diff options
author | Zach Reizner <zachr@google.com> | 2020-03-25 01:36:46 -0700 |
---|---|---|
committer | Commit Bot <commit-bot@chromium.org> | 2020-05-06 02:15:12 +0000 |
commit | 882e2cea3bdeb6341b1e38b04e93ac6ede5a493d (patch) | |
tree | 9d043a8ea945808aeead878561bb7891860ca8eb /msg_socket | |
parent | 8b3ee41b302d5d99da917c3b21037135d1665b75 (diff) | |
download | crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar.gz crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar.bz2 crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar.lz crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar.xz crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar.zst crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.zip |
msg_socket: impl skip helper attribute
Fields with a default value can be skipped using the `#[msg_on_socket(skip)]` attribute. TEST=cargo test -p msg_socket BUG=None Change-Id: I9fea33e641a7da62b7864ba1847e884b32502491 Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2168587 Reviewed-by: Dylan Reid <dgreid@chromium.org> Tested-by: kokoro <noreply+kokoro@google.com> Tested-by: Zach Reizner <zachr@chromium.org> Commit-Queue: Zach Reizner <zachr@chromium.org>
Diffstat (limited to 'msg_socket')
-rw-r--r-- | msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs | 150 |
1 files changed, 124 insertions, 26 deletions
diff --git a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs index a3c065c..6db0417 100644 --- a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs +++ b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs @@ -10,11 +10,12 @@ use std::vec::Vec; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ - parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Type, + parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Meta, + NestedMeta, Type, }; /// The function that derives the recursive implementation for struct or enum. -#[proc_macro_derive(MsgOnSocket)] +#[proc_macro_derive(MsgOnSocket, attributes(msg_on_socket))] pub fn msg_on_socket_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let impl_for_input = socket_msg_impl(input); @@ -50,6 +51,13 @@ fn is_named_struct(ds: &DataStruct) -> bool { } /************************** Named Struct Impls ********************************************/ + +struct StructField { + member: Member, + ty: Type, + skipped: bool, +} + fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream { let fields = get_struct_fields(ds); let uses_fd_impl = define_uses_fd_for_struct(&fields); @@ -68,7 +76,7 @@ fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream { } // Flatten struct fields. -fn get_struct_fields(ds: DataStruct) -> Vec<(Member, Type)> { +fn get_struct_fields(ds: DataStruct) -> Vec<StructField> { let fields = match ds.fields { Fields::Named(fields_named) => fields_named.named, _ => { @@ -82,17 +90,48 @@ fn get_struct_fields(ds: DataStruct) -> Vec<(Member, Type)> { None => panic!("Unknown Error."), }; let ty = field.ty; - vec.push((member, ty)); + let mut skipped = false; + for attr in field + .attrs + .iter() + .filter(|attr| attr.path.is_ident("msg_on_socket")) + { + match attr.parse_meta().unwrap() { + Meta::List(meta) => { + for nested in meta.nested { + match nested { + NestedMeta::Meta(Meta::Path(meta_path)) + if meta_path.is_ident("skip") => + { + skipped = true; + } + _ => panic!("unrecognized attribute meta `{}`", quote! { #nested }), + } + } + } + _ => panic!("unrecognized attribute `{}`", quote! { #attr }), + } + } + vec.push(StructField { + member, + ty, + skipped, + }); } vec } -fn define_uses_fd_for_struct(fields: &[(Member, Type)]) -> TokenStream { - if fields.len() == 0 { +fn define_uses_fd_for_struct(fields: &[StructField]) -> TokenStream { + let field_types: Vec<_> = fields + .iter() + .filter(|f| !f.skipped) + .map(|f| &f.ty) + .collect(); + + if field_types.is_empty() { return quote!(); } - let field_types = fields.iter().map(|(_, ty)| ty); quote! { fn uses_fd() -> bool { #(<#field_types>::uses_fd())||* @@ -100,7 +139,7 @@ fn define_uses_fd_for_struct(fields: &[(Member, Type)]) -> TokenStream { } } -fn define_buffer_size_for_struct(fields: &[(Member, Type)]) -> TokenStream { +fn define_buffer_size_for_struct(fields: &[StructField]) -> TokenStream { let (msg_size, fd_count) = get_fields_buffer_size_sum(fields); quote! { fn msg_size(&self) -> usize { @@ -112,17 +151,24 @@ fn define_buffer_size_for_struct(fields: &[(Member, Type)]) -> TokenStream { } } -fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> TokenStream { +fn define_read_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream { let mut read_fields = Vec::new(); let mut init_fields = Vec::new(); - for (field_member, field_ty) in fields { - let ident = match field_member { + for field in fields { + let ident = match &field.member { Member::Named(ident) => ident, Member::Unnamed(_) => unreachable!(), }; - let read_field = read_from_buffer_and_move_offset(&ident, &field_ty); - read_fields.push(read_field); let name = ident.clone(); + if field.skipped { + let ty = &field.ty; + init_fields.push(quote! { + #name: <#ty>::default() + }); + continue; + } + let read_field = read_from_buffer_and_move_offset(&ident, &field.ty); + read_fields.push(read_field); init_fields.push(quote!(#name)); } quote! { @@ -143,10 +189,13 @@ fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> To } } -fn define_write_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> TokenStream { +fn define_write_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream { let mut write_fields = Vec::new(); - for (field_member, _) in fields { - let ident = match field_member { + for field in fields { + if field.skipped { + continue; + } + let ident = match &field.member { Member::Named(ident) => ident, Member::Unnamed(_) => unreachable!(), }; @@ -438,7 +487,7 @@ fn impl_for_tuple_struct(name: Ident, ds: DataStruct) -> TokenStream { } } -fn get_tuple_fields(ds: DataStruct) -> Vec<(Member, Type)> { +fn get_tuple_fields(ds: DataStruct) -> Vec<StructField> { let mut field_idents = Vec::new(); let fields = match ds.fields { Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed, @@ -449,17 +498,21 @@ fn get_tuple_fields(ds: DataStruct) -> Vec<(Member, Type)> { for (idx, field) in fields.iter().enumerate() { let member = Member::Unnamed(Index::from(idx)); let ty = field.ty.clone(); - field_idents.push((member, ty)); + field_idents.push(StructField { + member, + ty, + skipped: false, + }); } field_idents } -fn define_uses_fd_for_tuples(fields: &[(Member, Type)]) -> TokenStream { +fn define_uses_fd_for_tuples(fields: &[StructField]) -> TokenStream { if fields.len() == 0 { return quote!(); } - let field_types = fields.iter().map(|(_, ty)| ty); + let field_types = fields.iter().map(|f| &f.ty); quote! { fn uses_fd() -> bool { #(<#field_types>::uses_fd())||* @@ -467,13 +520,13 @@ fn define_uses_fd_for_tuples(fields: &[(Member, Type)]) -> TokenStream { } } -fn define_read_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> TokenStream { +fn define_read_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream { let mut read_fields = Vec::new(); let mut init_fields = Vec::new(); - for (idx, (_, field_ty)) in fields.iter().enumerate() { + for (idx, field) in fields.iter().enumerate() { let tmp_name = format!("tuple_tmp{}", idx); let tmp_name = Ident::new(&tmp_name, Span::call_site()); - let read_field = read_from_buffer_and_move_offset(&tmp_name, field_ty); + let read_field = read_from_buffer_and_move_offset(&tmp_name, &field.ty); read_fields.push(read_field); init_fields.push(quote!(#tmp_name)); } @@ -496,7 +549,7 @@ fn define_read_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> Tok } } -fn define_write_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> TokenStream { +fn define_write_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream { let mut write_fields = Vec::new(); let mut tmp_names = Vec::new(); for idx in 0..fields.len() { @@ -520,8 +573,12 @@ fn define_write_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> To } } /************************** Helpers ********************************************/ -fn get_fields_buffer_size_sum(fields: &[(Member, Type)]) -> (TokenStream, TokenStream) { - let fields: Vec<_> = fields.iter().map(|(m, _)| m).collect(); +fn get_fields_buffer_size_sum(fields: &[StructField]) -> (TokenStream, TokenStream) { + let fields: Vec<_> = fields + .iter() + .filter(|f| !f.skipped) + .map(|f| &f.member) + .collect(); if fields.len() > 0 { ( quote! { @@ -808,4 +865,45 @@ mod tests { assert_eq!(socket_msg_impl(input).to_string(), expected.to_string()); } + + #[test] + fn end_to_end_struct_skip_test() { + let input: DeriveInput = parse_quote! { + struct MyMsg { + #[msg_on_socket(skip)] + a: u8, + } + }; + + let expected = quote! { + impl msg_socket::MsgOnSocket for MyMsg { + fn msg_size(&self) -> usize { + 0 + } + fn fd_count(&self) -> usize { + 0 + } + unsafe fn read_from_buffer( + buffer: &[u8], + fds: &[std::os::unix::io::RawFd], + ) -> msg_socket::MsgResult<(Self, usize)> { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + Ok((Self { a: <u8>::default() }, __fd_offset)) + } + fn write_to_buffer( + &self, + buffer: &mut [u8], + fds: &mut [std::os::unix::io::RawFd], + ) -> msg_socket::MsgResult<usize> { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + Ok(__fd_offset) + } + } + + }; + + assert_eq!(socket_msg_impl(input).to_string(), expected.to_string()); + } } |