mirror of
https://gitlab.com/msrd0/gotham-restful.git
synced 2025-02-23 04:52:28 +00:00
cors preflight
This commit is contained in:
parent
748bf65d3e
commit
f20c768d02
6 changed files with 182 additions and 12 deletions
|
@ -45,7 +45,7 @@ paste = "0.1.12"
|
||||||
trybuild = "1.0.26"
|
trybuild = "1.0.26"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["errorlog"]
|
default = ["cors", "errorlog"]
|
||||||
auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"]
|
auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"]
|
||||||
cors = []
|
cors = []
|
||||||
errorlog = []
|
errorlog = []
|
||||||
|
|
112
src/cors.rs
112
src/cors.rs
|
@ -1,13 +1,25 @@
|
||||||
|
use crate::matcher::AccessControlRequestMethodMatcher;
|
||||||
use gotham::{
|
use gotham::{
|
||||||
handler::HandlerFuture,
|
handler::HandlerFuture,
|
||||||
|
helpers::http::response::create_empty_response,
|
||||||
hyper::{
|
hyper::{
|
||||||
header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN, HeaderMap, HeaderValue},
|
header::{
|
||||||
Body, Method, Response
|
ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
|
||||||
|
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY,
|
||||||
|
HeaderMap, HeaderName, HeaderValue
|
||||||
|
},
|
||||||
|
Body, Method, Response, StatusCode
|
||||||
},
|
},
|
||||||
middleware::Middleware,
|
middleware::Middleware,
|
||||||
|
pipeline::chain::PipelineHandleChain,
|
||||||
|
router::builder::*,
|
||||||
state::{FromState, State},
|
state::{FromState, State},
|
||||||
};
|
};
|
||||||
use std::pin::Pin;
|
use itertools::Itertools;
|
||||||
|
use std::{
|
||||||
|
panic::RefUnwindSafe,
|
||||||
|
pin::Pin
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
Specify the allowed origins of the request. It is up to the browser to check the validity of the
|
Specify the allowed origins of the request. It is up to the browser to check the validity of the
|
||||||
|
@ -63,7 +75,8 @@ To change settings, you need to put this type into gotham's [`State`]:
|
||||||
# use gotham_restful::*;
|
# use gotham_restful::*;
|
||||||
fn main() {
|
fn main() {
|
||||||
let cors = CorsConfig {
|
let cors = CorsConfig {
|
||||||
origin: Origin::Star
|
origin: Origin::Star,
|
||||||
|
..Default::default()
|
||||||
};
|
};
|
||||||
let (chain, pipelines) = single_pipeline(new_pipeline().add(cors).build());
|
let (chain, pipelines) = single_pipeline(new_pipeline().add(cors).build());
|
||||||
gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| {
|
gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| {
|
||||||
|
@ -82,14 +95,16 @@ fn main() {
|
||||||
let pipelines = new_pipeline_set();
|
let pipelines = new_pipeline_set();
|
||||||
|
|
||||||
let cors_a = CorsConfig {
|
let cors_a = CorsConfig {
|
||||||
origin: Origin::Star
|
origin: Origin::Star,
|
||||||
|
..Default::default()
|
||||||
};
|
};
|
||||||
let (pipelines, chain_a) = pipelines.add(
|
let (pipelines, chain_a) = pipelines.add(
|
||||||
new_pipeline().add(cors_a).build()
|
new_pipeline().add(cors_a).build()
|
||||||
);
|
);
|
||||||
|
|
||||||
let cors_b = CorsConfig {
|
let cors_b = CorsConfig {
|
||||||
origin: Origin::Copy
|
origin: Origin::Copy,
|
||||||
|
..Default::default()
|
||||||
};
|
};
|
||||||
let (pipelines, chain_b) = pipelines.add(
|
let (pipelines, chain_b) = pipelines.add(
|
||||||
new_pipeline().add(cors_b).build()
|
new_pipeline().add(cors_b).build()
|
||||||
|
@ -113,7 +128,14 @@ fn main() {
|
||||||
#[derive(Clone, Debug, Default, NewMiddleware, StateData)]
|
#[derive(Clone, Debug, Default, NewMiddleware, StateData)]
|
||||||
pub struct CorsConfig
|
pub struct CorsConfig
|
||||||
{
|
{
|
||||||
pub origin : Origin
|
/// The allowed origins.
|
||||||
|
pub origin : Origin,
|
||||||
|
/// The allowed headers.
|
||||||
|
pub headers : Vec<HeaderName>,
|
||||||
|
/// The amount of seconds that the preflight request can be cached.
|
||||||
|
pub max_age : u64,
|
||||||
|
/// Whether or not the request may be made with supplying credentials.
|
||||||
|
pub credentials : bool
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Middleware for CorsConfig
|
impl Middleware for CorsConfig
|
||||||
|
@ -141,12 +163,84 @@ For further information on CORS, read https://developer.mozilla.org/en-US/docs/W
|
||||||
*/
|
*/
|
||||||
pub fn handle_cors(state : &State, res : &mut Response<Body>)
|
pub fn handle_cors(state : &State, res : &mut Response<Body>)
|
||||||
{
|
{
|
||||||
let method = Method::borrow_from(state);
|
|
||||||
let config = CorsConfig::try_borrow_from(state);
|
let config = CorsConfig::try_borrow_from(state);
|
||||||
|
let headers = res.headers_mut();
|
||||||
|
|
||||||
// non-preflight requests require nothing other than the Access-Control-Allow-Origin header
|
// non-preflight requests require nothing other than the Access-Control-Allow-Origin header
|
||||||
if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state))
|
if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state))
|
||||||
{
|
{
|
||||||
res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
|
headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the origin is copied over, we should tell the browser by specifying the Vary header
|
||||||
|
if matches!(config.map(|cfg| &cfg.origin), Some(Origin::Copy))
|
||||||
|
{
|
||||||
|
let vary = headers.get(VARY).map(|vary| format!("{},Origin", vary.to_str().unwrap()));
|
||||||
|
headers.insert(VARY, vary.as_deref().unwrap_or("Origin").parse().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we allow credentials, tell the browser
|
||||||
|
if config.map(|cfg| cfg.credentials).unwrap_or(false)
|
||||||
|
{
|
||||||
|
headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add CORS routing for your path.
|
||||||
|
pub trait CorsRoute<C, P>
|
||||||
|
where
|
||||||
|
C : PipelineHandleChain<P> + Copy + Send + Sync + 'static,
|
||||||
|
P : RefUnwindSafe + Send + Sync + 'static
|
||||||
|
{
|
||||||
|
fn cors(&mut self, path : &str, method : Method);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cors_preflight_handler(state : State) -> (State, Response<Body>)
|
||||||
|
{
|
||||||
|
let config = CorsConfig::try_borrow_from(&state);
|
||||||
|
|
||||||
|
// prepare the response
|
||||||
|
let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
|
||||||
|
let headers = res.headers_mut();
|
||||||
|
|
||||||
|
// copy the request method over to the response
|
||||||
|
let method = HeaderMap::borrow_from(&state).get(ACCESS_CONTROL_REQUEST_METHOD).unwrap().clone();
|
||||||
|
headers.insert(ACCESS_CONTROL_ALLOW_METHODS, method);
|
||||||
|
|
||||||
|
// if we allow any headers, put them in
|
||||||
|
if let Some(hdrs) = config.map(|cfg| &cfg.headers)
|
||||||
|
{
|
||||||
|
if hdrs.len() > 0
|
||||||
|
{
|
||||||
|
// TODO do we want to return all headers or just those asked by the browser?
|
||||||
|
headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, hdrs.iter().join(",").parse().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the max age for the preflight cache
|
||||||
|
if let Some(age) = config.map(|cfg| cfg.max_age)
|
||||||
|
{
|
||||||
|
headers.insert(ACCESS_CONTROL_MAX_AGE, age.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure the browser knows that this request was based on the method
|
||||||
|
headers.insert(VARY, "Access-Control-Request-Method".parse().unwrap());
|
||||||
|
|
||||||
|
handle_cors(&state, &mut res);
|
||||||
|
(state, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<D, C, P> CorsRoute<C, P> for D
|
||||||
|
where
|
||||||
|
D : DrawRoutes<C, P>,
|
||||||
|
C : PipelineHandleChain<P> + Copy + Send + Sync + 'static,
|
||||||
|
P : RefUnwindSafe + Send + Sync + 'static
|
||||||
|
{
|
||||||
|
fn cors(&mut self, path : &str, method : Method)
|
||||||
|
{
|
||||||
|
let matcher = AccessControlRequestMethodMatcher::new(method);
|
||||||
|
self.options(path)
|
||||||
|
.extend_route_matcher(matcher)
|
||||||
|
.to(cors_preflight_handler);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -291,6 +291,7 @@ mod cors;
|
||||||
pub use cors::{
|
pub use cors::{
|
||||||
handle_cors,
|
handle_cors,
|
||||||
CorsConfig,
|
CorsConfig,
|
||||||
|
CorsRoute,
|
||||||
Origin
|
Origin
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
57
src/matcher/access_control_request_method.rs
Normal file
57
src/matcher/access_control_request_method.rs
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
use gotham::{
|
||||||
|
hyper::{header::{ACCESS_CONTROL_REQUEST_METHOD, HeaderMap}, Method, StatusCode},
|
||||||
|
router::{non_match::RouteNonMatch, route::matcher::RouteMatcher},
|
||||||
|
state::{FromState, State}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A route matcher that checks whether the value of the `Access-Control-Request-Method` header matches the defined value.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use gotham::{helpers::http::response::create_empty_response,
|
||||||
|
/// # hyper::{header::ACCESS_CONTROL_ALLOW_METHODS, Method, StatusCode},
|
||||||
|
/// # router::builder::*
|
||||||
|
/// # };
|
||||||
|
/// # use gotham_restful::matcher::AccessControlRequestMethodMatcher;
|
||||||
|
/// let matcher = AccessControlRequestMethodMatcher::new(Method::PUT);
|
||||||
|
///
|
||||||
|
/// # build_simple_router(|route| {
|
||||||
|
/// // use the matcher for your request
|
||||||
|
/// route.options("/foo")
|
||||||
|
/// .extend_route_matcher(matcher)
|
||||||
|
/// .to(|state| {
|
||||||
|
/// // we know that this is a CORS preflight for a PUT request
|
||||||
|
/// let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
|
||||||
|
/// res.headers_mut().insert(ACCESS_CONTROL_ALLOW_METHODS, "PUT".parse().unwrap());
|
||||||
|
/// (state, res)
|
||||||
|
/// });
|
||||||
|
/// # });
|
||||||
|
/// ```
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct AccessControlRequestMethodMatcher
|
||||||
|
{
|
||||||
|
method : Method
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AccessControlRequestMethodMatcher
|
||||||
|
{
|
||||||
|
pub fn new(method : Method) -> Self
|
||||||
|
{
|
||||||
|
Self { method }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RouteMatcher for AccessControlRequestMethodMatcher
|
||||||
|
{
|
||||||
|
fn is_match(&self, state : &State) -> Result<(), RouteNonMatch>
|
||||||
|
{
|
||||||
|
match HeaderMap::borrow_from(state).get(ACCESS_CONTROL_REQUEST_METHOD)
|
||||||
|
.and_then(|value| value.to_str().ok())
|
||||||
|
.and_then(|str| str.parse::<Method>().ok())
|
||||||
|
{
|
||||||
|
Some(m) if m == self.method => Ok(()),
|
||||||
|
_ => Err(RouteNonMatch::new(StatusCode::NOT_FOUND))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,6 +8,10 @@ pub use accept::AcceptHeaderMatcher;
|
||||||
mod content_type;
|
mod content_type;
|
||||||
pub use content_type::ContentTypeMatcher;
|
pub use content_type::ContentTypeMatcher;
|
||||||
|
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
mod access_control_request_method;
|
||||||
|
pub use access_control_request_method::AccessControlRequestMethodMatcher;
|
||||||
|
|
||||||
type LookupTable = HashMap<String, Vec<usize>>;
|
type LookupTable = HashMap<String, Vec<usize>>;
|
||||||
|
|
||||||
trait LookupTableFromTypes
|
trait LookupTableFromTypes
|
||||||
|
|
|
@ -6,6 +6,8 @@ use crate::{
|
||||||
Response,
|
Response,
|
||||||
StatusCode
|
StatusCode
|
||||||
};
|
};
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
use crate::CorsRoute;
|
||||||
#[cfg(feature = "openapi")]
|
#[cfg(feature = "openapi")]
|
||||||
use crate::openapi::{
|
use crate::openapi::{
|
||||||
builder::{OpenapiBuilder, OpenapiInfo},
|
builder::{OpenapiBuilder, OpenapiInfo},
|
||||||
|
@ -391,6 +393,8 @@ macro_rules! implDrawResourceRoutes {
|
||||||
.extend_route_matcher(accept_matcher)
|
.extend_route_matcher(accept_matcher)
|
||||||
.extend_route_matcher(content_matcher)
|
.extend_route_matcher(content_matcher)
|
||||||
.to(|state| create_handler::<Handler>(state));
|
.to(|state| create_handler::<Handler>(state));
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
self.0.cors(&self.1, Method::POST);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn change_all<Handler : ResourceChangeAll>(&mut self)
|
fn change_all<Handler : ResourceChangeAll>(&mut self)
|
||||||
|
@ -404,6 +408,8 @@ macro_rules! implDrawResourceRoutes {
|
||||||
.extend_route_matcher(accept_matcher)
|
.extend_route_matcher(accept_matcher)
|
||||||
.extend_route_matcher(content_matcher)
|
.extend_route_matcher(content_matcher)
|
||||||
.to(|state| change_all_handler::<Handler>(state));
|
.to(|state| change_all_handler::<Handler>(state));
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
self.0.cors(&self.1, Method::PUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn change<Handler : ResourceChange>(&mut self)
|
fn change<Handler : ResourceChange>(&mut self)
|
||||||
|
@ -413,11 +419,14 @@ macro_rules! implDrawResourceRoutes {
|
||||||
{
|
{
|
||||||
let accept_matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
|
let accept_matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
|
||||||
let content_matcher : MaybeMatchContentTypeHeader = Handler::Body::supported_types().into();
|
let content_matcher : MaybeMatchContentTypeHeader = Handler::Body::supported_types().into();
|
||||||
self.0.put(&format!("{}/:id", self.1))
|
let path = format!("{}/:id", self.1);
|
||||||
|
self.0.put(&path)
|
||||||
.extend_route_matcher(accept_matcher)
|
.extend_route_matcher(accept_matcher)
|
||||||
.extend_route_matcher(content_matcher)
|
.extend_route_matcher(content_matcher)
|
||||||
.with_path_extractor::<PathExtractor<Handler::ID>>()
|
.with_path_extractor::<PathExtractor<Handler::ID>>()
|
||||||
.to(|state| change_handler::<Handler>(state));
|
.to(|state| change_handler::<Handler>(state));
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
self.0.cors(&path, Method::PUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove_all<Handler : ResourceRemoveAll>(&mut self)
|
fn remove_all<Handler : ResourceRemoveAll>(&mut self)
|
||||||
|
@ -426,15 +435,20 @@ macro_rules! implDrawResourceRoutes {
|
||||||
self.0.delete(&self.1)
|
self.0.delete(&self.1)
|
||||||
.extend_route_matcher(matcher)
|
.extend_route_matcher(matcher)
|
||||||
.to(|state| remove_all_handler::<Handler>(state));
|
.to(|state| remove_all_handler::<Handler>(state));
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
self.0.cors(&self.1, Method::DELETE);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove<Handler : ResourceRemove>(&mut self)
|
fn remove<Handler : ResourceRemove>(&mut self)
|
||||||
{
|
{
|
||||||
let matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
|
let matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
|
||||||
self.0.delete(&format!("{}/:id", self.1))
|
let path = format!("{}/:id", self.1);
|
||||||
|
self.0.delete(&path)
|
||||||
.extend_route_matcher(matcher)
|
.extend_route_matcher(matcher)
|
||||||
.with_path_extractor::<PathExtractor<Handler::ID>>()
|
.with_path_extractor::<PathExtractor<Handler::ID>>()
|
||||||
.to(|state| remove_handler::<Handler>(state));
|
.to(|state| remove_handler::<Handler>(state));
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
self.0.cors(&path, Method::POST);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue