diff --git a/src/cors.rs b/src/cors.rs
index 558bedb..57e1a10 100644
--- a/src/cors.rs
+++ b/src/cors.rs
@@ -166,7 +166,7 @@ pub fn handle_cors(state : &State, res : &mut Response
)
let config = CorsConfig::try_borrow_from(state);
let headers = res.headers_mut();
- // non-preflight requests require nothing other than the Access-Control-Allow-Origin header
+ // non-preflight requests require the Access-Control-Allow-Origin header
if let Some(header) = config.and_then(|cfg| cfg.origin.header_value(state))
{
headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
diff --git a/src/matcher/mod.rs b/src/matcher/mod.rs
index 3168ec3..9cbfcbb 100644
--- a/src/matcher/mod.rs
+++ b/src/matcher/mod.rs
@@ -10,6 +10,7 @@ pub use content_type::ContentTypeMatcher;
#[cfg(feature = "cors")]
mod access_control_request_method;
+#[cfg(feature = "cors")]
pub use access_control_request_method::AccessControlRequestMethodMatcher;
type LookupTable = HashMap>;
diff --git a/tests/cors_handling.rs b/tests/cors_handling.rs
new file mode 100644
index 0000000..a9fe498
--- /dev/null
+++ b/tests/cors_handling.rs
@@ -0,0 +1,156 @@
+#![cfg(feature = "cors")]
+use gotham::{
+ hyper::{body::Body, client::connect::Connect, header::*, StatusCode},
+ pipeline::{new_pipeline, single::single_pipeline},
+ router::builder::*,
+ test::{Server, TestRequest, TestServer}
+};
+use gotham_restful::{CorsConfig, DrawResources, Origin, Raw, Resource, change_all, read_all};
+use itertools::Itertools;
+use mime::TEXT_PLAIN;
+
+#[derive(Resource)]
+#[resource(read_all, change_all)]
+struct FooResource;
+
+#[read_all(FooResource)]
+fn read_all()
+{
+}
+
+#[change_all(FooResource)]
+fn change_all(_body : Raw>)
+{
+}
+
+fn test_server(cfg : CorsConfig) -> TestServer
+{
+ let (chain, pipeline) = single_pipeline(new_pipeline().add(cfg).build());
+ TestServer::new(build_router(chain, pipeline, |router| {
+ router.resource::("/foo")
+ })).unwrap()
+}
+
+fn test_response(req : TestRequest, origin : Option<&str>, vary : Option<&str>, credentials : bool)
+where
+ TS : Server + 'static,
+ C : Connect + Clone + Send + Sync + 'static
+{
+ let res = req.with_header(ORIGIN, "http://example.org".parse().unwrap()).perform().unwrap();
+ assert_eq!(res.status(), StatusCode::NO_CONTENT);
+ let headers = res.headers();
+ println!("{}", headers.keys().join(","));
+ assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).and_then(|value| value.to_str().ok()).as_deref(), origin);
+ assert_eq!(headers.get(VARY).and_then(|value| value.to_str().ok()).as_deref(), vary);
+ assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_CREDENTIALS).and_then(|value| value.to_str().ok()).map(|value| value == "true").unwrap_or(false), credentials);
+ assert!(headers.get(ACCESS_CONTROL_MAX_AGE).is_none());
+}
+
+fn test_preflight(server : &TestServer, method : &str, origin : Option<&str>, vary : &str, credentials : bool, max_age : u64)
+{
+ let res = server.client().options("http://example.org/foo")
+ .with_header(ACCESS_CONTROL_REQUEST_METHOD, method.parse().unwrap())
+ .with_header(ORIGIN, "http://example.org".parse().unwrap())
+ .perform().unwrap();
+ assert_eq!(res.status(), StatusCode::NO_CONTENT);
+ let headers = res.headers();
+ println!("{}", headers.keys().join(","));
+ assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_METHODS).and_then(|value| value.to_str().ok()).as_deref(), Some(method));
+ assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).and_then(|value| value.to_str().ok()).as_deref(), origin);
+ assert_eq!(headers.get(VARY).and_then(|value| value.to_str().ok()).as_deref(), Some(vary));
+ assert_eq!(headers.get(ACCESS_CONTROL_ALLOW_CREDENTIALS).and_then(|value| value.to_str().ok()).map(|value| value == "true").unwrap_or(false), credentials);
+ assert_eq!(headers.get(ACCESS_CONTROL_MAX_AGE).and_then(|value| value.to_str().ok()).and_then(|value| value.parse().ok()), Some(max_age));
+}
+
+
+#[test]
+fn cors_origin_none()
+{
+ let cfg = CorsConfig {
+ origin: Origin::None,
+ ..Default::default()
+ };
+ let server = test_server(cfg);
+
+ test_preflight(&server, "PUT", None, "Access-Control-Request-Method", false, 0);
+
+ test_response(server.client().get("http://example.org/foo"), None, None, false);
+ test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), None, None, false);
+}
+
+#[test]
+fn cors_origin_star()
+{
+ let cfg = CorsConfig {
+ origin: Origin::Star,
+ ..Default::default()
+ };
+ let server = test_server(cfg);
+
+ test_preflight(&server, "PUT", Some("*"), "Access-Control-Request-Method", false, 0);
+
+ test_response(server.client().get("http://example.org/foo"), Some("*"), None, false);
+ test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), Some("*"), None, false);
+}
+
+#[test]
+fn cors_origin_single()
+{
+ let cfg = CorsConfig {
+ origin: Origin::Single("https://foo.com".to_owned()),
+ ..Default::default()
+ };
+ let server = test_server(cfg);
+
+ test_preflight(&server, "PUT", Some("https://foo.com"), "Access-Control-Request-Method", false, 0);
+
+ test_response(server.client().get("http://example.org/foo"), Some("https://foo.com"), None, false);
+ test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), Some("https://foo.com"), None, false);
+}
+
+#[test]
+fn cors_origin_copy()
+{
+ let cfg = CorsConfig {
+ origin: Origin::Copy,
+ ..Default::default()
+ };
+ let server = test_server(cfg);
+
+ test_preflight(&server, "PUT", Some("http://example.org"), "Access-Control-Request-Method,Origin", false, 0);
+
+ test_response(server.client().get("http://example.org/foo"), Some("http://example.org"), Some("Origin"), false);
+ test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), Some("http://example.org"), Some("Origin"), false);
+}
+
+#[test]
+fn cors_credentials()
+{
+ let cfg = CorsConfig {
+ origin: Origin::None,
+ credentials: true,
+ ..Default::default()
+ };
+ let server = test_server(cfg);
+
+ test_preflight(&server, "PUT", None, "Access-Control-Request-Method", true, 0);
+
+ test_response(server.client().get("http://example.org/foo"), None, None, true);
+ test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), None, None, true);
+}
+
+#[test]
+fn cors_max_age()
+{
+ let cfg = CorsConfig {
+ origin: Origin::None,
+ max_age: 31536000,
+ ..Default::default()
+ };
+ let server = test_server(cfg);
+
+ test_preflight(&server, "PUT", None, "Access-Control-Request-Method", false, 31536000);
+
+ test_response(server.client().get("http://example.org/foo"), None, None, false);
+ test_response(server.client().put("http://example.org/foo", Body::empty(), TEXT_PLAIN), None, None, false);
+}
diff --git a/tests/openapi_supports_scope.rs b/tests/openapi_supports_scope.rs
index 62228da..3b9aa2c 100644
--- a/tests/openapi_supports_scope.rs
+++ b/tests/openapi_supports_scope.rs
@@ -1,8 +1,4 @@
-#[cfg(feature = "openapi")]
-mod openapi_supports_scope
-{
-
-
+#![cfg(feature = "openapi")]
use gotham::{
router::builder::*,
test::TestServer
@@ -29,7 +25,7 @@ fn read_all() -> Raw<&'static [u8]>
#[test]
-fn test()
+fn openapi_supports_scope()
{
let info = OpenapiInfo {
title: "Test".to_owned(),
@@ -54,6 +50,3 @@ fn test()
test_get_response(&server, "http://localhost/bar/baz/foo3", RESPONSE);
test_get_response(&server, "http://localhost/foo4", RESPONSE);
}
-
-
-} // mod test