summary refs log tree commit diff
path: root/msg_socket/src/msg_on_socket.rs
diff options
context:
space:
mode:
authorZach Reizner <zachr@google.com>2020-01-28 13:18:09 -0800
committerCommit Bot <commit-bot@chromium.org>2020-04-05 15:10:47 +0000
commit146450b4569e86657d1d8c4ffe17524781aae7e3 (patch)
treedef385d4cf3e5c0e5f96169f8f9555800a9daadd /msg_socket/src/msg_on_socket.rs
parent773c70740e98c1aaf73a7b02e65eadaeab33c9d8 (diff)
downloadcrosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar
crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar.gz
crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar.bz2
crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar.lz
crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar.xz
crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.tar.zst
crosvm-146450b4569e86657d1d8c4ffe17524781aae7e3.zip
msg_socket: support dynamically sized types
This change is a major shift in how the MsgOnSocket trait works to allow
`self` to be used to determine the result `msg_size()`. This is to
support data structures with `Vec` or other dynamically sized type.

TEST=./build_test
     cargo test -p msg_socket
     tast run <DUT> crostini.CopyPaste.*
BUG=None

Cq-Depend: chromium:2025907
Change-Id: Ibdb51b377b2a2a77892f6c75e1a9f30b2f8b0240
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2029930
Tested-by: Zach Reizner <zachr@chromium.org>
Auto-Submit: Zach Reizner <zachr@chromium.org>
Reviewed-by: Zach Reizner <zachr@chromium.org>
Commit-Queue: Zach Reizner <zachr@chromium.org>
Diffstat (limited to 'msg_socket/src/msg_on_socket.rs')
-rw-r--r--msg_socket/src/msg_on_socket.rs234
1 files changed, 175 insertions, 59 deletions
diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs
index f03c36f..3f34019 100644
--- a/msg_socket/src/msg_on_socket.rs
+++ b/msg_socket/src/msg_on_socket.rs
@@ -4,9 +4,11 @@
 
 use std::fmt::{self, Display};
 use std::fs::File;
+use std::mem::{size_of, transmute_copy, MaybeUninit};
 use std::net::{TcpListener, TcpStream, UdpSocket};
 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
 use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
+use std::ptr::drop_in_place;
 use std::result;
 
 use data_model::*;
@@ -90,10 +92,24 @@ impl Display for MsgError {
 /// Thus, read/write functions always the return correct count of fds in this variant. There will be
 /// no padding in fd_buffer.
 pub trait MsgOnSocket: Sized {
+    // `true` if this structure can potentially serialize fds.
+    fn uses_fd() -> bool {
+        false
+    }
+
+    // Returns `Some(size)` if this structure always has a fixed size.
+    fn fixed_size() -> Option<usize> {
+        None
+    }
+
     /// Size of message in bytes.
-    fn msg_size() -> usize;
-    /// Max possible fd count in this type.
-    fn max_fd_count() -> usize {
+    fn msg_size(&self) -> usize {
+        Self::fixed_size().unwrap()
+    }
+
+    /// Number of FDs in this message. This must be overridden if `uses_fd()` returns true.
+    fn fd_count(&self) -> usize {
+        assert!(!Self::uses_fd());
         0
     }
     /// Returns (self, fd read count).
@@ -103,13 +119,14 @@ pub trait MsgOnSocket: Sized {
     ///     2. write_to_buffer is implemented correctly(put valid fds into the buffer, has no padding,
     ///        return correct count).
     unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)>;
+
     /// Serialize self to buffers.
     fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize>;
 }
 
 impl MsgOnSocket for SysError {
-    fn msg_size() -> usize {
-        u32::msg_size()
+    fn fixed_size() -> Option<usize> {
+        Some(size_of::<u32>())
     }
     unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
         let (v, size) = u32::read_from_buffer(buffer, fds)?;
@@ -122,12 +139,14 @@ impl MsgOnSocket for SysError {
 }
 
 impl MsgOnSocket for RawFd {
-    fn msg_size() -> usize {
-        0
+    fn fixed_size() -> Option<usize> {
+        Some(0)
     }
-    fn max_fd_count() -> usize {
+
+    fn fd_count(&self) -> usize {
         1
     }
+
     unsafe fn read_from_buffer(_buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
         if fds.is_empty() {
             return Err(MsgError::ExpectFd);
@@ -144,12 +163,22 @@ impl MsgOnSocket for RawFd {
 }
 
 impl<T: MsgOnSocket> MsgOnSocket for Option<T> {
-    fn msg_size() -> usize {
-        T::msg_size() + 1
+    fn uses_fd() -> bool {
+        T::uses_fd()
+    }
+
+    fn msg_size(&self) -> usize {
+        match self {
+            Some(v) => v.msg_size() + 1,
+            None => 0,
+        }
     }
 
-    fn max_fd_count() -> usize {
-        T::max_fd_count()
+    fn fd_count(&self) -> usize {
+        match self {
+            Some(v) => v.fd_count(),
+            None => 0,
+        }
     }
 
     unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
@@ -178,12 +207,8 @@ impl<T: MsgOnSocket> MsgOnSocket for Option<T> {
 }
 
 impl MsgOnSocket for () {
-    fn msg_size() -> usize {
-        0
-    }
-
-    fn max_fd_count() -> usize {
-        0
+    fn fixed_size() -> Option<usize> {
+        Some(0)
     }
 
     unsafe fn read_from_buffer(_buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> {
@@ -198,17 +223,20 @@ impl MsgOnSocket for () {
 macro_rules! rawfd_impl {
     ($type:ident) => {
         impl MsgOnSocket for $type {
-            fn msg_size() -> usize {
+            fn uses_fd() -> bool {
+                true
+            }
+            fn msg_size(&self) -> usize {
                 0
             }
-            fn max_fd_count() -> usize {
+            fn fd_count(&self) -> usize {
                 1
             }
             unsafe fn read_from_buffer(_buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
                 if fds.len() < 1 {
                     return Err(MsgError::ExpectFd);
                 }
-                Ok(($type::from_raw_fd(fds[0].clone()), 1))
+                Ok(($type::from_raw_fd(fds[0]), 1))
             }
             fn write_to_buffer(&self, _buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
                 if fds.len() < 1 {
@@ -246,11 +274,11 @@ where
 
 // usize could be different sizes on different targets. We always use u64.
 impl MsgOnSocket for usize {
-    fn msg_size() -> usize {
-        std::mem::size_of::<u64>()
+    fn msg_size(&self) -> usize {
+        size_of::<u64>()
     }
     unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> {
-        if buffer.len() < std::mem::size_of::<u64>() {
+        if buffer.len() < size_of::<u64>() {
             return Err(MsgError::WrongMsgBufferSize);
         }
         let t = u64::from_le_bytes(slice_to_array(buffer));
@@ -258,22 +286,22 @@ impl MsgOnSocket for usize {
     }
 
     fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawFd]) -> MsgResult<usize> {
-        if buffer.len() < std::mem::size_of::<u64>() {
+        if buffer.len() < size_of::<u64>() {
             return Err(MsgError::WrongMsgBufferSize);
         }
         let t: Le64 = (*self as u64).into();
-        buffer[0..Self::msg_size()].copy_from_slice(t.as_slice());
+        buffer[0..self.msg_size()].copy_from_slice(t.as_slice());
         Ok(0)
     }
 }
 
 // Encode bool as a u8 of value 0 or 1
 impl MsgOnSocket for bool {
-    fn msg_size() -> usize {
-        std::mem::size_of::<u8>()
+    fn msg_size(&self) -> usize {
+        size_of::<u8>()
     }
     unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> {
-        if buffer.len() < std::mem::size_of::<u8>() {
+        if buffer.len() < size_of::<u8>() {
             return Err(MsgError::WrongMsgBufferSize);
         }
         let t: u8 = buffer[0];
@@ -284,7 +312,7 @@ impl MsgOnSocket for bool {
         }
     }
     fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawFd]) -> MsgResult<usize> {
-        if buffer.len() < std::mem::size_of::<u8>() {
+        if buffer.len() < size_of::<u8>() {
             return Err(MsgError::WrongMsgBufferSize);
         }
         buffer[0] = *self as u8;
@@ -295,11 +323,12 @@ impl MsgOnSocket for bool {
 macro_rules! le_impl {
     ($type:ident, $native_type:ident) => {
         impl MsgOnSocket for $type {
-            fn msg_size() -> usize {
-                std::mem::size_of::<$native_type>()
+            fn fixed_size() -> Option<usize> {
+                Some(size_of::<$native_type>())
             }
+
             unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> {
-                if buffer.len() < std::mem::size_of::<$native_type>() {
+                if buffer.len() < size_of::<$native_type>() {
                     return Err(MsgError::WrongMsgBufferSize);
                 }
                 let t = $native_type::from_le_bytes(slice_to_array(buffer));
@@ -307,11 +336,11 @@ macro_rules! le_impl {
             }
 
             fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawFd]) -> MsgResult<usize> {
-                if buffer.len() < std::mem::size_of::<$native_type>() {
+                if buffer.len() < size_of::<$native_type>() {
                     return Err(MsgError::WrongMsgBufferSize);
                 }
                 let t: $native_type = self.clone().into();
-                buffer[0..Self::msg_size()].copy_from_slice(&t.to_le_bytes());
+                buffer[0..self.msg_size()].copy_from_slice(&t.to_le_bytes());
                 Ok(0)
             }
         }
@@ -331,30 +360,85 @@ macro_rules! array_impls {
     ($N:expr, $t: ident $($ts:ident)*)
     => {
         impl<T: MsgOnSocket + Clone> MsgOnSocket for [T; $N] {
-            fn msg_size() -> usize {
-                T::msg_size() * $N
+            fn uses_fd() -> bool {
+                T::uses_fd()
             }
-            fn max_fd_count() -> usize {
-                T::max_fd_count() * $N
+
+            fn fixed_size() -> Option<usize> {
+                Some(T::fixed_size()? * $N)
             }
-            unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
-                if buffer.len() < Self::msg_size() {
-                    return Err(MsgError::WrongMsgBufferSize);
+
+            fn msg_size(&self) -> usize {
+                match T::fixed_size() {
+                    Some(s) => s * $N,
+                    None => self.iter().map(|i| i.msg_size()).sum::<usize>() + size_of::<u64>() * $N
                 }
+            }
+
+            fn fd_count(&self) -> usize {
+                if T::uses_fd() {
+                    self.iter().map(|i| i.fd_count()).sum()
+                } else {
+                    0
+                }
+            }
+
+            unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+                // Taken from the canonical example of initializing an array, the `assume_init` can
+                // be assumed safe because the array elements (`MaybeUninit<T>` in this case)
+                // themselves don't require initializing.
+                let mut msgs: [MaybeUninit<T>; $N] =  MaybeUninit::uninit().assume_init();
+
                 let mut offset = 0usize;
                 let mut fd_offset = 0usize;
-                let ($t, fd_size) =
-                    T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?;
-                offset += T::msg_size();
-                fd_offset += fd_size;
-                $(
-                    let ($ts, fd_size) =
-                        T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?;
-                    offset += T::msg_size();
-                    fd_offset += fd_size;
-                    )*
-                assert_eq!(offset, Self::msg_size());
-                Ok(([$t, $($ts),*], fd_offset))
+
+                // In case of an error, we need to keep track of how many elements got initialized.
+                // In order to perform the necessary drops, the below loop is executed in a closure
+                // to capture errors without returning.
+                let mut last_index = 0;
+                let res = (|| {
+                    for msg in &mut msgs[..] {
+                        let element_size = match T::fixed_size() {
+                            Some(s) => s,
+                            None => {
+                                let (element_size, _) = u64::read_from_buffer(&buffer[offset..], &[])?;
+                                offset += element_size.msg_size();
+                                element_size as usize
+                            }
+                        };
+                        let (m, fd_size) =
+                            T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?;
+                        *msg = MaybeUninit::new(m);
+                        offset += element_size;
+                        fd_offset += fd_size;
+                        last_index += 1;
+                    }
+                    Ok(())
+                })();
+
+                // Because `MaybeUninit` will not automatically call drops, we have to drop the
+                // partially initialized array manually in the case of an error.
+                if let Err(e) = res {
+                    for msg in &mut msgs[..last_index] {
+                        // The call to `as_mut_ptr()` turns the `MaybeUninit` element of the array
+                        // into a pointer, which can be used with `drop_in_place` to call the
+                        // destructor without moving the element, which is impossible. This is safe
+                        // because `last_index` prevents this loop from traversing into the
+                        // uninitialized parts of the array.
+                        drop_in_place(msg.as_mut_ptr());
+                    }
+                    return Err(e)
+                }
+
+                // Also taken from the canonical example, we initialized every member of the array
+                // in the first loop of this function, so it is safe to `transmute_copy` the array
+                // of `MaybeUninit` data to plain data. Although `transmute`, which checks the
+                // types' sizes, would have been preferred in this code, the compiler complains with
+                // "cannot transmute between types of different sizes, or dependently-sized types."
+                // Because this function operates on generic data, the type is "dependently-sized"
+                // and so the compiler will not check that the size of the input and output match.
+                // See this issue for details: https://github.com/rust-lang/rust/issues/61956
+                Ok((transmute_copy::<_, [T; $N]>(&msgs), fd_offset))
             }
 
             fn write_to_buffer(
@@ -362,21 +446,53 @@ macro_rules! array_impls {
                 buffer: &mut [u8],
                 fds: &mut [RawFd],
                 ) -> MsgResult<usize> {
-                if buffer.len() < Self::msg_size() {
-                    return Err(MsgError::WrongMsgBufferSize);
-                }
                 let mut offset = 0usize;
                 let mut fd_offset = 0usize;
                 for idx in 0..$N {
-                    let fd_size = self[idx].clone().write_to_buffer(&mut buffer[offset..],
+                    let element_size = match T::fixed_size() {
+                        Some(s) => s,
+                        None => {
+                            let element_size = self[idx].msg_size() as u64;
+                            element_size.write_to_buffer(&mut buffer[offset..], &mut [])?;
+                            offset += element_size.msg_size();
+                            element_size as usize
+                        }
+                    };
+                    let fd_size = self[idx].write_to_buffer(&mut buffer[offset..],
                                                             &mut fds[fd_offset..])?;
-                    offset += T::msg_size();
+                    offset += element_size;
                     fd_offset += fd_size;
                 }
 
                 Ok(fd_offset)
             }
         }
+        #[cfg(test)]
+        mod $t {
+            use super::MsgOnSocket;
+
+            #[test]
+            fn read_write_option_array() {
+                type ArrayType = [Option<u32>; $N];
+                let array = [Some($N); $N];
+                let mut buffer = vec![0; array.msg_size()];
+                array.write_to_buffer(&mut buffer, &mut []).unwrap();
+                let read_array = unsafe { ArrayType::read_from_buffer(&buffer, &[]) }.unwrap().0;
+
+                assert_eq!(array, read_array);
+            }
+
+            #[test]
+            fn read_write_fixed() {
+                type ArrayType = [u32; $N];
+                let mut buffer = vec![0; <ArrayType>::fixed_size().unwrap()];
+                let array = [$N as u32; $N];
+                array.write_to_buffer(&mut buffer, &mut []).unwrap();
+                let read_array = unsafe { ArrayType::read_from_buffer(&buffer, &[]) }.unwrap().0;
+
+                assert_eq!(array, read_array);
+            }
+        }
         array_impls!(($N - 1), $($ts)*);
     };
     {$N:expr, } => {};