summary refs log tree commit diff
path: root/msg_socket/src/msg_on_socket
diff options
context:
space:
mode:
authorZach Reizner <zachr@google.com>2020-04-27 12:52:08 -0700
committerCommit Bot <commit-bot@chromium.org>2020-06-03 22:41:32 +0000
commit5f1a64892b714885f6c7405084a390467c03201a (patch)
tree8a045d36646dcdefffc475b1fd30a2080273a91a /msg_socket/src/msg_on_socket
parent1d6967437edc98716e545b82a28c788febfbe79a (diff)
downloadcrosvm-5f1a64892b714885f6c7405084a390467c03201a.tar
crosvm-5f1a64892b714885f6c7405084a390467c03201a.tar.gz
crosvm-5f1a64892b714885f6c7405084a390467c03201a.tar.bz2
crosvm-5f1a64892b714885f6c7405084a390467c03201a.tar.lz
crosvm-5f1a64892b714885f6c7405084a390467c03201a.tar.xz
crosvm-5f1a64892b714885f6c7405084a390467c03201a.tar.zst
crosvm-5f1a64892b714885f6c7405084a390467c03201a.zip
msg_socket: implement MsgOnSocket for Vec and tuples
These container types are similar to arrays except tuples have
heterogeneous data types and Vec has a dynamic number of elements.

BUG=None
TEST=cargo test -p msg_socket

Change-Id: I2cbbaeb7f13b7700294ac50751530510098ba16d
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2168588
Reviewed-by: Daniel Verkamp <dverkamp@chromium.org>
Reviewed-by: Dylan Reid <dgreid@chromium.org>
Tested-by: kokoro <noreply+kokoro@google.com>
Tested-by: Zach Reizner <zachr@chromium.org>
Commit-Queue: Zach Reizner <zachr@chromium.org>
Diffstat (limited to 'msg_socket/src/msg_on_socket')
-rw-r--r--msg_socket/src/msg_on_socket/slice.rs184
-rw-r--r--msg_socket/src/msg_on_socket/tuple.rs205
2 files changed, 389 insertions, 0 deletions
diff --git a/msg_socket/src/msg_on_socket/slice.rs b/msg_socket/src/msg_on_socket/slice.rs
new file mode 100644
index 0000000..7b6ef28
--- /dev/null
+++ b/msg_socket/src/msg_on_socket/slice.rs
@@ -0,0 +1,184 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::mem::{size_of, ManuallyDrop, MaybeUninit};
+use std::os::unix::io::RawFd;
+use std::ptr::drop_in_place;
+
+use crate::{MsgOnSocket, MsgResult};
+
+use super::{simple_read, simple_write};
+
+/// Helper used by the types that read a slice of homegenously typed data.
+///
+/// # Safety
+/// This function has the same safety requirements as `T::read_from_buffer`, with the additional
+/// requirements that the `msgs` are only used on success of this function
+pub unsafe fn slice_read_helper<T: MsgOnSocket>(
+    buffer: &[u8],
+    fds: &[RawFd],
+    msgs: &mut [MaybeUninit<T>],
+) -> MsgResult<usize> {
+    let mut offset = 0usize;
+    let mut fd_offset = 0usize;
+
+    // In case of an error, we need to keep track of how many elements got initialized.
+    // In order to perform the necessary drops, the below loop is executed in a closure
+    // to capture errors without returning.
+    let mut last_index = 0;
+    let res = (|| {
+        for msg in &mut msgs[..] {
+            let element_size = match T::fixed_size() {
+                Some(s) => s,
+                None => simple_read::<u64>(buffer, &mut offset)? as usize,
+            };
+            // Assuming the unsafe caller gave valid FDs, this call should be safe.
+            let (m, fd_size) = T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?;
+            *msg = MaybeUninit::new(m);
+            offset += element_size;
+            fd_offset += fd_size;
+            last_index += 1;
+        }
+        Ok(())
+    })();
+
+    // Because `MaybeUninit` will not automatically call drops, we have to drop the
+    // partially initialized array manually in the case of an error.
+    if let Err(e) = res {
+        for msg in &mut msgs[..last_index] {
+            // The call to `as_mut_ptr()` turns the `MaybeUninit` element of the array
+            // into a pointer, which can be used with `drop_in_place` to call the
+            // destructor without moving the element, which is impossible. This is safe
+            // because `last_index` prevents this loop from traversing into the
+            // uninitialized parts of the array.
+            drop_in_place(msg.as_mut_ptr());
+        }
+        return Err(e);
+    }
+
+    Ok(fd_offset)
+}
+
+/// Helper used by the types that write a slice of homegenously typed data.
+pub fn slice_write_helper<T: MsgOnSocket>(
+    msgs: &[T],
+    buffer: &mut [u8],
+    fds: &mut [RawFd],
+) -> MsgResult<usize> {
+    let mut offset = 0usize;
+    let mut fd_offset = 0usize;
+    for msg in msgs {
+        let element_size = match T::fixed_size() {
+            Some(s) => s,
+            None => {
+                let element_size = msg.msg_size();
+                simple_write(element_size as u64, buffer, &mut offset)?;
+                element_size as usize
+            }
+        };
+        let fd_size = msg.write_to_buffer(&mut buffer[offset..], &mut fds[fd_offset..])?;
+        offset += element_size;
+        fd_offset += fd_size;
+    }
+
+    Ok(fd_offset)
+}
+
+impl<T: MsgOnSocket> MsgOnSocket for Vec<T> {
+    fn uses_fd() -> bool {
+        T::uses_fd()
+    }
+
+    fn fixed_size() -> Option<usize> {
+        None
+    }
+
+    fn msg_size(&self) -> usize {
+        let vec_size = match T::fixed_size() {
+            Some(s) => s * self.len(),
+            None => self.iter().map(|i| i.msg_size() + size_of::<u64>()).sum(),
+        };
+        size_of::<u64>() + vec_size
+    }
+
+    fn fd_count(&self) -> usize {
+        if T::uses_fd() {
+            self.iter().map(|i| i.fd_count()).sum()
+        } else {
+            0
+        }
+    }
+
+    unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+        let mut offset = 0;
+        let len = simple_read::<u64>(buffer, &mut offset)? as usize;
+        let mut msgs: Vec<MaybeUninit<T>> = Vec::with_capacity(len);
+        msgs.set_len(len);
+        let fd_count = slice_read_helper(&buffer[offset..], fds, &mut msgs)?;
+        let mut msgs = ManuallyDrop::new(msgs);
+        Ok((
+            Vec::from_raw_parts(msgs.as_mut_ptr() as *mut T, msgs.len(), msgs.capacity()),
+            fd_count,
+        ))
+    }
+
+    fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
+        let mut offset = 0;
+        simple_write(self.len() as u64, buffer, &mut offset)?;
+        slice_write_helper(self, &mut buffer[offset..], fds)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn read_write_1_fixed() {
+        let vec = vec![1u32];
+        let mut buffer = vec![0; vec.msg_size()];
+        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_vec = unsafe { <Vec<u32>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+
+        assert_eq!(vec, read_vec);
+    }
+
+    #[test]
+    fn read_write_8_fixed() {
+        let vec = vec![1u16, 1, 3, 5, 8, 13, 21, 34];
+        let mut buffer = vec![0; vec.msg_size()];
+        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_vec = unsafe { <Vec<u16>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+        assert_eq!(vec, read_vec);
+    }
+
+    #[test]
+    fn read_write_1() {
+        let vec = vec![Some(1u64)];
+        let mut buffer = vec![0; vec.msg_size()];
+        println!("{:?}", vec.msg_size());
+        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+
+        assert_eq!(vec, read_vec);
+    }
+
+    #[test]
+    fn read_write_4() {
+        let vec = vec![Some(12u16), Some(0), None, None];
+        let mut buffer = vec![0; vec.msg_size()];
+        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+
+        assert_eq!(vec, read_vec);
+    }
+}
diff --git a/msg_socket/src/msg_on_socket/tuple.rs b/msg_socket/src/msg_on_socket/tuple.rs
new file mode 100644
index 0000000..f960ce5
--- /dev/null
+++ b/msg_socket/src/msg_on_socket/tuple.rs
@@ -0,0 +1,205 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::mem::size_of;
+use std::os::unix::io::RawFd;
+
+use crate::{MsgOnSocket, MsgResult};
+
+use super::{simple_read, simple_write};
+
+// Returns the size of one part of a tuple.
+fn tuple_size_helper<T: MsgOnSocket>(v: &T) -> usize {
+    T::fixed_size().unwrap_or_else(|| v.msg_size() + size_of::<u64>())
+}
+
+unsafe fn tuple_read_helper<T: MsgOnSocket>(
+    buffer: &[u8],
+    fds: &[RawFd],
+    buffer_index: &mut usize,
+    fd_index: &mut usize,
+) -> MsgResult<T> {
+    let end = match T::fixed_size() {
+        Some(_) => buffer.len(),
+        None => {
+            let len = simple_read::<u64>(buffer, buffer_index)? as usize;
+            *buffer_index + len
+        }
+    };
+    let (v, fd_read) = T::read_from_buffer(&buffer[*buffer_index..end], &fds[*fd_index..])?;
+    *buffer_index += v.msg_size();
+    *fd_index += fd_read;
+    Ok(v)
+}
+
+fn tuple_write_helper<T: MsgOnSocket>(
+    v: &T,
+    buffer: &mut [u8],
+    fds: &mut [RawFd],
+    buffer_index: &mut usize,
+    fd_index: &mut usize,
+) -> MsgResult<()> {
+    let end = match T::fixed_size() {
+        Some(_) => buffer.len(),
+        None => {
+            let len = v.msg_size();
+            simple_write(len as u64, buffer, buffer_index)?;
+            *buffer_index + len
+        }
+    };
+    let fd_written = v.write_to_buffer(&mut buffer[*buffer_index..end], &mut fds[*fd_index..])?;
+    *buffer_index += v.msg_size();
+    *fd_index += fd_written;
+    Ok(())
+}
+
+macro_rules! tuple_impls {
+    () => {};
+    ($t: ident) => {
+        #[allow(unused_variables, non_snake_case)]
+        impl<$t: MsgOnSocket> MsgOnSocket for ($t,) {
+            fn uses_fd() -> bool {
+                $t::uses_fd()
+            }
+
+            fn fd_count(&self) -> usize {
+                self.0.fd_count()
+            }
+
+            fn fixed_size() -> Option<usize> {
+                $t::fixed_size()
+            }
+
+            fn msg_size(&self) -> usize {
+                self.0.msg_size()
+            }
+
+            unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+                let (t, s) = $t::read_from_buffer(buffer, fds)?;
+                Ok(((t,), s))
+            }
+
+            fn write_to_buffer(
+                &self,
+                buffer: &mut [u8],
+                fds: &mut [RawFd],
+            ) -> MsgResult<usize> {
+                self.0.write_to_buffer(buffer, fds)
+            }
+        }
+    };
+    ($t: ident, $($ts:ident),*) => {
+        #[allow(unused_variables, non_snake_case)]
+        impl<$t: MsgOnSocket $(, $ts: MsgOnSocket)*> MsgOnSocket for ($t$(, $ts)*) {
+            fn uses_fd() -> bool {
+                $t::uses_fd() $(|| $ts::uses_fd())*
+            }
+
+            fn fd_count(&self) -> usize {
+                if Self::uses_fd() {
+                    return 0;
+                }
+                let ($t $(,$ts)*) = self;
+                $t.fd_count() $(+ $ts.fd_count())*
+            }
+
+            fn fixed_size() -> Option<usize> {
+                // Returns None if any element is not fixed size.
+                Some($t::fixed_size()? $(+ $ts::fixed_size()?)*)
+            }
+
+            fn msg_size(&self) -> usize {
+                if let Some(size) = Self::fixed_size() {
+                    return size
+                }
+
+                let ($t $(,$ts)*) = self;
+                tuple_size_helper($t) $(+ tuple_size_helper($ts))*
+            }
+
+            unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+                let mut buffer_index = 0;
+                let mut fd_index = 0;
+                Ok((
+                        (
+                            tuple_read_helper(buffer, fds, &mut buffer_index, &mut fd_index)?,
+                            $({
+                                // Dummy let used to trigger the correct number of iterations.
+                                let $ts = ();
+                                tuple_read_helper(buffer, fds, &mut buffer_index, &mut fd_index)?
+                            },)*
+                        ),
+                        fd_index
+                ))
+            }
+
+            fn write_to_buffer(
+                &self,
+                buffer: &mut [u8],
+                fds: &mut [RawFd],
+            ) -> MsgResult<usize> {
+                let mut buffer_index = 0;
+                let mut fd_index = 0;
+                let ($t $(,$ts)*) = self;
+                tuple_write_helper($t, buffer, fds, &mut buffer_index, &mut fd_index)?;
+                $(
+                    tuple_write_helper($ts, buffer, fds, &mut buffer_index, &mut fd_index)?;
+                )*
+                Ok(fd_index)
+            }
+        }
+        tuple_impls!{ $($ts),* }
+    }
+}
+
+// Imlpement tuple for up to 8 elements.
+tuple_impls! { A, B, C, D, E, F, G, H }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn read_write_1_fixed() {
+        let tuple = (1,);
+        let mut buffer = vec![0; tuple.msg_size()];
+        tuple.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_tuple = unsafe { <(u32,)>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+
+        assert_eq!(tuple, read_tuple);
+    }
+
+    #[test]
+    fn read_write_8_fixed() {
+        let tuple = (1u32, 2u8, 3u16, 4u64, 5u32, 6u16, 7u8, 8u8);
+        let mut buffer = vec![0; tuple.msg_size()];
+        tuple.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0;
+
+        assert_eq!(tuple, read_tuple);
+    }
+
+    #[test]
+    fn read_write_1() {
+        let tuple = (Some(1u64),);
+        let mut buffer = vec![0; tuple.msg_size()];
+        tuple.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0;
+
+        assert_eq!(tuple, read_tuple);
+    }
+
+    #[test]
+    fn read_write_4() {
+        let tuple = (Some(12u16), Some(false), None::<u8>, None::<u64>);
+        let mut buffer = vec![0; tuple.msg_size()];
+        println!("{:?}", tuple.msg_size());
+        tuple.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0;
+
+        assert_eq!(tuple, read_tuple);
+    }
+}