1
0
Fork 0
mirror of https://gitlab.com/msrd0/gotham-restful.git synced 2025-02-23 04:52:28 +00:00

Merge branch 'auth' into 'master'

See merge request msrd0/gotham-restful!4
This commit is contained in:
msrd0 2020-01-22 16:53:02 +00:00
commit 747c0063c4
9 changed files with 721 additions and 49 deletions

View file

@ -12,7 +12,8 @@ resources.
## Usage ## Usage
To use this crate, add the following to your `Cargo.toml`: This crate targets stable rust, currently requiring rustc 1.40+. To use this crate, add the
following to your `Cargo.toml`:
```toml ```toml
[dependencies] [dependencies]

View file

@ -17,7 +17,7 @@ gitlab = { repository = "msrd0/gotham-restful", branch = "master" }
fake = "2.2" fake = "2.2"
gotham = "0.4" gotham = "0.4"
gotham_derive = "0.4" gotham_derive = "0.4"
gotham_restful = { version = "0.0.1", features = ["openapi"] } gotham_restful = { version = "0.0.1", features = ["auth", "openapi"] }
hyper = "0.12" hyper = "0.12"
log = "0.4" log = "0.4"
log4rs = { version = "0.8", features = ["console_appender"], default-features = false } log4rs = { version = "0.8", features = ["console_appender"], default-features = false }

View file

@ -23,6 +23,12 @@ struct Users
{ {
} }
#[derive(Resource)]
#[rest_resource(ReadAll)]
struct Auth
{
}
#[derive(Deserialize, OpenapiType, Serialize, StateData, StaticResponseExtender)] #[derive(Deserialize, OpenapiType, Serialize, StateData, StaticResponseExtender)]
struct User struct User
{ {
@ -82,8 +88,24 @@ fn delete(_state : &mut State, id : u64)
info!("Delete User {}", id); info!("Delete User {}", id);
} }
#[rest_read_all(Auth)]
fn auth_read_all(auth : AuthStatus<()>) -> Success<String>
{
format!("{:?}", auth).into()
}
const ADDR : &str = "127.0.0.1:18080"; const ADDR : &str = "127.0.0.1:18080";
#[derive(Clone, Default)]
struct Handler;
impl<T> AuthHandler<T> for Handler
{
fn jwt_secret<F : FnOnce() -> Option<T>>(&self, _state : &mut State, _decode_data : F) -> Option<Vec<u8>>
{
None
}
}
fn main() fn main()
{ {
let encoder = PatternEncoder::new("{d(%Y-%m-%d %H:%M:%S%.3f %Z)} [{l}] {M} - {m}\n"); let encoder = PatternEncoder::new("{d(%Y-%m-%d %H:%M:%S%.3f %Z)} [{l}] {M} - {m}\n");
@ -99,9 +121,11 @@ fn main()
.unwrap(); .unwrap();
log4rs::init_config(config).unwrap(); log4rs::init_config(config).unwrap();
let auth = <AuthMiddleware<(), Handler>>::from_source(AuthSource::AuthorizationHeader);
let logging = RequestLogger::new(log::Level::Info); let logging = RequestLogger::new(log::Level::Info);
let (chain, pipelines) = single_pipeline( let (chain, pipelines) = single_pipeline(
new_pipeline() new_pipeline()
.add(auth)
.add(logging) .add(logging)
.build() .build()
); );
@ -109,6 +133,7 @@ fn main()
gotham::start(ADDR, build_router(chain, pipelines, |route| { gotham::start(ADDR, build_router(chain, pipelines, |route| {
route.with_openapi("Users Example", "0.0.1", format!("http://{}", ADDR), |mut route| { route.with_openapi("Users Example", "0.0.1", format!("http://{}", ADDR), |mut route| {
route.resource::<Users, _>("users"); route.resource::<Users, _>("users");
route.resource::<Auth, _>("auth");
route.get_openapi("openapi"); route.get_openapi("openapi");
}); });
})); }));

View file

@ -16,7 +16,9 @@ gitlab = { repository = "msrd0/gotham-restful", branch = "master" }
codecov = { repository = "msrd0/gotham-restful", branch = "master", service = "gitlab" } codecov = { repository = "msrd0/gotham-restful", branch = "master", service = "gitlab" }
[dependencies] [dependencies]
base64 = { version = "0.11", optional = true }
chrono = { version = "0.4.9", optional = true } chrono = { version = "0.4.9", optional = true }
cookie = { version = "0.12", optional = true }
failure = "0.1.6" failure = "0.1.6"
futures = "0.1.29" futures = "0.1.29"
gotham = "0.4" gotham = "0.4"
@ -25,6 +27,7 @@ gotham_middleware_diesel = { version = "0.1", optional = true }
gotham_restful_derive = { version = "0.0.1" } gotham_restful_derive = { version = "0.0.1" }
hyper = "0.12.35" hyper = "0.12.35"
indexmap = { version = "1.3.0", optional = true } indexmap = { version = "1.3.0", optional = true }
jsonwebtoken = { version = "6.0.1", optional = true }
log = { version = "0.4.8", optional = true } log = { version = "0.4.8", optional = true }
mime = "0.3.14" mime = "0.3.14"
openapiv3 = { version = "0.3", optional = true } openapiv3 = { version = "0.3", optional = true }
@ -37,5 +40,6 @@ thiserror = "1"
[features] [features]
default = [] default = []
auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"]
database = ["gotham_restful_derive/database", "gotham_middleware_diesel"] database = ["gotham_restful_derive/database", "gotham_middleware_diesel"]
openapi = ["gotham_restful_derive/openapi", "indexmap", "log", "openapiv3"] openapi = ["gotham_restful_derive/openapi", "indexmap", "log", "openapiv3"]

452
gotham_restful/src/auth.rs Normal file
View file

@ -0,0 +1,452 @@
use crate::HeaderName;
use cookie::CookieJar;
use futures::{future, future::Future};
use gotham::{
handler::HandlerFuture,
middleware::{Middleware, NewMiddleware},
state::{FromState, State}
};
use hyper::header::{AUTHORIZATION, HeaderMap};
use jsonwebtoken::errors::ErrorKind;
use serde::de::DeserializeOwned;
use std::{
marker::PhantomData,
panic::RefUnwindSafe
};
pub use jsonwebtoken::Validation as AuthValidation;
/// The authentication status returned by the auth middleware for each request.
#[derive(Debug, StateData)]
pub enum AuthStatus<T : Send + 'static>
{
/// The auth status is unknown.
Unknown,
/// The request has been performed without any kind of authentication.
Unauthenticated,
/// The request has been performed with an invalid authentication.
Invalid,
/// The request has been performed with an expired authentication.
Expired,
/// The request has been performed with a valid authentication.
Authenticated(T)
}
impl<T> Clone for AuthStatus<T>
where
T : Clone + Send + 'static
{
fn clone(&self) -> Self
{
match self {
Self::Unknown => Self::Unknown,
Self::Unauthenticated => Self::Unauthenticated,
Self::Invalid => Self::Invalid,
Self::Expired => Self::Expired,
Self::Authenticated(data) => Self::Authenticated(data.clone())
}
}
}
/// The source of the authentication token in the request.
#[derive(Clone, StateData)]
pub enum AuthSource
{
/// Take the token from a cookie with the given name.
Cookie(String),
/// Take the token from a header with the given name.
Header(HeaderName),
/// Take the token from the HTTP Authorization header. This is different from `Header("Authorization")`
/// as it will follow the `scheme param` format from the HTTP specification. The `scheme` will
/// be discarded, so its value doesn't matter.
AuthorizationHeader
}
/**
This trait will help the auth middleware to determine the validity of an authentication token.
A very basic implementation could look like this:
```
# use gotham_restful::{export::State, AuthHandler};
#
const SECRET : &'static [u8; 32] = b"zlBsA2QXnkmpe0QTh8uCvtAEa4j33YAc";
struct CustomAuthHandler;
impl<T> AuthHandler<T> for CustomAuthHandler {
fn jwt_secret<F : FnOnce() -> Option<T>>(&self, _state : &mut State, _decode_data : F) -> Option<Vec<u8>> {
Some(SECRET.to_vec())
}
}
```
*/
pub trait AuthHandler<Data>
{
/// Return the SHA256-HMAC secret used to verify the JWT token.
fn jwt_secret<F : FnOnce() -> Option<Data>>(&self, state : &mut State, decode_data : F) -> Option<Vec<u8>>;
}
/// An `AuthHandler` returning always the same secret. See `AuthMiddleware` for a usage example.
#[derive(Clone, Debug)]
pub struct StaticAuthHandler
{
secret : Vec<u8>
}
impl StaticAuthHandler
{
pub fn from_vec(secret : Vec<u8>) -> Self
{
Self { secret }
}
pub fn from_array(secret : &[u8]) -> Self
{
Self::from_vec(secret.to_vec())
}
}
impl<T> AuthHandler<T> for StaticAuthHandler
{
fn jwt_secret<F : FnOnce() -> Option<T>>(&self, _state : &mut State, _decode_data : F) -> Option<Vec<u8>>
{
Some(self.secret.clone())
}
}
/**
This is the auth middleware. To use it, first make sure you have the `auth` feature enabled. Then
simply add it to your pipeline and request it inside your handler:
```rust,no_run
# #[macro_use] extern crate gotham_restful_derive;
# use gotham::{router::builder::*, pipeline::{new_pipeline, single::single_pipeline}, state::State};
# use gotham_restful::*;
# use serde::{Deserialize, Serialize};
#
#[derive(Resource)]
#[rest_resource(read_all)]
struct AuthResource;
#[derive(Debug, Deserialize)]
struct AuthData {
sub: String,
exp: u64
}
#[rest_read_all(AuthResource)]
fn read_all(auth : &AuthStatus<AuthData>) -> Success<String> {
format!("{:?}", auth).into()
}
fn main() {
let auth : AuthMiddleware<AuthData, _> = AuthMiddleware::new(
AuthSource::AuthorizationHeader,
AuthValidation::default(),
StaticAuthHandler::from_array(b"zlBsA2QXnkmpe0QTh8uCvtAEa4j33YAc")
);
let (chain, pipelines) = single_pipeline(new_pipeline().add(auth).build());
gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| {
route.resource::<AuthResource, _>("auth");
}));
}
```
*/
pub struct AuthMiddleware<Data, Handler>
{
source : AuthSource,
validation : AuthValidation,
handler : Handler,
_data : PhantomData<Data>
}
impl<Data, Handler> Clone for AuthMiddleware<Data, Handler>
where Handler : Clone
{
fn clone(&self) -> Self
{
Self {
source: self.source.clone(),
validation: self.validation.clone(),
handler: self.handler.clone(),
_data: self._data
}
}
}
impl<Data, Handler> AuthMiddleware<Data, Handler>
where
Data : DeserializeOwned + Send,
Handler : AuthHandler<Data> + Default
{
pub fn from_source(source : AuthSource) -> Self
{
Self {
source,
validation: Default::default(),
handler: Default::default(),
_data: Default::default()
}
}
}
impl<Data, Handler> AuthMiddleware<Data, Handler>
where
Data : DeserializeOwned + Send,
Handler : AuthHandler<Data>
{
pub fn new(source : AuthSource, validation : AuthValidation, handler : Handler) -> Self
{
Self {
source,
validation,
handler,
_data: Default::default()
}
}
fn auth_status(&self, state : &mut State) -> AuthStatus<Data>
{
// extract the provided token, if any
let token = match &self.source {
AuthSource::Cookie(name) => {
CookieJar::try_borrow_from(&state)
.and_then(|jar| jar.get(&name))
.map(|cookie| cookie.value().to_owned())
},
AuthSource::Header(name) => {
HeaderMap::try_borrow_from(&state)
.and_then(|map| map.get(name))
.and_then(|header| header.to_str().ok())
.map(|value| value.to_owned())
},
AuthSource::AuthorizationHeader => {
HeaderMap::try_borrow_from(&state)
.and_then(|map| map.get(AUTHORIZATION))
.and_then(|header| header.to_str().ok())
.and_then(|value| value.split_whitespace().nth(1))
.map(|value| value.to_owned())
}
};
// unauthed if no token
let token = match token {
Some(token) => token,
None => return AuthStatus::Unauthenticated
};
// get the secret from the handler, possibly decoding claims ourselves
let secret = self.handler.jwt_secret(state, || {
let b64 = token.split(".").nth(1)?;
let raw = base64::decode_config(b64, base64::URL_SAFE_NO_PAD).ok()?;
serde_json::from_slice(&raw).ok()?
});
// unknown if no secret
let secret = match secret {
Some(secret) => secret,
None => return AuthStatus::Unknown
};
// validate the token
let data : Data = match jsonwebtoken::decode(&token, &secret, &self.validation) {
Ok(data) => data.claims,
Err(e) => match dbg!(e.into_kind()) {
ErrorKind::ExpiredSignature => return AuthStatus::Expired,
_ => return AuthStatus::Invalid
}
};
// we found a valid token
return AuthStatus::Authenticated(data);
}
}
impl<Data, Handler> Middleware for AuthMiddleware<Data, Handler>
where
Data : DeserializeOwned + Send + 'static,
Handler : AuthHandler<Data>
{
fn call<Chain>(self, mut state : State, chain : Chain) -> Box<HandlerFuture>
where
Chain : FnOnce(State) -> Box<HandlerFuture>
{
// put the source in our state, required for e.g. openapi
state.put(self.source.clone());
// put the status in our state
let status = self.auth_status(&mut state);
state.put(status);
// call the rest of the chain
Box::new(chain(state).and_then(|(state, res)| future::ok((state, res))))
}
}
impl<Data, Handler> NewMiddleware for AuthMiddleware<Data, Handler>
where
Self : Clone + Middleware + Sync + RefUnwindSafe
{
type Instance = Self;
fn new_middleware(&self) -> Result<Self::Instance, std::io::Error>
{
let c : Self = self.clone();
Ok(c)
}
}
#[cfg(test)]
mod test
{
use super::*;
use cookie::Cookie;
use std::fmt::Debug;
// 256-bit random string
const JWT_SECRET : &'static [u8; 32] = b"Lyzsfnta0cdxyF0T9y6VGxp3jpgoMUuW";
// some known tokens
const VALID_TOKEN : &'static str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJtc3JkMCIsInN1YiI6ImdvdGhhbS1yZXN0ZnVsIiwiaWF0IjoxNTc3ODM2ODAwLCJleHAiOjQxMDI0NDQ4MDB9.8h8Ax-nnykqEQ62t7CxmM3ja6NzUQ4L0MLOOzddjLKk";
const EXPIRED_TOKEN : &'static str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJtc3JkMCIsInN1YiI6ImdvdGhhbS1yZXN0ZnVsIiwiaWF0IjoxNTc3ODM2ODAwLCJleHAiOjE1Nzc4MzcxMDB9.eV1snaGLYrJ7qUoMk74OvBY3WUU9M0Je5HTU2xtX1v0";
#[derive(Debug, Deserialize, PartialEq)]
struct TestData
{
iss : String,
sub : String,
iat : u64,
exp : u64
}
impl Default for TestData
{
fn default() -> Self
{
Self {
iss: "msrd0".to_owned(),
sub: "gotham-restful".to_owned(),
iat: 1577836800,
exp: 4102444800
}
}
}
#[derive(Default)]
struct TestHandler;
impl<T> AuthHandler<T> for TestHandler
{
fn jwt_secret<F : FnOnce() -> Option<T>>(&self, _state : &mut State, _decode_data : F) -> Option<Vec<u8>>
{
Some(JWT_SECRET.to_vec())
}
}
#[derive(Default)]
struct TestAssertingHandler;
impl<T> AuthHandler<T> for TestAssertingHandler
where T : Debug + Default + PartialEq
{
fn jwt_secret<F : FnOnce() -> Option<T>>(&self, _state : &mut State, decode_data : F) -> Option<Vec<u8>>
{
assert_eq!(decode_data(), Some(T::default()));
Some(JWT_SECRET.to_vec())
}
}
#[test]
fn test_auth_middleware_decode_data()
{
let middleware = <AuthMiddleware<TestData, TestAssertingHandler>>::from_source(AuthSource::AuthorizationHeader);
State::with_new(|mut state| {
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, format!("Bearer {}", VALID_TOKEN).parse().unwrap());
state.put(headers);
middleware.auth_status(&mut state);
});
}
fn new_middleware<T>(source : AuthSource) -> AuthMiddleware<T, TestHandler>
where T : DeserializeOwned + Send
{
AuthMiddleware::from_source(source)
}
#[test]
fn test_auth_middleware_no_token()
{
let middleware = new_middleware::<TestData>(AuthSource::AuthorizationHeader);
State::with_new(|mut state| {
let status = middleware.auth_status(&mut state);
match status {
AuthStatus::Unauthenticated => {},
_ => panic!("Expected AuthStatus::Unauthenticated, got {:?}", status)
};
});
}
#[test]
fn test_auth_middleware_expired_token()
{
let middleware = new_middleware::<TestData>(AuthSource::AuthorizationHeader);
State::with_new(|mut state| {
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, format!("Bearer {}", EXPIRED_TOKEN).parse().unwrap());
state.put(headers);
let status = middleware.auth_status(&mut state);
match status {
AuthStatus::Expired => {},
_ => panic!("Expected AuthStatus::Expired, got {:?}", status)
};
});
}
#[test]
fn test_auth_middleware_auth_header_token()
{
let middleware = new_middleware::<TestData>(AuthSource::AuthorizationHeader);
State::with_new(|mut state| {
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, format!("Bearer {}", VALID_TOKEN).parse().unwrap());
state.put(headers);
let status = middleware.auth_status(&mut state);
match status {
AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
_ => panic!("Expected AuthStatus::Authenticated, got {:?}", status)
};
})
}
#[test]
fn test_auth_middleware_header_token()
{
let header_name = "x-znoiprwmvfexju";
let middleware = new_middleware::<TestData>(AuthSource::Header(HeaderName::from_static(header_name)));
State::with_new(|mut state| {
let mut headers = HeaderMap::new();
headers.insert(header_name, VALID_TOKEN.parse().unwrap());
state.put(headers);
let status = middleware.auth_status(&mut state);
match status {
AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
_ => panic!("Expected AuthStatus::Authenticated, got {:?}", status)
};
})
}
#[test]
fn test_auth_middleware_cookie_token()
{
let cookie_name = "znoiprwmvfexju";
let middleware = new_middleware::<TestData>(AuthSource::Cookie(cookie_name.to_owned()));
State::with_new(|mut state| {
let mut jar = CookieJar::new();
jar.add_original(Cookie::new(cookie_name, VALID_TOKEN));
state.put(jar);
let status = middleware.auth_status(&mut state);
match status {
AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
_ => panic!("Expected AuthStatus::Authenticated, got {:?}", status)
};
})
}
}

View file

@ -8,7 +8,8 @@ resources.
# Usage # Usage
To use this crate, add the following to your `Cargo.toml`: This crate targets stable rust, currently requiring rustc 1.40+. To use this crate, add the
following to your `Cargo.toml`:
```toml ```toml
[dependencies] [dependencies]
@ -112,7 +113,7 @@ extern crate self as gotham_restful;
#[macro_use] extern crate serde; #[macro_use] extern crate serde;
#[doc(no_inline)] #[doc(no_inline)]
pub use hyper::{Chunk, StatusCode}; pub use hyper::{header::HeaderName, Chunk, StatusCode};
#[doc(no_inline)] #[doc(no_inline)]
pub use mime::Mime; pub use mime::Mime;
@ -134,6 +135,18 @@ pub mod export
pub use openapiv3 as openapi; pub use openapiv3 as openapi;
} }
#[cfg(feature = "auth")]
mod auth;
#[cfg(feature = "auth")]
pub use auth::{
AuthHandler,
AuthMiddleware,
AuthSource,
AuthStatus,
AuthValidation,
StaticAuthHandler
};
#[cfg(feature = "openapi")] #[cfg(feature = "openapi")]
mod openapi; mod openapi;
#[cfg(feature = "openapi")] #[cfg(feature = "openapi")]

View file

@ -21,9 +21,9 @@ use indexmap::IndexMap;
use log::error; use log::error;
use mime::{Mime, APPLICATION_JSON, TEXT_PLAIN}; use mime::{Mime, APPLICATION_JSON, TEXT_PLAIN};
use openapiv3::{ use openapiv3::{
Components, MediaType, OpenAPI, Operation, Parameter, ParameterData, ParameterSchemaOrContent, PathItem, APIKeyLocation, Components, MediaType, OpenAPI, Operation, Parameter, ParameterData, ParameterSchemaOrContent, PathItem,
Paths, ReferenceOr, ReferenceOr::Item, ReferenceOr::Reference, RequestBody as OARequestBody, Response, Responses, Schema, Paths, ReferenceOr, ReferenceOr::Item, ReferenceOr::Reference, RequestBody as OARequestBody, Response, Responses, Schema,
SchemaKind, Server, StatusCode, Type SchemaKind, SecurityRequirement, SecurityScheme, Server, StatusCode, Type
}; };
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::panic::RefUnwindSafe; use std::panic::RefUnwindSafe;
@ -125,16 +125,13 @@ impl OpenapiRouter
} }
#[derive(Clone)] #[derive(Clone)]
struct OpenapiHandler(Result<String, String>); struct OpenapiHandler(OpenAPI);
// dunno what/why/whatever
impl RefUnwindSafe for OpenapiHandler {}
impl OpenapiHandler impl OpenapiHandler
{ {
fn new(openapi : &OpenapiRouter) -> Self fn new(openapi : &OpenapiRouter) -> Self
{ {
Self(serde_json::to_string(&openapi.0).map_err(|e| format!("{}", e))) Self(openapi.0.clone())
} }
} }
@ -148,11 +145,63 @@ impl NewHandler for OpenapiHandler
} }
} }
#[cfg(feature = "auth")]
const SECURITY_NAME : &'static str = "authToken";
#[cfg(feature = "auth")]
fn get_security(state : &mut State) -> (Vec<SecurityRequirement>, IndexMap<String, ReferenceOr<SecurityScheme>>)
{
use crate::AuthSource;
use gotham::state::FromState;
let source = match AuthSource::try_borrow_from(state) {
Some(source) => source,
None => return Default::default()
};
let mut security : IndexMap<String, Vec<String>> = Default::default();
security.insert(SECURITY_NAME.to_owned(), Vec::new());
let security = vec![security];
let security_scheme = match source {
AuthSource::Cookie(name) => SecurityScheme::APIKey {
location: APIKeyLocation::Cookie,
name: name.to_string()
},
AuthSource::Header(name) => SecurityScheme::APIKey {
location: APIKeyLocation::Header,
name: name.to_string()
},
AuthSource::AuthorizationHeader => SecurityScheme::HTTP {
scheme: "bearer".to_owned(),
bearer_format: Some("JWT".to_owned())
}
};
let mut security_schemes : IndexMap<String, ReferenceOr<SecurityScheme>> = Default::default();
security_schemes.insert(SECURITY_NAME.to_owned(), ReferenceOr::Item(security_scheme));
(security, security_schemes)
}
#[cfg(not(feature = "auth"))]
fn get_security(state : &mut State) -> (Vec<SecurityRequirement>, IndexMap<String, ReferenceOr<SecurityScheme>>)
{
Default::default()
}
impl Handler for OpenapiHandler impl Handler for OpenapiHandler
{ {
fn handle(self, state : State) -> Box<HandlerFuture> fn handle(self, mut state : State) -> Box<HandlerFuture>
{ {
match self.0 { let mut openapi = self.0;
let (security, security_schemes) = get_security(&mut state);
openapi.security = security;
let mut components = openapi.components.unwrap_or_default();
components.security_schemes = security_schemes;
openapi.components = Some(components);
match serde_json::to_string(&openapi) {
Ok(body) => { Ok(body) => {
let res = create_response(&state, hyper::StatusCode::OK, APPLICATION_JSON, body); let res = create_response(&state, hyper::StatusCode::OK, APPLICATION_JSON, body);
Box::new(ok((state, res))) Box::new(ok((state, res)))

View file

@ -25,5 +25,6 @@ syn = { version = "1.0.13", features = ["extra-traits", "full"] }
[features] [features]
default = [] default = []
auth = []
database = [] database = []
openapi = [] openapi = []

View file

@ -3,8 +3,10 @@ use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2}; use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{ use syn::{
Attribute,
FnArg, FnArg,
ItemFn, ItemFn,
PatType,
ReturnType, ReturnType,
Type, Type,
parse_macro_input parse_macro_input
@ -84,6 +86,112 @@ impl Method
} }
} }
enum MethodArgumentType
{
StateRef,
StateMutRef,
MethodArg(Type),
DatabaseConnection(Type),
AuthStatus(Type),
AuthStatusRef(Type)
}
impl MethodArgumentType
{
fn is_method_arg(&self) -> bool
{
match self {
Self::MethodArg(_) => true,
_ => false,
}
}
fn is_database_conn(&self) -> bool
{
match self {
Self::DatabaseConnection(_) => true,
_ => false
}
}
fn is_auth_status(&self) -> bool
{
match self {
Self::AuthStatus(_) | Self::AuthStatusRef(_) => true,
_ => false
}
}
fn is_auth_status_ref(&self) -> bool
{
match self {
Self::AuthStatusRef(_) => true,
_ => false
}
}
fn quote_ty(&self) -> Option<TokenStream2>
{
match self {
Self::MethodArg(ty) => Some(quote!(#ty)),
Self::DatabaseConnection(ty) => Some(quote!(#ty)),
Self::AuthStatus(ty) => Some(quote!(#ty)),
Self::AuthStatusRef(ty) => Some(quote!(#ty)),
_ => None
}
}
}
struct MethodArgument
{
ident : Ident,
ty : MethodArgumentType
}
fn interpret_arg_ty(index : usize, attrs : &[Attribute], name : &str, ty : Type) -> MethodArgumentType
{
let attr = attrs.into_iter()
.filter(|arg| arg.path.segments.iter().filter(|path| &path.ident.to_string() == "rest_arg").nth(0).is_some())
.nth(0)
.map(|arg| arg.tokens.to_string());
if cfg!(feature = "auth") && (attr.as_deref() == Some("auth") || (attr.is_none() && name == "auth"))
{
return match ty {
Type::Reference(ty) => MethodArgumentType::AuthStatusRef(*ty.elem),
ty => MethodArgumentType::AuthStatus(ty)
};
}
if cfg!(feature = "database") && (attr.as_deref() == Some("connection") || attr.as_deref() == Some("conn") || (attr.is_none() && name == "conn"))
{
return MethodArgumentType::DatabaseConnection(match ty {
Type::Reference(ty) => *ty.elem,
ty => ty
});
}
if index == 0
{
return match ty {
Type::Reference(ty) => if ty.mutability.is_none() { MethodArgumentType::StateRef } else { MethodArgumentType::StateMutRef },
_ => panic!("The first argument, unless some feature is used, has to be a (mutable) reference to gotham::state::State")
};
}
MethodArgumentType::MethodArg(ty)
}
fn interpret_arg(index : usize, arg : &PatType) -> MethodArgument
{
let pat = &arg.pat;
let ident = format_ident!("arg{}", index);
let orig_name = quote!(#pat);
let ty = interpret_arg_ty(index, &arg.attrs, &orig_name.to_string(), *arg.ty.clone());
MethodArgument { ident, ty }
}
pub fn expand_method(method : Method, attrs : TokenStream, item : TokenStream) -> TokenStream pub fn expand_method(method : Method, attrs : TokenStream, item : TokenStream) -> TokenStream
{ {
let krate = super::krate(); let krate = super::krate();
@ -101,71 +209,90 @@ pub fn expand_method(method : Method, attrs : TokenStream, item : TokenStream) -
ReturnType::Type(_, ty) => (quote!(#ty), false) ReturnType::Type(_, ty) => (quote!(#ty), false)
}; };
// extract arguments into pattern, ident and type // some default idents we'll need
let state_ident = format_ident!("state"); let state_ident = format_ident!("state");
let args : Vec<(usize, TokenStream2, Ident, Type)> = fun.sig.inputs.iter().enumerate().map(|(i, arg)| match arg { let repo_ident = format_ident!("repo");
FnArg::Typed(arg) => { let conn_ident = format_ident!("conn");
let pat = &arg.pat; let auth_ident = format_ident!("auth");
let ident = format_ident!("arg{}", i);
(i, quote!(#pat), ident, *arg.ty.clone()) // extract arguments into pattern, ident and type
}, let args : Vec<MethodArgument> = fun.sig.inputs.iter().enumerate().map(|(i, arg)| match arg {
FnArg::Typed(arg) => interpret_arg(i, arg),
FnArg::Receiver(_) => panic!("didn't expect self parameter") FnArg::Receiver(_) => panic!("didn't expect self parameter")
}).collect(); }).collect();
// find the database connection if enabled and present
let repo_ident = format_ident!("database_repo");
let args_conn = if cfg!(feature = "database") {
args.iter().filter(|(_, pat, _, _)| pat.to_string() == "conn").nth(0)
} else { None };
let args_conn_name = args_conn.map(|(_, pat, _, _)| pat.to_string());
// extract the generic parameters to use // extract the generic parameters to use
let mut generics : Vec<TokenStream2> = args.iter().skip(1) let mut generics : Vec<TokenStream2> = args.iter()
.filter(|(_, pat, _, _)| Some(pat.to_string()) != args_conn_name) .filter(|arg| (*arg).ty.is_method_arg())
.map(|(_, _, _, ty)| quote!(#ty)).collect(); .map(|arg| arg.ty.quote_ty().unwrap())
.collect();
generics.push(quote!(#ret)); generics.push(quote!(#ret));
// extract the definition of our method // extract the definition of our method
let mut args_def : Vec<TokenStream2> = args.iter() let mut args_def : Vec<TokenStream2> = args.iter()
.filter(|(_, pat, _, _)| Some(pat.to_string()) != args_conn_name) .filter(|arg| (*arg).ty.is_method_arg())
.map(|(i, _, ident, ty)| if *i == 0 { quote!(#state_ident : #ty) } else { quote!(#ident : #ty) }).collect(); .map(|arg| {
if let Some(_) = args_conn let ident = &arg.ident;
{ let ty = arg.ty.quote_ty();
quote!(#ident : #ty)
}).collect();
args_def.insert(0, quote!(#state_ident : &mut #krate::export::State)); args_def.insert(0, quote!(#state_ident : &mut #krate::export::State));
}
// extract the arguments to pass over to the supplied method // extract the arguments to pass over to the supplied method
let args_pass : Vec<TokenStream2> = args.iter().map(|(i, pat, ident, _)| if Some(pat.to_string()) != args_conn_name { let args_pass : Vec<TokenStream2> = args.iter().map(|arg| match (&arg.ty, &arg.ident) {
if *i == 0 { quote!(#state_ident) } else { quote!(#ident) } (MethodArgumentType::StateRef, _) => quote!(#state_ident),
} else { (MethodArgumentType::StateMutRef, _) => quote!(#state_ident),
quote!(&#ident) (MethodArgumentType::MethodArg(_), ident) => quote!(#ident),
(MethodArgumentType::DatabaseConnection(_), _) => quote!(&#conn_ident),
(MethodArgumentType::AuthStatus(_), _) => quote!(#auth_ident.clone()),
(MethodArgumentType::AuthStatusRef(_), _) => quote!(#auth_ident)
}).collect(); }).collect();
// prepare the method block // prepare the method block
let mut block = if is_no_content { quote!(#fun_ident(#(#args_pass),*); Default::default()) } else { quote!(#fun_ident(#(#args_pass),*)) }; let mut block = quote!(#fun_ident(#(#args_pass),*));
if /*cfg!(feature = "database") &&*/ let Some((_, _, conn_ident, conn_ty)) = args_conn // https://github.com/rust-lang/rust/issues/53667 if is_no_content
{ {
let conn_ty_real = match conn_ty { block = quote!(#block; Default::default())
Type::Reference(ty) => &*ty.elem, }
ty => ty if let Some(arg) = args.iter().filter(|arg| (*arg).ty.is_database_conn()).nth(0)
}; {
let conn_ty = arg.ty.quote_ty();
block = quote! { block = quote! {
use #krate::export::{Future, FromState}; let #repo_ident = <#krate::export::Repo<#conn_ty>>::borrow_from(&#state_ident).clone();
let #repo_ident = <#krate::export::Repo<#conn_ty_real>>::borrow_from(&#state_ident).clone();
#repo_ident.run::<_, #ret, ()>(move |#conn_ident| { #repo_ident.run::<_, #ret, ()>(move |#conn_ident| {
Ok({#block}) Ok({#block})
}).wait().unwrap() }).wait().unwrap()
}; };
} }
if let Some(arg) = args.iter().filter(|arg| (*arg).ty.is_auth_status()).nth(0)
{
let auth_ty = arg.ty.quote_ty();
block = quote! {
let #auth_ident : &#auth_ty = <#auth_ty>::borrow_from(#state_ident);
#block
};
}
// prepare the where clause
let mut where_clause = quote!(#resource_ident : #krate::Resource,);
for arg in args.iter().filter(|arg| (*arg).ty.is_auth_status() && !(*arg).ty.is_auth_status_ref())
{
let auth_ty = arg.ty.quote_ty();
where_clause = quote!(#where_clause #auth_ty : Clone,);
}
// put everything together
let output = quote! { let output = quote! {
#fun #fun
impl #krate::#trait_ident<#(#generics),*> for #resource_ident impl #krate::#trait_ident<#(#generics),*> for #resource_ident
where #resource_ident : #krate::Resource where #where_clause
{ {
fn #method_ident(#(#args_def),*) -> #ret fn #method_ident(#(#args_def),*) -> #ret
{ {
#[allow(unused_imports)]
use #krate::export::{Future, FromState};
#block #block
} }
} }