1
0
Fork 0
mirror of https://gitlab.com/msrd0/gotham-restful.git synced 2025-02-23 04:52:28 +00:00

cors preflight

This commit is contained in:
Dominic 2020-05-14 23:30:59 +02:00
parent 748bf65d3e
commit f20c768d02
Signed by: msrd0
GPG key ID: DCC8C247452E98F9
6 changed files with 182 additions and 12 deletions

View file

@ -45,7 +45,7 @@ paste = "0.1.12"
trybuild = "1.0.26" trybuild = "1.0.26"
[features] [features]
default = ["errorlog"] default = ["cors", "errorlog"]
auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"] auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"]
cors = [] cors = []
errorlog = [] errorlog = []

View file

@ -1,13 +1,25 @@
use crate::matcher::AccessControlRequestMethodMatcher;
use gotham::{ use gotham::{
handler::HandlerFuture, handler::HandlerFuture,
helpers::http::response::create_empty_response,
hyper::{ hyper::{
header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN, HeaderMap, HeaderValue}, header::{
Body, Method, Response ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY,
HeaderMap, HeaderName, HeaderValue
},
Body, Method, Response, StatusCode
}, },
middleware::Middleware, middleware::Middleware,
pipeline::chain::PipelineHandleChain,
router::builder::*,
state::{FromState, State}, state::{FromState, State},
}; };
use std::pin::Pin; use itertools::Itertools;
use std::{
panic::RefUnwindSafe,
pin::Pin
};
/** /**
Specify the allowed origins of the request. It is up to the browser to check the validity of the Specify the allowed origins of the request. It is up to the browser to check the validity of the
@ -63,7 +75,8 @@ To change settings, you need to put this type into gotham's [`State`]:
# use gotham_restful::*; # use gotham_restful::*;
fn main() { fn main() {
let cors = CorsConfig { let cors = CorsConfig {
origin: Origin::Star origin: Origin::Star,
..Default::default()
}; };
let (chain, pipelines) = single_pipeline(new_pipeline().add(cors).build()); let (chain, pipelines) = single_pipeline(new_pipeline().add(cors).build());
gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| { gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| {
@ -82,14 +95,16 @@ fn main() {
let pipelines = new_pipeline_set(); let pipelines = new_pipeline_set();
let cors_a = CorsConfig { let cors_a = CorsConfig {
origin: Origin::Star origin: Origin::Star,
..Default::default()
}; };
let (pipelines, chain_a) = pipelines.add( let (pipelines, chain_a) = pipelines.add(
new_pipeline().add(cors_a).build() new_pipeline().add(cors_a).build()
); );
let cors_b = CorsConfig { let cors_b = CorsConfig {
origin: Origin::Copy origin: Origin::Copy,
..Default::default()
}; };
let (pipelines, chain_b) = pipelines.add( let (pipelines, chain_b) = pipelines.add(
new_pipeline().add(cors_b).build() new_pipeline().add(cors_b).build()
@ -113,7 +128,14 @@ fn main() {
#[derive(Clone, Debug, Default, NewMiddleware, StateData)] #[derive(Clone, Debug, Default, NewMiddleware, StateData)]
pub struct CorsConfig pub struct CorsConfig
{ {
pub origin : Origin /// The allowed origins.
pub origin : Origin,
/// The allowed headers.
pub headers : Vec<HeaderName>,
/// The amount of seconds that the preflight request can be cached.
pub max_age : u64,
/// Whether or not the request may be made with supplying credentials.
pub credentials : bool
} }
impl Middleware for CorsConfig impl Middleware for CorsConfig
@ -141,12 +163,84 @@ For further information on CORS, read https://developer.mozilla.org/en-US/docs/W
*/ */
pub fn handle_cors(state : &State, res : &mut Response<Body>) pub fn handle_cors(state : &State, res : &mut Response<Body>)
{ {
let method = Method::borrow_from(state);
let config = CorsConfig::try_borrow_from(state); let config = CorsConfig::try_borrow_from(state);
let headers = res.headers_mut();
// non-preflight requests require nothing other than the Access-Control-Allow-Origin header // non-preflight requests require nothing other than the Access-Control-Allow-Origin header
if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state)) if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state))
{ {
res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, header); headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
}
// if the origin is copied over, we should tell the browser by specifying the Vary header
if matches!(config.map(|cfg| &cfg.origin), Some(Origin::Copy))
{
let vary = headers.get(VARY).map(|vary| format!("{},Origin", vary.to_str().unwrap()));
headers.insert(VARY, vary.as_deref().unwrap_or("Origin").parse().unwrap());
}
// if we allow credentials, tell the browser
if config.map(|cfg| cfg.credentials).unwrap_or(false)
{
headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap());
}
}
/// Add CORS routing for your path.
pub trait CorsRoute<C, P>
where
C : PipelineHandleChain<P> + Copy + Send + Sync + 'static,
P : RefUnwindSafe + Send + Sync + 'static
{
fn cors(&mut self, path : &str, method : Method);
}
fn cors_preflight_handler(state : State) -> (State, Response<Body>)
{
let config = CorsConfig::try_borrow_from(&state);
// prepare the response
let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
let headers = res.headers_mut();
// copy the request method over to the response
let method = HeaderMap::borrow_from(&state).get(ACCESS_CONTROL_REQUEST_METHOD).unwrap().clone();
headers.insert(ACCESS_CONTROL_ALLOW_METHODS, method);
// if we allow any headers, put them in
if let Some(hdrs) = config.map(|cfg| &cfg.headers)
{
if hdrs.len() > 0
{
// TODO do we want to return all headers or just those asked by the browser?
headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, hdrs.iter().join(",").parse().unwrap());
}
}
// set the max age for the preflight cache
if let Some(age) = config.map(|cfg| cfg.max_age)
{
headers.insert(ACCESS_CONTROL_MAX_AGE, age.into());
}
// make sure the browser knows that this request was based on the method
headers.insert(VARY, "Access-Control-Request-Method".parse().unwrap());
handle_cors(&state, &mut res);
(state, res)
}
impl<D, C, P> CorsRoute<C, P> for D
where
D : DrawRoutes<C, P>,
C : PipelineHandleChain<P> + Copy + Send + Sync + 'static,
P : RefUnwindSafe + Send + Sync + 'static
{
fn cors(&mut self, path : &str, method : Method)
{
let matcher = AccessControlRequestMethodMatcher::new(method);
self.options(path)
.extend_route_matcher(matcher)
.to(cors_preflight_handler);
} }
} }

View file

@ -291,6 +291,7 @@ mod cors;
pub use cors::{ pub use cors::{
handle_cors, handle_cors,
CorsConfig, CorsConfig,
CorsRoute,
Origin Origin
}; };

View file

@ -0,0 +1,57 @@
use gotham::{
hyper::{header::{ACCESS_CONTROL_REQUEST_METHOD, HeaderMap}, Method, StatusCode},
router::{non_match::RouteNonMatch, route::matcher::RouteMatcher},
state::{FromState, State}
};
/// A route matcher that checks whether the value of the `Access-Control-Request-Method` header matches the defined value.
///
/// Usage:
///
/// ```rust
/// # use gotham::{helpers::http::response::create_empty_response,
/// # hyper::{header::ACCESS_CONTROL_ALLOW_METHODS, Method, StatusCode},
/// # router::builder::*
/// # };
/// # use gotham_restful::matcher::AccessControlRequestMethodMatcher;
/// let matcher = AccessControlRequestMethodMatcher::new(Method::PUT);
///
/// # build_simple_router(|route| {
/// // use the matcher for your request
/// route.options("/foo")
/// .extend_route_matcher(matcher)
/// .to(|state| {
/// // we know that this is a CORS preflight for a PUT request
/// let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
/// res.headers_mut().insert(ACCESS_CONTROL_ALLOW_METHODS, "PUT".parse().unwrap());
/// (state, res)
/// });
/// # });
/// ```
#[derive(Clone, Debug)]
pub struct AccessControlRequestMethodMatcher
{
method : Method
}
impl AccessControlRequestMethodMatcher
{
pub fn new(method : Method) -> Self
{
Self { method }
}
}
impl RouteMatcher for AccessControlRequestMethodMatcher
{
fn is_match(&self, state : &State) -> Result<(), RouteNonMatch>
{
match HeaderMap::borrow_from(state).get(ACCESS_CONTROL_REQUEST_METHOD)
.and_then(|value| value.to_str().ok())
.and_then(|str| str.parse::<Method>().ok())
{
Some(m) if m == self.method => Ok(()),
_ => Err(RouteNonMatch::new(StatusCode::NOT_FOUND))
}
}
}

View file

@ -8,6 +8,10 @@ pub use accept::AcceptHeaderMatcher;
mod content_type; mod content_type;
pub use content_type::ContentTypeMatcher; pub use content_type::ContentTypeMatcher;
#[cfg(feature = "cors")]
mod access_control_request_method;
pub use access_control_request_method::AccessControlRequestMethodMatcher;
type LookupTable = HashMap<String, Vec<usize>>; type LookupTable = HashMap<String, Vec<usize>>;
trait LookupTableFromTypes trait LookupTableFromTypes

View file

@ -6,6 +6,8 @@ use crate::{
Response, Response,
StatusCode StatusCode
}; };
#[cfg(feature = "cors")]
use crate::CorsRoute;
#[cfg(feature = "openapi")] #[cfg(feature = "openapi")]
use crate::openapi::{ use crate::openapi::{
builder::{OpenapiBuilder, OpenapiInfo}, builder::{OpenapiBuilder, OpenapiInfo},
@ -391,6 +393,8 @@ macro_rules! implDrawResourceRoutes {
.extend_route_matcher(accept_matcher) .extend_route_matcher(accept_matcher)
.extend_route_matcher(content_matcher) .extend_route_matcher(content_matcher)
.to(|state| create_handler::<Handler>(state)); .to(|state| create_handler::<Handler>(state));
#[cfg(feature = "cors")]
self.0.cors(&self.1, Method::POST);
} }
fn change_all<Handler : ResourceChangeAll>(&mut self) fn change_all<Handler : ResourceChangeAll>(&mut self)
@ -404,6 +408,8 @@ macro_rules! implDrawResourceRoutes {
.extend_route_matcher(accept_matcher) .extend_route_matcher(accept_matcher)
.extend_route_matcher(content_matcher) .extend_route_matcher(content_matcher)
.to(|state| change_all_handler::<Handler>(state)); .to(|state| change_all_handler::<Handler>(state));
#[cfg(feature = "cors")]
self.0.cors(&self.1, Method::PUT);
} }
fn change<Handler : ResourceChange>(&mut self) fn change<Handler : ResourceChange>(&mut self)
@ -413,11 +419,14 @@ macro_rules! implDrawResourceRoutes {
{ {
let accept_matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into(); let accept_matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
let content_matcher : MaybeMatchContentTypeHeader = Handler::Body::supported_types().into(); let content_matcher : MaybeMatchContentTypeHeader = Handler::Body::supported_types().into();
self.0.put(&format!("{}/:id", self.1)) let path = format!("{}/:id", self.1);
self.0.put(&path)
.extend_route_matcher(accept_matcher) .extend_route_matcher(accept_matcher)
.extend_route_matcher(content_matcher) .extend_route_matcher(content_matcher)
.with_path_extractor::<PathExtractor<Handler::ID>>() .with_path_extractor::<PathExtractor<Handler::ID>>()
.to(|state| change_handler::<Handler>(state)); .to(|state| change_handler::<Handler>(state));
#[cfg(feature = "cors")]
self.0.cors(&path, Method::PUT);
} }
fn remove_all<Handler : ResourceRemoveAll>(&mut self) fn remove_all<Handler : ResourceRemoveAll>(&mut self)
@ -426,15 +435,20 @@ macro_rules! implDrawResourceRoutes {
self.0.delete(&self.1) self.0.delete(&self.1)
.extend_route_matcher(matcher) .extend_route_matcher(matcher)
.to(|state| remove_all_handler::<Handler>(state)); .to(|state| remove_all_handler::<Handler>(state));
#[cfg(feature = "cors")]
self.0.cors(&self.1, Method::DELETE);
} }
fn remove<Handler : ResourceRemove>(&mut self) fn remove<Handler : ResourceRemove>(&mut self)
{ {
let matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into(); let matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
self.0.delete(&format!("{}/:id", self.1)) let path = format!("{}/:id", self.1);
self.0.delete(&path)
.extend_route_matcher(matcher) .extend_route_matcher(matcher)
.with_path_extractor::<PathExtractor<Handler::ID>>() .with_path_extractor::<PathExtractor<Handler::ID>>()
.to(|state| remove_handler::<Handler>(state)); .to(|state| remove_handler::<Handler>(state));
#[cfg(feature = "cors")]
self.0.cors(&path, Method::POST);
} }
} }
} }