mirror of
https://gitlab.com/msrd0/gotham-restful.git
synced 2025-02-23 04:52:28 +00:00
implement auth parsing/verifying inside a gotham middleware
This commit is contained in:
parent
c025cbd8ea
commit
088774fc50
9 changed files with 721 additions and 49 deletions
|
@ -12,7 +12,8 @@ resources.
|
|||
|
||||
## Usage
|
||||
|
||||
To use this crate, add the following to your `Cargo.toml`:
|
||||
This crate targets stable rust, currently requiring rustc 1.40+. To use this crate, add the
|
||||
following to your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
|
|
|
@ -17,7 +17,7 @@ gitlab = { repository = "msrd0/gotham-restful", branch = "master" }
|
|||
fake = "2.2"
|
||||
gotham = "0.4"
|
||||
gotham_derive = "0.4"
|
||||
gotham_restful = { version = "0.0.1", features = ["openapi"] }
|
||||
gotham_restful = { version = "0.0.1", features = ["auth", "openapi"] }
|
||||
hyper = "0.12"
|
||||
log = "0.4"
|
||||
log4rs = { version = "0.8", features = ["console_appender"], default-features = false }
|
||||
|
|
|
@ -23,6 +23,12 @@ struct Users
|
|||
{
|
||||
}
|
||||
|
||||
#[derive(Resource)]
|
||||
#[rest_resource(ReadAll)]
|
||||
struct Auth
|
||||
{
|
||||
}
|
||||
|
||||
#[derive(Deserialize, OpenapiType, Serialize, StateData, StaticResponseExtender)]
|
||||
struct User
|
||||
{
|
||||
|
@ -82,8 +88,24 @@ fn delete(_state : &mut State, id : u64)
|
|||
info!("Delete User {}", id);
|
||||
}
|
||||
|
||||
#[rest_read_all(Auth)]
|
||||
fn auth_read_all(auth : AuthStatus<()>) -> Success<String>
|
||||
{
|
||||
format!("{:?}", auth).into()
|
||||
}
|
||||
|
||||
const ADDR : &str = "127.0.0.1:18080";
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct Handler;
|
||||
impl<T> AuthHandler<T> for Handler
|
||||
{
|
||||
fn jwt_secret<F : FnOnce() -> Option<T>>(&self, _state : &mut State, _decode_data : F) -> Option<Vec<u8>>
|
||||
{
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn main()
|
||||
{
|
||||
let encoder = PatternEncoder::new("{d(%Y-%m-%d %H:%M:%S%.3f %Z)} [{l}] {M} - {m}\n");
|
||||
|
@ -99,9 +121,11 @@ fn main()
|
|||
.unwrap();
|
||||
log4rs::init_config(config).unwrap();
|
||||
|
||||
let auth = <AuthMiddleware<(), Handler>>::from_source(AuthSource::AuthorizationHeader);
|
||||
let logging = RequestLogger::new(log::Level::Info);
|
||||
let (chain, pipelines) = single_pipeline(
|
||||
new_pipeline()
|
||||
.add(auth)
|
||||
.add(logging)
|
||||
.build()
|
||||
);
|
||||
|
@ -109,6 +133,7 @@ fn main()
|
|||
gotham::start(ADDR, build_router(chain, pipelines, |route| {
|
||||
route.with_openapi("Users Example", "0.0.1", format!("http://{}", ADDR), |mut route| {
|
||||
route.resource::<Users, _>("users");
|
||||
route.resource::<Auth, _>("auth");
|
||||
route.get_openapi("openapi");
|
||||
});
|
||||
}));
|
||||
|
|
|
@ -16,7 +16,9 @@ gitlab = { repository = "msrd0/gotham-restful", branch = "master" }
|
|||
codecov = { repository = "msrd0/gotham-restful", branch = "master", service = "gitlab" }
|
||||
|
||||
[dependencies]
|
||||
base64 = { version = "0.11", optional = true }
|
||||
chrono = { version = "0.4.9", optional = true }
|
||||
cookie = { version = "0.12", optional = true }
|
||||
failure = "0.1.6"
|
||||
futures = "0.1.29"
|
||||
gotham = "0.4"
|
||||
|
@ -25,6 +27,7 @@ gotham_middleware_diesel = { version = "0.1", optional = true }
|
|||
gotham_restful_derive = { version = "0.0.1" }
|
||||
hyper = "0.12.35"
|
||||
indexmap = { version = "1.3.0", optional = true }
|
||||
jsonwebtoken = { version = "6.0.1", optional = true }
|
||||
log = { version = "0.4.8", optional = true }
|
||||
mime = "0.3.14"
|
||||
openapiv3 = { version = "0.3", optional = true }
|
||||
|
@ -37,5 +40,6 @@ thiserror = "1"
|
|||
|
||||
[features]
|
||||
default = []
|
||||
auth = ["gotham_restful_derive/auth", "base64", "cookie", "jsonwebtoken"]
|
||||
database = ["gotham_restful_derive/database", "gotham_middleware_diesel"]
|
||||
openapi = ["gotham_restful_derive/openapi", "indexmap", "log", "openapiv3"]
|
||||
|
|
452
gotham_restful/src/auth.rs
Normal file
452
gotham_restful/src/auth.rs
Normal file
|
@ -0,0 +1,452 @@
|
|||
use crate::HeaderName;
|
||||
use cookie::CookieJar;
|
||||
use futures::{future, future::Future};
|
||||
use gotham::{
|
||||
handler::HandlerFuture,
|
||||
middleware::{Middleware, NewMiddleware},
|
||||
state::{FromState, State}
|
||||
};
|
||||
use hyper::header::{AUTHORIZATION, HeaderMap};
|
||||
use jsonwebtoken::errors::ErrorKind;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{
|
||||
marker::PhantomData,
|
||||
panic::RefUnwindSafe
|
||||
};
|
||||
|
||||
pub use jsonwebtoken::Validation as AuthValidation;
|
||||
|
||||
/// The authentication status returned by the auth middleware for each request.
|
||||
#[derive(Debug, StateData)]
|
||||
pub enum AuthStatus<T : Send + 'static>
|
||||
{
|
||||
/// The auth status is unknown.
|
||||
Unknown,
|
||||
/// The request has been performed without any kind of authentication.
|
||||
Unauthenticated,
|
||||
/// The request has been performed with an invalid authentication.
|
||||
Invalid,
|
||||
/// The request has been performed with an expired authentication.
|
||||
Expired,
|
||||
/// The request has been performed with a valid authentication.
|
||||
Authenticated(T)
|
||||
}
|
||||
|
||||
impl<T> Clone for AuthStatus<T>
|
||||
where
|
||||
T : Clone + Send + 'static
|
||||
{
|
||||
fn clone(&self) -> Self
|
||||
{
|
||||
match self {
|
||||
Self::Unknown => Self::Unknown,
|
||||
Self::Unauthenticated => Self::Unauthenticated,
|
||||
Self::Invalid => Self::Invalid,
|
||||
Self::Expired => Self::Expired,
|
||||
Self::Authenticated(data) => Self::Authenticated(data.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The source of the authentication token in the request.
|
||||
#[derive(Clone, StateData)]
|
||||
pub enum AuthSource
|
||||
{
|
||||
/// Take the token from a cookie with the given name.
|
||||
Cookie(String),
|
||||
/// Take the token from a header with the given name.
|
||||
Header(HeaderName),
|
||||
/// Take the token from the HTTP Authorization header. This is different from `Header("Authorization")`
|
||||
/// as it will follow the `scheme param` format from the HTTP specification. The `scheme` will
|
||||
/// be discarded, so its value doesn't matter.
|
||||
AuthorizationHeader
|
||||
}
|
||||
|
||||
/**
|
||||
This trait will help the auth middleware to determine the validity of an authentication token.
|
||||
|
||||
A very basic implementation could look like this:
|
||||
```
|
||||
# use gotham_restful::{export::State, AuthHandler};
|
||||
#
|
||||
const SECRET : &'static [u8; 32] = b"zlBsA2QXnkmpe0QTh8uCvtAEa4j33YAc";
|
||||
|
||||
struct CustomAuthHandler;
|
||||
impl<T> AuthHandler<T> for CustomAuthHandler {
|
||||
fn jwt_secret<F : FnOnce() -> Option<T>>(&self, _state : &mut State, _decode_data : F) -> Option<Vec<u8>> {
|
||||
Some(SECRET.to_vec())
|
||||
}
|
||||
}
|
||||
```
|
||||
*/
|
||||
pub trait AuthHandler<Data>
|
||||
{
|
||||
/// Return the SHA256-HMAC secret used to verify the JWT token.
|
||||
fn jwt_secret<F : FnOnce() -> Option<Data>>(&self, state : &mut State, decode_data : F) -> Option<Vec<u8>>;
|
||||
}
|
||||
|
||||
/// An `AuthHandler` returning always the same secret. See `AuthMiddleware` for a usage example.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct StaticAuthHandler
|
||||
{
|
||||
secret : Vec<u8>
|
||||
}
|
||||
|
||||
impl StaticAuthHandler
|
||||
{
|
||||
pub fn from_vec(secret : Vec<u8>) -> Self
|
||||
{
|
||||
Self { secret }
|
||||
}
|
||||
|
||||
pub fn from_array(secret : &[u8]) -> Self
|
||||
{
|
||||
Self::from_vec(secret.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AuthHandler<T> for StaticAuthHandler
|
||||
{
|
||||
fn jwt_secret<F : FnOnce() -> Option<T>>(&self, _state : &mut State, _decode_data : F) -> Option<Vec<u8>>
|
||||
{
|
||||
Some(self.secret.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
This is the auth middleware. To use it, first make sure you have the `auth` feature enabled. Then
|
||||
simply add it to your pipeline and request it inside your handler:
|
||||
|
||||
```rust,no_run
|
||||
# #[macro_use] extern crate gotham_restful_derive;
|
||||
# use gotham::{router::builder::*, pipeline::{new_pipeline, single::single_pipeline}, state::State};
|
||||
# use gotham_restful::*;
|
||||
# use serde::{Deserialize, Serialize};
|
||||
#
|
||||
#[derive(Resource)]
|
||||
#[rest_resource(read_all)]
|
||||
struct AuthResource;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AuthData {
|
||||
sub: String,
|
||||
exp: u64
|
||||
}
|
||||
|
||||
#[rest_read_all(AuthResource)]
|
||||
fn read_all(auth : &AuthStatus<AuthData>) -> Success<String> {
|
||||
format!("{:?}", auth).into()
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let auth : AuthMiddleware<AuthData, _> = AuthMiddleware::new(
|
||||
AuthSource::AuthorizationHeader,
|
||||
AuthValidation::default(),
|
||||
StaticAuthHandler::from_array(b"zlBsA2QXnkmpe0QTh8uCvtAEa4j33YAc")
|
||||
);
|
||||
let (chain, pipelines) = single_pipeline(new_pipeline().add(auth).build());
|
||||
gotham::start("127.0.0.1:8080", build_router(chain, pipelines, |route| {
|
||||
route.resource::<AuthResource, _>("auth");
|
||||
}));
|
||||
}
|
||||
```
|
||||
*/
|
||||
pub struct AuthMiddleware<Data, Handler>
|
||||
{
|
||||
source : AuthSource,
|
||||
validation : AuthValidation,
|
||||
handler : Handler,
|
||||
_data : PhantomData<Data>
|
||||
}
|
||||
|
||||
impl<Data, Handler> Clone for AuthMiddleware<Data, Handler>
|
||||
where Handler : Clone
|
||||
{
|
||||
fn clone(&self) -> Self
|
||||
{
|
||||
Self {
|
||||
source: self.source.clone(),
|
||||
validation: self.validation.clone(),
|
||||
handler: self.handler.clone(),
|
||||
_data: self._data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Data, Handler> AuthMiddleware<Data, Handler>
|
||||
where
|
||||
Data : DeserializeOwned + Send,
|
||||
Handler : AuthHandler<Data> + Default
|
||||
{
|
||||
pub fn from_source(source : AuthSource) -> Self
|
||||
{
|
||||
Self {
|
||||
source,
|
||||
validation: Default::default(),
|
||||
handler: Default::default(),
|
||||
_data: Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Data, Handler> AuthMiddleware<Data, Handler>
|
||||
where
|
||||
Data : DeserializeOwned + Send,
|
||||
Handler : AuthHandler<Data>
|
||||
{
|
||||
pub fn new(source : AuthSource, validation : AuthValidation, handler : Handler) -> Self
|
||||
{
|
||||
Self {
|
||||
source,
|
||||
validation,
|
||||
handler,
|
||||
_data: Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn auth_status(&self, state : &mut State) -> AuthStatus<Data>
|
||||
{
|
||||
// extract the provided token, if any
|
||||
let token = match &self.source {
|
||||
AuthSource::Cookie(name) => {
|
||||
CookieJar::try_borrow_from(&state)
|
||||
.and_then(|jar| jar.get(&name))
|
||||
.map(|cookie| cookie.value().to_owned())
|
||||
},
|
||||
AuthSource::Header(name) => {
|
||||
HeaderMap::try_borrow_from(&state)
|
||||
.and_then(|map| map.get(name))
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.map(|value| value.to_owned())
|
||||
},
|
||||
AuthSource::AuthorizationHeader => {
|
||||
HeaderMap::try_borrow_from(&state)
|
||||
.and_then(|map| map.get(AUTHORIZATION))
|
||||
.and_then(|header| header.to_str().ok())
|
||||
.and_then(|value| value.split_whitespace().nth(1))
|
||||
.map(|value| value.to_owned())
|
||||
}
|
||||
};
|
||||
|
||||
// unauthed if no token
|
||||
let token = match token {
|
||||
Some(token) => token,
|
||||
None => return AuthStatus::Unauthenticated
|
||||
};
|
||||
|
||||
// get the secret from the handler, possibly decoding claims ourselves
|
||||
let secret = self.handler.jwt_secret(state, || {
|
||||
let b64 = token.split(".").nth(1)?;
|
||||
let raw = base64::decode_config(b64, base64::URL_SAFE_NO_PAD).ok()?;
|
||||
serde_json::from_slice(&raw).ok()?
|
||||
});
|
||||
|
||||
// unknown if no secret
|
||||
let secret = match secret {
|
||||
Some(secret) => secret,
|
||||
None => return AuthStatus::Unknown
|
||||
};
|
||||
|
||||
// validate the token
|
||||
let data : Data = match jsonwebtoken::decode(&token, &secret, &self.validation) {
|
||||
Ok(data) => data.claims,
|
||||
Err(e) => match dbg!(e.into_kind()) {
|
||||
ErrorKind::ExpiredSignature => return AuthStatus::Expired,
|
||||
_ => return AuthStatus::Invalid
|
||||
}
|
||||
};
|
||||
|
||||
// we found a valid token
|
||||
return AuthStatus::Authenticated(data);
|
||||
}
|
||||
}
|
||||
|
||||
impl<Data, Handler> Middleware for AuthMiddleware<Data, Handler>
|
||||
where
|
||||
Data : DeserializeOwned + Send + 'static,
|
||||
Handler : AuthHandler<Data>
|
||||
{
|
||||
fn call<Chain>(self, mut state : State, chain : Chain) -> Box<HandlerFuture>
|
||||
where
|
||||
Chain : FnOnce(State) -> Box<HandlerFuture>
|
||||
{
|
||||
// put the source in our state, required for e.g. openapi
|
||||
state.put(self.source.clone());
|
||||
|
||||
// put the status in our state
|
||||
let status = self.auth_status(&mut state);
|
||||
state.put(status);
|
||||
|
||||
// call the rest of the chain
|
||||
Box::new(chain(state).and_then(|(state, res)| future::ok((state, res))))
|
||||
}
|
||||
}
|
||||
|
||||
impl<Data, Handler> NewMiddleware for AuthMiddleware<Data, Handler>
|
||||
where
|
||||
Self : Clone + Middleware + Sync + RefUnwindSafe
|
||||
{
|
||||
type Instance = Self;
|
||||
|
||||
fn new_middleware(&self) -> Result<Self::Instance, std::io::Error>
|
||||
{
|
||||
let c : Self = self.clone();
|
||||
Ok(c)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test
|
||||
{
|
||||
use super::*;
|
||||
use cookie::Cookie;
|
||||
use std::fmt::Debug;
|
||||
|
||||
// 256-bit random string
|
||||
const JWT_SECRET : &'static [u8; 32] = b"Lyzsfnta0cdxyF0T9y6VGxp3jpgoMUuW";
|
||||
|
||||
// some known tokens
|
||||
const VALID_TOKEN : &'static str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJtc3JkMCIsInN1YiI6ImdvdGhhbS1yZXN0ZnVsIiwiaWF0IjoxNTc3ODM2ODAwLCJleHAiOjQxMDI0NDQ4MDB9.8h8Ax-nnykqEQ62t7CxmM3ja6NzUQ4L0MLOOzddjLKk";
|
||||
const EXPIRED_TOKEN : &'static str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJtc3JkMCIsInN1YiI6ImdvdGhhbS1yZXN0ZnVsIiwiaWF0IjoxNTc3ODM2ODAwLCJleHAiOjE1Nzc4MzcxMDB9.eV1snaGLYrJ7qUoMk74OvBY3WUU9M0Je5HTU2xtX1v0";
|
||||
|
||||
#[derive(Debug, Deserialize, PartialEq)]
|
||||
struct TestData
|
||||
{
|
||||
iss : String,
|
||||
sub : String,
|
||||
iat : u64,
|
||||
exp : u64
|
||||
}
|
||||
|
||||
impl Default for TestData
|
||||
{
|
||||
fn default() -> Self
|
||||
{
|
||||
Self {
|
||||
iss: "msrd0".to_owned(),
|
||||
sub: "gotham-restful".to_owned(),
|
||||
iat: 1577836800,
|
||||
exp: 4102444800
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestHandler;
|
||||
impl<T> AuthHandler<T> for TestHandler
|
||||
{
|
||||
fn jwt_secret<F : FnOnce() -> Option<T>>(&self, _state : &mut State, _decode_data : F) -> Option<Vec<u8>>
|
||||
{
|
||||
Some(JWT_SECRET.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestAssertingHandler;
|
||||
impl<T> AuthHandler<T> for TestAssertingHandler
|
||||
where T : Debug + Default + PartialEq
|
||||
{
|
||||
fn jwt_secret<F : FnOnce() -> Option<T>>(&self, _state : &mut State, decode_data : F) -> Option<Vec<u8>>
|
||||
{
|
||||
assert_eq!(decode_data(), Some(T::default()));
|
||||
Some(JWT_SECRET.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_middleware_decode_data()
|
||||
{
|
||||
let middleware = <AuthMiddleware<TestData, TestAssertingHandler>>::from_source(AuthSource::AuthorizationHeader);
|
||||
State::with_new(|mut state| {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(AUTHORIZATION, format!("Bearer {}", VALID_TOKEN).parse().unwrap());
|
||||
state.put(headers);
|
||||
middleware.auth_status(&mut state);
|
||||
});
|
||||
}
|
||||
|
||||
fn new_middleware<T>(source : AuthSource) -> AuthMiddleware<T, TestHandler>
|
||||
where T : DeserializeOwned + Send
|
||||
{
|
||||
AuthMiddleware::from_source(source)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_middleware_no_token()
|
||||
{
|
||||
let middleware = new_middleware::<TestData>(AuthSource::AuthorizationHeader);
|
||||
State::with_new(|mut state| {
|
||||
let status = middleware.auth_status(&mut state);
|
||||
match status {
|
||||
AuthStatus::Unauthenticated => {},
|
||||
_ => panic!("Expected AuthStatus::Unauthenticated, got {:?}", status)
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_middleware_expired_token()
|
||||
{
|
||||
let middleware = new_middleware::<TestData>(AuthSource::AuthorizationHeader);
|
||||
State::with_new(|mut state| {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(AUTHORIZATION, format!("Bearer {}", EXPIRED_TOKEN).parse().unwrap());
|
||||
state.put(headers);
|
||||
let status = middleware.auth_status(&mut state);
|
||||
match status {
|
||||
AuthStatus::Expired => {},
|
||||
_ => panic!("Expected AuthStatus::Expired, got {:?}", status)
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_middleware_auth_header_token()
|
||||
{
|
||||
let middleware = new_middleware::<TestData>(AuthSource::AuthorizationHeader);
|
||||
State::with_new(|mut state| {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(AUTHORIZATION, format!("Bearer {}", VALID_TOKEN).parse().unwrap());
|
||||
state.put(headers);
|
||||
let status = middleware.auth_status(&mut state);
|
||||
match status {
|
||||
AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
|
||||
_ => panic!("Expected AuthStatus::Authenticated, got {:?}", status)
|
||||
};
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_middleware_header_token()
|
||||
{
|
||||
let header_name = "x-znoiprwmvfexju";
|
||||
let middleware = new_middleware::<TestData>(AuthSource::Header(HeaderName::from_static(header_name)));
|
||||
State::with_new(|mut state| {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header_name, VALID_TOKEN.parse().unwrap());
|
||||
state.put(headers);
|
||||
let status = middleware.auth_status(&mut state);
|
||||
match status {
|
||||
AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
|
||||
_ => panic!("Expected AuthStatus::Authenticated, got {:?}", status)
|
||||
};
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auth_middleware_cookie_token()
|
||||
{
|
||||
let cookie_name = "znoiprwmvfexju";
|
||||
let middleware = new_middleware::<TestData>(AuthSource::Cookie(cookie_name.to_owned()));
|
||||
State::with_new(|mut state| {
|
||||
let mut jar = CookieJar::new();
|
||||
jar.add_original(Cookie::new(cookie_name, VALID_TOKEN));
|
||||
state.put(jar);
|
||||
let status = middleware.auth_status(&mut state);
|
||||
match status {
|
||||
AuthStatus::Authenticated(data) => assert_eq!(data, TestData::default()),
|
||||
_ => panic!("Expected AuthStatus::Authenticated, got {:?}", status)
|
||||
};
|
||||
})
|
||||
}
|
||||
}
|
|
@ -8,7 +8,8 @@ resources.
|
|||
|
||||
# Usage
|
||||
|
||||
To use this crate, add the following to your `Cargo.toml`:
|
||||
This crate targets stable rust, currently requiring rustc 1.40+. To use this crate, add the
|
||||
following to your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
|
@ -112,7 +113,7 @@ extern crate self as gotham_restful;
|
|||
#[macro_use] extern crate serde;
|
||||
|
||||
#[doc(no_inline)]
|
||||
pub use hyper::{Chunk, StatusCode};
|
||||
pub use hyper::{header::HeaderName, Chunk, StatusCode};
|
||||
#[doc(no_inline)]
|
||||
pub use mime::Mime;
|
||||
|
||||
|
@ -134,6 +135,18 @@ pub mod export
|
|||
pub use openapiv3 as openapi;
|
||||
}
|
||||
|
||||
#[cfg(feature = "auth")]
|
||||
mod auth;
|
||||
#[cfg(feature = "auth")]
|
||||
pub use auth::{
|
||||
AuthHandler,
|
||||
AuthMiddleware,
|
||||
AuthSource,
|
||||
AuthStatus,
|
||||
AuthValidation,
|
||||
StaticAuthHandler
|
||||
};
|
||||
|
||||
#[cfg(feature = "openapi")]
|
||||
mod openapi;
|
||||
#[cfg(feature = "openapi")]
|
||||
|
|
|
@ -21,9 +21,9 @@ use indexmap::IndexMap;
|
|||
use log::error;
|
||||
use mime::{Mime, APPLICATION_JSON, TEXT_PLAIN};
|
||||
use openapiv3::{
|
||||
Components, MediaType, OpenAPI, Operation, Parameter, ParameterData, ParameterSchemaOrContent, PathItem,
|
||||
APIKeyLocation, Components, MediaType, OpenAPI, Operation, Parameter, ParameterData, ParameterSchemaOrContent, PathItem,
|
||||
Paths, ReferenceOr, ReferenceOr::Item, ReferenceOr::Reference, RequestBody as OARequestBody, Response, Responses, Schema,
|
||||
SchemaKind, Server, StatusCode, Type
|
||||
SchemaKind, SecurityRequirement, SecurityScheme, Server, StatusCode, Type
|
||||
};
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::panic::RefUnwindSafe;
|
||||
|
@ -125,16 +125,13 @@ impl OpenapiRouter
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct OpenapiHandler(Result<String, String>);
|
||||
|
||||
// dunno what/why/whatever
|
||||
impl RefUnwindSafe for OpenapiHandler {}
|
||||
struct OpenapiHandler(OpenAPI);
|
||||
|
||||
impl OpenapiHandler
|
||||
{
|
||||
fn new(openapi : &OpenapiRouter) -> Self
|
||||
{
|
||||
Self(serde_json::to_string(&openapi.0).map_err(|e| format!("{}", e)))
|
||||
Self(openapi.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -148,11 +145,63 @@ impl NewHandler for OpenapiHandler
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "auth")]
|
||||
const SECURITY_NAME : &'static str = "authToken";
|
||||
|
||||
#[cfg(feature = "auth")]
|
||||
fn get_security(state : &mut State) -> (Vec<SecurityRequirement>, IndexMap<String, ReferenceOr<SecurityScheme>>)
|
||||
{
|
||||
use crate::AuthSource;
|
||||
use gotham::state::FromState;
|
||||
|
||||
let source = match AuthSource::try_borrow_from(state) {
|
||||
Some(source) => source,
|
||||
None => return Default::default()
|
||||
};
|
||||
|
||||
let mut security : IndexMap<String, Vec<String>> = Default::default();
|
||||
security.insert(SECURITY_NAME.to_owned(), Vec::new());
|
||||
let security = vec![security];
|
||||
|
||||
let security_scheme = match source {
|
||||
AuthSource::Cookie(name) => SecurityScheme::APIKey {
|
||||
location: APIKeyLocation::Cookie,
|
||||
name: name.to_string()
|
||||
},
|
||||
AuthSource::Header(name) => SecurityScheme::APIKey {
|
||||
location: APIKeyLocation::Header,
|
||||
name: name.to_string()
|
||||
},
|
||||
AuthSource::AuthorizationHeader => SecurityScheme::HTTP {
|
||||
scheme: "bearer".to_owned(),
|
||||
bearer_format: Some("JWT".to_owned())
|
||||
}
|
||||
};
|
||||
|
||||
let mut security_schemes : IndexMap<String, ReferenceOr<SecurityScheme>> = Default::default();
|
||||
security_schemes.insert(SECURITY_NAME.to_owned(), ReferenceOr::Item(security_scheme));
|
||||
|
||||
(security, security_schemes)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "auth"))]
|
||||
fn get_security(state : &mut State) -> (Vec<SecurityRequirement>, IndexMap<String, ReferenceOr<SecurityScheme>>)
|
||||
{
|
||||
Default::default()
|
||||
}
|
||||
|
||||
impl Handler for OpenapiHandler
|
||||
{
|
||||
fn handle(self, state : State) -> Box<HandlerFuture>
|
||||
fn handle(self, mut state : State) -> Box<HandlerFuture>
|
||||
{
|
||||
match self.0 {
|
||||
let mut openapi = self.0;
|
||||
let (security, security_schemes) = get_security(&mut state);
|
||||
openapi.security = security;
|
||||
let mut components = openapi.components.unwrap_or_default();
|
||||
components.security_schemes = security_schemes;
|
||||
openapi.components = Some(components);
|
||||
|
||||
match serde_json::to_string(&openapi) {
|
||||
Ok(body) => {
|
||||
let res = create_response(&state, hyper::StatusCode::OK, APPLICATION_JSON, body);
|
||||
Box::new(ok((state, res)))
|
||||
|
|
|
@ -25,5 +25,6 @@ syn = { version = "1.0.13", features = ["extra-traits", "full"] }
|
|||
|
||||
[features]
|
||||
default = []
|
||||
auth = []
|
||||
database = []
|
||||
openapi = []
|
||||
|
|
|
@ -3,8 +3,10 @@ use proc_macro::TokenStream;
|
|||
use proc_macro2::{Ident, TokenStream as TokenStream2};
|
||||
use quote::{format_ident, quote};
|
||||
use syn::{
|
||||
Attribute,
|
||||
FnArg,
|
||||
ItemFn,
|
||||
PatType,
|
||||
ReturnType,
|
||||
Type,
|
||||
parse_macro_input
|
||||
|
@ -84,6 +86,112 @@ impl Method
|
|||
}
|
||||
}
|
||||
|
||||
enum MethodArgumentType
|
||||
{
|
||||
StateRef,
|
||||
StateMutRef,
|
||||
MethodArg(Type),
|
||||
DatabaseConnection(Type),
|
||||
AuthStatus(Type),
|
||||
AuthStatusRef(Type)
|
||||
}
|
||||
|
||||
impl MethodArgumentType
|
||||
{
|
||||
fn is_method_arg(&self) -> bool
|
||||
{
|
||||
match self {
|
||||
Self::MethodArg(_) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_database_conn(&self) -> bool
|
||||
{
|
||||
match self {
|
||||
Self::DatabaseConnection(_) => true,
|
||||
_ => false
|
||||
}
|
||||
}
|
||||
|
||||
fn is_auth_status(&self) -> bool
|
||||
{
|
||||
match self {
|
||||
Self::AuthStatus(_) | Self::AuthStatusRef(_) => true,
|
||||
_ => false
|
||||
}
|
||||
}
|
||||
|
||||
fn is_auth_status_ref(&self) -> bool
|
||||
{
|
||||
match self {
|
||||
Self::AuthStatusRef(_) => true,
|
||||
_ => false
|
||||
}
|
||||
}
|
||||
|
||||
fn quote_ty(&self) -> Option<TokenStream2>
|
||||
{
|
||||
match self {
|
||||
Self::MethodArg(ty) => Some(quote!(#ty)),
|
||||
Self::DatabaseConnection(ty) => Some(quote!(#ty)),
|
||||
Self::AuthStatus(ty) => Some(quote!(#ty)),
|
||||
Self::AuthStatusRef(ty) => Some(quote!(#ty)),
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MethodArgument
|
||||
{
|
||||
ident : Ident,
|
||||
ty : MethodArgumentType
|
||||
}
|
||||
|
||||
fn interpret_arg_ty(index : usize, attrs : &[Attribute], name : &str, ty : Type) -> MethodArgumentType
|
||||
{
|
||||
let attr = attrs.into_iter()
|
||||
.filter(|arg| arg.path.segments.iter().filter(|path| &path.ident.to_string() == "rest_arg").nth(0).is_some())
|
||||
.nth(0)
|
||||
.map(|arg| arg.tokens.to_string());
|
||||
|
||||
if cfg!(feature = "auth") && (attr.as_deref() == Some("auth") || (attr.is_none() && name == "auth"))
|
||||
{
|
||||
return match ty {
|
||||
Type::Reference(ty) => MethodArgumentType::AuthStatusRef(*ty.elem),
|
||||
ty => MethodArgumentType::AuthStatus(ty)
|
||||
};
|
||||
}
|
||||
|
||||
if cfg!(feature = "database") && (attr.as_deref() == Some("connection") || attr.as_deref() == Some("conn") || (attr.is_none() && name == "conn"))
|
||||
{
|
||||
return MethodArgumentType::DatabaseConnection(match ty {
|
||||
Type::Reference(ty) => *ty.elem,
|
||||
ty => ty
|
||||
});
|
||||
}
|
||||
|
||||
if index == 0
|
||||
{
|
||||
return match ty {
|
||||
Type::Reference(ty) => if ty.mutability.is_none() { MethodArgumentType::StateRef } else { MethodArgumentType::StateMutRef },
|
||||
_ => panic!("The first argument, unless some feature is used, has to be a (mutable) reference to gotham::state::State")
|
||||
};
|
||||
}
|
||||
|
||||
MethodArgumentType::MethodArg(ty)
|
||||
}
|
||||
|
||||
fn interpret_arg(index : usize, arg : &PatType) -> MethodArgument
|
||||
{
|
||||
let pat = &arg.pat;
|
||||
let ident = format_ident!("arg{}", index);
|
||||
let orig_name = quote!(#pat);
|
||||
let ty = interpret_arg_ty(index, &arg.attrs, &orig_name.to_string(), *arg.ty.clone());
|
||||
|
||||
MethodArgument { ident, ty }
|
||||
}
|
||||
|
||||
pub fn expand_method(method : Method, attrs : TokenStream, item : TokenStream) -> TokenStream
|
||||
{
|
||||
let krate = super::krate();
|
||||
|
@ -101,71 +209,90 @@ pub fn expand_method(method : Method, attrs : TokenStream, item : TokenStream) -
|
|||
ReturnType::Type(_, ty) => (quote!(#ty), false)
|
||||
};
|
||||
|
||||
// extract arguments into pattern, ident and type
|
||||
// some default idents we'll need
|
||||
let state_ident = format_ident!("state");
|
||||
let args : Vec<(usize, TokenStream2, Ident, Type)> = fun.sig.inputs.iter().enumerate().map(|(i, arg)| match arg {
|
||||
FnArg::Typed(arg) => {
|
||||
let pat = &arg.pat;
|
||||
let ident = format_ident!("arg{}", i);
|
||||
(i, quote!(#pat), ident, *arg.ty.clone())
|
||||
},
|
||||
let repo_ident = format_ident!("repo");
|
||||
let conn_ident = format_ident!("conn");
|
||||
let auth_ident = format_ident!("auth");
|
||||
|
||||
// extract arguments into pattern, ident and type
|
||||
let args : Vec<MethodArgument> = fun.sig.inputs.iter().enumerate().map(|(i, arg)| match arg {
|
||||
FnArg::Typed(arg) => interpret_arg(i, arg),
|
||||
FnArg::Receiver(_) => panic!("didn't expect self parameter")
|
||||
}).collect();
|
||||
|
||||
// find the database connection if enabled and present
|
||||
let repo_ident = format_ident!("database_repo");
|
||||
let args_conn = if cfg!(feature = "database") {
|
||||
args.iter().filter(|(_, pat, _, _)| pat.to_string() == "conn").nth(0)
|
||||
} else { None };
|
||||
let args_conn_name = args_conn.map(|(_, pat, _, _)| pat.to_string());
|
||||
|
||||
// extract the generic parameters to use
|
||||
let mut generics : Vec<TokenStream2> = args.iter().skip(1)
|
||||
.filter(|(_, pat, _, _)| Some(pat.to_string()) != args_conn_name)
|
||||
.map(|(_, _, _, ty)| quote!(#ty)).collect();
|
||||
let mut generics : Vec<TokenStream2> = args.iter()
|
||||
.filter(|arg| (*arg).ty.is_method_arg())
|
||||
.map(|arg| arg.ty.quote_ty().unwrap())
|
||||
.collect();
|
||||
generics.push(quote!(#ret));
|
||||
|
||||
// extract the definition of our method
|
||||
let mut args_def : Vec<TokenStream2> = args.iter()
|
||||
.filter(|(_, pat, _, _)| Some(pat.to_string()) != args_conn_name)
|
||||
.map(|(i, _, ident, ty)| if *i == 0 { quote!(#state_ident : #ty) } else { quote!(#ident : #ty) }).collect();
|
||||
if let Some(_) = args_conn
|
||||
{
|
||||
args_def.insert(0, quote!(#state_ident : &mut #krate::export::State));
|
||||
}
|
||||
.filter(|arg| (*arg).ty.is_method_arg())
|
||||
.map(|arg| {
|
||||
let ident = &arg.ident;
|
||||
let ty = arg.ty.quote_ty();
|
||||
quote!(#ident : #ty)
|
||||
}).collect();
|
||||
args_def.insert(0, quote!(#state_ident : &mut #krate::export::State));
|
||||
|
||||
// extract the arguments to pass over to the supplied method
|
||||
let args_pass : Vec<TokenStream2> = args.iter().map(|(i, pat, ident, _)| if Some(pat.to_string()) != args_conn_name {
|
||||
if *i == 0 { quote!(#state_ident) } else { quote!(#ident) }
|
||||
} else {
|
||||
quote!(&#ident)
|
||||
let args_pass : Vec<TokenStream2> = args.iter().map(|arg| match (&arg.ty, &arg.ident) {
|
||||
(MethodArgumentType::StateRef, _) => quote!(#state_ident),
|
||||
(MethodArgumentType::StateMutRef, _) => quote!(#state_ident),
|
||||
(MethodArgumentType::MethodArg(_), ident) => quote!(#ident),
|
||||
(MethodArgumentType::DatabaseConnection(_), _) => quote!(&#conn_ident),
|
||||
(MethodArgumentType::AuthStatus(_), _) => quote!(#auth_ident.clone()),
|
||||
(MethodArgumentType::AuthStatusRef(_), _) => quote!(#auth_ident)
|
||||
}).collect();
|
||||
|
||||
// prepare the method block
|
||||
let mut block = if is_no_content { quote!(#fun_ident(#(#args_pass),*); Default::default()) } else { quote!(#fun_ident(#(#args_pass),*)) };
|
||||
if /*cfg!(feature = "database") &&*/ let Some((_, _, conn_ident, conn_ty)) = args_conn // https://github.com/rust-lang/rust/issues/53667
|
||||
let mut block = quote!(#fun_ident(#(#args_pass),*));
|
||||
if is_no_content
|
||||
{
|
||||
let conn_ty_real = match conn_ty {
|
||||
Type::Reference(ty) => &*ty.elem,
|
||||
ty => ty
|
||||
};
|
||||
block = quote!(#block; Default::default())
|
||||
}
|
||||
if let Some(arg) = args.iter().filter(|arg| (*arg).ty.is_database_conn()).nth(0)
|
||||
{
|
||||
let conn_ty = arg.ty.quote_ty();
|
||||
block = quote! {
|
||||
use #krate::export::{Future, FromState};
|
||||
let #repo_ident = <#krate::export::Repo<#conn_ty_real>>::borrow_from(&#state_ident).clone();
|
||||
let #repo_ident = <#krate::export::Repo<#conn_ty>>::borrow_from(&#state_ident).clone();
|
||||
#repo_ident.run::<_, #ret, ()>(move |#conn_ident| {
|
||||
Ok({#block})
|
||||
}).wait().unwrap()
|
||||
};
|
||||
}
|
||||
if let Some(arg) = args.iter().filter(|arg| (*arg).ty.is_auth_status()).nth(0)
|
||||
{
|
||||
let auth_ty = arg.ty.quote_ty();
|
||||
block = quote! {
|
||||
let #auth_ident : &#auth_ty = <#auth_ty>::borrow_from(#state_ident);
|
||||
#block
|
||||
};
|
||||
}
|
||||
|
||||
// prepare the where clause
|
||||
let mut where_clause = quote!(#resource_ident : #krate::Resource,);
|
||||
for arg in args.iter().filter(|arg| (*arg).ty.is_auth_status() && !(*arg).ty.is_auth_status_ref())
|
||||
{
|
||||
let auth_ty = arg.ty.quote_ty();
|
||||
where_clause = quote!(#where_clause #auth_ty : Clone,);
|
||||
}
|
||||
|
||||
// put everything together
|
||||
let output = quote! {
|
||||
#fun
|
||||
|
||||
impl #krate::#trait_ident<#(#generics),*> for #resource_ident
|
||||
where #resource_ident : #krate::Resource
|
||||
where #where_clause
|
||||
{
|
||||
fn #method_ident(#(#args_def),*) -> #ret
|
||||
{
|
||||
#[allow(unused_imports)]
|
||||
use #krate::export::{Future, FromState};
|
||||
|
||||
#block
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue