diff options
author | Zach Reizner <zachr@google.com> | 2020-01-28 13:18:09 -0800 |
---|---|---|
committer | Commit Bot <commit-bot@chromium.org> | 2020-04-05 15:10:47 +0000 |
commit | 146450b4569e86657d1d8c4ffe17524781aae7e3 (patch) | |
tree | def385d4cf3e5c0e5f96169f8f9555800a9daadd /msg_socket/src/msg_on_socket.rs | |
parent | 773c70740e98c1aaf73a7b02e65eadaeab33c9d8 (diff) | |
download | crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar.gz crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar.bz2 crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar.lz crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar.xz crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar.zst crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.zip |
msg_socket: support dynamically sized types
This change is a major shift in how the MsgOnSocket trait works to allow `self` to be used to determine the result `msg_size()`. This is to support data structures with `Vec` or other dynamically sized type. TEST=./build_test cargo test -p msg_socket tast run <DUT> crostini.CopyPaste.* BUG=None Cq-Depend: chromium:2025907 Change-Id: Ibdb51b377b2a2a77892f6c75e1a9f30b2f8b0240 Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2029930 Tested-by: Zach Reizner <zachr@chromium.org> Auto-Submit: Zach Reizner <zachr@chromium.org> Reviewed-by: Zach Reizner <zachr@chromium.org> Commit-Queue: Zach Reizner <zachr@chromium.org>
Diffstat (limited to 'msg_socket/src/msg_on_socket.rs')
-rw-r--r-- | msg_socket/src/msg_on_socket.rs | 234 |
1 files changed, 175 insertions, 59 deletions
diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs index f03c36f..3f34019 100644 --- a/msg_socket/src/msg_on_socket.rs +++ b/msg_socket/src/msg_on_socket.rs @@ -4,9 +4,11 @@ use std::fmt::{self, Display}; use std::fs::File; +use std::mem::{size_of, transmute_copy, MaybeUninit}; use std::net::{TcpListener, TcpStream, UdpSocket}; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream}; +use std::ptr::drop_in_place; use std::result; use data_model::*; @@ -90,10 +92,24 @@ impl Display for MsgError { /// Thus, read/write functions always the return correct count of fds in this variant. There will be /// no padding in fd_buffer. pub trait MsgOnSocket: Sized { + // `true` if this structure can potentially serialize fds. + fn uses_fd() -> bool { + false + } + + // Returns `Some(size)` if this structure always has a fixed size. + fn fixed_size() -> Option<usize> { + None + } + /// Size of message in bytes. - fn msg_size() -> usize; - /// Max possible fd count in this type. - fn max_fd_count() -> usize { + fn msg_size(&self) -> usize { + Self::fixed_size().unwrap() + } + + /// Number of FDs in this message. This must be overridden if `uses_fd()` returns true. + fn fd_count(&self) -> usize { + assert!(!Self::uses_fd()); 0 } /// Returns (self, fd read count). @@ -103,13 +119,14 @@ pub trait MsgOnSocket: Sized { /// 2. write_to_buffer is implemented correctly(put valid fds into the buffer, has no padding, /// return correct count). unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)>; + /// Serialize self to buffers. fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize>; } impl MsgOnSocket for SysError { - fn msg_size() -> usize { - u32::msg_size() + fn fixed_size() -> Option<usize> { + Some(size_of::<u32>()) } unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { let (v, size) = u32::read_from_buffer(buffer, fds)?; @@ -122,12 +139,14 @@ impl MsgOnSocket for SysError { } impl MsgOnSocket for RawFd { - fn msg_size() -> usize { - 0 + fn fixed_size() -> Option<usize> { + Some(0) } - fn max_fd_count() -> usize { + + fn fd_count(&self) -> usize { 1 } + unsafe fn read_from_buffer(_buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { if fds.is_empty() { return Err(MsgError::ExpectFd); @@ -144,12 +163,22 @@ impl MsgOnSocket for RawFd { } impl<T: MsgOnSocket> MsgOnSocket for Option<T> { - fn msg_size() -> usize { - T::msg_size() + 1 + fn uses_fd() -> bool { + T::uses_fd() + } + + fn msg_size(&self) -> usize { + match self { + Some(v) => v.msg_size() + 1, + None => 0, + } } - fn max_fd_count() -> usize { - T::max_fd_count() + fn fd_count(&self) -> usize { + match self { + Some(v) => v.fd_count(), + None => 0, + } } unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { @@ -178,12 +207,8 @@ impl<T: MsgOnSocket> MsgOnSocket for Option<T> { } impl MsgOnSocket for () { - fn msg_size() -> usize { - 0 - } - - fn max_fd_count() -> usize { - 0 + fn fixed_size() -> Option<usize> { + Some(0) } unsafe fn read_from_buffer(_buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> { @@ -198,17 +223,20 @@ impl MsgOnSocket for () { macro_rules! rawfd_impl { ($type:ident) => { impl MsgOnSocket for $type { - fn msg_size() -> usize { + fn uses_fd() -> bool { + true + } + fn msg_size(&self) -> usize { 0 } - fn max_fd_count() -> usize { + fn fd_count(&self) -> usize { 1 } unsafe fn read_from_buffer(_buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { if fds.len() < 1 { return Err(MsgError::ExpectFd); } - Ok(($type::from_raw_fd(fds[0].clone()), 1)) + Ok(($type::from_raw_fd(fds[0]), 1)) } fn write_to_buffer(&self, _buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> { if fds.len() < 1 { @@ -246,11 +274,11 @@ where // usize could be different sizes on different targets. We always use u64. impl MsgOnSocket for usize { - fn msg_size() -> usize { - std::mem::size_of::<u64>() + fn msg_size(&self) -> usize { + size_of::<u64>() } unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> { - if buffer.len() < std::mem::size_of::<u64>() { + if buffer.len() < size_of::<u64>() { return Err(MsgError::WrongMsgBufferSize); } let t = u64::from_le_bytes(slice_to_array(buffer)); @@ -258,22 +286,22 @@ impl MsgOnSocket for usize { } fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawFd]) -> MsgResult<usize> { - if buffer.len() < std::mem::size_of::<u64>() { + if buffer.len() < size_of::<u64>() { return Err(MsgError::WrongMsgBufferSize); } let t: Le64 = (*self as u64).into(); - buffer[0..Self::msg_size()].copy_from_slice(t.as_slice()); + buffer[0..self.msg_size()].copy_from_slice(t.as_slice()); Ok(0) } } // Encode bool as a u8 of value 0 or 1 impl MsgOnSocket for bool { - fn msg_size() -> usize { - std::mem::size_of::<u8>() + fn msg_size(&self) -> usize { + size_of::<u8>() } unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> { - if buffer.len() < std::mem::size_of::<u8>() { + if buffer.len() < size_of::<u8>() { return Err(MsgError::WrongMsgBufferSize); } let t: u8 = buffer[0]; @@ -284,7 +312,7 @@ impl MsgOnSocket for bool { } } fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawFd]) -> MsgResult<usize> { - if buffer.len() < std::mem::size_of::<u8>() { + if buffer.len() < size_of::<u8>() { return Err(MsgError::WrongMsgBufferSize); } buffer[0] = *self as u8; @@ -295,11 +323,12 @@ impl MsgOnSocket for bool { macro_rules! le_impl { ($type:ident, $native_type:ident) => { impl MsgOnSocket for $type { - fn msg_size() -> usize { - std::mem::size_of::<$native_type>() + fn fixed_size() -> Option<usize> { + Some(size_of::<$native_type>()) } + unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> { - if buffer.len() < std::mem::size_of::<$native_type>() { + if buffer.len() < size_of::<$native_type>() { return Err(MsgError::WrongMsgBufferSize); } let t = $native_type::from_le_bytes(slice_to_array(buffer)); @@ -307,11 +336,11 @@ macro_rules! le_impl { } fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawFd]) -> MsgResult<usize> { - if buffer.len() < std::mem::size_of::<$native_type>() { + if buffer.len() < size_of::<$native_type>() { return Err(MsgError::WrongMsgBufferSize); } let t: $native_type = self.clone().into(); - buffer[0..Self::msg_size()].copy_from_slice(&t.to_le_bytes()); + buffer[0..self.msg_size()].copy_from_slice(&t.to_le_bytes()); Ok(0) } } @@ -331,30 +360,85 @@ macro_rules! array_impls { ($N:expr, $t: ident $($ts:ident)*) => { impl<T: MsgOnSocket + Clone> MsgOnSocket for [T; $N] { - fn msg_size() -> usize { - T::msg_size() * $N + fn uses_fd() -> bool { + T::uses_fd() } - fn max_fd_count() -> usize { - T::max_fd_count() * $N + + fn fixed_size() -> Option<usize> { + Some(T::fixed_size()? * $N) } - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { - if buffer.len() < Self::msg_size() { - return Err(MsgError::WrongMsgBufferSize); + + fn msg_size(&self) -> usize { + match T::fixed_size() { + Some(s) => s * $N, + None => self.iter().map(|i| i.msg_size()).sum::<usize>() + size_of::<u64>() * $N } + } + + fn fd_count(&self) -> usize { + if T::uses_fd() { + self.iter().map(|i| i.fd_count()).sum() + } else { + 0 + } + } + + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + // Taken from the canonical example of initializing an array, the `assume_init` can + // be assumed safe because the array elements (`MaybeUninit<T>` in this case) + // themselves don't require initializing. + let mut msgs: [MaybeUninit<T>; $N] = MaybeUninit::uninit().assume_init(); + let mut offset = 0usize; let mut fd_offset = 0usize; - let ($t, fd_size) = - T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?; - offset += T::msg_size(); - fd_offset += fd_size; - $( - let ($ts, fd_size) = - T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?; - offset += T::msg_size(); - fd_offset += fd_size; - )* - assert_eq!(offset, Self::msg_size()); - Ok(([$t, $($ts),*], fd_offset)) + + // In case of an error, we need to keep track of how many elements got initialized. + // In order to perform the necessary drops, the below loop is executed in a closure + // to capture errors without returning. + let mut last_index = 0; + let res = (|| { + for msg in &mut msgs[..] { + let element_size = match T::fixed_size() { + Some(s) => s, + None => { + let (element_size, _) = u64::read_from_buffer(&buffer[offset..], &[])?; + offset += element_size.msg_size(); + element_size as usize + } + }; + let (m, fd_size) = + T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?; + *msg = MaybeUninit::new(m); + offset += element_size; + fd_offset += fd_size; + last_index += 1; + } + Ok(()) + })(); + + // Because `MaybeUninit` will not automatically call drops, we have to drop the + // partially initialized array manually in the case of an error. + if let Err(e) = res { + for msg in &mut msgs[..last_index] { + // The call to `as_mut_ptr()` turns the `MaybeUninit` element of the array + // into a pointer, which can be used with `drop_in_place` to call the + // destructor without moving the element, which is impossible. This is safe + // because `last_index` prevents this loop from traversing into the + // uninitialized parts of the array. + drop_in_place(msg.as_mut_ptr()); + } + return Err(e) + } + + // Also taken from the canonical example, we initialized every member of the array + // in the first loop of this function, so it is safe to `transmute_copy` the array + // of `MaybeUninit` data to plain data. Although `transmute`, which checks the + // types' sizes, would have been preferred in this code, the compiler complains with + // "cannot transmute between types of different sizes, or dependently-sized types." + // Because this function operates on generic data, the type is "dependently-sized" + // and so the compiler will not check that the size of the input and output match. + // See this issue for details: https://github.com/rust-lang/rust/issues/61956 + Ok((transmute_copy::<_, [T; $N]>(&msgs), fd_offset)) } fn write_to_buffer( @@ -362,21 +446,53 @@ macro_rules! array_impls { buffer: &mut [u8], fds: &mut [RawFd], ) -> MsgResult<usize> { - if buffer.len() < Self::msg_size() { - return Err(MsgError::WrongMsgBufferSize); - } let mut offset = 0usize; let mut fd_offset = 0usize; for idx in 0..$N { - let fd_size = self[idx].clone().write_to_buffer(&mut buffer[offset..], + let element_size = match T::fixed_size() { + Some(s) => s, + None => { + let element_size = self[idx].msg_size() as u64; + element_size.write_to_buffer(&mut buffer[offset..], &mut [])?; + offset += element_size.msg_size(); + element_size as usize + } + }; + let fd_size = self[idx].write_to_buffer(&mut buffer[offset..], &mut fds[fd_offset..])?; - offset += T::msg_size(); + offset += element_size; fd_offset += fd_size; } Ok(fd_offset) } } + #[cfg(test)] + mod $t { + use super::MsgOnSocket; + + #[test] + fn read_write_option_array() { + type ArrayType = [Option<u32>; $N]; + let array = [Some($N); $N]; + let mut buffer = vec![0; array.msg_size()]; + array.write_to_buffer(&mut buffer, &mut []).unwrap(); + let read_array = unsafe { ArrayType::read_from_buffer(&buffer, &[]) }.unwrap().0; + + assert_eq!(array, read_array); + } + + #[test] + fn read_write_fixed() { + type ArrayType = [u32; $N]; + let mut buffer = vec![0; <ArrayType>::fixed_size().unwrap()]; + let array = [$N as u32; $N]; + array.write_to_buffer(&mut buffer, &mut []).unwrap(); + let read_array = unsafe { ArrayType::read_from_buffer(&buffer, &[]) }.unwrap().0; + + assert_eq!(array, read_array); + } + } array_impls!(($N - 1), $($ts)*); }; {$N:expr, } => {}; |