mirror of
https://gitlab.com/msrd0/gotham-restful.git
synced 2025-02-22 20:52:27 +00:00
cors preflight
This commit is contained in:
parent
748bf65d3e
commit
f20c768d02
6 changed files with 182 additions and 12 deletions
|
@ -45,7 +45,7 @@ paste = "0.1.12"
|
|||
trybuild = "1.0.26"
|
||||
|
||||
[features]
|
||||
default = ["errorlog"]
|
||||
default = ["cors", "errorlog"]
|
||||
auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"]
|
||||
cors = []
|
||||
errorlog = []
|
||||
|
|
112
src/cors.rs
112
src/cors.rs
|
@ -1,13 +1,25 @@
|
|||
use crate::matcher::AccessControlRequestMethodMatcher;
|
||||
use gotham::{
|
||||
handler::HandlerFuture,
|
||||
helpers::http::response::create_empty_response,
|
||||
hyper::{
|
||||
header::{ACCESS_CONTROL_ALLOW_ORIGIN, ORIGIN, HeaderMap, HeaderValue},
|
||||
Body, Method, Response
|
||||
header::{
|
||||
ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
|
||||
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY,
|
||||
HeaderMap, HeaderName, HeaderValue
|
||||
},
|
||||
Body, Method, Response, StatusCode
|
||||
},
|
||||
middleware::Middleware,
|
||||
pipeline::chain::PipelineHandleChain,
|
||||
router::builder::*,
|
||||
state::{FromState, State},
|
||||
};
|
||||
use std::pin::Pin;
|
||||
use itertools::Itertools;
|
||||
use std::{
|
||||
panic::RefUnwindSafe,
|
||||
pin::Pin
|
||||
};
|
||||
|
||||
/**
|
||||
Specify the allowed origins of the request. It is up to the browser to check the validity of the
|
||||
|
@ -63,7 +75,8 @@ To change settings, you need to put this type into gotham's [`State`]:
|
|||
# use gotham_restful::*;
|
||||
fn main() {
|
||||
let cors = CorsConfig {
|
||||
origin: Origin::Star
|
||||
origin: Origin::Star,
|
||||
..Default::default()
|
||||
};
|
||||
let (chain, pipelines) = single_pipeline(new_pipeline().add(cors).build());
|
||||
gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| {
|
||||
|
@ -82,14 +95,16 @@ fn main() {
|
|||
let pipelines = new_pipeline_set();
|
||||
|
||||
let cors_a = CorsConfig {
|
||||
origin: Origin::Star
|
||||
origin: Origin::Star,
|
||||
..Default::default()
|
||||
};
|
||||
let (pipelines, chain_a) = pipelines.add(
|
||||
new_pipeline().add(cors_a).build()
|
||||
);
|
||||
|
||||
let cors_b = CorsConfig {
|
||||
origin: Origin::Copy
|
||||
origin: Origin::Copy,
|
||||
..Default::default()
|
||||
};
|
||||
let (pipelines, chain_b) = pipelines.add(
|
||||
new_pipeline().add(cors_b).build()
|
||||
|
@ -113,7 +128,14 @@ fn main() {
|
|||
#[derive(Clone, Debug, Default, NewMiddleware, StateData)]
|
||||
pub struct CorsConfig
|
||||
{
|
||||
pub origin : Origin
|
||||
/// The allowed origins.
|
||||
pub origin : Origin,
|
||||
/// The allowed headers.
|
||||
pub headers : Vec<HeaderName>,
|
||||
/// The amount of seconds that the preflight request can be cached.
|
||||
pub max_age : u64,
|
||||
/// Whether or not the request may be made with supplying credentials.
|
||||
pub credentials : bool
|
||||
}
|
||||
|
||||
impl Middleware for CorsConfig
|
||||
|
@ -141,12 +163,84 @@ For further information on CORS, read https://developer.mozilla.org/en-US/docs/W
|
|||
*/
|
||||
pub fn handle_cors(state : &State, res : &mut Response<Body>)
|
||||
{
|
||||
let method = Method::borrow_from(state);
|
||||
let config = CorsConfig::try_borrow_from(state);
|
||||
let headers = res.headers_mut();
|
||||
|
||||
// non-preflight requests require nothing other than the Access-Control-Allow-Origin header
|
||||
if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state))
|
||||
{
|
||||
res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
|
||||
headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
|
||||
}
|
||||
|
||||
// if the origin is copied over, we should tell the browser by specifying the Vary header
|
||||
if matches!(config.map(|cfg| &cfg.origin), Some(Origin::Copy))
|
||||
{
|
||||
let vary = headers.get(VARY).map(|vary| format!("{},Origin", vary.to_str().unwrap()));
|
||||
headers.insert(VARY, vary.as_deref().unwrap_or("Origin").parse().unwrap());
|
||||
}
|
||||
|
||||
// if we allow credentials, tell the browser
|
||||
if config.map(|cfg| cfg.credentials).unwrap_or(false)
|
||||
{
|
||||
headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
/// Add CORS routing for your path.
|
||||
pub trait CorsRoute<C, P>
|
||||
where
|
||||
C : PipelineHandleChain<P> + Copy + Send + Sync + 'static,
|
||||
P : RefUnwindSafe + Send + Sync + 'static
|
||||
{
|
||||
fn cors(&mut self, path : &str, method : Method);
|
||||
}
|
||||
|
||||
fn cors_preflight_handler(state : State) -> (State, Response<Body>)
|
||||
{
|
||||
let config = CorsConfig::try_borrow_from(&state);
|
||||
|
||||
// prepare the response
|
||||
let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
|
||||
let headers = res.headers_mut();
|
||||
|
||||
// copy the request method over to the response
|
||||
let method = HeaderMap::borrow_from(&state).get(ACCESS_CONTROL_REQUEST_METHOD).unwrap().clone();
|
||||
headers.insert(ACCESS_CONTROL_ALLOW_METHODS, method);
|
||||
|
||||
// if we allow any headers, put them in
|
||||
if let Some(hdrs) = config.map(|cfg| &cfg.headers)
|
||||
{
|
||||
if hdrs.len() > 0
|
||||
{
|
||||
// TODO do we want to return all headers or just those asked by the browser?
|
||||
headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, hdrs.iter().join(",").parse().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// set the max age for the preflight cache
|
||||
if let Some(age) = config.map(|cfg| cfg.max_age)
|
||||
{
|
||||
headers.insert(ACCESS_CONTROL_MAX_AGE, age.into());
|
||||
}
|
||||
|
||||
// make sure the browser knows that this request was based on the method
|
||||
headers.insert(VARY, "Access-Control-Request-Method".parse().unwrap());
|
||||
|
||||
handle_cors(&state, &mut res);
|
||||
(state, res)
|
||||
}
|
||||
|
||||
impl<D, C, P> CorsRoute<C, P> for D
|
||||
where
|
||||
D : DrawRoutes<C, P>,
|
||||
C : PipelineHandleChain<P> + Copy + Send + Sync + 'static,
|
||||
P : RefUnwindSafe + Send + Sync + 'static
|
||||
{
|
||||
fn cors(&mut self, path : &str, method : Method)
|
||||
{
|
||||
let matcher = AccessControlRequestMethodMatcher::new(method);
|
||||
self.options(path)
|
||||
.extend_route_matcher(matcher)
|
||||
.to(cors_preflight_handler);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -291,6 +291,7 @@ mod cors;
|
|||
pub use cors::{
|
||||
handle_cors,
|
||||
CorsConfig,
|
||||
CorsRoute,
|
||||
Origin
|
||||
};
|
||||
|
||||
|
|
57
src/matcher/access_control_request_method.rs
Normal file
57
src/matcher/access_control_request_method.rs
Normal file
|
@ -0,0 +1,57 @@
|
|||
use gotham::{
|
||||
hyper::{header::{ACCESS_CONTROL_REQUEST_METHOD, HeaderMap}, Method, StatusCode},
|
||||
router::{non_match::RouteNonMatch, route::matcher::RouteMatcher},
|
||||
state::{FromState, State}
|
||||
};
|
||||
|
||||
/// A route matcher that checks whether the value of the `Access-Control-Request-Method` header matches the defined value.
|
||||
///
|
||||
/// Usage:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use gotham::{helpers::http::response::create_empty_response,
|
||||
/// # hyper::{header::ACCESS_CONTROL_ALLOW_METHODS, Method, StatusCode},
|
||||
/// # router::builder::*
|
||||
/// # };
|
||||
/// # use gotham_restful::matcher::AccessControlRequestMethodMatcher;
|
||||
/// let matcher = AccessControlRequestMethodMatcher::new(Method::PUT);
|
||||
///
|
||||
/// # build_simple_router(|route| {
|
||||
/// // use the matcher for your request
|
||||
/// route.options("/foo")
|
||||
/// .extend_route_matcher(matcher)
|
||||
/// .to(|state| {
|
||||
/// // we know that this is a CORS preflight for a PUT request
|
||||
/// let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
|
||||
/// res.headers_mut().insert(ACCESS_CONTROL_ALLOW_METHODS, "PUT".parse().unwrap());
|
||||
/// (state, res)
|
||||
/// });
|
||||
/// # });
|
||||
/// ```
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AccessControlRequestMethodMatcher
|
||||
{
|
||||
method : Method
|
||||
}
|
||||
|
||||
impl AccessControlRequestMethodMatcher
|
||||
{
|
||||
pub fn new(method : Method) -> Self
|
||||
{
|
||||
Self { method }
|
||||
}
|
||||
}
|
||||
|
||||
impl RouteMatcher for AccessControlRequestMethodMatcher
|
||||
{
|
||||
fn is_match(&self, state : &State) -> Result<(), RouteNonMatch>
|
||||
{
|
||||
match HeaderMap::borrow_from(state).get(ACCESS_CONTROL_REQUEST_METHOD)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.and_then(|str| str.parse::<Method>().ok())
|
||||
{
|
||||
Some(m) if m == self.method => Ok(()),
|
||||
_ => Err(RouteNonMatch::new(StatusCode::NOT_FOUND))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -8,6 +8,10 @@ pub use accept::AcceptHeaderMatcher;
|
|||
mod content_type;
|
||||
pub use content_type::ContentTypeMatcher;
|
||||
|
||||
#[cfg(feature = "cors")]
|
||||
mod access_control_request_method;
|
||||
pub use access_control_request_method::AccessControlRequestMethodMatcher;
|
||||
|
||||
type LookupTable = HashMap<String, Vec<usize>>;
|
||||
|
||||
trait LookupTableFromTypes
|
||||
|
|
|
@ -6,6 +6,8 @@ use crate::{
|
|||
Response,
|
||||
StatusCode
|
||||
};
|
||||
#[cfg(feature = "cors")]
|
||||
use crate::CorsRoute;
|
||||
#[cfg(feature = "openapi")]
|
||||
use crate::openapi::{
|
||||
builder::{OpenapiBuilder, OpenapiInfo},
|
||||
|
@ -391,6 +393,8 @@ macro_rules! implDrawResourceRoutes {
|
|||
.extend_route_matcher(accept_matcher)
|
||||
.extend_route_matcher(content_matcher)
|
||||
.to(|state| create_handler::<Handler>(state));
|
||||
#[cfg(feature = "cors")]
|
||||
self.0.cors(&self.1, Method::POST);
|
||||
}
|
||||
|
||||
fn change_all<Handler : ResourceChangeAll>(&mut self)
|
||||
|
@ -404,6 +408,8 @@ macro_rules! implDrawResourceRoutes {
|
|||
.extend_route_matcher(accept_matcher)
|
||||
.extend_route_matcher(content_matcher)
|
||||
.to(|state| change_all_handler::<Handler>(state));
|
||||
#[cfg(feature = "cors")]
|
||||
self.0.cors(&self.1, Method::PUT);
|
||||
}
|
||||
|
||||
fn change<Handler : ResourceChange>(&mut self)
|
||||
|
@ -413,11 +419,14 @@ macro_rules! implDrawResourceRoutes {
|
|||
{
|
||||
let accept_matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
|
||||
let content_matcher : MaybeMatchContentTypeHeader = Handler::Body::supported_types().into();
|
||||
self.0.put(&format!("{}/:id", self.1))
|
||||
let path = format!("{}/:id", self.1);
|
||||
self.0.put(&path)
|
||||
.extend_route_matcher(accept_matcher)
|
||||
.extend_route_matcher(content_matcher)
|
||||
.with_path_extractor::<PathExtractor<Handler::ID>>()
|
||||
.to(|state| change_handler::<Handler>(state));
|
||||
#[cfg(feature = "cors")]
|
||||
self.0.cors(&path, Method::PUT);
|
||||
}
|
||||
|
||||
fn remove_all<Handler : ResourceRemoveAll>(&mut self)
|
||||
|
@ -426,15 +435,20 @@ macro_rules! implDrawResourceRoutes {
|
|||
self.0.delete(&self.1)
|
||||
.extend_route_matcher(matcher)
|
||||
.to(|state| remove_all_handler::<Handler>(state));
|
||||
#[cfg(feature = "cors")]
|
||||
self.0.cors(&self.1, Method::DELETE);
|
||||
}
|
||||
|
||||
fn remove<Handler : ResourceRemove>(&mut self)
|
||||
{
|
||||
let matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
|
||||
self.0.delete(&format!("{}/:id", self.1))
|
||||
let path = format!("{}/:id", self.1);
|
||||
self.0.delete(&path)
|
||||
.extend_route_matcher(matcher)
|
||||
.with_path_extractor::<PathExtractor<Handler::ID>>()
|
||||
.to(|state| remove_handler::<Handler>(state));
|
||||
#[cfg(feature = "cors")]
|
||||
self.0.cors(&path, Method::POST);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue