// SPDX-License-Identifier: MIT OR Apache-2.0 // Copyright 2020, Alyssa Ross use proc_macro; use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens}; use std::convert::{TryFrom, TryInto}; use std::fmt::{self, Display, Formatter}; use syn::spanned::Spanned; use syn::{ parse_macro_input, punctuated::Punctuated, token::Comma, Attribute, DeriveInput, GenericParam, Generics, Lifetime, LifetimeDef, Lit, Meta, MetaList, MetaNameValue, NestedMeta, }; #[proc_macro_derive(SerializeWithFds, attributes(msg_socket2))] pub fn serialize_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let impl_for_input = impl_serialize_with_fds(input); // println!("{}", impl_for_input); impl_for_input.into() } #[proc_macro_derive(DeserializeWithFds, attributes(msg_socket2))] pub fn deserialize_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let impl_for_input = impl_deserialize_with_fds(input); // println!("{}", impl_for_input); impl_for_input.into() } #[derive(Debug)] enum Strategy { AsRawFd, Serde, } impl TryFrom for Strategy { type Error = Error; fn try_from(lit: Lit) -> Result { match lit { Lit::Str(value) if value.value() == "AsRawFd" => Ok(Strategy::AsRawFd), Lit::Str(value) if value.value() == "serde" => Ok(Strategy::Serde), _ => Err(Error::Str { message: "invalid strategy", span: Some(lit.span()), }), } } } #[derive(Debug)] struct Config { strategy: Strategy, } #[derive(Debug)] enum Error { Syn(syn::Error), Str { message: &'static str, span: Option, }, } impl Display for Error { fn fmt(&self, f: &mut Formatter) -> fmt::Result { use Error::*; match self { Syn(error) => write!(f, "{}", error), Str { message, .. } => write!(f, "{}", message), } } } impl Spanned for Error { fn span(&self) -> Span { use Error::*; match self { Syn(error) => error.span(), Str { span, .. } => span.unwrap_or(Span::call_site()), } } } impl From for Error { fn from(error: syn::Error) -> Self { Self::Syn(error) } } fn get_strategy(attrs: &[Attribute]) -> Result, Error> { for attr in attrs { let (attr_path, nested) = match attr.parse_meta()? { Meta::List(MetaList { path, nested, .. }) => (path, nested), _ => continue, }; match attr_path.get_ident() { Some(ident) if ident == "msg_socket2" => (), _ => continue, } for meta in nested { let (name_path, value) = match meta { NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => (path, lit), _ => continue, }; match name_path.get_ident() { Some(ident) if ident == "strategy" => (), _ => continue, } return value.try_into().map(Some); } } Ok(None) } fn get_config(attrs: &[Attribute]) -> Result { let strategy = get_strategy(&attrs)?.ok_or_else(|| Error::Str { message: "missing strategy", span: None, })?; Ok(Config { strategy }) } /// Like `syn::Generics::split_for_impl`, but a list of lifetimes can /// be added to the start of the generics list to go next to the /// `impl` keyword, but not the generics list to go after the type. /// /// This allows for generics that are required for a trait lifetime, /// rather than for the type its being implemented for. fn split_for_impl_with_trait_lifetimes( lifetimes: &'static [&'static str], mut generics: Generics, ) -> (TokenStream, TokenStream, TokenStream) { let (_, ty_generics, where_clause) = generics.split_for_impl(); // Convert to token streams to take ownership, so generics can be // mutated to generate impl_generics. let ty_generics = ty_generics.into_token_stream(); let where_clause = where_clause.into_token_stream(); // Add the trait lifetimes to the list of generics, then calculate // the impl_generics. Can't do this before calculating // ty_generics because then the trait lifetimes would be passed as // lifetime parameters to the type. // // FIXME: there must be a better way of doing this. let defs: Punctuated<_, Comma> = lifetimes .iter() .map(|name| Lifetime::new(name, Span::call_site())) .map(LifetimeDef::new) .map(GenericParam::Lifetime) .chain(generics.params.into_iter()) .collect(); generics.params = defs; let (impl_generics, _, _) = generics.split_for_impl(); (impl_generics.into_token_stream(), ty_generics, where_clause) } fn as_raw_fd_serialize_impl(input: DeriveInput) -> TokenStream { let name = input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); quote! { impl #impl_generics ::msg_socket2::SerializeWithFds for #name #ty_generics #where_clause { fn serialize(&self, serializer: S) -> ::std::result::Result where S: ::msg_socket2::Serializer, { serializer.serialize_unit() } fn serialize_fds<'__fds, S: ::msg_socket2::FdSerializer<'__fds>>( &'__fds self, serializer: S, ) -> ::std::result::Result { serializer.serialize_raw_fd(self) } } } } fn as_raw_fd_deserialize_impl(input: DeriveInput) -> TokenStream { let name = input.ident; let (impl_generics, ty_generics, where_clause) = split_for_impl_with_trait_lifetimes(&["'__de"], input.generics); quote! { impl #impl_generics ::msg_socket2::DeserializeWithFds<'__de> for #name #ty_generics #where_clause { fn deserialize(deserializer: D) -> ::std::result::Result where D: ::msg_socket2::DeserializerWithFds<'__de>, { deserializer.deserialize_fd().map(|x| x.specialize()) } } } } fn serde_serialize_impl(input: DeriveInput) -> TokenStream { let name = input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); quote! { impl #impl_generics ::msg_socket2::SerializeWithFds for #name #ty_generics #where_clause { fn serialize(&self, serializer: S) -> ::std::result::Result where S: ::msg_socket2::Serializer, { ::msg_socket2::Serialize::serialize(&self, serializer) } fn serialize_fds<'__fds, S: ::msg_socket2::FdSerializer<'__fds>>( &'__fds self, serializer: S, ) -> ::std::result::Result { serializer.serialize_unit() } } } } fn serde_deserialize_impl(input: DeriveInput) -> TokenStream { let name = input.ident; let (impl_generics, ty_generics, where_clause) = split_for_impl_with_trait_lifetimes(&["'__de"], input.generics); quote! { impl #impl_generics ::msg_socket2::DeserializeWithFds<'__de> for #name #ty_generics #where_clause { fn deserialize(deserializer: D) -> ::std::result::Result where D: ::msg_socket2::DeserializerWithFds<'__de>, { deserializer.invite_deserialize() } } } } fn impl_serialize_with_fds(input: DeriveInput) -> TokenStream { let config = match get_config(&input.attrs) { Ok(config) => config, Err(error) => { // FIXME: Tried using compile_error! but it just got // swallowed. Is that not possible with derive macros? panic!("{}", error); } }; match config.strategy { Strategy::AsRawFd => as_raw_fd_serialize_impl(input), Strategy::Serde => serde_serialize_impl(input), } } fn impl_deserialize_with_fds(input: DeriveInput) -> TokenStream { let config = match get_config(&input.attrs) { Ok(config) => config, Err(error) => { // FIXME: Tried using compile_error! but it just got // swallowed. Is that not possible with derive macros? panic!("{}", error); } }; match config.strategy { Strategy::AsRawFd => as_raw_fd_deserialize_impl(input), Strategy::Serde => serde_deserialize_impl(input), } }