summary refs log tree commit diff
path: root/msg_socket
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-05-08 15:27:56 +0000
committerAlyssa Ross <hi@alyssa.is>2020-05-10 02:39:28 +0000
commit2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b (patch)
treefefaf2c13796f8f2fa9a13b99b09c3b40ab5966b /msg_socket
parent00c41c28bbc44b37fc8dcf5d2a5b4679f2aa4297 (diff)
parent03a54abf852984f696e7a101ff9590f05ebcba5b (diff)
downloadcrosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar
crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar.gz
crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar.bz2
crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar.lz
crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar.xz
crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar.zst
crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.zip
Merge remote-tracking branch 'origin/master'
Diffstat (limited to 'msg_socket')
-rw-r--r--msg_socket/Cargo.toml1
-rw-r--r--msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs168
-rw-r--r--msg_socket/src/lib.rs22
-rw-r--r--msg_socket/src/msg_on_socket.rs51
4 files changed, 200 insertions, 42 deletions
diff --git a/msg_socket/Cargo.toml b/msg_socket/Cargo.toml
index c803bed..80eba0b 100644
--- a/msg_socket/Cargo.toml
+++ b/msg_socket/Cargo.toml
@@ -11,3 +11,4 @@ futures = "*"
 libc = "*"
 msg_on_socket_derive = { path = "msg_on_socket_derive" }
 sys_util = { path = "../sys_util"  }
+sync = { path = "../sync"  }
diff --git a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs
index 7b16546..a5f31f8 100644
--- a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs
+++ b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs
@@ -9,11 +9,12 @@ use std::vec::Vec;
 use proc_macro2::{Span, TokenStream};
 use quote::{format_ident, quote};
 use syn::{
-    parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Type,
+    parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Meta,
+    NestedMeta, Type,
 };
 
 /// The function that derives the recursive implementation for struct or enum.
-#[proc_macro_derive(MsgOnSocket)]
+#[proc_macro_derive(MsgOnSocket, attributes(msg_on_socket))]
 pub fn msg_on_socket_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let input = parse_macro_input!(input as DeriveInput);
     let impl_for_input = msg_socket_impl(input);
@@ -46,6 +47,13 @@ fn is_named_struct(ds: &DataStruct) -> bool {
 }
 
 /************************** Named Struct Impls ********************************************/
+
+struct StructField {
+    member: Member,
+    ty: Type,
+    skipped: bool,
+}
+
 fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream {
     let fields = get_struct_fields(ds);
     let uses_fd_impl = define_uses_fd_for_struct(&fields);
@@ -64,7 +72,7 @@ fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream {
 }
 
 // Flatten struct fields.
-fn get_struct_fields(ds: DataStruct) -> Vec<(Member, Type)> {
+fn get_struct_fields(ds: DataStruct) -> Vec<StructField> {
     let fields = match ds.fields {
         Fields::Named(fields_named) => fields_named.named,
         _ => {
@@ -78,17 +86,48 @@ fn get_struct_fields(ds: DataStruct) -> Vec<(Member, Type)> {
             None => panic!("Unknown Error."),
         };
         let ty = field.ty;
-        vec.push((member, ty));
+        let mut skipped = false;
+        for attr in field
+            .attrs
+            .iter()
+            .filter(|attr| attr.path.is_ident("msg_on_socket"))
+        {
+            match attr.parse_meta().unwrap() {
+                Meta::List(meta) => {
+                    for nested in meta.nested {
+                        match nested {
+                            NestedMeta::Meta(Meta::Path(meta_path))
+                                if meta_path.is_ident("skip") =>
+                            {
+                                skipped = true;
+                            }
+                            _ => panic!("unrecognized attribute meta `{}`", quote! { #nested }),
+                        }
+                    }
+                }
+                _ => panic!("unrecognized attribute `{}`", quote! { #attr }),
+            }
+        }
+        vec.push(StructField {
+            member,
+            ty,
+            skipped,
+        });
     }
     vec
 }
 
-fn define_uses_fd_for_struct(fields: &[(Member, Type)]) -> TokenStream {
-    if fields.len() == 0 {
+fn define_uses_fd_for_struct(fields: &[StructField]) -> TokenStream {
+    let field_types: Vec<_> = fields
+        .iter()
+        .filter(|f| !f.skipped)
+        .map(|f| &f.ty)
+        .collect();
+
+    if field_types.is_empty() {
         return quote!();
     }
 
-    let field_types = fields.iter().map(|(_, ty)| ty);
     quote! {
         fn uses_fd() -> bool {
             #(<#field_types>::uses_fd())||*
@@ -96,7 +135,7 @@ fn define_uses_fd_for_struct(fields: &[(Member, Type)]) -> TokenStream {
     }
 }
 
-fn define_buffer_size_for_struct(fields: &[(Member, Type)]) -> TokenStream {
+fn define_buffer_size_for_struct(fields: &[StructField]) -> TokenStream {
     let (msg_size, fd_count) = get_fields_buffer_size_sum(fields);
     quote! {
         fn msg_size(&self) -> usize {
@@ -108,17 +147,24 @@ fn define_buffer_size_for_struct(fields: &[(Member, Type)]) -> TokenStream {
     }
 }
 
-fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> TokenStream {
+fn define_read_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream {
     let mut read_fields = Vec::new();
     let mut init_fields = Vec::new();
-    for (field_member, field_ty) in fields {
-        let ident = match field_member {
+    for field in fields {
+        let ident = match &field.member {
             Member::Named(ident) => ident,
             Member::Unnamed(_) => unreachable!(),
         };
-        let read_field = read_from_buffer_and_move_offset(&ident, &field_ty);
-        read_fields.push(read_field);
         let name = ident.clone();
+        if field.skipped {
+            let ty = &field.ty;
+            init_fields.push(quote! {
+                #name: <#ty>::default()
+            });
+            continue;
+        }
+        let read_field = read_from_buffer_and_move_offset(&ident, &field.ty);
+        read_fields.push(read_field);
         init_fields.push(quote!(#name));
     }
     quote! {
@@ -139,10 +185,13 @@ fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> To
     }
 }
 
-fn define_write_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> TokenStream {
+fn define_write_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream {
     let mut write_fields = Vec::new();
-    for (field_member, _) in fields {
-        let ident = match field_member {
+    for field in fields {
+        if field.skipped {
+            continue;
+        }
+        let ident = match &field.member {
             Member::Named(ident) => ident,
             Member::Unnamed(_) => unreachable!(),
         };
@@ -187,17 +236,13 @@ fn define_uses_fd_for_enum(de: &DataEnum) -> TokenStream {
         }
     }
 
-    if variant_field_types.is_empty() {
-        quote! {
-            fn uses_fd() -> bool {
-                false
-            }
-        }
-    } else {
-        quote! {
-            fn uses_fd() -> bool {
-                #(<#variant_field_types>::uses_fd())||*
-            }
+    if variant_field_types.len() == 0 {
+        return quote!();
+    }
+
+    quote! {
+        fn uses_fd() -> bool {
+            #(<#variant_field_types>::uses_fd())||*
         }
     }
 }
@@ -438,7 +483,7 @@ fn impl_for_tuple_struct(name: Ident, ds: DataStruct) -> TokenStream {
     }
 }
 
-fn get_tuple_fields(ds: DataStruct) -> Vec<(Member, Type)> {
+fn get_tuple_fields(ds: DataStruct) -> Vec<StructField> {
     let mut field_idents = Vec::new();
     let fields = match ds.fields {
         Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed,
@@ -449,17 +494,21 @@ fn get_tuple_fields(ds: DataStruct) -> Vec<(Member, Type)> {
     for (idx, field) in fields.iter().enumerate() {
         let member = Member::Unnamed(Index::from(idx));
         let ty = field.ty.clone();
-        field_idents.push((member, ty));
+        field_idents.push(StructField {
+            member,
+            ty,
+            skipped: false,
+        });
     }
     field_idents
 }
 
-fn define_uses_fd_for_tuples(fields: &[(Member, Type)]) -> TokenStream {
+fn define_uses_fd_for_tuples(fields: &[StructField]) -> TokenStream {
     if fields.len() == 0 {
         return quote!();
     }
 
-    let field_types = fields.iter().map(|(_, ty)| ty);
+    let field_types = fields.iter().map(|f| &f.ty);
     quote! {
         fn uses_fd() -> bool {
             #(<#field_types>::uses_fd())||*
@@ -467,13 +516,13 @@ fn define_uses_fd_for_tuples(fields: &[(Member, Type)]) -> TokenStream {
     }
 }
 
-fn define_read_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> TokenStream {
+fn define_read_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream {
     let mut read_fields = Vec::new();
     let mut init_fields = Vec::new();
-    for (idx, (_, field_ty)) in fields.iter().enumerate() {
+    for (idx, field) in fields.iter().enumerate() {
         let tmp_name = format!("tuple_tmp{}", idx);
         let tmp_name = Ident::new(&tmp_name, Span::call_site());
-        let read_field = read_from_buffer_and_move_offset(&tmp_name, field_ty);
+        let read_field = read_from_buffer_and_move_offset(&tmp_name, &field.ty);
         read_fields.push(read_field);
         init_fields.push(quote!(#tmp_name));
     }
@@ -496,7 +545,7 @@ fn define_read_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> Tok
     }
 }
 
-fn define_write_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> TokenStream {
+fn define_write_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream {
     let mut write_fields = Vec::new();
     let mut tmp_names = Vec::new();
     for idx in 0..fields.len() {
@@ -520,8 +569,12 @@ fn define_write_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> To
     }
 }
 /************************** Helpers ********************************************/
-fn get_fields_buffer_size_sum(fields: &[(Member, Type)]) -> (TokenStream, TokenStream) {
-    let fields: Vec<_> = fields.iter().map(|(m, _)| m).collect();
+fn get_fields_buffer_size_sum(fields: &[StructField]) -> (TokenStream, TokenStream) {
+    let fields: Vec<_> = fields
+        .iter()
+        .filter(|f| !f.skipped)
+        .map(|f| &f.member)
+        .collect();
     if fields.len() > 0 {
         (
             quote! {
@@ -832,4 +885,45 @@ mod tests {
 
         assert_eq!(msg_socket_impl(input).to_string(), expected.to_string());
     }
+
+    #[test]
+    fn end_to_end_struct_skip_test() {
+        let input: DeriveInput = parse_quote! {
+            struct MyMsg {
+                #[msg_on_socket(skip)]
+                a: u8,
+            }
+        };
+
+        let expected = quote! {
+            impl msg_socket::MsgOnSocket for MyMsg {
+                fn msg_size(&self) -> usize {
+                    0
+                }
+                fn fd_count(&self) -> usize {
+                    0
+                }
+                unsafe fn read_from_buffer(
+                    buffer: &[u8],
+                    fds: &[std::os::unix::io::RawFd],
+                ) -> msg_socket::MsgResult<(Self, usize)> {
+                    let mut __offset = 0usize;
+                    let mut __fd_offset = 0usize;
+                    Ok((Self { a: <u8>::default() }, __fd_offset))
+                }
+                fn write_to_buffer(
+                    &self,
+                    buffer: &mut [u8],
+                    fds: &mut [std::os::unix::io::RawFd],
+                ) -> msg_socket::MsgResult<usize> {
+                    let mut __offset = 0usize;
+                    let mut __fd_offset = 0usize;
+                    Ok(__fd_offset)
+                }
+            }
+
+        };
+
+        assert_eq!(socket_msg_impl(input).to_string(), expected.to_string());
+    }
 }
diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs
index 5871735..7f5d1a6 100644
--- a/msg_socket/src/lib.rs
+++ b/msg_socket/src/lib.rs
@@ -13,7 +13,7 @@ use std::task::{Context, Poll};
 use futures::Stream;
 use libc::{EWOULDBLOCK, O_NONBLOCK};
 
-use cros_async::fd_executor::add_read_waker;
+use cros_async::add_read_waker;
 use sys_util::{
     add_fd_flags, clear_fd_flags, error, handle_eintr, net::UnixSeqpacket, Error as SysError,
     ScmSocket,
@@ -50,7 +50,7 @@ impl<I: MsgOnSocket, O: MsgOnSocket> MsgSocket<I, O> {
     }
 
     // Creates an async receiver that implements `futures::Stream`.
-    pub fn async_receiver(&mut self) -> MsgResult<AsyncReceiver<I, O>> {
+    pub fn async_receiver(&self) -> MsgResult<AsyncReceiver<I, O>> {
         AsyncReceiver::new(self)
     }
 }
@@ -164,6 +164,20 @@ pub trait MsgReceiver: AsRef<UnixSeqpacket> {
                 )
             }
         };
+
+        if msg_buffer.len() == 0 && Self::M::fixed_size() != Some(0) {
+            return Err(MsgError::RecvZero);
+        }
+
+        if let Some(fixed_size) = Self::M::fixed_size() {
+            if fixed_size != msg_buffer.len() {
+                return Err(MsgError::BadRecvSize {
+                    expected: fixed_size,
+                    actual: msg_buffer.len(),
+                });
+            }
+        }
+
         // Safe because fd buffer is read from socket.
         let (v, read_fd_size) = unsafe { Self::M::read_from_buffer(&msg_buffer, &fd_buffer)? };
         if fd_buffer.len() != read_fd_size {
@@ -189,12 +203,12 @@ impl<O: MsgOnSocket> MsgReceiver for Receiver<O> {
 
 /// Asynchronous adaptor for `MsgSocket`.
 pub struct AsyncReceiver<'a, I: MsgOnSocket, O: MsgOnSocket> {
-    inner: &'a mut MsgSocket<I, O>,
+    inner: &'a MsgSocket<I, O>,
     done: bool, // Have hit an error and the Stream should return null when polled.
 }
 
 impl<'a, I: MsgOnSocket, O: MsgOnSocket> AsyncReceiver<'a, I, O> {
-    fn new(msg_socket: &mut MsgSocket<I, O>) -> MsgResult<AsyncReceiver<I, O>> {
+    fn new(msg_socket: &MsgSocket<I, O>) -> MsgResult<AsyncReceiver<I, O>> {
         add_fd_flags(msg_socket.as_raw_fd(), O_NONBLOCK).map_err(MsgError::SettingFdFlags)?;
         Ok(AsyncReceiver {
             inner: msg_socket,
diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs
index 624d514..67d26aa 100644
--- a/msg_socket/src/msg_on_socket.rs
+++ b/msg_socket/src/msg_on_socket.rs
@@ -10,15 +10,17 @@ use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
 use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
 use std::ptr::drop_in_place;
 use std::result;
+use std::sync::Arc;
 
 use data_model::*;
+use sync::Mutex;
 use sys_util::{Error as SysError, EventFd};
 
 #[derive(Debug, PartialEq)]
 /// An error during transaction or serialization/deserialization.
 pub enum MsgError {
     /// Error adding a waker for async read.
-    AddingWaker(cros_async::fd_executor::Error),
+    AddingWaker(cros_async::Error),
     /// Error while sending a request or response.
     Send(SysError),
     /// Error while receiving a request or response.
@@ -28,6 +30,8 @@ pub enum MsgError {
     /// There was not the expected amount of data when receiving a message. The inner
     /// value is how much data is expected and how much data was actually received.
     BadRecvSize { expected: usize, actual: usize },
+    /// There was no data received when the socket `recv`-ed.
+    RecvZero,
     /// There was no associated file descriptor received for a request that expected it.
     ExpectFd,
     /// There was some associated file descriptor received but not used when deserialize.
@@ -58,6 +62,7 @@ impl Display for MsgError {
                 "wrong amount of data received; expected {} bytes; got {} bytes",
                 expected, actual
             ),
+            RecvZero => write!(f, "received zero data"),
             ExpectFd => write!(f, "missing associated file descriptor for request"),
             NotExpectFd => write!(f, "unexpected file descriptor is unused"),
             SettingFdFlags(e) => write!(f, "failed setting flags on the message FD: {}", e),
@@ -207,6 +212,50 @@ impl<T: MsgOnSocket> MsgOnSocket for Option<T> {
     }
 }
 
+impl<T: MsgOnSocket> MsgOnSocket for Mutex<T> {
+    fn uses_fd() -> bool {
+        T::uses_fd()
+    }
+
+    fn msg_size(&self) -> usize {
+        self.lock().msg_size()
+    }
+
+    fn fd_count(&self) -> usize {
+        self.lock().fd_count()
+    }
+
+    unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+        T::read_from_buffer(buffer, fds).map(|(v, count)| (Mutex::new(v), count))
+    }
+
+    fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
+        self.lock().write_to_buffer(buffer, fds)
+    }
+}
+
+impl<T: MsgOnSocket> MsgOnSocket for Arc<T> {
+    fn uses_fd() -> bool {
+        T::uses_fd()
+    }
+
+    fn msg_size(&self) -> usize {
+        (**self).msg_size()
+    }
+
+    fn fd_count(&self) -> usize {
+        (**self).fd_count()
+    }
+
+    unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+        T::read_from_buffer(buffer, fds).map(|(v, count)| (Arc::new(v), count))
+    }
+
+    fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
+        (**self).write_to_buffer(buffer, fds)
+    }
+}
+
 impl MsgOnSocket for () {
     fn fixed_size() -> Option<usize> {
         Some(0)