summary refs log tree commit diff
path: root/msg_socket
diff options
context:
space:
mode:
authorZach Reizner <zachr@google.com>2020-03-25 01:36:46 -0700
committerCommit Bot <commit-bot@chromium.org>2020-05-06 02:15:12 +0000
commit882e2cea3bdeb6341b1e38b04e93ac6ede5a493d (patch)
tree9d043a8ea945808aeead878561bb7891860ca8eb /msg_socket
parent8b3ee41b302d5d99da917c3b21037135d1665b75 (diff)
downloadcrosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar
crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar.gz
crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar.bz2
crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar.lz
crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar.xz
crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.tar.zst
crosvm-882e2cea3bdeb6341b1e38b04e93ac6ede5a493d.zip
msg_socket: impl skip helper attribute
Fields with a default value can be skipped using the
`#[msg_on_socket(skip)]` attribute.

TEST=cargo test -p msg_socket
BUG=None

Change-Id: I9fea33e641a7da62b7864ba1847e884b32502491
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2168587
Reviewed-by: Dylan Reid <dgreid@chromium.org>
Tested-by: kokoro <noreply+kokoro@google.com>
Tested-by: Zach Reizner <zachr@chromium.org>
Commit-Queue: Zach Reizner <zachr@chromium.org>
Diffstat (limited to 'msg_socket')
-rw-r--r--msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs150
1 files changed, 124 insertions, 26 deletions
diff --git a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs
index a3c065c..6db0417 100644
--- a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs
+++ b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs
@@ -10,11 +10,12 @@ use std::vec::Vec;
 use proc_macro2::{Span, TokenStream};
 use quote::{format_ident, quote};
 use syn::{
-    parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Type,
+    parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Meta,
+    NestedMeta, Type,
 };
 
 /// The function that derives the recursive implementation for struct or enum.
-#[proc_macro_derive(MsgOnSocket)]
+#[proc_macro_derive(MsgOnSocket, attributes(msg_on_socket))]
 pub fn msg_on_socket_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let input = parse_macro_input!(input as DeriveInput);
     let impl_for_input = socket_msg_impl(input);
@@ -50,6 +51,13 @@ fn is_named_struct(ds: &DataStruct) -> bool {
 }
 
 /************************** Named Struct Impls ********************************************/
+
+struct StructField {
+    member: Member,
+    ty: Type,
+    skipped: bool,
+}
+
 fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream {
     let fields = get_struct_fields(ds);
     let uses_fd_impl = define_uses_fd_for_struct(&fields);
@@ -68,7 +76,7 @@ fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream {
 }
 
 // Flatten struct fields.
-fn get_struct_fields(ds: DataStruct) -> Vec<(Member, Type)> {
+fn get_struct_fields(ds: DataStruct) -> Vec<StructField> {
     let fields = match ds.fields {
         Fields::Named(fields_named) => fields_named.named,
         _ => {
@@ -82,17 +90,48 @@ fn get_struct_fields(ds: DataStruct) -> Vec<(Member, Type)> {
             None => panic!("Unknown Error."),
         };
         let ty = field.ty;
-        vec.push((member, ty));
+        let mut skipped = false;
+        for attr in field
+            .attrs
+            .iter()
+            .filter(|attr| attr.path.is_ident("msg_on_socket"))
+        {
+            match attr.parse_meta().unwrap() {
+                Meta::List(meta) => {
+                    for nested in meta.nested {
+                        match nested {
+                            NestedMeta::Meta(Meta::Path(meta_path))
+                                if meta_path.is_ident("skip") =>
+                            {
+                                skipped = true;
+                            }
+                            _ => panic!("unrecognized attribute meta `{}`", quote! { #nested }),
+                        }
+                    }
+                }
+                _ => panic!("unrecognized attribute `{}`", quote! { #attr }),
+            }
+        }
+        vec.push(StructField {
+            member,
+            ty,
+            skipped,
+        });
     }
     vec
 }
 
-fn define_uses_fd_for_struct(fields: &[(Member, Type)]) -> TokenStream {
-    if fields.len() == 0 {
+fn define_uses_fd_for_struct(fields: &[StructField]) -> TokenStream {
+    let field_types: Vec<_> = fields
+        .iter()
+        .filter(|f| !f.skipped)
+        .map(|f| &f.ty)
+        .collect();
+
+    if field_types.is_empty() {
         return quote!();
     }
 
-    let field_types = fields.iter().map(|(_, ty)| ty);
     quote! {
         fn uses_fd() -> bool {
             #(<#field_types>::uses_fd())||*
@@ -100,7 +139,7 @@ fn define_uses_fd_for_struct(fields: &[(Member, Type)]) -> TokenStream {
     }
 }
 
-fn define_buffer_size_for_struct(fields: &[(Member, Type)]) -> TokenStream {
+fn define_buffer_size_for_struct(fields: &[StructField]) -> TokenStream {
     let (msg_size, fd_count) = get_fields_buffer_size_sum(fields);
     quote! {
         fn msg_size(&self) -> usize {
@@ -112,17 +151,24 @@ fn define_buffer_size_for_struct(fields: &[(Member, Type)]) -> TokenStream {
     }
 }
 
-fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> TokenStream {
+fn define_read_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream {
     let mut read_fields = Vec::new();
     let mut init_fields = Vec::new();
-    for (field_member, field_ty) in fields {
-        let ident = match field_member {
+    for field in fields {
+        let ident = match &field.member {
             Member::Named(ident) => ident,
             Member::Unnamed(_) => unreachable!(),
         };
-        let read_field = read_from_buffer_and_move_offset(&ident, &field_ty);
-        read_fields.push(read_field);
         let name = ident.clone();
+        if field.skipped {
+            let ty = &field.ty;
+            init_fields.push(quote! {
+                #name: <#ty>::default()
+            });
+            continue;
+        }
+        let read_field = read_from_buffer_and_move_offset(&ident, &field.ty);
+        read_fields.push(read_field);
         init_fields.push(quote!(#name));
     }
     quote! {
@@ -143,10 +189,13 @@ fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> To
     }
 }
 
-fn define_write_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> TokenStream {
+fn define_write_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream {
     let mut write_fields = Vec::new();
-    for (field_member, _) in fields {
-        let ident = match field_member {
+    for field in fields {
+        if field.skipped {
+            continue;
+        }
+        let ident = match &field.member {
             Member::Named(ident) => ident,
             Member::Unnamed(_) => unreachable!(),
         };
@@ -438,7 +487,7 @@ fn impl_for_tuple_struct(name: Ident, ds: DataStruct) -> TokenStream {
     }
 }
 
-fn get_tuple_fields(ds: DataStruct) -> Vec<(Member, Type)> {
+fn get_tuple_fields(ds: DataStruct) -> Vec<StructField> {
     let mut field_idents = Vec::new();
     let fields = match ds.fields {
         Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed,
@@ -449,17 +498,21 @@ fn get_tuple_fields(ds: DataStruct) -> Vec<(Member, Type)> {
     for (idx, field) in fields.iter().enumerate() {
         let member = Member::Unnamed(Index::from(idx));
         let ty = field.ty.clone();
-        field_idents.push((member, ty));
+        field_idents.push(StructField {
+            member,
+            ty,
+            skipped: false,
+        });
     }
     field_idents
 }
 
-fn define_uses_fd_for_tuples(fields: &[(Member, Type)]) -> TokenStream {
+fn define_uses_fd_for_tuples(fields: &[StructField]) -> TokenStream {
     if fields.len() == 0 {
         return quote!();
     }
 
-    let field_types = fields.iter().map(|(_, ty)| ty);
+    let field_types = fields.iter().map(|f| &f.ty);
     quote! {
         fn uses_fd() -> bool {
             #(<#field_types>::uses_fd())||*
@@ -467,13 +520,13 @@ fn define_uses_fd_for_tuples(fields: &[(Member, Type)]) -> TokenStream {
     }
 }
 
-fn define_read_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> TokenStream {
+fn define_read_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream {
     let mut read_fields = Vec::new();
     let mut init_fields = Vec::new();
-    for (idx, (_, field_ty)) in fields.iter().enumerate() {
+    for (idx, field) in fields.iter().enumerate() {
         let tmp_name = format!("tuple_tmp{}", idx);
         let tmp_name = Ident::new(&tmp_name, Span::call_site());
-        let read_field = read_from_buffer_and_move_offset(&tmp_name, field_ty);
+        let read_field = read_from_buffer_and_move_offset(&tmp_name, &field.ty);
         read_fields.push(read_field);
         init_fields.push(quote!(#tmp_name));
     }
@@ -496,7 +549,7 @@ fn define_read_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> Tok
     }
 }
 
-fn define_write_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> TokenStream {
+fn define_write_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream {
     let mut write_fields = Vec::new();
     let mut tmp_names = Vec::new();
     for idx in 0..fields.len() {
@@ -520,8 +573,12 @@ fn define_write_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> To
     }
 }
 /************************** Helpers ********************************************/
-fn get_fields_buffer_size_sum(fields: &[(Member, Type)]) -> (TokenStream, TokenStream) {
-    let fields: Vec<_> = fields.iter().map(|(m, _)| m).collect();
+fn get_fields_buffer_size_sum(fields: &[StructField]) -> (TokenStream, TokenStream) {
+    let fields: Vec<_> = fields
+        .iter()
+        .filter(|f| !f.skipped)
+        .map(|f| &f.member)
+        .collect();
     if fields.len() > 0 {
         (
             quote! {
@@ -808,4 +865,45 @@ mod tests {
 
         assert_eq!(socket_msg_impl(input).to_string(), expected.to_string());
     }
+
+    #[test]
+    fn end_to_end_struct_skip_test() {
+        let input: DeriveInput = parse_quote! {
+            struct MyMsg {
+                #[msg_on_socket(skip)]
+                a: u8,
+            }
+        };
+
+        let expected = quote! {
+            impl msg_socket::MsgOnSocket for MyMsg {
+                fn msg_size(&self) -> usize {
+                    0
+                }
+                fn fd_count(&self) -> usize {
+                    0
+                }
+                unsafe fn read_from_buffer(
+                    buffer: &[u8],
+                    fds: &[std::os::unix::io::RawFd],
+                ) -> msg_socket::MsgResult<(Self, usize)> {
+                    let mut __offset = 0usize;
+                    let mut __fd_offset = 0usize;
+                    Ok((Self { a: <u8>::default() }, __fd_offset))
+                }
+                fn write_to_buffer(
+                    &self,
+                    buffer: &mut [u8],
+                    fds: &mut [std::os::unix::io::RawFd],
+                ) -> msg_socket::MsgResult<usize> {
+                    let mut __offset = 0usize;
+                    let mut __fd_offset = 0usize;
+                    Ok(__fd_offset)
+                }
+            }
+
+        };
+
+        assert_eq!(socket_msg_impl(input).to_string(), expected.to_string());
+    }
 }