diff options
author | Alyssa Ross <hi@alyssa.is> | 2020-05-08 15:27:56 +0000 |
---|---|---|
committer | Alyssa Ross <hi@alyssa.is> | 2020-05-10 02:39:28 +0000 |
commit | 2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b (patch) | |
tree | fefaf2c13796f8f2fa9a13b99b09c3b40ab5966b /msg_socket | |
parent | 00c41c28bbc44b37fc8dcf5d2a5b4679f2aa4297 (diff) | |
parent | 03a54abf852984f696e7a101ff9590f05ebcba5b (diff) | |
download | crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar.gz crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar.bz2 crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar.lz crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar.xz crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.tar.zst crosvm-2f8d50adc97cc7fca6f710bd575b4f71ccb40f6b.zip |
Merge remote-tracking branch 'origin/master'
Diffstat (limited to 'msg_socket')
-rw-r--r-- | msg_socket/Cargo.toml | 1 | ||||
-rw-r--r-- | msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs | 168 | ||||
-rw-r--r-- | msg_socket/src/lib.rs | 22 | ||||
-rw-r--r-- | msg_socket/src/msg_on_socket.rs | 51 |
4 files changed, 200 insertions, 42 deletions
diff --git a/msg_socket/Cargo.toml b/msg_socket/Cargo.toml index c803bed..80eba0b 100644 --- a/msg_socket/Cargo.toml +++ b/msg_socket/Cargo.toml @@ -11,3 +11,4 @@ futures = "*" libc = "*" msg_on_socket_derive = { path = "msg_on_socket_derive" } sys_util = { path = "../sys_util" } +sync = { path = "../sync" } diff --git a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs index 7b16546..a5f31f8 100644 --- a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs +++ b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs @@ -9,11 +9,12 @@ use std::vec::Vec; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ - parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Type, + parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Meta, + NestedMeta, Type, }; /// The function that derives the recursive implementation for struct or enum. -#[proc_macro_derive(MsgOnSocket)] +#[proc_macro_derive(MsgOnSocket, attributes(msg_on_socket))] pub fn msg_on_socket_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let impl_for_input = msg_socket_impl(input); @@ -46,6 +47,13 @@ fn is_named_struct(ds: &DataStruct) -> bool { } /************************** Named Struct Impls ********************************************/ + +struct StructField { + member: Member, + ty: Type, + skipped: bool, +} + fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream { let fields = get_struct_fields(ds); let uses_fd_impl = define_uses_fd_for_struct(&fields); @@ -64,7 +72,7 @@ fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream { } // Flatten struct fields. -fn get_struct_fields(ds: DataStruct) -> Vec<(Member, Type)> { +fn get_struct_fields(ds: DataStruct) -> Vec<StructField> { let fields = match ds.fields { Fields::Named(fields_named) => fields_named.named, _ => { @@ -78,17 +86,48 @@ fn get_struct_fields(ds: DataStruct) -> Vec<(Member, Type)> { None => panic!("Unknown Error."), }; let ty = field.ty; - vec.push((member, ty)); + let mut skipped = false; + for attr in field + .attrs + .iter() + .filter(|attr| attr.path.is_ident("msg_on_socket")) + { + match attr.parse_meta().unwrap() { + Meta::List(meta) => { + for nested in meta.nested { + match nested { + NestedMeta::Meta(Meta::Path(meta_path)) + if meta_path.is_ident("skip") => + { + skipped = true; + } + _ => panic!("unrecognized attribute meta `{}`", quote! { #nested }), + } + } + } + _ => panic!("unrecognized attribute `{}`", quote! { #attr }), + } + } + vec.push(StructField { + member, + ty, + skipped, + }); } vec } -fn define_uses_fd_for_struct(fields: &[(Member, Type)]) -> TokenStream { - if fields.len() == 0 { +fn define_uses_fd_for_struct(fields: &[StructField]) -> TokenStream { + let field_types: Vec<_> = fields + .iter() + .filter(|f| !f.skipped) + .map(|f| &f.ty) + .collect(); + + if field_types.is_empty() { return quote!(); } - let field_types = fields.iter().map(|(_, ty)| ty); quote! { fn uses_fd() -> bool { #(<#field_types>::uses_fd())||* @@ -96,7 +135,7 @@ fn define_uses_fd_for_struct(fields: &[(Member, Type)]) -> TokenStream { } } -fn define_buffer_size_for_struct(fields: &[(Member, Type)]) -> TokenStream { +fn define_buffer_size_for_struct(fields: &[StructField]) -> TokenStream { let (msg_size, fd_count) = get_fields_buffer_size_sum(fields); quote! { fn msg_size(&self) -> usize { @@ -108,17 +147,24 @@ fn define_buffer_size_for_struct(fields: &[(Member, Type)]) -> TokenStream { } } -fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> TokenStream { +fn define_read_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream { let mut read_fields = Vec::new(); let mut init_fields = Vec::new(); - for (field_member, field_ty) in fields { - let ident = match field_member { + for field in fields { + let ident = match &field.member { Member::Named(ident) => ident, Member::Unnamed(_) => unreachable!(), }; - let read_field = read_from_buffer_and_move_offset(&ident, &field_ty); - read_fields.push(read_field); let name = ident.clone(); + if field.skipped { + let ty = &field.ty; + init_fields.push(quote! { + #name: <#ty>::default() + }); + continue; + } + let read_field = read_from_buffer_and_move_offset(&ident, &field.ty); + read_fields.push(read_field); init_fields.push(quote!(#name)); } quote! { @@ -139,10 +185,13 @@ fn define_read_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> To } } -fn define_write_buffer_for_struct(_name: &Ident, fields: &[(Member, Type)]) -> TokenStream { +fn define_write_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream { let mut write_fields = Vec::new(); - for (field_member, _) in fields { - let ident = match field_member { + for field in fields { + if field.skipped { + continue; + } + let ident = match &field.member { Member::Named(ident) => ident, Member::Unnamed(_) => unreachable!(), }; @@ -187,17 +236,13 @@ fn define_uses_fd_for_enum(de: &DataEnum) -> TokenStream { } } - if variant_field_types.is_empty() { - quote! { - fn uses_fd() -> bool { - false - } - } - } else { - quote! { - fn uses_fd() -> bool { - #(<#variant_field_types>::uses_fd())||* - } + if variant_field_types.len() == 0 { + return quote!(); + } + + quote! { + fn uses_fd() -> bool { + #(<#variant_field_types>::uses_fd())||* } } } @@ -438,7 +483,7 @@ fn impl_for_tuple_struct(name: Ident, ds: DataStruct) -> TokenStream { } } -fn get_tuple_fields(ds: DataStruct) -> Vec<(Member, Type)> { +fn get_tuple_fields(ds: DataStruct) -> Vec<StructField> { let mut field_idents = Vec::new(); let fields = match ds.fields { Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed, @@ -449,17 +494,21 @@ fn get_tuple_fields(ds: DataStruct) -> Vec<(Member, Type)> { for (idx, field) in fields.iter().enumerate() { let member = Member::Unnamed(Index::from(idx)); let ty = field.ty.clone(); - field_idents.push((member, ty)); + field_idents.push(StructField { + member, + ty, + skipped: false, + }); } field_idents } -fn define_uses_fd_for_tuples(fields: &[(Member, Type)]) -> TokenStream { +fn define_uses_fd_for_tuples(fields: &[StructField]) -> TokenStream { if fields.len() == 0 { return quote!(); } - let field_types = fields.iter().map(|(_, ty)| ty); + let field_types = fields.iter().map(|f| &f.ty); quote! { fn uses_fd() -> bool { #(<#field_types>::uses_fd())||* @@ -467,13 +516,13 @@ fn define_uses_fd_for_tuples(fields: &[(Member, Type)]) -> TokenStream { } } -fn define_read_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> TokenStream { +fn define_read_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream { let mut read_fields = Vec::new(); let mut init_fields = Vec::new(); - for (idx, (_, field_ty)) in fields.iter().enumerate() { + for (idx, field) in fields.iter().enumerate() { let tmp_name = format!("tuple_tmp{}", idx); let tmp_name = Ident::new(&tmp_name, Span::call_site()); - let read_field = read_from_buffer_and_move_offset(&tmp_name, field_ty); + let read_field = read_from_buffer_and_move_offset(&tmp_name, &field.ty); read_fields.push(read_field); init_fields.push(quote!(#tmp_name)); } @@ -496,7 +545,7 @@ fn define_read_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> Tok } } -fn define_write_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> TokenStream { +fn define_write_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream { let mut write_fields = Vec::new(); let mut tmp_names = Vec::new(); for idx in 0..fields.len() { @@ -520,8 +569,12 @@ fn define_write_buffer_for_tuples(name: &Ident, fields: &[(Member, Type)]) -> To } } /************************** Helpers ********************************************/ -fn get_fields_buffer_size_sum(fields: &[(Member, Type)]) -> (TokenStream, TokenStream) { - let fields: Vec<_> = fields.iter().map(|(m, _)| m).collect(); +fn get_fields_buffer_size_sum(fields: &[StructField]) -> (TokenStream, TokenStream) { + let fields: Vec<_> = fields + .iter() + .filter(|f| !f.skipped) + .map(|f| &f.member) + .collect(); if fields.len() > 0 { ( quote! { @@ -832,4 +885,45 @@ mod tests { assert_eq!(msg_socket_impl(input).to_string(), expected.to_string()); } + + #[test] + fn end_to_end_struct_skip_test() { + let input: DeriveInput = parse_quote! { + struct MyMsg { + #[msg_on_socket(skip)] + a: u8, + } + }; + + let expected = quote! { + impl msg_socket::MsgOnSocket for MyMsg { + fn msg_size(&self) -> usize { + 0 + } + fn fd_count(&self) -> usize { + 0 + } + unsafe fn read_from_buffer( + buffer: &[u8], + fds: &[std::os::unix::io::RawFd], + ) -> msg_socket::MsgResult<(Self, usize)> { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + Ok((Self { a: <u8>::default() }, __fd_offset)) + } + fn write_to_buffer( + &self, + buffer: &mut [u8], + fds: &mut [std::os::unix::io::RawFd], + ) -> msg_socket::MsgResult<usize> { + let mut __offset = 0usize; + let mut __fd_offset = 0usize; + Ok(__fd_offset) + } + } + + }; + + assert_eq!(socket_msg_impl(input).to_string(), expected.to_string()); + } } diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs index 5871735..7f5d1a6 100644 --- a/msg_socket/src/lib.rs +++ b/msg_socket/src/lib.rs @@ -13,7 +13,7 @@ use std::task::{Context, Poll}; use futures::Stream; use libc::{EWOULDBLOCK, O_NONBLOCK}; -use cros_async::fd_executor::add_read_waker; +use cros_async::add_read_waker; use sys_util::{ add_fd_flags, clear_fd_flags, error, handle_eintr, net::UnixSeqpacket, Error as SysError, ScmSocket, @@ -50,7 +50,7 @@ impl<I: MsgOnSocket, O: MsgOnSocket> MsgSocket<I, O> { } // Creates an async receiver that implements `futures::Stream`. - pub fn async_receiver(&mut self) -> MsgResult<AsyncReceiver<I, O>> { + pub fn async_receiver(&self) -> MsgResult<AsyncReceiver<I, O>> { AsyncReceiver::new(self) } } @@ -164,6 +164,20 @@ pub trait MsgReceiver: AsRef<UnixSeqpacket> { ) } }; + + if msg_buffer.len() == 0 && Self::M::fixed_size() != Some(0) { + return Err(MsgError::RecvZero); + } + + if let Some(fixed_size) = Self::M::fixed_size() { + if fixed_size != msg_buffer.len() { + return Err(MsgError::BadRecvSize { + expected: fixed_size, + actual: msg_buffer.len(), + }); + } + } + // Safe because fd buffer is read from socket. let (v, read_fd_size) = unsafe { Self::M::read_from_buffer(&msg_buffer, &fd_buffer)? }; if fd_buffer.len() != read_fd_size { @@ -189,12 +203,12 @@ impl<O: MsgOnSocket> MsgReceiver for Receiver<O> { /// Asynchronous adaptor for `MsgSocket`. pub struct AsyncReceiver<'a, I: MsgOnSocket, O: MsgOnSocket> { - inner: &'a mut MsgSocket<I, O>, + inner: &'a MsgSocket<I, O>, done: bool, // Have hit an error and the Stream should return null when polled. } impl<'a, I: MsgOnSocket, O: MsgOnSocket> AsyncReceiver<'a, I, O> { - fn new(msg_socket: &mut MsgSocket<I, O>) -> MsgResult<AsyncReceiver<I, O>> { + fn new(msg_socket: &MsgSocket<I, O>) -> MsgResult<AsyncReceiver<I, O>> { add_fd_flags(msg_socket.as_raw_fd(), O_NONBLOCK).map_err(MsgError::SettingFdFlags)?; Ok(AsyncReceiver { inner: msg_socket, diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs index 624d514..67d26aa 100644 --- a/msg_socket/src/msg_on_socket.rs +++ b/msg_socket/src/msg_on_socket.rs @@ -10,15 +10,17 @@ use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream}; use std::ptr::drop_in_place; use std::result; +use std::sync::Arc; use data_model::*; +use sync::Mutex; use sys_util::{Error as SysError, EventFd}; #[derive(Debug, PartialEq)] /// An error during transaction or serialization/deserialization. pub enum MsgError { /// Error adding a waker for async read. - AddingWaker(cros_async::fd_executor::Error), + AddingWaker(cros_async::Error), /// Error while sending a request or response. Send(SysError), /// Error while receiving a request or response. @@ -28,6 +30,8 @@ pub enum MsgError { /// There was not the expected amount of data when receiving a message. The inner /// value is how much data is expected and how much data was actually received. BadRecvSize { expected: usize, actual: usize }, + /// There was no data received when the socket `recv`-ed. + RecvZero, /// There was no associated file descriptor received for a request that expected it. ExpectFd, /// There was some associated file descriptor received but not used when deserialize. @@ -58,6 +62,7 @@ impl Display for MsgError { "wrong amount of data received; expected {} bytes; got {} bytes", expected, actual ), + RecvZero => write!(f, "received zero data"), ExpectFd => write!(f, "missing associated file descriptor for request"), NotExpectFd => write!(f, "unexpected file descriptor is unused"), SettingFdFlags(e) => write!(f, "failed setting flags on the message FD: {}", e), @@ -207,6 +212,50 @@ impl<T: MsgOnSocket> MsgOnSocket for Option<T> { } } +impl<T: MsgOnSocket> MsgOnSocket for Mutex<T> { + fn uses_fd() -> bool { + T::uses_fd() + } + + fn msg_size(&self) -> usize { + self.lock().msg_size() + } + + fn fd_count(&self) -> usize { + self.lock().fd_count() + } + + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + T::read_from_buffer(buffer, fds).map(|(v, count)| (Mutex::new(v), count)) + } + + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> { + self.lock().write_to_buffer(buffer, fds) + } +} + +impl<T: MsgOnSocket> MsgOnSocket for Arc<T> { + fn uses_fd() -> bool { + T::uses_fd() + } + + fn msg_size(&self) -> usize { + (**self).msg_size() + } + + fn fd_count(&self) -> usize { + (**self).fd_count() + } + + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + T::read_from_buffer(buffer, fds).map(|(v, count)| (Arc::new(v), count)) + } + + fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> { + (**self).write_to_buffer(buffer, fds) + } +} + impl MsgOnSocket for () { fn fixed_size() -> Option<usize> { Some(0) |