From 088774fc50eb54cc5e66d70435c9bcbd5af26e83 Mon Sep 17 00:00:00 2001 From: msrd0 Date: Wed, 22 Jan 2020 16:53:02 +0000 Subject: [PATCH] implement auth parsing/verifying inside a gotham middleware --- README.md | 3 +- example/Cargo.toml | 2 +- example/src/main.rs | 25 ++ gotham_restful/Cargo.toml | 4 + gotham_restful/src/auth.rs | 452 +++++++++++++++++++++++++++ gotham_restful/src/lib.rs | 17 +- gotham_restful/src/openapi/router.rs | 67 +++- gotham_restful_derive/Cargo.toml | 1 + gotham_restful_derive/src/method.rs | 199 +++++++++--- 9 files changed, 721 insertions(+), 49 deletions(-) create mode 100644 gotham_restful/src/auth.rs diff --git a/README.md b/README.md index 1c7ebe9..4dbd9c6 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,8 @@ resources. ## Usage -To use this crate, add the following to your `Cargo.toml`: +This crate targets stable rust, currently requiring rustc 1.40+. To use this crate, add the +following to your `Cargo.toml`: ```toml [dependencies] diff --git a/example/Cargo.toml b/example/Cargo.toml index 58ee7ba..7875c5c 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -17,7 +17,7 @@ gitlab = { repository = "msrd0/gotham-restful", branch = "master" } fake = "2.2" gotham = "0.4" gotham_derive = "0.4" -gotham_restful = { version = "0.0.1", features = ["openapi"] } +gotham_restful = { version = "0.0.1", features = ["auth", "openapi"] } hyper = "0.12" log = "0.4" log4rs = { version = "0.8", features = ["console_appender"], default-features = false } diff --git a/example/src/main.rs b/example/src/main.rs index 69b99b3..ea16d0f 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -23,6 +23,12 @@ struct Users { } +#[derive(Resource)] +#[rest_resource(ReadAll)] +struct Auth +{ +} + #[derive(Deserialize, OpenapiType, Serialize, StateData, StaticResponseExtender)] struct User { @@ -82,8 +88,24 @@ fn delete(_state : &mut State, id : u64) info!("Delete User {}", id); } +#[rest_read_all(Auth)] +fn auth_read_all(auth : AuthStatus<()>) -> Success +{ + format!("{:?}", auth).into() +} + const ADDR : &str = "127.0.0.1:18080"; +#[derive(Clone, Default)] +struct Handler; +impl AuthHandler for Handler +{ + fn jwt_secret Option>(&self, _state : &mut State, _decode_data : F) -> Option> + { + None + } +} + fn main() { let encoder = PatternEncoder::new("{d(%Y-%m-%d %H:%M:%S%.3f %Z)} [{l}] {M} - {m}\n"); @@ -99,9 +121,11 @@ fn main() .unwrap(); log4rs::init_config(config).unwrap(); + let auth = >::from_source(AuthSource::AuthorizationHeader); let logging = RequestLogger::new(log::Level::Info); let (chain, pipelines) = single_pipeline( new_pipeline() + .add(auth) .add(logging) .build() ); @@ -109,6 +133,7 @@ fn main() gotham::start(ADDR, build_router(chain, pipelines, |route| { route.with_openapi("Users Example", "0.0.1", format!("http://{}", ADDR), |mut route| { route.resource::("users"); + route.resource::("auth"); route.get_openapi("openapi"); }); })); diff --git a/gotham_restful/Cargo.toml b/gotham_restful/Cargo.toml index 1b9114b..43f2d06 100644 --- a/gotham_restful/Cargo.toml +++ b/gotham_restful/Cargo.toml @@ -16,7 +16,9 @@ gitlab = { repository = "msrd0/gotham-restful", branch = "master" } codecov = { repository = "msrd0/gotham-restful", branch = "master", service = "gitlab" } [dependencies] +base64 = { version = "0.11", optional = true } chrono = { version = "0.4.9", optional = true } +cookie = { version = "0.12", optional = true } failure = "0.1.6" futures = "0.1.29" gotham = "0.4" @@ -25,6 +27,7 @@ gotham_middleware_diesel = { version = "0.1", optional = true } gotham_restful_derive = { version = "0.0.1" } hyper = "0.12.35" indexmap = { version = "1.3.0", optional = true } +jsonwebtoken = { version = "6.0.1", optional = true } log = { version = "0.4.8", optional = true } mime = "0.3.14" openapiv3 = { version = "0.3", optional = true } @@ -37,5 +40,6 @@ thiserror = "1" [features] default = [] +auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"] database = ["gotham_restful_derive/database", "gotham_middleware_diesel"] openapi = ["gotham_restful_derive/openapi", "indexmap", "log", "openapiv3"] diff --git a/gotham_restful/src/auth.rs b/gotham_restful/src/auth.rs new file mode 100644 index 0000000..323a344 --- /dev/null +++ b/gotham_restful/src/auth.rs @@ -0,0 +1,452 @@ +use crate::HeaderName; +use cookie::CookieJar; +use futures::{future, future::Future}; +use gotham::{ + handler::HandlerFuture, + middleware::{Middleware, NewMiddleware}, + state::{FromState, State} +}; +use hyper::header::{AUTHORIZATION, HeaderMap}; +use jsonwebtoken::errors::ErrorKind; +use serde::de::DeserializeOwned; +use std::{ + marker::PhantomData, + panic::RefUnwindSafe +}; + +pub use jsonwebtoken::Validation as AuthValidation; + +/// The authentication status returned by the auth middleware for each request. +#[derive(Debug, StateData)] +pub enum AuthStatus +{ + /// The auth status is unknown. + Unknown, + /// The request has been performed without any kind of authentication. + Unauthenticated, + /// The request has been performed with an invalid authentication. + Invalid, + /// The request has been performed with an expired authentication. + Expired, + /// The request has been performed with a valid authentication. + Authenticated(T) +} + +impl Clone for AuthStatus +where + T : Clone + Send + 'static +{ + fn clone(&self) -> Self + { + match self { + Self::Unknown => Self::Unknown, + Self::Unauthenticated => Self::Unauthenticated, + Self::Invalid => Self::Invalid, + Self::Expired => Self::Expired, + Self::Authenticated(data) => Self::Authenticated(data.clone()) + } + } +} + +/// The source of the authentication token in the request. +#[derive(Clone, StateData)] +pub enum AuthSource +{ + /// Take the token from a cookie with the given name. + Cookie(String), + /// Take the token from a header with the given name. + Header(HeaderName), + /// Take the token from the HTTP Authorization header. This is different from `Header("Authorization")` + /// as it will follow the `scheme param` format from the HTTP specification. The `scheme` will + /// be discarded, so its value doesn't matter. + AuthorizationHeader +} + +/** +This trait will help the auth middleware to determine the validity of an authentication token. + +A very basic implementation could look like this: +``` +# use gotham_restful::{export::State, AuthHandler}; +# +const SECRET : &'static [u8; 32] = b"zlBsA2QXnkmpe0QTh8uCvtAEa4j33YAc"; + +struct CustomAuthHandler; +impl AuthHandler for CustomAuthHandler { + fn jwt_secret Option>(&self, _state : &mut State, _decode_data : F) -> Option> { + Some(SECRET.to_vec()) + } +} +``` +*/ +pub trait AuthHandler +{ + /// Return the SHA256-HMAC secret used to verify the JWT token. + fn jwt_secret Option>(&self, state : &mut State, decode_data : F) -> Option>; +} + +/// An `AuthHandler` returning always the same secret. See `AuthMiddleware` for a usage example. +#[derive(Clone, Debug)] +pub struct StaticAuthHandler +{ + secret : Vec +} + +impl StaticAuthHandler +{ + pub fn from_vec(secret : Vec) -> Self + { + Self { secret } + } + + pub fn from_array(secret : &[u8]) -> Self + { + Self::from_vec(secret.to_vec()) + } +} + +impl AuthHandler for StaticAuthHandler +{ + fn jwt_secret Option>(&self, _state : &mut State, _decode_data : F) -> Option> + { + Some(self.secret.clone()) + } +} + +/** +This is the auth middleware. To use it, first make sure you have the `auth` feature enabled. Then +simply add it to your pipeline and request it inside your handler: + +```rust,no_run +# #[macro_use] extern crate gotham_restful_derive; +# use gotham::{router::builder::*, pipeline::{new_pipeline, single::single_pipeline}, state::State}; +# use gotham_restful::*; +# use serde::{Deserialize, Serialize}; +# +#[derive(Resource)] +#[rest_resource(read_all)] +struct AuthResource; + +#[derive(Debug, Deserialize)] +struct AuthData { + sub: String, + exp: u64 +} + +#[rest_read_all(AuthResource)] +fn read_all(auth : &AuthStatus) -> Success { + format!("{:?}", auth).into() +} + +fn main() { + let auth : AuthMiddleware = AuthMiddleware::new( + AuthSource::AuthorizationHeader, + AuthValidation::default(), + StaticAuthHandler::from_array(b"zlBsA2QXnkmpe0QTh8uCvtAEa4j33YAc") + ); + let (chain, pipelines) = single_pipeline(new_pipeline().add(auth).build()); + gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| { + route.resource::("auth"); + })); +} +``` +*/ +pub struct AuthMiddleware +{ + source : AuthSource, + validation : AuthValidation, + handler : Handler, + _data : PhantomData +} + +impl Clone for AuthMiddleware +where Handler : Clone +{ + fn clone(&self) -> Self + { + Self { + source: self.source.clone(), + validation: self.validation.clone(), + handler: self.handler.clone(), + _data: self._data + } + } +} + +impl AuthMiddleware +where + Data : DeserializeOwned + Send, + Handler : AuthHandler + Default +{ + pub fn from_source(source : AuthSource) -> Self + { + Self { + source, + validation: Default::default(), + handler: Default::default(), + _data: Default::default() + } + } +} + +impl AuthMiddleware +where + Data : DeserializeOwned + Send, + Handler : AuthHandler +{ + pub fn new(source : AuthSource, validation : AuthValidation, handler : Handler) -> Self + { + Self { + source, + validation, + handler, + _data: Default::default() + } + } + + fn auth_status(&self, state : &mut State) -> AuthStatus + { + // extract the provided token, if any + let token = match &self.source { + AuthSource::Cookie(name) => { + CookieJar::try_borrow_from(&state) + .and_then(|jar| jar.get(&name)) + .map(|cookie| cookie.value().to_owned()) + }, + AuthSource::Header(name) => { + HeaderMap::try_borrow_from(&state) + .and_then(|map| map.get(name)) + .and_then(|header| header.to_str().ok()) + .map(|value| value.to_owned()) + }, + AuthSource::AuthorizationHeader => { + HeaderMap::try_borrow_from(&state) + .and_then(|map| map.get(AUTHORIZATION)) + .and_then(|header| header.to_str().ok()) + .and_then(|value| value.split_whitespace().nth(1)) + .map(|value| value.to_owned()) + } + }; + + // unauthed if no token + let token = match token { + Some(token) => token, + None => return AuthStatus::Unauthenticated + }; + + // get the secret from the handler, possibly decoding claims ourselves + let secret = self.handler.jwt_secret(state, || { + let b64 = token.split(".").nth(1)?; + let raw = base64::decode_config(b64, base64::URL_SAFE_NO_PAD).ok()?; + serde_json::from_slice(&raw).ok()? + }); + + // unknown if no secret + let secret = match secret { + Some(secret) => secret, + None => return AuthStatus::Unknown + }; + + // validate the token + let data : Data = match jsonwebtoken::decode(&token, &secret, &self.validation) { + Ok(data) => data.claims, + Err(e) => match dbg!(e.into_kind()) { + ErrorKind::ExpiredSignature => return AuthStatus::Expired, + _ => return AuthStatus::Invalid + } + }; + + // we found a valid token + return AuthStatus::Authenticated(data); + } +} + +impl Middleware for AuthMiddleware +where + Data : DeserializeOwned + Send + 'static, + Handler : AuthHandler +{ + fn call(self, mut state : State, chain : Chain) -> Box + where + Chain : FnOnce(State) -> Box + { + // put the source in our state, required for e.g. openapi + state.put(self.source.clone()); + + // put the status in our state + let status = self.auth_status(&mut state); + state.put(status); + + // call the rest of the chain + Box::new(chain(state).and_then(|(state, res)| future::ok((state, res)))) + } +} + +impl NewMiddleware for AuthMiddleware +where + Self : Clone + Middleware + Sync + RefUnwindSafe +{ + type Instance = Self; + + fn new_middleware(&self) -> Result + { + let c : Self = self.clone(); + Ok(c) + } +} + +#[cfg(test)] +mod test +{ + use super::*; + use cookie::Cookie; + use std::fmt::Debug; + + // 256-bit random string + const JWT_SECRET : &'static [u8; 32] = b"Lyzsfnta0cdxyF0T9y6VGxp3jpgoMUuW"; + + // some known tokens + const VALID_TOKEN : &'static str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJtc3JkMCIsInN1YiI6ImdvdGhhbS1yZXN0ZnVsIiwiaWF0IjoxNTc3ODM2ODAwLCJleHAiOjQxMDI0NDQ4MDB9.8h8Ax-nnykqEQ62t7CxmM3ja6NzUQ4L0MLOOzddjLKk"; + const EXPIRED_TOKEN : &'static str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJtc3JkMCIsInN1YiI6ImdvdGhhbS1yZXN0ZnVsIiwiaWF0IjoxNTc3ODM2ODAwLCJleHAiOjE1Nzc4MzcxMDB9.eV1snaGLYrJ7qUoMk74OvBY3WUU9M0Je5HTU2xtX1v0"; + + #[derive(Debug, Deserialize, PartialEq)] + struct TestData + { + iss : String, + sub : String, + iat : u64, + exp : u64 + } + + impl Default for TestData + { + fn default() -> Self + { + Self { + iss: "msrd0".to_owned(), + sub: "gotham-restful".to_owned(), + iat: 1577836800, + exp: 4102444800 + } + } + } + + #[derive(Default)] + struct TestHandler; + impl AuthHandler for TestHandler + { + fn jwt_secret Option>(&self, _state : &mut State, _decode_data : F) -> Option> + { + Some(JWT_SECRET.to_vec()) + } + } + + #[derive(Default)] + struct TestAssertingHandler; + impl AuthHandler for TestAssertingHandler + where T : Debug + Default + PartialEq + { + fn jwt_secret Option>(&self, _state : &mut State, decode_data : F) -> Option> + { + assert_eq!(decode_data(), Some(T::default())); + Some(JWT_SECRET.to_vec()) + } + } + + #[test] + fn test_auth_middleware_decode_data() + { + let middleware = >::from_source(AuthSource::AuthorizationHeader); + State::with_new(|mut state| { + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, format!("Bearer {}", VALID_TOKEN).parse().unwrap()); + state.put(headers); + middleware.auth_status(&mut state); + }); + } + + fn new_middleware(source : AuthSource) -> AuthMiddleware + where T : DeserializeOwned + Send + { + AuthMiddleware::from_source(source) + } + + #[test] + fn test_auth_middleware_no_token() + { + let middleware = new_middleware::(AuthSource::AuthorizationHeader); + State::with_new(|mut state| { + let status = middleware.auth_status(&mut state); + match status { + AuthStatus::Unauthenticated => {}, + _ => panic!("Expected AuthStatus::Unauthenticated, got {:?}", status) + }; + }); + } + + #[test] + fn test_auth_middleware_expired_token() + { + let middleware = new_middleware::(AuthSource::AuthorizationHeader); + State::with_new(|mut state| { + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, format!("Bearer {}", EXPIRED_TOKEN).parse().unwrap()); + state.put(headers); + let status = middleware.auth_status(&mut state); + match status { + AuthStatus::Expired => {}, + _ => panic!("Expected AuthStatus::Expired, got {:?}", status) + }; + }); + } + + #[test] + fn test_auth_middleware_auth_header_token() + { + let middleware = new_middleware::(AuthSource::AuthorizationHeader); + State::with_new(|mut state| { + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, format!("Bearer {}", VALID_TOKEN).parse().unwrap()); + state.put(headers); + let status = middleware.auth_status(&mut state); + match status { + AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()), + _ => panic!("Expected AuthStatus::Authenticated, got {:?}", status) + }; + }) + } + + #[test] + fn test_auth_middleware_header_token() + { + let header_name = "x-znoiprwmvfexju"; + let middleware = new_middleware::(AuthSource::Header(HeaderName::from_static(header_name))); + State::with_new(|mut state| { + let mut headers = HeaderMap::new(); + headers.insert(header_name, VALID_TOKEN.parse().unwrap()); + state.put(headers); + let status = middleware.auth_status(&mut state); + match status { + AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()), + _ => panic!("Expected AuthStatus::Authenticated, got {:?}", status) + }; + }) + } + + #[test] + fn test_auth_middleware_cookie_token() + { + let cookie_name = "znoiprwmvfexju"; + let middleware = new_middleware::(AuthSource::Cookie(cookie_name.to_owned())); + State::with_new(|mut state| { + let mut jar = CookieJar::new(); + jar.add_original(Cookie::new(cookie_name, VALID_TOKEN)); + state.put(jar); + let status = middleware.auth_status(&mut state); + match status { + AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()), + _ => panic!("Expected AuthStatus::Authenticated, got {:?}", status) + }; + }) + } +} diff --git a/gotham_restful/src/lib.rs b/gotham_restful/src/lib.rs index f330536..b46e7bd 100644 --- a/gotham_restful/src/lib.rs +++ b/gotham_restful/src/lib.rs @@ -8,7 +8,8 @@ resources. # Usage -To use this crate, add the following to your `Cargo.toml`: +This crate targets stable rust, currently requiring rustc 1.40+. To use this crate, add the +following to your `Cargo.toml`: ```toml [dependencies] @@ -112,7 +113,7 @@ extern crate self as gotham_restful; #[macro_use] extern crate serde; #[doc(no_inline)] -pub use hyper::{Chunk, StatusCode}; +pub use hyper::{header::HeaderName, Chunk, StatusCode}; #[doc(no_inline)] pub use mime::Mime; @@ -134,6 +135,18 @@ pub mod export pub use openapiv3 as openapi; } +#[cfg(feature = "auth")] +mod auth; +#[cfg(feature = "auth")] +pub use auth::{ + AuthHandler, + AuthMiddleware, + AuthSource, + AuthStatus, + AuthValidation, + StaticAuthHandler +}; + #[cfg(feature = "openapi")] mod openapi; #[cfg(feature = "openapi")] diff --git a/gotham_restful/src/openapi/router.rs b/gotham_restful/src/openapi/router.rs index d1d8dc0..a2b9077 100644 --- a/gotham_restful/src/openapi/router.rs +++ b/gotham_restful/src/openapi/router.rs @@ -21,9 +21,9 @@ use indexmap::IndexMap; use log::error; use mime::{Mime, APPLICATION_JSON, TEXT_PLAIN}; use openapiv3::{ - Components, MediaType, OpenAPI, Operation, Parameter, ParameterData, ParameterSchemaOrContent, PathItem, + APIKeyLocation, Components, MediaType, OpenAPI, Operation, Parameter, ParameterData, ParameterSchemaOrContent, PathItem, Paths, ReferenceOr, ReferenceOr::Item, ReferenceOr::Reference, RequestBody as OARequestBody, Response, Responses, Schema, - SchemaKind, Server, StatusCode, Type + SchemaKind, SecurityRequirement, SecurityScheme, Server, StatusCode, Type }; use serde::de::DeserializeOwned; use std::panic::RefUnwindSafe; @@ -125,16 +125,13 @@ impl OpenapiRouter } #[derive(Clone)] -struct OpenapiHandler(Result); - -// dunno what/why/whatever -impl RefUnwindSafe for OpenapiHandler {} +struct OpenapiHandler(OpenAPI); impl OpenapiHandler { fn new(openapi : &OpenapiRouter) -> Self { - Self(serde_json::to_string(&openapi.0).map_err(|e| format!("{}", e))) + Self(openapi.0.clone()) } } @@ -148,11 +145,63 @@ impl NewHandler for OpenapiHandler } } +#[cfg(feature = "auth")] +const SECURITY_NAME : &'static str = "authToken"; + +#[cfg(feature = "auth")] +fn get_security(state : &mut State) -> (Vec, IndexMap>) +{ + use crate::AuthSource; + use gotham::state::FromState; + + let source = match AuthSource::try_borrow_from(state) { + Some(source) => source, + None => return Default::default() + }; + + let mut security : IndexMap> = Default::default(); + security.insert(SECURITY_NAME.to_owned(), Vec::new()); + let security = vec![security]; + + let security_scheme = match source { + AuthSource::Cookie(name) => SecurityScheme::APIKey { + location: APIKeyLocation::Cookie, + name: name.to_string() + }, + AuthSource::Header(name) => SecurityScheme::APIKey { + location: APIKeyLocation::Header, + name: name.to_string() + }, + AuthSource::AuthorizationHeader => SecurityScheme::HTTP { + scheme: "bearer".to_owned(), + bearer_format: Some("JWT".to_owned()) + } + }; + + let mut security_schemes : IndexMap> = Default::default(); + security_schemes.insert(SECURITY_NAME.to_owned(), ReferenceOr::Item(security_scheme)); + + (security, security_schemes) +} + +#[cfg(not(feature = "auth"))] +fn get_security(state : &mut State) -> (Vec, IndexMap>) +{ + Default::default() +} + impl Handler for OpenapiHandler { - fn handle(self, state : State) -> Box + fn handle(self, mut state : State) -> Box { - match self.0 { + let mut openapi = self.0; + let (security, security_schemes) = get_security(&mut state); + openapi.security = security; + let mut components = openapi.components.unwrap_or_default(); + components.security_schemes = security_schemes; + openapi.components = Some(components); + + match serde_json::to_string(&openapi) { Ok(body) => { let res = create_response(&state, hyper::StatusCode::OK, APPLICATION_JSON, body); Box::new(ok((state, res))) diff --git a/gotham_restful_derive/Cargo.toml b/gotham_restful_derive/Cargo.toml index 6f07807..09e505b 100644 --- a/gotham_restful_derive/Cargo.toml +++ b/gotham_restful_derive/Cargo.toml @@ -25,5 +25,6 @@ syn = { version = "1.0.13", features = ["extra-traits", "full"] } [features] default = [] +auth = [] database = [] openapi = [] diff --git a/gotham_restful_derive/src/method.rs b/gotham_restful_derive/src/method.rs index 6f4c461..e8c0647 100644 --- a/gotham_restful_derive/src/method.rs +++ b/gotham_restful_derive/src/method.rs @@ -3,8 +3,10 @@ use proc_macro::TokenStream; use proc_macro2::{Ident, TokenStream as TokenStream2}; use quote::{format_ident, quote}; use syn::{ + Attribute, FnArg, ItemFn, + PatType, ReturnType, Type, parse_macro_input @@ -84,6 +86,112 @@ impl Method } } +enum MethodArgumentType +{ + StateRef, + StateMutRef, + MethodArg(Type), + DatabaseConnection(Type), + AuthStatus(Type), + AuthStatusRef(Type) +} + +impl MethodArgumentType +{ + fn is_method_arg(&self) -> bool + { + match self { + Self::MethodArg(_) => true, + _ => false, + } + } + + fn is_database_conn(&self) -> bool + { + match self { + Self::DatabaseConnection(_) => true, + _ => false + } + } + + fn is_auth_status(&self) -> bool + { + match self { + Self::AuthStatus(_) | Self::AuthStatusRef(_) => true, + _ => false + } + } + + fn is_auth_status_ref(&self) -> bool + { + match self { + Self::AuthStatusRef(_) => true, + _ => false + } + } + + fn quote_ty(&self) -> Option + { + match self { + Self::MethodArg(ty) => Some(quote!(#ty)), + Self::DatabaseConnection(ty) => Some(quote!(#ty)), + Self::AuthStatus(ty) => Some(quote!(#ty)), + Self::AuthStatusRef(ty) => Some(quote!(#ty)), + _ => None + } + } +} + +struct MethodArgument +{ + ident : Ident, + ty : MethodArgumentType +} + +fn interpret_arg_ty(index : usize, attrs : &[Attribute], name : &str, ty : Type) -> MethodArgumentType +{ + let attr = attrs.into_iter() + .filter(|arg| arg.path.segments.iter().filter(|path| &path.ident.to_string() == "rest_arg").nth(0).is_some()) + .nth(0) + .map(|arg| arg.tokens.to_string()); + + if cfg!(feature = "auth") && (attr.as_deref() == Some("auth") || (attr.is_none() && name == "auth")) + { + return match ty { + Type::Reference(ty) => MethodArgumentType::AuthStatusRef(*ty.elem), + ty => MethodArgumentType::AuthStatus(ty) + }; + } + + if cfg!(feature = "database") && (attr.as_deref() == Some("connection") || attr.as_deref() == Some("conn") || (attr.is_none() && name == "conn")) + { + return MethodArgumentType::DatabaseConnection(match ty { + Type::Reference(ty) => *ty.elem, + ty => ty + }); + } + + if index == 0 + { + return match ty { + Type::Reference(ty) => if ty.mutability.is_none() { MethodArgumentType::StateRef } else { MethodArgumentType::StateMutRef }, + _ => panic!("The first argument, unless some feature is used, has to be a (mutable) reference to gotham::state::State") + }; + } + + MethodArgumentType::MethodArg(ty) +} + +fn interpret_arg(index : usize, arg : &PatType) -> MethodArgument +{ + let pat = &arg.pat; + let ident = format_ident!("arg{}", index); + let orig_name = quote!(#pat); + let ty = interpret_arg_ty(index, &arg.attrs, &orig_name.to_string(), *arg.ty.clone()); + + MethodArgument { ident, ty } +} + pub fn expand_method(method : Method, attrs : TokenStream, item : TokenStream) -> TokenStream { let krate = super::krate(); @@ -101,71 +209,90 @@ pub fn expand_method(method : Method, attrs : TokenStream, item : TokenStream) - ReturnType::Type(_, ty) => (quote!(#ty), false) }; - // extract arguments into pattern, ident and type + // some default idents we'll need let state_ident = format_ident!("state"); - let args : Vec<(usize, TokenStream2, Ident, Type)> = fun.sig.inputs.iter().enumerate().map(|(i, arg)| match arg { - FnArg::Typed(arg) => { - let pat = &arg.pat; - let ident = format_ident!("arg{}", i); - (i, quote!(#pat), ident, *arg.ty.clone()) - }, + let repo_ident = format_ident!("repo"); + let conn_ident = format_ident!("conn"); + let auth_ident = format_ident!("auth"); + + // extract arguments into pattern, ident and type + let args : Vec = fun.sig.inputs.iter().enumerate().map(|(i, arg)| match arg { + FnArg::Typed(arg) => interpret_arg(i, arg), FnArg::Receiver(_) => panic!("didn't expect self parameter") }).collect(); - // find the database connection if enabled and present - let repo_ident = format_ident!("database_repo"); - let args_conn = if cfg!(feature = "database") { - args.iter().filter(|(_, pat, _, _)| pat.to_string() == "conn").nth(0) - } else { None }; - let args_conn_name = args_conn.map(|(_, pat, _, _)| pat.to_string()); - // extract the generic parameters to use - let mut generics : Vec = args.iter().skip(1) - .filter(|(_, pat, _, _)| Some(pat.to_string()) != args_conn_name) - .map(|(_, _, _, ty)| quote!(#ty)).collect(); + let mut generics : Vec = args.iter() + .filter(|arg| (*arg).ty.is_method_arg()) + .map(|arg| arg.ty.quote_ty().unwrap()) + .collect(); generics.push(quote!(#ret)); // extract the definition of our method let mut args_def : Vec = args.iter() - .filter(|(_, pat, _, _)| Some(pat.to_string()) != args_conn_name) - .map(|(i, _, ident, ty)| if *i == 0 { quote!(#state_ident : #ty) } else { quote!(#ident : #ty) }).collect(); - if let Some(_) = args_conn - { - args_def.insert(0, quote!(#state_ident : &mut #krate::export::State)); - } + .filter(|arg| (*arg).ty.is_method_arg()) + .map(|arg| { + let ident = &arg.ident; + let ty = arg.ty.quote_ty(); + quote!(#ident : #ty) + }).collect(); + args_def.insert(0, quote!(#state_ident : &mut #krate::export::State)); // extract the arguments to pass over to the supplied method - let args_pass : Vec = args.iter().map(|(i, pat, ident, _)| if Some(pat.to_string()) != args_conn_name { - if *i == 0 { quote!(#state_ident) } else { quote!(#ident) } - } else { - quote!(&#ident) + let args_pass : Vec = args.iter().map(|arg| match (&arg.ty, &arg.ident) { + (MethodArgumentType::StateRef, _) => quote!(#state_ident), + (MethodArgumentType::StateMutRef, _) => quote!(#state_ident), + (MethodArgumentType::MethodArg(_), ident) => quote!(#ident), + (MethodArgumentType::DatabaseConnection(_), _) => quote!(&#conn_ident), + (MethodArgumentType::AuthStatus(_), _) => quote!(#auth_ident.clone()), + (MethodArgumentType::AuthStatusRef(_), _) => quote!(#auth_ident) }).collect(); // prepare the method block - let mut block = if is_no_content { quote!(#fun_ident(#(#args_pass),*); Default::default()) } else { quote!(#fun_ident(#(#args_pass),*)) }; - if /*cfg!(feature = "database") &&*/ let Some((_, _, conn_ident, conn_ty)) = args_conn // https://github.com/rust-lang/rust/issues/53667 + let mut block = quote!(#fun_ident(#(#args_pass),*)); + if is_no_content { - let conn_ty_real = match conn_ty { - Type::Reference(ty) => &*ty.elem, - ty => ty - }; + block = quote!(#block; Default::default()) + } + if let Some(arg) = args.iter().filter(|arg| (*arg).ty.is_database_conn()).nth(0) + { + let conn_ty = arg.ty.quote_ty(); block = quote! { - use #krate::export::{Future, FromState}; - let #repo_ident = <#krate::export::Repo<#conn_ty_real>>::borrow_from(&#state_ident).clone(); + let #repo_ident = <#krate::export::Repo<#conn_ty>>::borrow_from(&#state_ident).clone(); #repo_ident.run::<_, #ret, ()>(move |#conn_ident| { Ok({#block}) }).wait().unwrap() }; } + if let Some(arg) = args.iter().filter(|arg| (*arg).ty.is_auth_status()).nth(0) + { + let auth_ty = arg.ty.quote_ty(); + block = quote! { + let #auth_ident : &#auth_ty = <#auth_ty>::borrow_from(#state_ident); + #block + }; + } + // prepare the where clause + let mut where_clause = quote!(#resource_ident : #krate::Resource,); + for arg in args.iter().filter(|arg| (*arg).ty.is_auth_status() && !(*arg).ty.is_auth_status_ref()) + { + let auth_ty = arg.ty.quote_ty(); + where_clause = quote!(#where_clause #auth_ty : Clone,); + } + + // put everything together let output = quote! { #fun impl #krate::#trait_ident<#(#generics),*> for #resource_ident - where #resource_ident : #krate::Resource + where #where_clause { fn #method_ident(#(#args_def),*) -> #ret { + #[allow(unused_imports)] + use #krate::export::{Future, FromState}; + #block } }