mirror of
https://gitlab.com/msrd0/gotham-restful.git
synced 2025-02-23 04:52:28 +00:00
Merge branch 'cors' into 'master'
Allow configuring CORS Closes #22 See merge request msrd0/gotham-restful!14
This commit is contained in:
commit
604494651d
8 changed files with 501 additions and 13 deletions
|
@ -45,8 +45,9 @@ paste = "0.1.12"
|
||||||
trybuild = "1.0.26"
|
trybuild = "1.0.26"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["errorlog"]
|
default = ["cors", "errorlog"]
|
||||||
auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"]
|
auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"]
|
||||||
|
cors = []
|
||||||
errorlog = []
|
errorlog = []
|
||||||
database = ["gotham_restful_derive/database", "gotham_middleware_diesel"]
|
database = ["gotham_restful_derive/database", "gotham_middleware_diesel"]
|
||||||
openapi = ["gotham_restful_derive/openapi", "indexmap", "openapiv3"]
|
openapi = ["gotham_restful_derive/openapi", "indexmap", "openapiv3"]
|
||||||
|
|
246
src/cors.rs
Normal file
246
src/cors.rs
Normal file
|
@ -0,0 +1,246 @@
|
||||||
|
use crate::matcher::AccessControlRequestMethodMatcher;
|
||||||
|
use gotham::{
|
||||||
|
handler::HandlerFuture,
|
||||||
|
helpers::http::response::create_empty_response,
|
||||||
|
hyper::{
|
||||||
|
header::{
|
||||||
|
ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
|
||||||
|
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY,
|
||||||
|
HeaderMap, HeaderName, HeaderValue
|
||||||
|
},
|
||||||
|
Body, Method, Response, StatusCode
|
||||||
|
},
|
||||||
|
middleware::Middleware,
|
||||||
|
pipeline::chain::PipelineHandleChain,
|
||||||
|
router::builder::*,
|
||||||
|
state::{FromState, State},
|
||||||
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
use std::{
|
||||||
|
panic::RefUnwindSafe,
|
||||||
|
pin::Pin
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
Specify the allowed origins of the request. It is up to the browser to check the validity of the
|
||||||
|
origin. This, when sent to the browser, will indicate whether or not the request's origin was
|
||||||
|
allowed to make the request.
|
||||||
|
*/
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum Origin
|
||||||
|
{
|
||||||
|
/// Do not send any `Access-Control-Allow-Origin` headers.
|
||||||
|
None,
|
||||||
|
/// Send `Access-Control-Allow-Origin: *`. Note that browser will not send credentials.
|
||||||
|
Star,
|
||||||
|
/// Set the `Access-Control-Allow-Origin` header to a single origin.
|
||||||
|
Single(String),
|
||||||
|
/// Copy the `Origin` header into the `Access-Control-Allow-Origin` header.
|
||||||
|
Copy
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Origin
|
||||||
|
{
|
||||||
|
fn default() -> Self
|
||||||
|
{
|
||||||
|
Self::None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Origin
|
||||||
|
{
|
||||||
|
/// Get the header value for the `Access-Control-Allow-Origin` header.
|
||||||
|
fn header_value(&self, state : &State) -> Option<HeaderValue>
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Self::None => None,
|
||||||
|
Self::Star => Some("*".parse().unwrap()),
|
||||||
|
Self::Single(origin) => Some(origin.parse().unwrap()),
|
||||||
|
Self::Copy => {
|
||||||
|
let headers = HeaderMap::borrow_from(state);
|
||||||
|
headers.get(ORIGIN).map(Clone::clone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
This is the configuration that the CORS handler will follow. Its default configuration is basically
|
||||||
|
not to touch any responses, resulting in the browser's default behaviour.
|
||||||
|
|
||||||
|
To change settings, you need to put this type into gotham's [`State`]:
|
||||||
|
|
||||||
|
```rust,no_run
|
||||||
|
# use gotham::{router::builder::*, pipeline::{new_pipeline, single::single_pipeline}, state::State};
|
||||||
|
# use gotham_restful::*;
|
||||||
|
fn main() {
|
||||||
|
let cors = CorsConfig {
|
||||||
|
origin: Origin::Star,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let (chain, pipelines) = single_pipeline(new_pipeline().add(cors).build());
|
||||||
|
gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| {
|
||||||
|
// your routing logic
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This easy approach allows you to have one global cors configuration. If you prefer to have separate
|
||||||
|
configurations for different scopes, you need to register the middleware inside your routing logic:
|
||||||
|
|
||||||
|
```rust,no_run
|
||||||
|
# use gotham::{router::builder::*, pipeline::*, pipeline::set::*, state::State};
|
||||||
|
# use gotham_restful::*;
|
||||||
|
fn main() {
|
||||||
|
let pipelines = new_pipeline_set();
|
||||||
|
|
||||||
|
let cors_a = CorsConfig {
|
||||||
|
origin: Origin::Star,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let (pipelines, chain_a) = pipelines.add(
|
||||||
|
new_pipeline().add(cors_a).build()
|
||||||
|
);
|
||||||
|
|
||||||
|
let cors_b = CorsConfig {
|
||||||
|
origin: Origin::Copy,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let (pipelines, chain_b) = pipelines.add(
|
||||||
|
new_pipeline().add(cors_b).build()
|
||||||
|
);
|
||||||
|
|
||||||
|
let pipeline_set = finalize_pipeline_set(pipelines);
|
||||||
|
gotham::start("127.0.0.1:8080", build_router((), pipeline_set, |route| {
|
||||||
|
// routing without any cors config
|
||||||
|
route.with_pipeline_chain((chain_a, ()), |route| {
|
||||||
|
// routing with cors config a
|
||||||
|
});
|
||||||
|
route.with_pipeline_chain((chain_b, ()), |route| {
|
||||||
|
// routing with cors config b
|
||||||
|
});
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
[`State`]: ../gotham/state/struct.State.html
|
||||||
|
*/
|
||||||
|
#[derive(Clone, Debug, Default, NewMiddleware, StateData)]
|
||||||
|
pub struct CorsConfig
|
||||||
|
{
|
||||||
|
/// The allowed origins.
|
||||||
|
pub origin : Origin,
|
||||||
|
/// The allowed headers.
|
||||||
|
pub headers : Vec<HeaderName>,
|
||||||
|
/// The amount of seconds that the preflight request can be cached.
|
||||||
|
pub max_age : u64,
|
||||||
|
/// Whether or not the request may be made with supplying credentials.
|
||||||
|
pub credentials : bool
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Middleware for CorsConfig
|
||||||
|
{
|
||||||
|
fn call<Chain>(self, mut state : State, chain : Chain) -> Pin<Box<HandlerFuture>>
|
||||||
|
where
|
||||||
|
Chain : FnOnce(State) -> Pin<Box<HandlerFuture>>
|
||||||
|
{
|
||||||
|
state.put(self);
|
||||||
|
chain(state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
Handle CORS for a non-preflight request. This means manipulating the `res` HTTP headers so that
|
||||||
|
the response is aligned with the `state`'s [`CorsConfig`].
|
||||||
|
|
||||||
|
If you are using the [`Resource`] type (which is the recommended way), you'll never have to call
|
||||||
|
this method. However, if you are writing your own handler method, you might want to call this
|
||||||
|
after your request to add the required CORS headers.
|
||||||
|
|
||||||
|
For further information on CORS, read https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS.
|
||||||
|
|
||||||
|
[`CorsConfig`]: ./struct.CorsConfig.html
|
||||||
|
*/
|
||||||
|
pub fn handle_cors(state : &State, res : &mut Response<Body>)
|
||||||
|
{
|
||||||
|
let config = CorsConfig::try_borrow_from(state);
|
||||||
|
let headers = res.headers_mut();
|
||||||
|
|
||||||
|
// non-preflight requests require the Access-Control-Allow-Origin header
|
||||||
|
if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state))
|
||||||
|
{
|
||||||
|
headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the origin is copied over, we should tell the browser by specifying the Vary header
|
||||||
|
if matches!(config.map(|cfg| &cfg.origin), Some(Origin::Copy))
|
||||||
|
{
|
||||||
|
let vary = headers.get(VARY).map(|vary| format!("{},Origin", vary.to_str().unwrap()));
|
||||||
|
headers.insert(VARY, vary.as_deref().unwrap_or("Origin").parse().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we allow credentials, tell the browser
|
||||||
|
if config.map(|cfg| cfg.credentials).unwrap_or(false)
|
||||||
|
{
|
||||||
|
headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add CORS routing for your path.
|
||||||
|
pub trait CorsRoute<C, P>
|
||||||
|
where
|
||||||
|
C : PipelineHandleChain<P> + Copy + Send + Sync + 'static,
|
||||||
|
P : RefUnwindSafe + Send + Sync + 'static
|
||||||
|
{
|
||||||
|
fn cors(&mut self, path : &str, method : Method);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cors_preflight_handler(state : State) -> (State, Response<Body>)
|
||||||
|
{
|
||||||
|
let config = CorsConfig::try_borrow_from(&state);
|
||||||
|
|
||||||
|
// prepare the response
|
||||||
|
let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
|
||||||
|
let headers = res.headers_mut();
|
||||||
|
|
||||||
|
// copy the request method over to the response
|
||||||
|
let method = HeaderMap::borrow_from(&state).get(ACCESS_CONTROL_REQUEST_METHOD).unwrap().clone();
|
||||||
|
headers.insert(ACCESS_CONTROL_ALLOW_METHODS, method);
|
||||||
|
|
||||||
|
// if we allow any headers, put them in
|
||||||
|
if let Some(hdrs) = config.map(|cfg| &cfg.headers)
|
||||||
|
{
|
||||||
|
if hdrs.len() > 0
|
||||||
|
{
|
||||||
|
// TODO do we want to return all headers or just those asked by the browser?
|
||||||
|
headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, hdrs.iter().join(",").parse().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the max age for the preflight cache
|
||||||
|
if let Some(age) = config.map(|cfg| cfg.max_age)
|
||||||
|
{
|
||||||
|
headers.insert(ACCESS_CONTROL_MAX_AGE, age.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure the browser knows that this request was based on the method
|
||||||
|
headers.insert(VARY, "Access-Control-Request-Method".parse().unwrap());
|
||||||
|
|
||||||
|
handle_cors(&state, &mut res);
|
||||||
|
(state, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<D, C, P> CorsRoute<C, P> for D
|
||||||
|
where
|
||||||
|
D : DrawRoutes<C, P>,
|
||||||
|
C : PipelineHandleChain<P> + Copy + Send + Sync + 'static,
|
||||||
|
P : RefUnwindSafe + Send + Sync + 'static
|
||||||
|
{
|
||||||
|
fn cors(&mut self, path : &str, method : Method)
|
||||||
|
{
|
||||||
|
let matcher = AccessControlRequestMethodMatcher::new(method);
|
||||||
|
self.options(path)
|
||||||
|
.extend_route_matcher(matcher)
|
||||||
|
.to(cors_preflight_handler);
|
||||||
|
}
|
||||||
|
}
|
10
src/lib.rs
10
src/lib.rs
|
@ -285,6 +285,16 @@ pub use auth::{
|
||||||
StaticAuthHandler
|
StaticAuthHandler
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
mod cors;
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
pub use cors::{
|
||||||
|
handle_cors,
|
||||||
|
CorsConfig,
|
||||||
|
CorsRoute,
|
||||||
|
Origin
|
||||||
|
};
|
||||||
|
|
||||||
pub mod matcher;
|
pub mod matcher;
|
||||||
|
|
||||||
#[cfg(feature = "openapi")]
|
#[cfg(feature = "openapi")]
|
||||||
|
|
57
src/matcher/access_control_request_method.rs
Normal file
57
src/matcher/access_control_request_method.rs
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
use gotham::{
|
||||||
|
hyper::{header::{ACCESS_CONTROL_REQUEST_METHOD, HeaderMap}, Method, StatusCode},
|
||||||
|
router::{non_match::RouteNonMatch, route::matcher::RouteMatcher},
|
||||||
|
state::{FromState, State}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A route matcher that checks whether the value of the `Access-Control-Request-Method` header matches the defined value.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use gotham::{helpers::http::response::create_empty_response,
|
||||||
|
/// # hyper::{header::ACCESS_CONTROL_ALLOW_METHODS, Method, StatusCode},
|
||||||
|
/// # router::builder::*
|
||||||
|
/// # };
|
||||||
|
/// # use gotham_restful::matcher::AccessControlRequestMethodMatcher;
|
||||||
|
/// let matcher = AccessControlRequestMethodMatcher::new(Method::PUT);
|
||||||
|
///
|
||||||
|
/// # build_simple_router(|route| {
|
||||||
|
/// // use the matcher for your request
|
||||||
|
/// route.options("/foo")
|
||||||
|
/// .extend_route_matcher(matcher)
|
||||||
|
/// .to(|state| {
|
||||||
|
/// // we know that this is a CORS preflight for a PUT request
|
||||||
|
/// let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
|
||||||
|
/// res.headers_mut().insert(ACCESS_CONTROL_ALLOW_METHODS, "PUT".parse().unwrap());
|
||||||
|
/// (state, res)
|
||||||
|
/// });
|
||||||
|
/// # });
|
||||||
|
/// ```
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct AccessControlRequestMethodMatcher
|
||||||
|
{
|
||||||
|
method : Method
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AccessControlRequestMethodMatcher
|
||||||
|
{
|
||||||
|
pub fn new(method : Method) -> Self
|
||||||
|
{
|
||||||
|
Self { method }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RouteMatcher for AccessControlRequestMethodMatcher
|
||||||
|
{
|
||||||
|
fn is_match(&self, state : &State) -> Result<(), RouteNonMatch>
|
||||||
|
{
|
||||||
|
match HeaderMap::borrow_from(state).get(ACCESS_CONTROL_REQUEST_METHOD)
|
||||||
|
.and_then(|value| value.to_str().ok())
|
||||||
|
.and_then(|str| str.parse::<Method>().ok())
|
||||||
|
{
|
||||||
|
Some(m) if m == self.method => Ok(()),
|
||||||
|
_ => Err(RouteNonMatch::new(StatusCode::NOT_FOUND))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,6 +8,11 @@ pub use accept::AcceptHeaderMatcher;
|
||||||
mod content_type;
|
mod content_type;
|
||||||
pub use content_type::ContentTypeMatcher;
|
pub use content_type::ContentTypeMatcher;
|
||||||
|
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
mod access_control_request_method;
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
pub use access_control_request_method::AccessControlRequestMethodMatcher;
|
||||||
|
|
||||||
type LookupTable = HashMap<String, Vec<usize>>;
|
type LookupTable = HashMap<String, Vec<usize>>;
|
||||||
|
|
||||||
trait LookupTableFromTypes
|
trait LookupTableFromTypes
|
||||||
|
|
|
@ -6,6 +6,8 @@ use crate::{
|
||||||
Response,
|
Response,
|
||||||
StatusCode
|
StatusCode
|
||||||
};
|
};
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
use crate::CorsRoute;
|
||||||
#[cfg(feature = "openapi")]
|
#[cfg(feature = "openapi")]
|
||||||
use crate::openapi::{
|
use crate::openapi::{
|
||||||
builder::{OpenapiBuilder, OpenapiInfo},
|
builder::{OpenapiBuilder, OpenapiInfo},
|
||||||
|
@ -100,10 +102,16 @@ fn response_from(res : Response, state : &State) -> gotham::hyper::Response<Body
|
||||||
{
|
{
|
||||||
r.headers_mut().insert(CONTENT_TYPE, mime.as_ref().parse().unwrap());
|
r.headers_mut().insert(CONTENT_TYPE, mime.as_ref().parse().unwrap());
|
||||||
}
|
}
|
||||||
if Method::borrow_from(state) != Method::HEAD
|
|
||||||
|
let method = Method::borrow_from(state);
|
||||||
|
if method != Method::HEAD
|
||||||
{
|
{
|
||||||
*r.body_mut() = res.body;
|
*r.body_mut() = res.body;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
crate::cors::handle_cors(state, &mut r);
|
||||||
|
|
||||||
r
|
r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -385,6 +393,8 @@ macro_rules! implDrawResourceRoutes {
|
||||||
.extend_route_matcher(accept_matcher)
|
.extend_route_matcher(accept_matcher)
|
||||||
.extend_route_matcher(content_matcher)
|
.extend_route_matcher(content_matcher)
|
||||||
.to(|state| create_handler::<Handler>(state));
|
.to(|state| create_handler::<Handler>(state));
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
self.0.cors(&self.1, Method::POST);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn change_all<Handler : ResourceChangeAll>(&mut self)
|
fn change_all<Handler : ResourceChangeAll>(&mut self)
|
||||||
|
@ -398,6 +408,8 @@ macro_rules! implDrawResourceRoutes {
|
||||||
.extend_route_matcher(accept_matcher)
|
.extend_route_matcher(accept_matcher)
|
||||||
.extend_route_matcher(content_matcher)
|
.extend_route_matcher(content_matcher)
|
||||||
.to(|state| change_all_handler::<Handler>(state));
|
.to(|state| change_all_handler::<Handler>(state));
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
self.0.cors(&self.1, Method::PUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn change<Handler : ResourceChange>(&mut self)
|
fn change<Handler : ResourceChange>(&mut self)
|
||||||
|
@ -407,11 +419,14 @@ macro_rules! implDrawResourceRoutes {
|
||||||
{
|
{
|
||||||
let accept_matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
|
let accept_matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
|
||||||
let content_matcher : MaybeMatchContentTypeHeader = Handler::Body::supported_types().into();
|
let content_matcher : MaybeMatchContentTypeHeader = Handler::Body::supported_types().into();
|
||||||
self.0.put(&format!("{}/:id", self.1))
|
let path = format!("{}/:id", self.1);
|
||||||
|
self.0.put(&path)
|
||||||
.extend_route_matcher(accept_matcher)
|
.extend_route_matcher(accept_matcher)
|
||||||
.extend_route_matcher(content_matcher)
|
.extend_route_matcher(content_matcher)
|
||||||
.with_path_extractor::<PathExtractor<Handler::ID>>()
|
.with_path_extractor::<PathExtractor<Handler::ID>>()
|
||||||
.to(|state| change_handler::<Handler>(state));
|
.to(|state| change_handler::<Handler>(state));
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
self.0.cors(&path, Method::PUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove_all<Handler : ResourceRemoveAll>(&mut self)
|
fn remove_all<Handler : ResourceRemoveAll>(&mut self)
|
||||||
|
@ -420,15 +435,20 @@ macro_rules! implDrawResourceRoutes {
|
||||||
self.0.delete(&self.1)
|
self.0.delete(&self.1)
|
||||||
.extend_route_matcher(matcher)
|
.extend_route_matcher(matcher)
|
||||||
.to(|state| remove_all_handler::<Handler>(state));
|
.to(|state| remove_all_handler::<Handler>(state));
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
self.0.cors(&self.1, Method::DELETE);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove<Handler : ResourceRemove>(&mut self)
|
fn remove<Handler : ResourceRemove>(&mut self)
|
||||||
{
|
{
|
||||||
let matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
|
let matcher : MaybeMatchAcceptHeader = Handler::Res::accepted_types().into();
|
||||||
self.0.delete(&format!("{}/:id", self.1))
|
let path = format!("{}/:id", self.1);
|
||||||
|
self.0.delete(&path)
|
||||||
.extend_route_matcher(matcher)
|
.extend_route_matcher(matcher)
|
||||||
.with_path_extractor::<PathExtractor<Handler::ID>>()
|
.with_path_extractor::<PathExtractor<Handler::ID>>()
|
||||||
.to(|state| remove_handler::<Handler>(state));
|
.to(|state| remove_handler::<Handler>(state));
|
||||||
|
#[cfg(feature = "cors")]
|
||||||
|
self.0.cors(&path, Method::POST);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
156
tests/cors_handling.rs
Normal file
156
tests/cors_handling.rs
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
#![cfg(feature = "cors")]
|
||||||
|
use gotham::{
|
||||||
|
hyper::{body::Body, client::connect::Connect, header::*, StatusCode},
|
||||||
|
pipeline::{new_pipeline, single::single_pipeline},
|
||||||
|
router::builder::*,
|
||||||
|
test::{Server, TestRequest, TestServer}
|
||||||
|
};
|
||||||
|
use gotham_restful::{CorsConfig, DrawResources, Origin, Raw, Resource, change_all, read_all};
|
||||||
|
use itertools::Itertools;
|
||||||
|
use mime::TEXT_PLAIN;
|
||||||
|
|
||||||
|
#[derive(Resource)]
|
||||||
|
#[resource(read_all, change_all)]
|
||||||
|
struct FooResource;
|
||||||
|
|
||||||
|
#[read_all(FooResource)]
|
||||||
|
fn read_all()
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
#[change_all(FooResource)]
|
||||||
|
fn change_all(_body : Raw<Vec<u8>>)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_server(cfg : CorsConfig) -> TestServer
|
||||||
|
{
|
||||||
|
let (chain, pipeline) = single_pipeline(new_pipeline().add(cfg).build());
|
||||||
|
TestServer::new(build_router(chain, pipeline, |router| {
|
||||||
|
router.resource::<FooResource>("/foo")
|
||||||
|
})).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_response<TS, C>(req : TestRequest<TS, C>, origin : Option<&str>, vary : Option<&str>, credentials : bool)
|
||||||
|
where
|
||||||
|
TS : Server + 'static,
|
||||||
|
C : Connect + Clone + Send + Sync + 'static
|
||||||
|
{
|
||||||
|
let res = req.with_header(ORIGIN, "http://example.org".parse().unwrap()).perform().unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::NO_CONTENT);
|
||||||
|
let headers = res.headers();
|
||||||
|
println!("{}", headers.keys().join(","));
|
||||||
|
assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).and_then(|value| value.to_str().ok()).as_deref(), origin);
|
||||||
|
assert_eq!(headers.get(VARY).and_then(|value| value.to_str().ok()).as_deref(), vary);
|
||||||
|
assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_CREDENTIALS).and_then(|value| value.to_str().ok()).map(|value| value == "true").unwrap_or(false), credentials);
|
||||||
|
assert!(headers.get(ACCESS_CONTROL_MAX_AGE).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_preflight(server : &TestServer, method : &str, origin : Option<&str>, vary : &str, credentials : bool, max_age : u64)
|
||||||
|
{
|
||||||
|
let res = server.client().options("http://example.org/foo")
|
||||||
|
.with_header(ACCESS_CONTROL_REQUEST_METHOD, method.parse().unwrap())
|
||||||
|
.with_header(ORIGIN, "http://example.org".parse().unwrap())
|
||||||
|
.perform().unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::NO_CONTENT);
|
||||||
|
let headers = res.headers();
|
||||||
|
println!("{}", headers.keys().join(","));
|
||||||
|
assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_METHODS).and_then(|value| value.to_str().ok()).as_deref(), Some(method));
|
||||||
|
assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).and_then(|value| value.to_str().ok()).as_deref(), origin);
|
||||||
|
assert_eq!(headers.get(VARY).and_then(|value| value.to_str().ok()).as_deref(), Some(vary));
|
||||||
|
assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_CREDENTIALS).and_then(|value| value.to_str().ok()).map(|value| value == "true").unwrap_or(false), credentials);
|
||||||
|
assert_eq!(headers.get(ACCESS_CONTROL_MAX_AGE).and_then(|value| value.to_str().ok()).and_then(|value| value.parse().ok()), Some(max_age));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cors_origin_none()
|
||||||
|
{
|
||||||
|
let cfg = CorsConfig {
|
||||||
|
origin: Origin::None,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let server = test_server(cfg);
|
||||||
|
|
||||||
|
test_preflight(&server, "PUT", None, "Access-Control-Request-Method", false, 0);
|
||||||
|
|
||||||
|
test_response(server.client().get("http://example.org/foo"), None, None, false);
|
||||||
|
test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), None, None, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cors_origin_star()
|
||||||
|
{
|
||||||
|
let cfg = CorsConfig {
|
||||||
|
origin: Origin::Star,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let server = test_server(cfg);
|
||||||
|
|
||||||
|
test_preflight(&server, "PUT", Some("*"), "Access-Control-Request-Method", false, 0);
|
||||||
|
|
||||||
|
test_response(server.client().get("http://example.org/foo"), Some("*"), None, false);
|
||||||
|
test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), Some("*"), None, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cors_origin_single()
|
||||||
|
{
|
||||||
|
let cfg = CorsConfig {
|
||||||
|
origin: Origin::Single("https://foo.com".to_owned()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let server = test_server(cfg);
|
||||||
|
|
||||||
|
test_preflight(&server, "PUT", Some("https://foo.com"), "Access-Control-Request-Method", false, 0);
|
||||||
|
|
||||||
|
test_response(server.client().get("http://example.org/foo"), Some("https://foo.com"), None, false);
|
||||||
|
test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), Some("https://foo.com"), None, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cors_origin_copy()
|
||||||
|
{
|
||||||
|
let cfg = CorsConfig {
|
||||||
|
origin: Origin::Copy,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let server = test_server(cfg);
|
||||||
|
|
||||||
|
test_preflight(&server, "PUT", Some("http://example.org"), "Access-Control-Request-Method,Origin", false, 0);
|
||||||
|
|
||||||
|
test_response(server.client().get("http://example.org/foo"), Some("http://example.org"), Some("Origin"), false);
|
||||||
|
test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), Some("http://example.org"), Some("Origin"), false);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cors_credentials()
|
||||||
|
{
|
||||||
|
let cfg = CorsConfig {
|
||||||
|
origin: Origin::None,
|
||||||
|
credentials: true,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let server = test_server(cfg);
|
||||||
|
|
||||||
|
test_preflight(&server, "PUT", None, "Access-Control-Request-Method", true, 0);
|
||||||
|
|
||||||
|
test_response(server.client().get("http://example.org/foo"), None, None, true);
|
||||||
|
test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), None, None, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cors_max_age()
|
||||||
|
{
|
||||||
|
let cfg = CorsConfig {
|
||||||
|
origin: Origin::None,
|
||||||
|
max_age: 31536000,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let server = test_server(cfg);
|
||||||
|
|
||||||
|
test_preflight(&server, "PUT", None, "Access-Control-Request-Method", false, 31536000);
|
||||||
|
|
||||||
|
test_response(server.client().get("http://example.org/foo"), None, None, false);
|
||||||
|
test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), None, None, false);
|
||||||
|
}
|
|
@ -1,8 +1,4 @@
|
||||||
#[cfg(feature = "openapi")]
|
#![cfg(feature = "openapi")]
|
||||||
mod openapi_supports_scope
|
|
||||||
{
|
|
||||||
|
|
||||||
|
|
||||||
use gotham::{
|
use gotham::{
|
||||||
router::builder::*,
|
router::builder::*,
|
||||||
test::TestServer
|
test::TestServer
|
||||||
|
@ -29,7 +25,7 @@ fn read_all() -> Raw<&'static [u8]>
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test()
|
fn openapi_supports_scope()
|
||||||
{
|
{
|
||||||
let info = OpenapiInfo {
|
let info = OpenapiInfo {
|
||||||
title: "Test".to_owned(),
|
title: "Test".to_owned(),
|
||||||
|
@ -54,6 +50,3 @@ fn test()
|
||||||
test_get_response(&server, "http://localhost/bar/baz/foo3", RESPONSE);
|
test_get_response(&server, "http://localhost/bar/baz/foo3", RESPONSE);
|
||||||
test_get_response(&server, "http://localhost/foo4", RESPONSE);
|
test_get_response(&server, "http://localhost/foo4", RESPONSE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
} // mod test
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue