diff --git a/Cargo.toml b/Cargo.toml index 59519a6..ebed5db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,8 +45,9 @@ paste = "0.1.12" trybuild = "1.0.26" [features] -default = ["errorlog"] +default = ["cors", "errorlog"] auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"] +cors = [] errorlog = [] database = ["gotham_restful_derive/database", "gotham_middleware_diesel"] openapi = ["gotham_restful_derive/openapi", "indexmap", "openapiv3"] diff --git a/src/cors.rs b/src/cors.rs new file mode 100644 index 0000000..57e1a10 --- /dev/null +++ b/src/cors.rs @@ -0,0 +1,246 @@ +use crate::matcher::AccessControlRequestMethodMatcher; +use gotham::{ + handler::HandlerFuture, + helpers::http::response::create_empty_response, + hyper::{ + header::{ + ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, + ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY, + HeaderMap, HeaderName, HeaderValue + }, + Body, Method, Response, StatusCode + }, + middleware::Middleware, + pipeline::chain::PipelineHandleChain, + router::builder::*, + state::{FromState, State}, +}; +use itertools::Itertools; +use std::{ + panic::RefUnwindSafe, + pin::Pin +}; + +/** +Specify the allowed origins of the request. It is up to the browser to check the validity of the +origin. This, when sent to the browser, will indicate whether or not the request's origin was +allowed to make the request. +*/ +#[derive(Clone, Debug)] +pub enum Origin +{ + /// Do not send any `Access-Control-Allow-Origin` headers. + None, + /// Send `Access-Control-Allow-Origin: *`. Note that browser will not send credentials. + Star, + /// Set the `Access-Control-Allow-Origin` header to a single origin. + Single(String), + /// Copy the `Origin` header into the `Access-Control-Allow-Origin` header. + Copy +} + +impl Default for Origin +{ + fn default() -> Self + { + Self::None + } +} + +impl Origin +{ + /// Get the header value for the `Access-Control-Allow-Origin` header. + fn header_value(&self, state : &State) -> Option + { + match self { + Self::None => None, + Self::Star => Some("*".parse().unwrap()), + Self::Single(origin) => Some(origin.parse().unwrap()), + Self::Copy => { + let headers = HeaderMap::borrow_from(state); + headers.get(ORIGIN).map(Clone::clone) + } + } + } +} + +/** +This is the configuration that the CORS handler will follow. Its default configuration is basically +not to touch any responses, resulting in the browser's default behaviour. + +To change settings, you need to put this type into gotham's [`State`]: + +```rust,no_run +# use gotham::{router::builder::*, pipeline::{new_pipeline, single::single_pipeline}, state::State}; +# use gotham_restful::*; +fn main() { + let cors = CorsConfig { + origin: Origin::Star, + ..Default::default() + }; + let (chain, pipelines) = single_pipeline(new_pipeline().add(cors).build()); + gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| { + // your routing logic + })); +} +``` + +This easy approach allows you to have one global cors configuration. If you prefer to have separate +configurations for different scopes, you need to register the middleware inside your routing logic: + +```rust,no_run +# use gotham::{router::builder::*, pipeline::*, pipeline::set::*, state::State}; +# use gotham_restful::*; +fn main() { + let pipelines = new_pipeline_set(); + + let cors_a = CorsConfig { + origin: Origin::Star, + ..Default::default() + }; + let (pipelines, chain_a) = pipelines.add( + new_pipeline().add(cors_a).build() + ); + + let cors_b = CorsConfig { + origin: Origin::Copy, + ..Default::default() + }; + let (pipelines, chain_b) = pipelines.add( + new_pipeline().add(cors_b).build() + ); + + let pipeline_set = finalize_pipeline_set(pipelines); + gotham::start("127.0.0.1:8080", build_router((), pipeline_set, |route| { + // routing without any cors config + route.with_pipeline_chain((chain_a, ()), |route| { + // routing with cors config a + }); + route.with_pipeline_chain((chain_b, ()), |route| { + // routing with cors config b + }); + })); +} +``` + + [`State`]: ../gotham/state/struct.State.html +*/ +#[derive(Clone, Debug, Default, NewMiddleware, StateData)] +pub struct CorsConfig +{ + /// The allowed origins. + pub origin : Origin, + /// The allowed headers. + pub headers : Vec, + /// The amount of seconds that the preflight request can be cached. + pub max_age : u64, + /// Whether or not the request may be made with supplying credentials. + pub credentials : bool +} + +impl Middleware for CorsConfig +{ + fn call(self, mut state : State, chain : Chain) -> Pin> + where + Chain : FnOnce(State) -> Pin> + { + state.put(self); + chain(state) + } +} + +/** +Handle CORS for a non-preflight request. This means manipulating the `res` HTTP headers so that +the response is aligned with the `state`'s [`CorsConfig`]. + +If you are using the [`Resource`] type (which is the recommended way), you'll never have to call +this method. However, if you are writing your own handler method, you might want to call this +after your request to add the required CORS headers. + +For further information on CORS, read https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS. + + [`CorsConfig`]: ./struct.CorsConfig.html +*/ +pub fn handle_cors(state : &State, res : &mut Response) +{ + let config = CorsConfig::try_borrow_from(state); + let headers = res.headers_mut(); + + // non-preflight requests require the Access-Control-Allow-Origin header + if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state)) + { + headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header); + } + + // if the origin is copied over, we should tell the browser by specifying the Vary header + if matches!(config.map(|cfg| &cfg.origin), Some(Origin::Copy)) + { + let vary = headers.get(VARY).map(|vary| format!("{},Origin", vary.to_str().unwrap())); + headers.insert(VARY, vary.as_deref().unwrap_or("Origin").parse().unwrap()); + } + + // if we allow credentials, tell the browser + if config.map(|cfg| cfg.credentials).unwrap_or(false) + { + headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap()); + } +} + +/// Add CORS routing for your path. +pub trait CorsRoute +where + C : PipelineHandleChain

+ Copy + Send + Sync + 'static, + P : RefUnwindSafe + Send + Sync + 'static +{ + fn cors(&mut self, path : &str, method : Method); +} + +fn cors_preflight_handler(state : State) -> (State, Response) +{ + let config = CorsConfig::try_borrow_from(&state); + + // prepare the response + let mut res = create_empty_response(&state, StatusCode::NO_CONTENT); + let headers = res.headers_mut(); + + // copy the request method over to the response + let method = HeaderMap::borrow_from(&state).get(ACCESS_CONTROL_REQUEST_METHOD).unwrap().clone(); + headers.insert(ACCESS_CONTROL_ALLOW_METHODS, method); + + // if we allow any headers, put them in + if let Some(hdrs) = config.map(|cfg| &cfg.headers) + { + if hdrs.len() > 0 + { + // TODO do we want to return all headers or just those asked by the browser? + headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, hdrs.iter().join(",").parse().unwrap()); + } + } + + // set the max age for the preflight cache + if let Some(age) = config.map(|cfg| cfg.max_age) + { + headers.insert(ACCESS_CONTROL_MAX_AGE, age.into()); + } + + // make sure the browser knows that this request was based on the method + headers.insert(VARY, "Access-Control-Request-Method".parse().unwrap()); + + handle_cors(&state, &mut res); + (state, res) +} + +impl CorsRoute for D +where + D : DrawRoutes, + C : PipelineHandleChain

+ Copy + Send + Sync + 'static, + P : RefUnwindSafe + Send + Sync + 'static +{ + fn cors(&mut self, path : &str, method : Method) + { + let matcher = AccessControlRequestMethodMatcher::new(method); + self.options(path) + .extend_route_matcher(matcher) + .to(cors_preflight_handler); + } +} diff --git a/src/lib.rs b/src/lib.rs index 1fa19fd..ff60998 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -285,6 +285,16 @@ pub use auth::{ StaticAuthHandler }; +#[cfg(feature = "cors")] +mod cors; +#[cfg(feature = "cors")] +pub use cors::{ + handle_cors, + CorsConfig, + CorsRoute, + Origin +}; + pub mod matcher; #[cfg(feature = "openapi")] diff --git a/src/matcher/access_control_request_method.rs b/src/matcher/access_control_request_method.rs new file mode 100644 index 0000000..6c912cc --- /dev/null +++ b/src/matcher/access_control_request_method.rs @@ -0,0 +1,57 @@ +use gotham::{ + hyper::{header::{ACCESS_CONTROL_REQUEST_METHOD, HeaderMap}, Method, StatusCode}, + router::{non_match::RouteNonMatch, route::matcher::RouteMatcher}, + state::{FromState, State} +}; + +/// A route matcher that checks whether the value of the `Access-Control-Request-Method` header matches the defined value. +/// +/// Usage: +/// +/// ```rust +/// # use gotham::{helpers::http::response::create_empty_response, +/// # hyper::{header::ACCESS_CONTROL_ALLOW_METHODS, Method, StatusCode}, +/// # router::builder::* +/// # }; +/// # use gotham_restful::matcher::AccessControlRequestMethodMatcher; +/// let matcher = AccessControlRequestMethodMatcher::new(Method::PUT); +/// +/// # build_simple_router(|route| { +/// // use the matcher for your request +/// route.options("/foo") +/// .extend_route_matcher(matcher) +/// .to(|state| { +/// // we know that this is a CORS preflight for a PUT request +/// let mut res = create_empty_response(&state, StatusCode::NO_CONTENT); +/// res.headers_mut().insert(ACCESS_CONTROL_ALLOW_METHODS, "PUT".parse().unwrap()); +/// (state, res) +/// }); +/// # }); +/// ``` +#[derive(Clone, Debug)] +pub struct AccessControlRequestMethodMatcher +{ + method : Method +} + +impl AccessControlRequestMethodMatcher +{ + pub fn new(method : Method) -> Self + { + Self { method } + } +} + +impl RouteMatcher for AccessControlRequestMethodMatcher +{ + fn is_match(&self, state : &State) -> Result<(), RouteNonMatch> + { + match HeaderMap::borrow_from(state).get(ACCESS_CONTROL_REQUEST_METHOD) + .and_then(|value| value.to_str().ok()) + .and_then(|str| str.parse::().ok()) + { + Some(m) if m == self.method => Ok(()), + _ => Err(RouteNonMatch::new(StatusCode::NOT_FOUND)) + } + } +} diff --git a/src/matcher/mod.rs b/src/matcher/mod.rs index 4d5268e..9cbfcbb 100644 --- a/src/matcher/mod.rs +++ b/src/matcher/mod.rs @@ -8,6 +8,11 @@ pub use accept::AcceptHeaderMatcher; mod content_type; pub use content_type::ContentTypeMatcher; +#[cfg(feature = "cors")] +mod access_control_request_method; +#[cfg(feature = "cors")] +pub use access_control_request_method::AccessControlRequestMethodMatcher; + type LookupTable = HashMap>; trait LookupTableFromTypes diff --git a/src/routing.rs b/src/routing.rs index 1b0aa46..916b244 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -6,6 +6,8 @@ use crate::{ Response, StatusCode }; +#[cfg(feature = "cors")] +use crate::CorsRoute; #[cfg(feature = "openapi")] use crate::openapi::{ builder::{OpenapiBuilder, OpenapiInfo}, @@ -100,10 +102,16 @@ fn response_from(res : Response, state : &State) -> gotham::hyper::Response(state)); + #[cfg(feature = "cors")] + self.0.cors(&self.1, Method::POST); } fn change_all(&mut self) @@ -398,6 +408,8 @@ macro_rules! implDrawResourceRoutes { .extend_route_matcher(accept_matcher) .extend_route_matcher(content_matcher) .to(|state| change_all_handler::(state)); + #[cfg(feature = "cors")] + self.0.cors(&self.1, Method::PUT); } fn change(&mut self) @@ -407,11 +419,14 @@ macro_rules! implDrawResourceRoutes { { let accept_matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into(); let content_matcher : MaybeMatchContentTypeHeader = Handler::Body::supported_types().into(); - self.0.put(&format!("{}/:id", self.1)) + let path = format!("{}/:id", self.1); + self.0.put(&path) .extend_route_matcher(accept_matcher) .extend_route_matcher(content_matcher) .with_path_extractor::>() .to(|state| change_handler::(state)); + #[cfg(feature = "cors")] + self.0.cors(&path, Method::PUT); } fn remove_all(&mut self) @@ -420,15 +435,20 @@ macro_rules! implDrawResourceRoutes { self.0.delete(&self.1) .extend_route_matcher(matcher) .to(|state| remove_all_handler::(state)); + #[cfg(feature = "cors")] + self.0.cors(&self.1, Method::DELETE); } fn remove(&mut self) { let matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into(); - self.0.delete(&format!("{}/:id", self.1)) + let path = format!("{}/:id", self.1); + self.0.delete(&path) .extend_route_matcher(matcher) .with_path_extractor::>() .to(|state| remove_handler::(state)); + #[cfg(feature = "cors")] + self.0.cors(&path, Method::POST); } } } diff --git a/tests/cors_handling.rs b/tests/cors_handling.rs new file mode 100644 index 0000000..a9fe498 --- /dev/null +++ b/tests/cors_handling.rs @@ -0,0 +1,156 @@ +#![cfg(feature = "cors")] +use gotham::{ + hyper::{body::Body, client::connect::Connect, header::*, StatusCode}, + pipeline::{new_pipeline, single::single_pipeline}, + router::builder::*, + test::{Server, TestRequest, TestServer} +}; +use gotham_restful::{CorsConfig, DrawResources, Origin, Raw, Resource, change_all, read_all}; +use itertools::Itertools; +use mime::TEXT_PLAIN; + +#[derive(Resource)] +#[resource(read_all, change_all)] +struct FooResource; + +#[read_all(FooResource)] +fn read_all() +{ +} + +#[change_all(FooResource)] +fn change_all(_body : Raw>) +{ +} + +fn test_server(cfg : CorsConfig) -> TestServer +{ + let (chain, pipeline) = single_pipeline(new_pipeline().add(cfg).build()); + TestServer::new(build_router(chain, pipeline, |router| { + router.resource::("/foo") + })).unwrap() +} + +fn test_response(req : TestRequest, origin : Option<&str>, vary : Option<&str>, credentials : bool) +where + TS : Server + 'static, + C : Connect + Clone + Send + Sync + 'static +{ + let res = req.with_header(ORIGIN, "http://example.org".parse().unwrap()).perform().unwrap(); + assert_eq!(res.status(), StatusCode::NO_CONTENT); + let headers = res.headers(); + println!("{}", headers.keys().join(",")); + assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).and_then(|value| value.to_str().ok()).as_deref(), origin); + assert_eq!(headers.get(VARY).and_then(|value| value.to_str().ok()).as_deref(), vary); + assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_CREDENTIALS).and_then(|value| value.to_str().ok()).map(|value| value == "true").unwrap_or(false), credentials); + assert!(headers.get(ACCESS_CONTROL_MAX_AGE).is_none()); +} + +fn test_preflight(server : &TestServer, method : &str, origin : Option<&str>, vary : &str, credentials : bool, max_age : u64) +{ + let res = server.client().options("http://example.org/foo") + .with_header(ACCESS_CONTROL_REQUEST_METHOD, method.parse().unwrap()) + .with_header(ORIGIN, "http://example.org".parse().unwrap()) + .perform().unwrap(); + assert_eq!(res.status(), StatusCode::NO_CONTENT); + let headers = res.headers(); + println!("{}", headers.keys().join(",")); + assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_METHODS).and_then(|value| value.to_str().ok()).as_deref(), Some(method)); + assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).and_then(|value| value.to_str().ok()).as_deref(), origin); + assert_eq!(headers.get(VARY).and_then(|value| value.to_str().ok()).as_deref(), Some(vary)); + assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_CREDENTIALS).and_then(|value| value.to_str().ok()).map(|value| value == "true").unwrap_or(false), credentials); + assert_eq!(headers.get(ACCESS_CONTROL_MAX_AGE).and_then(|value| value.to_str().ok()).and_then(|value| value.parse().ok()), Some(max_age)); +} + + +#[test] +fn cors_origin_none() +{ + let cfg = CorsConfig { + origin: Origin::None, + ..Default::default() + }; + let server = test_server(cfg); + + test_preflight(&server, "PUT", None, "Access-Control-Request-Method", false, 0); + + test_response(server.client().get("http://example.org/foo"), None, None, false); + test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), None, None, false); +} + +#[test] +fn cors_origin_star() +{ + let cfg = CorsConfig { + origin: Origin::Star, + ..Default::default() + }; + let server = test_server(cfg); + + test_preflight(&server, "PUT", Some("*"), "Access-Control-Request-Method", false, 0); + + test_response(server.client().get("http://example.org/foo"), Some("*"), None, false); + test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), Some("*"), None, false); +} + +#[test] +fn cors_origin_single() +{ + let cfg = CorsConfig { + origin: Origin::Single("https://foo.com".to_owned()), + ..Default::default() + }; + let server = test_server(cfg); + + test_preflight(&server, "PUT", Some("https://foo.com"), "Access-Control-Request-Method", false, 0); + + test_response(server.client().get("http://example.org/foo"), Some("https://foo.com"), None, false); + test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), Some("https://foo.com"), None, false); +} + +#[test] +fn cors_origin_copy() +{ + let cfg = CorsConfig { + origin: Origin::Copy, + ..Default::default() + }; + let server = test_server(cfg); + + test_preflight(&server, "PUT", Some("http://example.org"), "Access-Control-Request-Method,Origin", false, 0); + + test_response(server.client().get("http://example.org/foo"), Some("http://example.org"), Some("Origin"), false); + test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), Some("http://example.org"), Some("Origin"), false); +} + +#[test] +fn cors_credentials() +{ + let cfg = CorsConfig { + origin: Origin::None, + credentials: true, + ..Default::default() + }; + let server = test_server(cfg); + + test_preflight(&server, "PUT", None, "Access-Control-Request-Method", true, 0); + + test_response(server.client().get("http://example.org/foo"), None, None, true); + test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), None, None, true); +} + +#[test] +fn cors_max_age() +{ + let cfg = CorsConfig { + origin: Origin::None, + max_age: 31536000, + ..Default::default() + }; + let server = test_server(cfg); + + test_preflight(&server, "PUT", None, "Access-Control-Request-Method", false, 31536000); + + test_response(server.client().get("http://example.org/foo"), None, None, false); + test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), None, None, false); +} diff --git a/tests/openapi_supports_scope.rs b/tests/openapi_supports_scope.rs index 62228da..3b9aa2c 100644 --- a/tests/openapi_supports_scope.rs +++ b/tests/openapi_supports_scope.rs @@ -1,8 +1,4 @@ -#[cfg(feature = "openapi")] -mod openapi_supports_scope -{ - - +#![cfg(feature = "openapi")] use gotham::{ router::builder::*, test::TestServer @@ -29,7 +25,7 @@ fn read_all() -> Raw<&'static [u8]> #[test] -fn test() +fn openapi_supports_scope() { let info = OpenapiInfo { title: "Test".to_owned(), @@ -54,6 +50,3 @@ fn test() test_get_response(&server, "http://localhost/bar/baz/foo3", RESPONSE); test_get_response(&server, "http://localhost/foo4", RESPONSE); } - - -} // mod test