summary refs log tree commit diff
path: root/msg_socket/src/msg_on_socket.rs
diff options
context:
space:
mode:
Diffstat (limited to 'msg_socket/src/msg_on_socket.rs')
-rw-r--r--msg_socket/src/msg_on_socket.rs123
1 files changed, 42 insertions, 81 deletions
diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs
index 82ee9a8..2101192 100644
--- a/msg_socket/src/msg_on_socket.rs
+++ b/msg_socket/src/msg_on_socket.rs
@@ -2,17 +2,20 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
+mod slice;
+mod tuple;
+
 use std::fmt::{self, Display};
 use std::fs::File;
 use std::mem::{size_of, transmute_copy, MaybeUninit};
 use std::net::{TcpListener, TcpStream, UdpSocket};
 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
 use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
-use std::ptr::drop_in_place;
 use std::result;
 use std::sync::Arc;
 
 use data_model::*;
+use slice::{slice_read_helper, slice_write_helper};
 use sync::Mutex;
 use sys_util::{Error as SysError, EventFd};
 
@@ -286,20 +289,6 @@ rawfd_impl!(UdpSocket);
 rawfd_impl!(UnixListener);
 rawfd_impl!(UnixDatagram);
 
-// Converts a slice into an array of fixed size inferred from by the return value. Panics if the
-// slice is too small, but will tolerate slices that are too large.
-fn slice_to_array<T, O>(s: &[T]) -> O
-where
-    T: Copy,
-    O: Default + AsMut<[T]>,
-{
-    let mut o = O::default();
-    let o_slice = o.as_mut();
-    let len = o_slice.len();
-    o_slice.copy_from_slice(&s[..len]);
-    o
-}
-
 // usize could be different sizes on different targets. We always use u64.
 impl MsgOnSocket for usize {
     fn msg_size(&self) -> usize {
@@ -384,6 +373,35 @@ le_impl!(Le16, u16);
 le_impl!(Le32, u32);
 le_impl!(Le64, u64);
 
+fn simple_read<T: MsgOnSocket>(buffer: &[u8], offset: &mut usize) -> MsgResult<T> {
+    assert!(!T::uses_fd());
+    // Safety for T::read_from_buffer depends on the given FDs being valid, but we pass no FDs.
+    let (v, _) = unsafe { T::read_from_buffer(&buffer[*offset..], &[])? };
+    *offset += v.msg_size();
+    Ok(v)
+}
+
+fn simple_write<T: MsgOnSocket>(v: T, buffer: &mut [u8], offset: &mut usize) -> MsgResult<()> {
+    assert!(!T::uses_fd());
+    v.write_to_buffer(&mut buffer[*offset..], &mut [])?;
+    *offset += v.msg_size();
+    Ok(())
+}
+
+// Converts a slice into an array of fixed size inferred from by the return value. Panics if the
+// slice is too small, but will tolerate slices that are too large.
+fn slice_to_array<T, O>(s: &[T]) -> O
+where
+    T: Copy,
+    O: Default + AsMut<[T]>,
+{
+    let mut o = O::default();
+    let o_slice = o.as_mut();
+    let len = o_slice.len();
+    o_slice.copy_from_slice(&s[..len]);
+    o
+}
+
 macro_rules! array_impls {
     ($N:expr, $t: ident $($ts:ident)*)
     => {
@@ -417,46 +435,7 @@ macro_rules! array_impls {
                 // themselves don't require initializing.
                 let mut msgs: [MaybeUninit<T>; $N] =  MaybeUninit::uninit().assume_init();
 
-                let mut offset = 0usize;
-                let mut fd_offset = 0usize;
-
-                // In case of an error, we need to keep track of how many elements got initialized.
-                // In order to perform the necessary drops, the below loop is executed in a closure
-                // to capture errors without returning.
-                let mut last_index = 0;
-                let res = (|| {
-                    for msg in &mut msgs[..] {
-                        let element_size = match T::fixed_size() {
-                            Some(s) => s,
-                            None => {
-                                let (element_size, _) = u64::read_from_buffer(&buffer[offset..], &[])?;
-                                offset += element_size.msg_size();
-                                element_size as usize
-                            }
-                        };
-                        let (m, fd_size) =
-                            T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?;
-                        *msg = MaybeUninit::new(m);
-                        offset += element_size;
-                        fd_offset += fd_size;
-                        last_index += 1;
-                    }
-                    Ok(())
-                })();
-
-                // Because `MaybeUninit` will not automatically call drops, we have to drop the
-                // partially initialized array manually in the case of an error.
-                if let Err(e) = res {
-                    for msg in &mut msgs[..last_index] {
-                        // The call to `as_mut_ptr()` turns the `MaybeUninit` element of the array
-                        // into a pointer, which can be used with `drop_in_place` to call the
-                        // destructor without moving the element, which is impossible. This is safe
-                        // because `last_index` prevents this loop from traversing into the
-                        // uninitialized parts of the array.
-                        drop_in_place(msg.as_mut_ptr());
-                    }
-                    return Err(e)
-                }
+                let fd_count = slice_read_helper(buffer, fds, &mut msgs)?;
 
                 // Also taken from the canonical example, we initialized every member of the array
                 // in the first loop of this function, so it is safe to `transmute_copy` the array
@@ -466,7 +445,7 @@ macro_rules! array_impls {
                 // Because this function operates on generic data, the type is "dependently-sized"
                 // and so the compiler will not check that the size of the input and output match.
                 // See this issue for details: https://github.com/rust-lang/rust/issues/61956
-                Ok((transmute_copy::<_, [T; $N]>(&msgs), fd_offset))
+                Ok((transmute_copy::<_, [T; $N]>(&msgs), fd_count))
             }
 
             fn write_to_buffer(
@@ -474,25 +453,7 @@ macro_rules! array_impls {
                 buffer: &mut [u8],
                 fds: &mut [RawFd],
                 ) -> MsgResult<usize> {
-                let mut offset = 0usize;
-                let mut fd_offset = 0usize;
-                for idx in 0..$N {
-                    let element_size = match T::fixed_size() {
-                        Some(s) => s,
-                        None => {
-                            let element_size = self[idx].msg_size() as u64;
-                            element_size.write_to_buffer(&mut buffer[offset..], &mut [])?;
-                            offset += element_size.msg_size();
-                            element_size as usize
-                        }
-                    };
-                    let fd_size = self[idx].write_to_buffer(&mut buffer[offset..],
-                                                            &mut fds[fd_offset..])?;
-                    offset += element_size;
-                    fd_offset += fd_size;
-                }
-
-                Ok(fd_offset)
+                slice_write_helper(self, buffer, fds)
             }
         }
         #[cfg(test)]
@@ -507,7 +468,7 @@ macro_rules! array_impls {
                 array.write_to_buffer(&mut buffer, &mut []).unwrap();
                 let read_array = unsafe { ArrayType::read_from_buffer(&buffer, &[]) }.unwrap().0;
 
-                assert_eq!(array, read_array);
+                assert!(array.iter().eq(read_array.iter()));
             }
 
             #[test]
@@ -518,7 +479,7 @@ macro_rules! array_impls {
                 array.write_to_buffer(&mut buffer, &mut []).unwrap();
                 let read_array = unsafe { ArrayType::read_from_buffer(&buffer, &[]) }.unwrap().0;
 
-                assert_eq!(array, read_array);
+                assert!(array.iter().eq(read_array.iter()));
             }
         }
         array_impls!(($N - 1), $($ts)*);
@@ -527,9 +488,9 @@ macro_rules! array_impls {
 }
 
 array_impls! {
-    32, tmp1 tmp2 tmp3 tmp4 tmp5 tmp6 tmp7 tmp8 tmp9 tmp10 tmp11 tmp12 tmp13 tmp14 tmp15 tmp16
+    64, tmp1 tmp2 tmp3 tmp4 tmp5 tmp6 tmp7 tmp8 tmp9 tmp10 tmp11 tmp12 tmp13 tmp14 tmp15 tmp16
         tmp17 tmp18 tmp19 tmp20 tmp21 tmp22 tmp23 tmp24 tmp25 tmp26 tmp27 tmp28 tmp29 tmp30 tmp31
-        tmp32
+        tmp32 tmp33 tmp34 tmp35 tmp36 tmp37 tmp38 tmp39 tmp40 tmp41 tmp42 tmp43 tmp44 tmp45 tmp46
+        tmp47 tmp48 tmp49 tmp50 tmp51 tmp52 tmp53 tmp54 tmp55 tmp56 tmp57 tmp58 tmp59 tmp60 tmp61
+        tmp62 tmp63 tmp64
 }
-
-// TODO(jkwang) Define MsgOnSocket for tuple?