summary refs log tree commit diff
path: root/msg_socket/src/msg_on_socket/slice.rs
diff options
context:
space:
mode:
Diffstat (limited to 'msg_socket/src/msg_on_socket/slice.rs')
-rw-r--r--msg_socket/src/msg_on_socket/slice.rs184
1 files changed, 184 insertions, 0 deletions
diff --git a/msg_socket/src/msg_on_socket/slice.rs b/msg_socket/src/msg_on_socket/slice.rs
new file mode 100644
index 0000000..7b6ef28
--- /dev/null
+++ b/msg_socket/src/msg_on_socket/slice.rs
@@ -0,0 +1,184 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::mem::{size_of, ManuallyDrop, MaybeUninit};
+use std::os::unix::io::RawFd;
+use std::ptr::drop_in_place;
+
+use crate::{MsgOnSocket, MsgResult};
+
+use super::{simple_read, simple_write};
+
+/// Helper used by the types that read a slice of homegenously typed data.
+///
+/// # Safety
+/// This function has the same safety requirements as `T::read_from_buffer`, with the additional
+/// requirements that the `msgs` are only used on success of this function
+pub unsafe fn slice_read_helper<T: MsgOnSocket>(
+    buffer: &[u8],
+    fds: &[RawFd],
+    msgs: &mut [MaybeUninit<T>],
+) -> MsgResult<usize> {
+    let mut offset = 0usize;
+    let mut fd_offset = 0usize;
+
+    // In case of an error, we need to keep track of how many elements got initialized.
+    // In order to perform the necessary drops, the below loop is executed in a closure
+    // to capture errors without returning.
+    let mut last_index = 0;
+    let res = (|| {
+        for msg in &mut msgs[..] {
+            let element_size = match T::fixed_size() {
+                Some(s) => s,
+                None => simple_read::<u64>(buffer, &mut offset)? as usize,
+            };
+            // Assuming the unsafe caller gave valid FDs, this call should be safe.
+            let (m, fd_size) = T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?;
+            *msg = MaybeUninit::new(m);
+            offset += element_size;
+            fd_offset += fd_size;
+            last_index += 1;
+        }
+        Ok(())
+    })();
+
+    // Because `MaybeUninit` will not automatically call drops, we have to drop the
+    // partially initialized array manually in the case of an error.
+    if let Err(e) = res {
+        for msg in &mut msgs[..last_index] {
+            // The call to `as_mut_ptr()` turns the `MaybeUninit` element of the array
+            // into a pointer, which can be used with `drop_in_place` to call the
+            // destructor without moving the element, which is impossible. This is safe
+            // because `last_index` prevents this loop from traversing into the
+            // uninitialized parts of the array.
+            drop_in_place(msg.as_mut_ptr());
+        }
+        return Err(e);
+    }
+
+    Ok(fd_offset)
+}
+
+/// Helper used by the types that write a slice of homegenously typed data.
+pub fn slice_write_helper<T: MsgOnSocket>(
+    msgs: &[T],
+    buffer: &mut [u8],
+    fds: &mut [RawFd],
+) -> MsgResult<usize> {
+    let mut offset = 0usize;
+    let mut fd_offset = 0usize;
+    for msg in msgs {
+        let element_size = match T::fixed_size() {
+            Some(s) => s,
+            None => {
+                let element_size = msg.msg_size();
+                simple_write(element_size as u64, buffer, &mut offset)?;
+                element_size as usize
+            }
+        };
+        let fd_size = msg.write_to_buffer(&mut buffer[offset..], &mut fds[fd_offset..])?;
+        offset += element_size;
+        fd_offset += fd_size;
+    }
+
+    Ok(fd_offset)
+}
+
+impl<T: MsgOnSocket> MsgOnSocket for Vec<T> {
+    fn uses_fd() -> bool {
+        T::uses_fd()
+    }
+
+    fn fixed_size() -> Option<usize> {
+        None
+    }
+
+    fn msg_size(&self) -> usize {
+        let vec_size = match T::fixed_size() {
+            Some(s) => s * self.len(),
+            None => self.iter().map(|i| i.msg_size() + size_of::<u64>()).sum(),
+        };
+        size_of::<u64>() + vec_size
+    }
+
+    fn fd_count(&self) -> usize {
+        if T::uses_fd() {
+            self.iter().map(|i| i.fd_count()).sum()
+        } else {
+            0
+        }
+    }
+
+    unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+        let mut offset = 0;
+        let len = simple_read::<u64>(buffer, &mut offset)? as usize;
+        let mut msgs: Vec<MaybeUninit<T>> = Vec::with_capacity(len);
+        msgs.set_len(len);
+        let fd_count = slice_read_helper(&buffer[offset..], fds, &mut msgs)?;
+        let mut msgs = ManuallyDrop::new(msgs);
+        Ok((
+            Vec::from_raw_parts(msgs.as_mut_ptr() as *mut T, msgs.len(), msgs.capacity()),
+            fd_count,
+        ))
+    }
+
+    fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
+        let mut offset = 0;
+        simple_write(self.len() as u64, buffer, &mut offset)?;
+        slice_write_helper(self, &mut buffer[offset..], fds)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn read_write_1_fixed() {
+        let vec = vec![1u32];
+        let mut buffer = vec![0; vec.msg_size()];
+        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_vec = unsafe { <Vec<u32>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+
+        assert_eq!(vec, read_vec);
+    }
+
+    #[test]
+    fn read_write_8_fixed() {
+        let vec = vec![1u16, 1, 3, 5, 8, 13, 21, 34];
+        let mut buffer = vec![0; vec.msg_size()];
+        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_vec = unsafe { <Vec<u16>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+        assert_eq!(vec, read_vec);
+    }
+
+    #[test]
+    fn read_write_1() {
+        let vec = vec![Some(1u64)];
+        let mut buffer = vec![0; vec.msg_size()];
+        println!("{:?}", vec.msg_size());
+        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+
+        assert_eq!(vec, read_vec);
+    }
+
+    #[test]
+    fn read_write_4() {
+        let vec = vec![Some(12u16), Some(0), None, None];
+        let mut buffer = vec![0; vec.msg_size()];
+        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
+        let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) }
+            .unwrap()
+            .0;
+
+        assert_eq!(vec, read_vec);
+    }
+}