summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-06-07 09:40:26 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:37:44 +0000
commitfd20697db144fc6e4dbd06856efb61abba236648 (patch)
tree1b969aefe4f55029585ae067d458f8a528b001bc
parenta5d8b143a440ab8def7f1bcb8438279df30d4d21 (diff)
downloadcrosvm-fd20697db144fc6e4dbd06856efb61abba236648.tar
crosvm-fd20697db144fc6e4dbd06856efb61abba236648.tar.gz
crosvm-fd20697db144fc6e4dbd06856efb61abba236648.tar.bz2
crosvm-fd20697db144fc6e4dbd06856efb61abba236648.tar.lz
crosvm-fd20697db144fc6e4dbd06856efb61abba236648.tar.xz
crosvm-fd20697db144fc6e4dbd06856efb61abba236648.tar.zst
crosvm-fd20697db144fc6e4dbd06856efb61abba236648.zip
msg_socket: impl MsgOnSocket for Cow<[T]>
This is unlikely to be directly useful, but it will be helpful for
Vec-based types like String and PathBuf to delegate to this
implementation for write_to_buffer, since they can't delegate to Vec's
without copying.
-rw-r--r--msg_socket/src/msg_on_socket/slice.rs112
1 files changed, 108 insertions, 4 deletions
diff --git a/msg_socket/src/msg_on_socket/slice.rs b/msg_socket/src/msg_on_socket/slice.rs
index 7b6ef28..471e487 100644
--- a/msg_socket/src/msg_on_socket/slice.rs
+++ b/msg_socket/src/msg_on_socket/slice.rs
@@ -2,6 +2,7 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
+use std::borrow::Cow;
 use std::mem::{size_of, ManuallyDrop, MaybeUninit};
 use std::os::unix::io::RawFd;
 use std::ptr::drop_in_place;
@@ -85,6 +86,39 @@ pub fn slice_write_helper<T: MsgOnSocket>(
     Ok(fd_offset)
 }
 
+impl<'a, T: MsgOnSocket + Clone> MsgOnSocket for Cow<'a, [T]> {
+    fn uses_fd() -> bool {
+        T::uses_fd()
+    }
+
+    fn msg_size(&self) -> usize {
+        let slice_size = match T::fixed_size() {
+            Some(s) => s * self.len(),
+            None => self.iter().map(|i| i.msg_size() + size_of::<u64>()).sum(),
+        };
+        size_of::<u64>() + slice_size
+    }
+
+    fn fd_count(&self) -> usize {
+        if T::uses_fd() {
+            self.iter().map(MsgOnSocket::fd_count).sum()
+        } else {
+            0
+        }
+    }
+
+    unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+        let (vec, fd_count) = Vec::read_from_buffer(buffer, fds)?;
+        Ok((Self::Owned(vec), fd_count))
+    }
+
+    fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
+        let mut offset = 0;
+        simple_write(self.len() as u64, buffer, &mut offset)?;
+        slice_write_helper(self, &mut buffer[offset..], fds)
+    }
+}
+
 impl<T: MsgOnSocket> MsgOnSocket for Vec<T> {
     fn uses_fd() -> bool {
         T::uses_fd()
@@ -135,7 +169,7 @@ mod tests {
     use super::*;
 
     #[test]
-    fn read_write_1_fixed() {
+    fn vec_read_write_1_fixed() {
         let vec = vec![1u32];
         let mut buffer = vec![0; vec.msg_size()];
         vec.write_to_buffer(&mut buffer, &mut []).unwrap();
@@ -147,7 +181,7 @@ mod tests {
     }
 
     #[test]
-    fn read_write_8_fixed() {
+    fn vec_read_write_8_fixed() {
         let vec = vec![1u16, 1, 3, 5, 8, 13, 21, 34];
         let mut buffer = vec![0; vec.msg_size()];
         vec.write_to_buffer(&mut buffer, &mut []).unwrap();
@@ -158,7 +192,7 @@ mod tests {
     }
 
     #[test]
-    fn read_write_1() {
+    fn vec_read_write_1() {
         let vec = vec![Some(1u64)];
         let mut buffer = vec![0; vec.msg_size()];
         println!("{:?}", vec.msg_size());
@@ -171,7 +205,7 @@ mod tests {
     }
 
     #[test]
-    fn read_write_4() {
+    fn vec_read_write_4() {
         let vec = vec![Some(12u16), Some(0), None, None];
         let mut buffer = vec![0; vec.msg_size()];
         vec.write_to_buffer(&mut buffer, &mut []).unwrap();
@@ -181,4 +215,74 @@ mod tests {
 
         assert_eq!(vec, read_vec);
     }
+
+    #[test]
+    fn cow_vec_equiv() {
+        let vec = vec![1u16, 1, 3, 5, 8, 13, 21, 34];
+
+        let mut vec_buffer = vec![0; vec.msg_size()];
+        vec.write_to_buffer(&mut vec_buffer, &mut []).unwrap();
+
+        let mut cow_borrowed_buffer = vec![0; vec.msg_size()];
+        let cow_borrowed = Cow::Borrowed(&vec);
+        cow_borrowed
+            .write_to_buffer(&mut cow_borrowed_buffer, &mut [])
+            .unwrap();
+
+        let mut cow_owned_buffer = vec![0; vec.msg_size()];
+        let cow_owned: Cow<[_]> = Cow::Owned(vec);
+        cow_owned
+            .write_to_buffer(&mut cow_owned_buffer, &mut [])
+            .unwrap();
+
+        assert_eq!(cow_borrowed_buffer, vec_buffer);
+        assert_eq!(cow_owned_buffer, vec_buffer);
+    }
+
+    #[test]
+    fn cow_read_write_1_fixed() {
+        let cow = Cow::Borrowed(&[1u32][..]);
+        let mut buffer = vec![0; cow.msg_size()];
+        cow.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_cow = unsafe { <Vec<u32>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+
+        assert_eq!(cow, read_cow);
+    }
+
+    #[test]
+    fn cow_read_write_8_fixed() {
+        let cow = Cow::Borrowed(&[1u16, 1, 3, 5, 8, 13, 21, 34][..]);
+        let mut buffer = vec![0; cow.msg_size()];
+        cow.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_cow = unsafe { <Cow<[u16]>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+        assert_eq!(cow, read_cow);
+    }
+
+    #[test]
+    fn cow_read_write_1() {
+        let cow = Cow::Borrowed(&[Some(1u64)][..]);
+        let mut buffer = vec![0; cow.msg_size()];
+        cow.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_cow = unsafe { <Cow<_>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+
+        assert_eq!(cow, read_cow);
+    }
+
+    #[test]
+    fn cow_read_write_4() {
+        let cow = Cow::Borrowed(&[Some(12u16), Some(0), None, None][..]);
+        let mut buffer = vec![0; cow.msg_size()];
+        cow.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_cow = unsafe { <Cow<_>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+
+        assert_eq!(cow, read_cow);
+    }
 }