From 766bc9d17d8481966c435dc5e19d97f3e1642933 Mon Sep 17 00:00:00 2001 From: Dominic Date: Fri, 1 Jan 2021 16:44:55 +0100 Subject: [PATCH] support copying headers in cors preflight requests --- CHANGELOG.md | 2 + Cargo.toml | 3 +- src/cors.rs | 111 ++++++++++++++++++++++++++++++----------- src/lib.rs | 8 +-- tests/cors_handling.rs | 111 ++++++++++++++++++++++++++++++++++++----- 5 files changed, 188 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a561d6f..bf19e66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Changed + - The cors handler can now copy headers from the request if desired ## [0.1.1] - 2020-12-28 ### Added diff --git a/Cargo.toml b/Cargo.toml index b78e7c3..2d7d073 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,6 @@ gotham_derive = "0.5.0" gotham_middleware_diesel = { version = "0.2.0", optional = true } gotham_restful_derive = "0.2.0-dev" indexmap = { version = "1.3.2", optional = true } -itertools = { version = "0.10.0", optional = true } jsonwebtoken = { version = "7.1.0", optional = true } log = "0.4.8" mime = "0.3.16" @@ -49,7 +48,7 @@ trybuild = "1.0.27" [features] default = ["cors", "errorlog"] auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"] -cors = ["itertools"] +cors = [] errorlog = [] database = ["gotham_restful_derive/database", "gotham_middleware_diesel"] openapi = ["gotham_restful_derive/openapi", "indexmap", "openapiv3"] diff --git a/src/cors.rs b/src/cors.rs index 5463de1..d8f2988 100644 --- a/src/cors.rs +++ b/src/cors.rs @@ -5,7 +5,7 @@ use gotham::{ header::{ HeaderMap, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, - ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY + ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY }, Body, Method, Response, StatusCode }, @@ -14,7 +14,6 @@ use gotham::{ router::{builder::*, route::matcher::AccessControlRequestMethodMatcher}, state::{FromState, State} }; -use itertools::Itertools; use std::{panic::RefUnwindSafe, pin::Pin}; /** @@ -53,6 +52,52 @@ impl Origin { } } } + + /// Returns true if the `Vary` header has to include `Origin`. + fn varies(&self) -> bool { + matches!(self, Self::Copy) + } +} + +/** +Specify the allowed headers of the request. It is up to the browser to check that only the allowed +headers are sent with the request. +*/ +#[derive(Clone, Debug)] +pub enum Headers { + /// Do not send any `Access-Control-Allow-Headers` headers. + None, + /// Set the `Access-Control-Allow-Headers` header to the following header list. If empty, this + /// is treated as if it was [None]. + List(Vec), + /// Copy the `Access-Control-Request-Headers` header into the `Access-Control-Allow-Header` + /// header. + Copy +} + +impl Default for Headers { + fn default() -> Self { + Self::None + } +} + +impl Headers { + /// Get the header value for the `Access-Control-Allow-Headers` header. + fn header_value(&self, state: &State) -> Option { + match self { + Self::None => None, + Self::List(list) => Some(list.join(",").parse().unwrap()), + Self::Copy => { + let headers = HeaderMap::borrow_from(state); + headers.get(ACCESS_CONTROL_REQUEST_HEADERS).map(Clone::clone) + } + } + } + + /// Returns true if the `Vary` header has to include `Origin`. + fn varies(&self) -> bool { + matches!(self, Self::Copy) + } } /** @@ -63,7 +108,7 @@ To change settings, you need to put this type into gotham's [State]: ```rust,no_run # use gotham::{router::builder::*, pipeline::{new_pipeline, single::single_pipeline}, state::State}; -# use gotham_restful::*; +# use gotham_restful::{*, cors::Origin}; fn main() { let cors = CorsConfig { origin: Origin::Star, @@ -81,7 +126,7 @@ configurations for different scopes, you need to register the middleware inside ```rust,no_run # use gotham::{router::builder::*, pipeline::*, pipeline::set::*, state::State}; -# use gotham_restful::*; +# use gotham_restful::{*, cors::Origin}; let pipelines = new_pipeline_set(); // The first cors configuration @@ -119,7 +164,7 @@ pub struct CorsConfig { /// The allowed origins. pub origin: Origin, /// The allowed headers. - pub headers: Vec, + pub headers: Headers, /// The amount of seconds that the preflight request can be cached. pub max_age: u64, /// Whether or not the request may be made with supplying credentials. @@ -149,22 +194,24 @@ For further information on CORS, read */ pub fn handle_cors(state: &State, res: &mut Response) { let config = CorsConfig::try_borrow_from(state); - let headers = res.headers_mut(); + if let Some(cfg) = config { + let headers = res.headers_mut(); - // non-preflight requests require the Access-Control-Allow-Origin header - if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state)) { - headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header); - } + // non-preflight requests require the Access-Control-Allow-Origin header + if let Some(header) = cfg.origin.header_value(state) { + headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header); + } - // if the origin is copied over, we should tell the browser by specifying the Vary header - if matches!(config.map(|cfg| &cfg.origin), Some(Origin::Copy)) { - let vary = headers.get(VARY).map(|vary| format!("{},Origin", vary.to_str().unwrap())); - headers.insert(VARY, vary.as_deref().unwrap_or("Origin").parse().unwrap()); - } + // if the origin is copied over, we should tell the browser by specifying the Vary header + if cfg.origin.varies() { + let vary = headers.get(VARY).map(|vary| format!("{},origin", vary.to_str().unwrap())); + headers.insert(VARY, vary.as_deref().unwrap_or("origin").parse().unwrap()); + } - // if we allow credentials, tell the browser - if config.map(|cfg| cfg.credentials).unwrap_or(false) { - headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap()); + // if we allow credentials, tell the browser + if cfg.credentials { + headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true")); + } } } @@ -202,6 +249,7 @@ fn cors_preflight_handler(state: State) -> (State, Response) { // prepare the response let mut res = create_empty_response(&state, StatusCode::NO_CONTENT); let headers = res.headers_mut(); + let mut vary: Vec = Vec::new(); // copy the request method over to the response let method = HeaderMap::borrow_from(&state) @@ -209,22 +257,27 @@ fn cors_preflight_handler(state: State) -> (State, Response) { .unwrap() .clone(); headers.insert(ACCESS_CONTROL_ALLOW_METHODS, method); + vary.push(ACCESS_CONTROL_REQUEST_METHOD); - // if we allow any headers, put them in - if let Some(hdrs) = config.map(|cfg| &cfg.headers) { - if hdrs.len() > 0 { - // TODO do we want to return all headers or just those asked by the browser? - headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, hdrs.iter().join(",").parse().unwrap()); + if let Some(cfg) = config { + // if we allow any headers, copy them over + if let Some(header) = cfg.headers.header_value(&state) { + headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, header); + } + + // if the headers are copied over, we should tell the browser by specifying the Vary header + if cfg.headers.varies() { + vary.push(ACCESS_CONTROL_REQUEST_HEADERS); + } + + // set the max age for the preflight cache + if let Some(age) = config.map(|cfg| cfg.max_age) { + headers.insert(ACCESS_CONTROL_MAX_AGE, age.into()); } } - // set the max age for the preflight cache - if let Some(age) = config.map(|cfg| cfg.max_age) { - headers.insert(ACCESS_CONTROL_MAX_AGE, age.into()); - } - // make sure the browser knows that this request was based on the method - headers.insert(VARY, "Access-Control-Request-Method".parse().unwrap()); + headers.insert(VARY, vary.join(",").parse().unwrap()); handle_cors(&state, &mut res); (state, res) diff --git a/src/lib.rs b/src/lib.rs index 1d234e9..0ca5549 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -196,7 +196,7 @@ authentication), and every content type, could look like this: # #[cfg(feature = "cors")] # mod cors_feature_enabled { # use gotham::{hyper::header::*, router::builder::*, pipeline::{new_pipeline, single::single_pipeline}, state::State}; -# use gotham_restful::*; +# use gotham_restful::{*, cors::*}; # use serde::{Deserialize, Serialize}; #[derive(Resource)] #[resource(read_all)] @@ -210,7 +210,7 @@ fn read_all() { fn main() { let cors = CorsConfig { origin: Origin::Copy, - headers: vec![CONTENT_TYPE], + headers: Headers::List(vec![CONTENT_TYPE]), max_age: 0, credentials: true }; @@ -417,9 +417,9 @@ mod auth; pub use auth::{AuthHandler, AuthMiddleware, AuthSource, AuthStatus, AuthValidation, StaticAuthHandler}; #[cfg(feature = "cors")] -mod cors; +pub mod cors; #[cfg(feature = "cors")] -pub use cors::{handle_cors, CorsConfig, CorsRoute, Origin}; +pub use cors::{handle_cors, CorsConfig, CorsRoute}; #[cfg(feature = "openapi")] mod openapi; diff --git a/tests/cors_handling.rs b/tests/cors_handling.rs index 35d5841..226b3d2 100644 --- a/tests/cors_handling.rs +++ b/tests/cors_handling.rs @@ -5,8 +5,11 @@ use gotham::{ router::builder::*, test::{Server, TestRequest, TestServer} }; -use gotham_restful::{change_all, read_all, CorsConfig, DrawResources, Origin, Raw, Resource}; -use itertools::Itertools; +use gotham_restful::{ + change_all, + cors::{Headers, Origin}, + read_all, CorsConfig, DrawResources, Raw, Resource +}; use mime::TEXT_PLAIN; #[derive(Resource)] @@ -35,7 +38,7 @@ where .unwrap(); assert_eq!(res.status(), StatusCode::NO_CONTENT); let headers = res.headers(); - println!("{}", headers.keys().join(",")); + println!("{}", headers.keys().map(|name| name.as_str()).collect::>().join(",")); assert_eq!( headers .get(ACCESS_CONTROL_ALLOW_ORIGIN) @@ -65,7 +68,7 @@ fn test_preflight(server: &TestServer, method: &str, origin: Option<&str>, vary: .unwrap(); assert_eq!(res.status(), StatusCode::NO_CONTENT); let headers = res.headers(); - println!("{}", headers.keys().join(",")); + println!("{}", headers.keys().map(|name| name.as_str()).collect::>().join(",")); assert_eq!( headers .get(ACCESS_CONTROL_ALLOW_METHODS) @@ -98,12 +101,45 @@ fn test_preflight(server: &TestServer, method: &str, origin: Option<&str>, vary: ); } +fn test_preflight_headers( + server: &TestServer, + method: &str, + request_headers: Option<&str>, + allowed_headers: Option<&str>, + vary: &str +) { + let client = server.client(); + let mut res = client + .options("http://example.org/foo") + .with_header(ACCESS_CONTROL_REQUEST_METHOD, method.parse().unwrap()) + .with_header(ORIGIN, "http://example.org".parse().unwrap()); + if let Some(hdr) = request_headers { + res = res.with_header(ACCESS_CONTROL_REQUEST_HEADERS, hdr.parse().unwrap()); + } + let res = res.perform().unwrap(); + assert_eq!(res.status(), StatusCode::NO_CONTENT); + let headers = res.headers(); + println!("{}", headers.keys().map(|name| name.as_str()).collect::>().join(",")); + if let Some(hdr) = allowed_headers { + assert_eq!( + headers + .get(ACCESS_CONTROL_ALLOW_HEADERS) + .and_then(|value| value.to_str().ok()) + .as_deref(), + Some(hdr) + ) + } else { + assert!(!headers.contains_key(ACCESS_CONTROL_ALLOW_HEADERS)); + } + assert_eq!(headers.get(VARY).and_then(|value| value.to_str().ok()).as_deref(), Some(vary)); +} + #[test] fn cors_origin_none() { let cfg = Default::default(); let server = test_server(cfg); - test_preflight(&server, "PUT", None, "Access-Control-Request-Method", false, 0); + test_preflight(&server, "PUT", None, "access-control-request-method", false, 0); test_response(server.client().get("http://example.org/foo"), None, None, false); test_response( @@ -122,7 +158,7 @@ fn cors_origin_star() { }; let server = test_server(cfg); - test_preflight(&server, "PUT", Some("*"), "Access-Control-Request-Method", false, 0); + test_preflight(&server, "PUT", Some("*"), "access-control-request-method", false, 0); test_response(server.client().get("http://example.org/foo"), Some("*"), None, false); test_response( @@ -145,7 +181,7 @@ fn cors_origin_single() { &server, "PUT", Some("https://foo.com"), - "Access-Control-Request-Method", + "access-control-request-method", false, 0 ); @@ -176,7 +212,7 @@ fn cors_origin_copy() { &server, "PUT", Some("http://example.org"), - "Access-Control-Request-Method,Origin", + "access-control-request-method,origin", false, 0 ); @@ -184,17 +220,68 @@ fn cors_origin_copy() { test_response( server.client().get("http://example.org/foo"), Some("http://example.org"), - Some("Origin"), + Some("origin"), false ); test_response( server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), Some("http://example.org"), - Some("Origin"), + Some("origin"), false ); } +#[test] +fn cors_headers_none() { + let cfg = Default::default(); + let server = test_server(cfg); + + test_preflight_headers(&server, "PUT", None, None, "access-control-request-method"); + test_preflight_headers(&server, "PUT", Some("Content-Type"), None, "access-control-request-method"); +} + +#[test] +fn cors_headers_list() { + let cfg = CorsConfig { + headers: Headers::List(vec![CONTENT_TYPE]), + ..Default::default() + }; + let server = test_server(cfg); + + test_preflight_headers(&server, "PUT", None, Some("content-type"), "access-control-request-method"); + test_preflight_headers( + &server, + "PUT", + Some("content-type"), + Some("content-type"), + "access-control-request-method" + ); +} + +#[test] +fn cors_headers_copy() { + let cfg = CorsConfig { + headers: Headers::Copy, + ..Default::default() + }; + let server = test_server(cfg); + + test_preflight_headers( + &server, + "PUT", + None, + None, + "access-control-request-method,access-control-request-headers" + ); + test_preflight_headers( + &server, + "PUT", + Some("content-type"), + Some("content-type"), + "access-control-request-method,access-control-request-headers" + ); +} + #[test] fn cors_credentials() { let cfg = CorsConfig { @@ -204,7 +291,7 @@ fn cors_credentials() { }; let server = test_server(cfg); - test_preflight(&server, "PUT", None, "Access-Control-Request-Method", true, 0); + test_preflight(&server, "PUT", None, "access-control-request-method", true, 0); test_response(server.client().get("http://example.org/foo"), None, None, true); test_response( @@ -224,7 +311,7 @@ fn cors_max_age() { }; let server = test_server(cfg); - test_preflight(&server, "PUT", None, "Access-Control-Request-Method", false, 31536000); + test_preflight(&server, "PUT", None, "access-control-request-method", false, 31536000); test_response(server.client().get("http://example.org/foo"), None, None, false); test_response(