1
0
Fork 0
mirror of https://gitlab.com/msrd0/gotham-restful.git synced 2025-02-22 20:52:27 +00:00

cors preflight

This commit is contained in:
Dominic 2020-05-14 23:30:59 +02:00
parent 748bf65d3e
commit f20c768d02
Signed by: msrd0
GPG key ID: DCC8C247452E98F9
6 changed files with 182 additions and 12 deletions

View file

@ -45,7 +45,7 @@ paste = "0.1.12"
trybuild = "1.0.26"
[features]
default = ["errorlog"]
default = ["cors", "errorlog"]
auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"]
cors = []
errorlog = []

View file

@ -1,13 +1,25 @@
use crate::matcher::AccessControlRequestMethodMatcher;
use gotham::{
handler::HandlerFuture,
helpers::http::response::create_empty_response,
hyper::{
header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN, HeaderMap, HeaderValue},
Body, Method, Response
header::{
ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY,
HeaderMap, HeaderName, HeaderValue
},
Body, Method, Response, StatusCode
},
middleware::Middleware,
pipeline::chain::PipelineHandleChain,
router::builder::*,
state::{FromState, State},
};
use std::pin::Pin;
use itertools::Itertools;
use std::{
panic::RefUnwindSafe,
pin::Pin
};
/**
Specify the allowed origins of the request. It is up to the browser to check the validity of the
@ -63,7 +75,8 @@ To change settings, you need to put this type into gotham's [`State`]:
# use gotham_restful::*;
fn main() {
let cors = CorsConfig {
origin: Origin::Star
origin: Origin::Star,
..Default::default()
};
let (chain, pipelines) = single_pipeline(new_pipeline().add(cors).build());
gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| {
@ -82,14 +95,16 @@ fn main() {
let pipelines = new_pipeline_set();
let cors_a = CorsConfig {
origin: Origin::Star
origin: Origin::Star,
..Default::default()
};
let (pipelines, chain_a) = pipelines.add(
new_pipeline().add(cors_a).build()
);
let cors_b = CorsConfig {
origin: Origin::Copy
origin: Origin::Copy,
..Default::default()
};
let (pipelines, chain_b) = pipelines.add(
new_pipeline().add(cors_b).build()
@ -113,7 +128,14 @@ fn main() {
#[derive(Clone, Debug, Default, NewMiddleware, StateData)]
pub struct CorsConfig
{
pub origin : Origin
/// The allowed origins.
pub origin : Origin,
/// The allowed headers.
pub headers : Vec<HeaderName>,
/// The amount of seconds that the preflight request can be cached.
pub max_age : u64,
/// Whether or not the request may be made with supplying credentials.
pub credentials : bool
}
impl Middleware for CorsConfig
@ -141,12 +163,84 @@ For further information on CORS, read https://developer.mozilla.org/en-US/docs/W
*/
pub fn handle_cors(state : &State, res : &mut Response<Body>)
{
let method = Method::borrow_from(state);
let config = CorsConfig::try_borrow_from(state);
let headers = res.headers_mut();
// non-preflight requests require nothing other than the Access-Control-Allow-Origin header
if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state))
{
res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
}
// if the origin is copied over, we should tell the browser by specifying the Vary header
if matches!(config.map(|cfg| &cfg.origin), Some(Origin::Copy))
{
let vary = headers.get(VARY).map(|vary| format!("{},Origin", vary.to_str().unwrap()));
headers.insert(VARY, vary.as_deref().unwrap_or("Origin").parse().unwrap());
}
// if we allow credentials, tell the browser
if config.map(|cfg| cfg.credentials).unwrap_or(false)
{
headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap());
}
}
/// Add CORS routing for your path.
pub trait CorsRoute<C, P>
where
C : PipelineHandleChain<P> + Copy + Send + Sync + 'static,
P : RefUnwindSafe + Send + Sync + 'static
{
fn cors(&mut self, path : &str, method : Method);
}
fn cors_preflight_handler(state : State) -> (State, Response<Body>)
{
let config = CorsConfig::try_borrow_from(&state);
// prepare the response
let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
let headers = res.headers_mut();
// copy the request method over to the response
let method = HeaderMap::borrow_from(&state).get(ACCESS_CONTROL_REQUEST_METHOD).unwrap().clone();
headers.insert(ACCESS_CONTROL_ALLOW_METHODS, method);
// if we allow any headers, put them in
if let Some(hdrs) = config.map(|cfg| &cfg.headers)
{
if hdrs.len() > 0
{
// TODO do we want to return all headers or just those asked by the browser?
headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, hdrs.iter().join(",").parse().unwrap());
}
}
// set the max age for the preflight cache
if let Some(age) = config.map(|cfg| cfg.max_age)
{
headers.insert(ACCESS_CONTROL_MAX_AGE, age.into());
}
// make sure the browser knows that this request was based on the method
headers.insert(VARY, "Access-Control-Request-Method".parse().unwrap());
handle_cors(&state, &mut res);
(state, res)
}
impl<D, C, P> CorsRoute<C, P> for D
where
D : DrawRoutes<C, P>,
C : PipelineHandleChain<P> + Copy + Send + Sync + 'static,
P : RefUnwindSafe + Send + Sync + 'static
{
fn cors(&mut self, path : &str, method : Method)
{
let matcher = AccessControlRequestMethodMatcher::new(method);
self.options(path)
.extend_route_matcher(matcher)
.to(cors_preflight_handler);
}
}

View file

@ -291,6 +291,7 @@ mod cors;
pub use cors::{
handle_cors,
CorsConfig,
CorsRoute,
Origin
};

View file

@ -0,0 +1,57 @@
use gotham::{
hyper::{header::{ACCESS_CONTROL_REQUEST_METHOD, HeaderMap}, Method, StatusCode},
router::{non_match::RouteNonMatch, route::matcher::RouteMatcher},
state::{FromState, State}
};
/// A route matcher that checks whether the value of the `Access-Control-Request-Method` header matches the defined value.
///
/// Usage:
///
/// ```rust
/// # use gotham::{helpers::http::response::create_empty_response,
/// # hyper::{header::ACCESS_CONTROL_ALLOW_METHODS, Method, StatusCode},
/// # router::builder::*
/// # };
/// # use gotham_restful::matcher::AccessControlRequestMethodMatcher;
/// let matcher = AccessControlRequestMethodMatcher::new(Method::PUT);
///
/// # build_simple_router(|route| {
/// // use the matcher for your request
/// route.options("/foo")
/// .extend_route_matcher(matcher)
/// .to(|state| {
/// // we know that this is a CORS preflight for a PUT request
/// let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
/// res.headers_mut().insert(ACCESS_CONTROL_ALLOW_METHODS, "PUT".parse().unwrap());
/// (state, res)
/// });
/// # });
/// ```
#[derive(Clone, Debug)]
pub struct AccessControlRequestMethodMatcher
{
method : Method
}
impl AccessControlRequestMethodMatcher
{
pub fn new(method : Method) -> Self
{
Self { method }
}
}
impl RouteMatcher for AccessControlRequestMethodMatcher
{
fn is_match(&self, state : &State) -> Result<(), RouteNonMatch>
{
match HeaderMap::borrow_from(state).get(ACCESS_CONTROL_REQUEST_METHOD)
.and_then(|value| value.to_str().ok())
.and_then(|str| str.parse::<Method>().ok())
{
Some(m) if m == self.method => Ok(()),
_ => Err(RouteNonMatch::new(StatusCode::NOT_FOUND))
}
}
}

View file

@ -8,6 +8,10 @@ pub use accept::AcceptHeaderMatcher;
mod content_type;
pub use content_type::ContentTypeMatcher;
#[cfg(feature = "cors")]
mod access_control_request_method;
pub use access_control_request_method::AccessControlRequestMethodMatcher;
type LookupTable = HashMap<String, Vec<usize>>;
trait LookupTableFromTypes

View file

@ -6,6 +6,8 @@ use crate::{
Response,
StatusCode
};
#[cfg(feature = "cors")]
use crate::CorsRoute;
#[cfg(feature = "openapi")]
use crate::openapi::{
builder::{OpenapiBuilder, OpenapiInfo},
@ -391,6 +393,8 @@ macro_rules! implDrawResourceRoutes {
.extend_route_matcher(accept_matcher)
.extend_route_matcher(content_matcher)
.to(|state| create_handler::<Handler>(state));
#[cfg(feature = "cors")]
self.0.cors(&self.1, Method::POST);
}
fn change_all<Handler : ResourceChangeAll>(&mut self)
@ -404,6 +408,8 @@ macro_rules! implDrawResourceRoutes {
.extend_route_matcher(accept_matcher)
.extend_route_matcher(content_matcher)
.to(|state| change_all_handler::<Handler>(state));
#[cfg(feature = "cors")]
self.0.cors(&self.1, Method::PUT);
}
fn change<Handler : ResourceChange>(&mut self)
@ -413,11 +419,14 @@ macro_rules! implDrawResourceRoutes {
{
let accept_matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
let content_matcher : MaybeMatchContentTypeHeader = Handler::Body::supported_types().into();
self.0.put(&format!("{}/:id", self.1))
let path = format!("{}/:id", self.1);
self.0.put(&path)
.extend_route_matcher(accept_matcher)
.extend_route_matcher(content_matcher)
.with_path_extractor::<PathExtractor<Handler::ID>>()
.to(|state| change_handler::<Handler>(state));
#[cfg(feature = "cors")]
self.0.cors(&path, Method::PUT);
}
fn remove_all<Handler : ResourceRemoveAll>(&mut self)
@ -426,15 +435,20 @@ macro_rules! implDrawResourceRoutes {
self.0.delete(&self.1)
.extend_route_matcher(matcher)
.to(|state| remove_all_handler::<Handler>(state));
#[cfg(feature = "cors")]
self.0.cors(&self.1, Method::DELETE);
}
fn remove<Handler : ResourceRemove>(&mut self)
{
let matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
self.0.delete(&format!("{}/:id", self.1))
let path = format!("{}/:id", self.1);
self.0.delete(&path)
.extend_route_matcher(matcher)
.with_path_extractor::<PathExtractor<Handler::ID>>()
.to(|state| remove_handler::<Handler>(state));
#[cfg(feature = "cors")]
self.0.cors(&path, Method::POST);
}
}
}