From f20c768d02ccf6c65a1c3a4d784cff4b2bdc8c2e Mon Sep 17 00:00:00 2001 From: Dominic Date: Thu, 14 May 2020 23:30:59 +0200 Subject: [PATCH] cors preflight --- Cargo.toml | 2 +- src/cors.rs | 112 +++++++++++++++++-- src/lib.rs | 1 + src/matcher/access_control_request_method.rs | 57 ++++++++++ src/matcher/mod.rs | 4 + src/routing.rs | 18 ++- 6 files changed, 182 insertions(+), 12 deletions(-) create mode 100644 src/matcher/access_control_request_method.rs diff --git a/Cargo.toml b/Cargo.toml index 618f11d..ebed5db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ paste = "0.1.12" trybuild = "1.0.26" [features] -default = ["errorlog"] +default = ["cors", "errorlog"] auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"] cors = [] errorlog = [] diff --git a/src/cors.rs b/src/cors.rs index c48a38a..558bedb 100644 --- a/src/cors.rs +++ b/src/cors.rs @@ -1,13 +1,25 @@ +use crate::matcher::AccessControlRequestMethodMatcher; use gotham::{ handler::HandlerFuture, + helpers::http::response::create_empty_response, hyper::{ - header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN, HeaderMap, HeaderValue}, - Body, Method, Response + header::{ + ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, + ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY, + HeaderMap, HeaderName, HeaderValue + }, + Body, Method, Response, StatusCode }, middleware::Middleware, + pipeline::chain::PipelineHandleChain, + router::builder::*, state::{FromState, State}, }; -use std::pin::Pin; +use itertools::Itertools; +use std::{ + panic::RefUnwindSafe, + pin::Pin +}; /** Specify the allowed origins of the request. It is up to the browser to check the validity of the @@ -63,7 +75,8 @@ To change settings, you need to put this type into gotham's [`State`]: # use gotham_restful::*; fn main() { let cors = CorsConfig { - origin: Origin::Star + origin: Origin::Star, + ..Default::default() }; let (chain, pipelines) = single_pipeline(new_pipeline().add(cors).build()); gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| { @@ -82,14 +95,16 @@ fn main() { let pipelines = new_pipeline_set(); let cors_a = CorsConfig { - origin: Origin::Star + origin: Origin::Star, + ..Default::default() }; let (pipelines, chain_a) = pipelines.add( new_pipeline().add(cors_a).build() ); let cors_b = CorsConfig { - origin: Origin::Copy + origin: Origin::Copy, + ..Default::default() }; let (pipelines, chain_b) = pipelines.add( new_pipeline().add(cors_b).build() @@ -113,7 +128,14 @@ fn main() { #[derive(Clone, Debug, Default, NewMiddleware, StateData)] pub struct CorsConfig { - pub origin : Origin + /// The allowed origins. + pub origin : Origin, + /// The allowed headers. + pub headers : Vec, + /// The amount of seconds that the preflight request can be cached. + pub max_age : u64, + /// Whether or not the request may be made with supplying credentials. + pub credentials : bool } impl Middleware for CorsConfig @@ -141,12 +163,84 @@ For further information on CORS, read https://developer.mozilla.org/en-US/docs/W */ pub fn handle_cors(state : &State, res : &mut Response) { - let method = Method::borrow_from(state); let config = CorsConfig::try_borrow_from(state); + let headers = res.headers_mut(); // non-preflight requests require nothing other than the Access-Control-Allow-Origin header if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state)) { - res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, header); + headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header); + } + + // if the origin is copied over, we should tell the browser by specifying the Vary header + if matches!(config.map(|cfg| &cfg.origin), Some(Origin::Copy)) + { + let vary = headers.get(VARY).map(|vary| format!("{},Origin", vary.to_str().unwrap())); + headers.insert(VARY, vary.as_deref().unwrap_or("Origin").parse().unwrap()); + } + + // if we allow credentials, tell the browser + if config.map(|cfg| cfg.credentials).unwrap_or(false) + { + headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap()); + } +} + +/// Add CORS routing for your path. +pub trait CorsRoute +where + C : PipelineHandleChain

+ Copy + Send + Sync + 'static, + P : RefUnwindSafe + Send + Sync + 'static +{ + fn cors(&mut self, path : &str, method : Method); +} + +fn cors_preflight_handler(state : State) -> (State, Response) +{ + let config = CorsConfig::try_borrow_from(&state); + + // prepare the response + let mut res = create_empty_response(&state, StatusCode::NO_CONTENT); + let headers = res.headers_mut(); + + // copy the request method over to the response + let method = HeaderMap::borrow_from(&state).get(ACCESS_CONTROL_REQUEST_METHOD).unwrap().clone(); + headers.insert(ACCESS_CONTROL_ALLOW_METHODS, method); + + // if we allow any headers, put them in + if let Some(hdrs) = config.map(|cfg| &cfg.headers) + { + if hdrs.len() > 0 + { + // TODO do we want to return all headers or just those asked by the browser? + headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, hdrs.iter().join(",").parse().unwrap()); + } + } + + // set the max age for the preflight cache + if let Some(age) = config.map(|cfg| cfg.max_age) + { + headers.insert(ACCESS_CONTROL_MAX_AGE, age.into()); + } + + // make sure the browser knows that this request was based on the method + headers.insert(VARY, "Access-Control-Request-Method".parse().unwrap()); + + handle_cors(&state, &mut res); + (state, res) +} + +impl CorsRoute for D +where + D : DrawRoutes, + C : PipelineHandleChain

+ Copy + Send + Sync + 'static, + P : RefUnwindSafe + Send + Sync + 'static +{ + fn cors(&mut self, path : &str, method : Method) + { + let matcher = AccessControlRequestMethodMatcher::new(method); + self.options(path) + .extend_route_matcher(matcher) + .to(cors_preflight_handler); } } diff --git a/src/lib.rs b/src/lib.rs index d45363d..ff60998 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -291,6 +291,7 @@ mod cors; pub use cors::{ handle_cors, CorsConfig, + CorsRoute, Origin }; diff --git a/src/matcher/access_control_request_method.rs b/src/matcher/access_control_request_method.rs new file mode 100644 index 0000000..6c912cc --- /dev/null +++ b/src/matcher/access_control_request_method.rs @@ -0,0 +1,57 @@ +use gotham::{ + hyper::{header::{ACCESS_CONTROL_REQUEST_METHOD, HeaderMap}, Method, StatusCode}, + router::{non_match::RouteNonMatch, route::matcher::RouteMatcher}, + state::{FromState, State} +}; + +/// A route matcher that checks whether the value of the `Access-Control-Request-Method` header matches the defined value. +/// +/// Usage: +/// +/// ```rust +/// # use gotham::{helpers::http::response::create_empty_response, +/// # hyper::{header::ACCESS_CONTROL_ALLOW_METHODS, Method, StatusCode}, +/// # router::builder::* +/// # }; +/// # use gotham_restful::matcher::AccessControlRequestMethodMatcher; +/// let matcher = AccessControlRequestMethodMatcher::new(Method::PUT); +/// +/// # build_simple_router(|route| { +/// // use the matcher for your request +/// route.options("/foo") +/// .extend_route_matcher(matcher) +/// .to(|state| { +/// // we know that this is a CORS preflight for a PUT request +/// let mut res = create_empty_response(&state, StatusCode::NO_CONTENT); +/// res.headers_mut().insert(ACCESS_CONTROL_ALLOW_METHODS, "PUT".parse().unwrap()); +/// (state, res) +/// }); +/// # }); +/// ``` +#[derive(Clone, Debug)] +pub struct AccessControlRequestMethodMatcher +{ + method : Method +} + +impl AccessControlRequestMethodMatcher +{ + pub fn new(method : Method) -> Self + { + Self { method } + } +} + +impl RouteMatcher for AccessControlRequestMethodMatcher +{ + fn is_match(&self, state : &State) -> Result<(), RouteNonMatch> + { + match HeaderMap::borrow_from(state).get(ACCESS_CONTROL_REQUEST_METHOD) + .and_then(|value| value.to_str().ok()) + .and_then(|str| str.parse::().ok()) + { + Some(m) if m == self.method => Ok(()), + _ => Err(RouteNonMatch::new(StatusCode::NOT_FOUND)) + } + } +} diff --git a/src/matcher/mod.rs b/src/matcher/mod.rs index 4d5268e..3168ec3 100644 --- a/src/matcher/mod.rs +++ b/src/matcher/mod.rs @@ -8,6 +8,10 @@ pub use accept::AcceptHeaderMatcher; mod content_type; pub use content_type::ContentTypeMatcher; +#[cfg(feature = "cors")] +mod access_control_request_method; +pub use access_control_request_method::AccessControlRequestMethodMatcher; + type LookupTable = HashMap>; trait LookupTableFromTypes diff --git a/src/routing.rs b/src/routing.rs index 320cd7a..916b244 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -6,6 +6,8 @@ use crate::{ Response, StatusCode }; +#[cfg(feature = "cors")] +use crate::CorsRoute; #[cfg(feature = "openapi")] use crate::openapi::{ builder::{OpenapiBuilder, OpenapiInfo}, @@ -391,6 +393,8 @@ macro_rules! implDrawResourceRoutes { .extend_route_matcher(accept_matcher) .extend_route_matcher(content_matcher) .to(|state| create_handler::(state)); + #[cfg(feature = "cors")] + self.0.cors(&self.1, Method::POST); } fn change_all(&mut self) @@ -404,6 +408,8 @@ macro_rules! implDrawResourceRoutes { .extend_route_matcher(accept_matcher) .extend_route_matcher(content_matcher) .to(|state| change_all_handler::(state)); + #[cfg(feature = "cors")] + self.0.cors(&self.1, Method::PUT); } fn change(&mut self) @@ -413,11 +419,14 @@ macro_rules! implDrawResourceRoutes { { let accept_matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into(); let content_matcher : MaybeMatchContentTypeHeader = Handler::Body::supported_types().into(); - self.0.put(&format!("{}/:id", self.1)) + let path = format!("{}/:id", self.1); + self.0.put(&path) .extend_route_matcher(accept_matcher) .extend_route_matcher(content_matcher) .with_path_extractor::>() .to(|state| change_handler::(state)); + #[cfg(feature = "cors")] + self.0.cors(&path, Method::PUT); } fn remove_all(&mut self) @@ -426,15 +435,20 @@ macro_rules! implDrawResourceRoutes { self.0.delete(&self.1) .extend_route_matcher(matcher) .to(|state| remove_all_handler::(state)); + #[cfg(feature = "cors")] + self.0.cors(&self.1, Method::DELETE); } fn remove(&mut self) { let matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into(); - self.0.delete(&format!("{}/:id", self.1)) + let path = format!("{}/:id", self.1); + self.0.delete(&path) .extend_route_matcher(matcher) .with_path_extractor::>() .to(|state| remove_handler::(state)); + #[cfg(feature = "cors")] + self.0.cors(&path, Method::POST); } } }