1
0
Fork 0
mirror of https://gitlab.com/msrd0/gotham-restful.git synced 2025-04-20 06:54:46 +00:00

support copying headers in cors preflight requests

This commit is contained in:
Dominic 2021-01-01 16:44:55 +01:00
parent b005346e54
commit 766bc9d17d
Signed by: msrd0
GPG key ID: DCC8C247452E98F9
5 changed files with 188 additions and 47 deletions

View file

@ -5,7 +5,7 @@ use gotham::{
header::{
HeaderMap, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS,
ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE,
ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY
ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY
},
Body, Method, Response, StatusCode
},
@ -14,7 +14,6 @@ use gotham::{
router::{builder::*, route::matcher::AccessControlRequestMethodMatcher},
state::{FromState, State}
};
use itertools::Itertools;
use std::{panic::RefUnwindSafe, pin::Pin};
/**
@ -53,6 +52,52 @@ impl Origin {
}
}
}
/// Returns true if the `Vary` header has to include `Origin`.
fn varies(&self) -> bool {
matches!(self, Self::Copy)
}
}
/**
Specify the allowed headers of the request. It is up to the browser to check that only the allowed
headers are sent with the request.
*/
#[derive(Clone, Debug)]
pub enum Headers {
/// Do not send any `Access-Control-Allow-Headers` headers.
None,
/// Set the `Access-Control-Allow-Headers` header to the following header list. If empty, this
/// is treated as if it was [None].
List(Vec<HeaderName>),
/// Copy the `Access-Control-Request-Headers` header into the `Access-Control-Allow-Header`
/// header.
Copy
}
impl Default for Headers {
fn default() -> Self {
Self::None
}
}
impl Headers {
/// Get the header value for the `Access-Control-Allow-Headers` header.
fn header_value(&self, state: &State) -> Option<HeaderValue> {
match self {
Self::None => None,
Self::List(list) => Some(list.join(",").parse().unwrap()),
Self::Copy => {
let headers = HeaderMap::borrow_from(state);
headers.get(ACCESS_CONTROL_REQUEST_HEADERS).map(Clone::clone)
}
}
}
/// Returns true if the `Vary` header has to include `Origin`.
fn varies(&self) -> bool {
matches!(self, Self::Copy)
}
}
/**
@ -63,7 +108,7 @@ To change settings, you need to put this type into gotham's [State]:
```rust,no_run
# use gotham::{router::builder::*, pipeline::{new_pipeline, single::single_pipeline}, state::State};
# use gotham_restful::*;
# use gotham_restful::{*, cors::Origin};
fn main() {
let cors = CorsConfig {
origin: Origin::Star,
@ -81,7 +126,7 @@ configurations for different scopes, you need to register the middleware inside
```rust,no_run
# use gotham::{router::builder::*, pipeline::*, pipeline::set::*, state::State};
# use gotham_restful::*;
# use gotham_restful::{*, cors::Origin};
let pipelines = new_pipeline_set();
// The first cors configuration
@ -119,7 +164,7 @@ pub struct CorsConfig {
/// The allowed origins.
pub origin: Origin,
/// The allowed headers.
pub headers: Vec<HeaderName>,
pub headers: Headers,
/// The amount of seconds that the preflight request can be cached.
pub max_age: u64,
/// Whether or not the request may be made with supplying credentials.
@ -149,22 +194,24 @@ For further information on CORS, read
*/
pub fn handle_cors(state: &State, res: &mut Response<Body>) {
let config = CorsConfig::try_borrow_from(state);
let headers = res.headers_mut();
if let Some(cfg) = config {
let headers = res.headers_mut();
// non-preflight requests require the Access-Control-Allow-Origin header
if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state)) {
headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
}
// non-preflight requests require the Access-Control-Allow-Origin header
if let Some(header) = cfg.origin.header_value(state) {
headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
}
// if the origin is copied over, we should tell the browser by specifying the Vary header
if matches!(config.map(|cfg| &cfg.origin), Some(Origin::Copy)) {
let vary = headers.get(VARY).map(|vary| format!("{},Origin", vary.to_str().unwrap()));
headers.insert(VARY, vary.as_deref().unwrap_or("Origin").parse().unwrap());
}
// if the origin is copied over, we should tell the browser by specifying the Vary header
if cfg.origin.varies() {
let vary = headers.get(VARY).map(|vary| format!("{},origin", vary.to_str().unwrap()));
headers.insert(VARY, vary.as_deref().unwrap_or("origin").parse().unwrap());
}
// if we allow credentials, tell the browser
if config.map(|cfg| cfg.credentials).unwrap_or(false) {
headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true".parse().unwrap());
// if we allow credentials, tell the browser
if cfg.credentials {
headers.insert(ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"));
}
}
}
@ -202,6 +249,7 @@ fn cors_preflight_handler(state: State) -> (State, Response<Body>) {
// prepare the response
let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
let headers = res.headers_mut();
let mut vary: Vec<HeaderName> = Vec::new();
// copy the request method over to the response
let method = HeaderMap::borrow_from(&state)
@ -209,22 +257,27 @@ fn cors_preflight_handler(state: State) -> (State, Response<Body>) {
.unwrap()
.clone();
headers.insert(ACCESS_CONTROL_ALLOW_METHODS, method);
vary.push(ACCESS_CONTROL_REQUEST_METHOD);
// if we allow any headers, put them in
if let Some(hdrs) = config.map(|cfg| &cfg.headers) {
if hdrs.len() > 0 {
// TODO do we want to return all headers or just those asked by the browser?
headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, hdrs.iter().join(",").parse().unwrap());
if let Some(cfg) = config {
// if we allow any headers, copy them over
if let Some(header) = cfg.headers.header_value(&state) {
headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, header);
}
// if the headers are copied over, we should tell the browser by specifying the Vary header
if cfg.headers.varies() {
vary.push(ACCESS_CONTROL_REQUEST_HEADERS);
}
// set the max age for the preflight cache
if let Some(age) = config.map(|cfg| cfg.max_age) {
headers.insert(ACCESS_CONTROL_MAX_AGE, age.into());
}
}
// set the max age for the preflight cache
if let Some(age) = config.map(|cfg| cfg.max_age) {
headers.insert(ACCESS_CONTROL_MAX_AGE, age.into());
}
// make sure the browser knows that this request was based on the method
headers.insert(VARY, "Access-Control-Request-Method".parse().unwrap());
headers.insert(VARY, vary.join(",").parse().unwrap());
handle_cors(&state, &mut res);
(state, res)

View file

@ -196,7 +196,7 @@ authentication), and every content type, could look like this:
# #[cfg(feature = "cors")]
# mod cors_feature_enabled {
# use gotham::{hyper::header::*, router::builder::*, pipeline::{new_pipeline, single::single_pipeline}, state::State};
# use gotham_restful::*;
# use gotham_restful::{*, cors::*};
# use serde::{Deserialize, Serialize};
#[derive(Resource)]
#[resource(read_all)]
@ -210,7 +210,7 @@ fn read_all() {
fn main() {
let cors = CorsConfig {
origin: Origin::Copy,
headers: vec![CONTENT_TYPE],
headers: Headers::List(vec![CONTENT_TYPE]),
max_age: 0,
credentials: true
};
@ -417,9 +417,9 @@ mod auth;
pub use auth::{AuthHandler, AuthMiddleware, AuthSource, AuthStatus, AuthValidation, StaticAuthHandler};
#[cfg(feature = "cors")]
mod cors;
pub mod cors;
#[cfg(feature = "cors")]
pub use cors::{handle_cors, CorsConfig, CorsRoute, Origin};
pub use cors::{handle_cors, CorsConfig, CorsRoute};
#[cfg(feature = "openapi")]
mod openapi;