From 876da721f857e7ed8ccecbf1c2ca90bec6c38126 Mon Sep 17 00:00:00 2001 From: Dominic Date: Tue, 14 Jan 2020 23:16:31 +0100 Subject: [PATCH] some fixes for database support --- gotham_restful/src/lib.rs | 4 +++ gotham_restful_derive/src/method.rs | 53 +++++++++++++++++++++-------- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/gotham_restful/src/lib.rs b/gotham_restful/src/lib.rs index d5dc50c..4384b5c 100644 --- a/gotham_restful/src/lib.rs +++ b/gotham_restful/src/lib.rs @@ -122,8 +122,12 @@ pub use gotham_restful_derive::*; #[doc(hidden)] pub mod export { + pub use futures::future::Future; + pub use gotham::state::FromState; + #[cfg(feature = "database")] pub use gotham_middleware_diesel::Repo; + #[cfg(feature = "openapi")] pub use indexmap::IndexMap; #[cfg(feature = "openapi")] diff --git a/gotham_restful_derive/src/method.rs b/gotham_restful_derive/src/method.rs index 6a42f6c..09c1dd0 100644 --- a/gotham_restful_derive/src/method.rs +++ b/gotham_restful_derive/src/method.rs @@ -6,6 +6,8 @@ use syn::{ FnArg, ItemFn, ReturnType, + Type, + TypePath, parse_macro_input }; use std::str::FromStr; @@ -99,33 +101,55 @@ pub fn expand_method(method : Method, attrs : TokenStream, item : TokenStream) - ReturnType::Default => (quote!(#krate::NoContent), true), ReturnType::Type(_, ty) => (quote!(#ty), false) }; - let args : Vec<(TokenStream2, TokenStream2)> = fun.sig.inputs.iter().map(|arg| match arg { + + // extract arguments into pattern, ident and type + let state_ident = format_ident!("state"); + let args : Vec<(TokenStream2, Ident, Type)> = fun.sig.inputs.iter().enumerate().map(|(i, arg)| match arg { FnArg::Typed(arg) => { let pat = &arg.pat; - let ty = &arg.ty; - (quote!(#pat), quote!(#ty)) + let ident = if i == 0 { state_ident.clone() } else { format_ident!("arg{}", i-1) }; + (quote!(#pat), ident, *arg.ty.clone()) }, FnArg::Receiver(_) => panic!("didn't expect self parameter") }).collect(); - let args_state = args.iter().map(|(pat, _)| pat).nth(0).expect("state parameter is required"); + + // find the database connection if enabled and present + let repo_ident = format_ident!("database_repo"); let args_conn = if cfg!(feature = "database") { - args.iter().filter(|(pat, _)| pat.to_string() == "conn").nth(0) + args.iter().filter(|(pat, _, _)| pat.to_string() == "conn").nth(0) } else { None }; + let args_conn_name = args_conn.map(|(pat, _, _)| pat.to_string()); + + // extract the generic parameters to use let mut generics : Vec = args.iter().skip(1) - .filter(|(pat, _)| Some(pat.to_string()) != args_conn.map(|(pat, _)| pat.to_string())) - .map(|(_, ty)| quote!(#ty)).collect(); + .filter(|(pat, _, _)| Some(pat.to_string()) != args_conn_name) + .map(|(_, _, ty)| quote!(#ty)).collect(); generics.push(quote!(#ret)); + + // extract the definition of our method let args_def : Vec = args.iter() - .filter(|(pat, _)| Some(pat.to_string()) != args_conn.map(|(pat, _)| pat.to_string())) - .map(|(pat, ty)| quote!(#pat : #ty)).collect(); - let args_pass : Vec = args.iter().map(|(pat, _)| quote!(#pat)).collect(); + .filter(|(pat, _, _)| Some(pat.to_string()) != args_conn_name) + .map(|(_, ident, ty)| quote!(#ident : #ty)).collect(); + + // extract the arguments to pass over to the supplied method + let args_pass : Vec = args.iter().map(|(pat, ident, _)| if Some(pat.to_string()) != args_conn_name { + quote!(#ident) + } else { + quote!(&#ident) + }).collect(); + + // prepare the method block let mut block = if is_no_content { quote!(#fun_ident(#(#args_pass),*); Default::default()) } else { quote!(#fun_ident(#(#args_pass),*)) }; - if /*cfg!(feature = "database") &&*/ let Some((conn_pat, conn_ty)) = args_conn // https://github.com/rust-lang/rust/issues/53667 + if /*cfg!(feature = "database") &&*/ let Some((_, conn_ident, conn_ty)) = args_conn // https://github.com/rust-lang/rust/issues/53667 { - let repo_ident = format_ident!("{}_database_repo", conn_pat.to_string()); + let conn_ty_real = match conn_ty { + Type::Reference(ty) => &*ty.elem, + ty => ty + }; block = quote! { - let #repo_ident = <#krate::export::Repo<#conn_ty>>::borrow_from(&#args_state).clone(); - #repo_ident.run(move |#conn_pat| { + use #krate::export::{Future, FromState}; + let #repo_ident = <#krate::export::Repo<#conn_ty_real>>::borrow_from(&#state_ident).clone(); + #repo_ident.run(move |#conn_ident| { #block }).wait() }; @@ -149,5 +173,6 @@ pub fn expand_method(method : Method, attrs : TokenStream, item : TokenStream) - route.#method_ident::<#resource_ident, #(#generics),*>(); } }; + println!("{}", output); output.into() }