summary refs log blame commit diff
path: root/msg_socket/src/msg_on_socket/slice.rs
blob: 471e487d10f98ca71318ec90b3d407c9271e078e (plain) (tree)
1
2
3
4
5



                                                                         
                     


















































































                                                                                                
































                                                                                          

















































                                                                                          
                                 










                                                                            
                                 









                                                                            
                           











                                                                          
                           








                                                                          





































































                                                                              
 
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use std::borrow::Cow;
use std::mem::{size_of, ManuallyDrop, MaybeUninit};
use std::os::unix::io::RawFd;
use std::ptr::drop_in_place;

use crate::{MsgOnSocket, MsgResult};

use super::{simple_read, simple_write};

/// Helper used by the types that read a slice of homogeneously typed data.
///
/// Fills every slot of `msgs`, in order, with an element decoded from `buffer`. When
/// `T::fixed_size()` is `Some`, elements are packed back to back at that width; otherwise each
/// element is preceded by its byte length encoded as a `u64`. On success, returns the total
/// number of file descriptors consumed from `fds`.
///
/// # Safety
/// This function has the same safety requirements as `T::read_from_buffer`. In addition, the
/// slots of `msgs` may only be treated as initialized when this function returns `Ok`; on error,
/// every element that had already been decoded is dropped in place before returning.
pub unsafe fn slice_read_helper<T: MsgOnSocket>(
    buffer: &[u8],
    fds: &[RawFd],
    msgs: &mut [MaybeUninit<T>],
) -> MsgResult<usize> {
    let mut byte_pos = 0usize;
    let mut fd_pos = 0usize;
    // Count of leading slots in `msgs` that currently hold a live `T`.
    let mut filled = 0usize;

    // `try_for_each` stops at the first error, leaving `filled` pointing just past the last
    // successfully decoded element so the cleanup below knows exactly what to drop.
    let outcome = msgs.iter_mut().try_for_each(|slot| -> MsgResult<()> {
        let element_size = if let Some(width) = T::fixed_size() {
            width
        } else {
            simple_read::<u64>(buffer, &mut byte_pos)? as usize
        };
        // Relies on the unsafe caller having supplied valid FDs.
        let (value, fds_used) = T::read_from_buffer(&buffer[byte_pos..], &fds[fd_pos..])?;
        *slot = MaybeUninit::new(value);
        byte_pos += element_size;
        fd_pos += fds_used;
        filled += 1;
        Ok(())
    });

    match outcome {
        Ok(()) => Ok(fd_pos),
        Err(e) => {
            // `MaybeUninit` never runs destructors on its own, so release the elements decoded
            // before the failure. Only the first `filled` slots are initialized, so nothing
            // uninitialized is ever dropped; `drop_in_place` runs the destructor through the
            // slot pointer without moving the value out.
            for slot in &mut msgs[..filled] {
                drop_in_place(slot.as_mut_ptr());
            }
            Err(e)
        }
    }
}

/// Helper used by the types that write a slice of homogeneously typed data.
///
/// Serializes `msgs` into `buffer` in order. When `T::fixed_size()` is `None`, each element is
/// preceded by its `msg_size()` encoded as a `u64`; otherwise elements are packed back to back
/// at the fixed width. Returns the total number of file descriptors written into `fds`.
pub fn slice_write_helper<T: MsgOnSocket>(
    msgs: &[T],
    buffer: &mut [u8],
    fds: &mut [RawFd],
) -> MsgResult<usize> {
    let mut byte_pos = 0usize;
    // The fold accumulator is the running fd position; the byte position lives outside the
    // closure because the length prefix is written through `simple_write`'s `&mut usize`.
    msgs.iter().try_fold(0usize, |fd_pos, msg| -> MsgResult<usize> {
        let element_size = if let Some(width) = T::fixed_size() {
            width
        } else {
            let width = msg.msg_size();
            simple_write(width as u64, buffer, &mut byte_pos)?;
            width
        };
        let fds_used = msg.write_to_buffer(&mut buffer[byte_pos..], &mut fds[fd_pos..])?;
        byte_pos += element_size;
        Ok(fd_pos + fds_used)
    })
}

impl<'a, T: MsgOnSocket + Clone> MsgOnSocket for Cow<'a, [T]> {
    fn uses_fd() -> bool {
        T::uses_fd()
    }

    /// Encoded size: a `u64` element count followed by the elements. Variable-size elements each
    /// carry their own `u64` length prefix. This is the same layout `Vec<T>` uses.
    fn msg_size(&self) -> usize {
        let prefix = size_of::<u64>();
        let body: usize = if let Some(width) = T::fixed_size() {
            width * self.len()
        } else {
            self.iter()
                .fold(0, |total, item| total + item.msg_size() + prefix)
        };
        prefix + body
    }

    /// Sums the descriptor counts of the elements, short-circuiting to zero when `T` can never
    /// carry descriptors.
    fn fd_count(&self) -> usize {
        if !T::uses_fd() {
            return 0;
        }
        self.iter().map(|item| item.fd_count()).sum()
    }

    /// Always decodes into `Cow::Owned`, reusing the `Vec<T>` decoder since both types share one
    /// wire format.
    unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
        <Vec<T>>::read_from_buffer(buffer, fds)
            .map(|(owned, fd_count)| (Self::Owned(owned), fd_count))
    }

    /// Writes the `u64` element count, then the elements via `slice_write_helper`; returns the
    /// number of descriptors written.
    fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
        let mut header_len = 0;
        simple_write(self.len() as u64, buffer, &mut header_len)?;
        let elements: &[T] = self;
        slice_write_helper(elements, &mut buffer[header_len..], fds)
    }
}

impl<T: MsgOnSocket> MsgOnSocket for Vec<T> {
    fn uses_fd() -> bool {
        T::uses_fd()
    }

    // A `Vec` carries a runtime element count, so its encoded size is never fixed.
    fn fixed_size() -> Option<usize> {
        None
    }

    /// Encoded size: a `u64` element count followed by the elements. Fixed-size elements are
    /// packed back to back; variable-size elements each carry their own `u64` length prefix.
    fn msg_size(&self) -> usize {
        let vec_size = match T::fixed_size() {
            Some(s) => s * self.len(),
            None => self.iter().map(|i| i.msg_size() + size_of::<u64>()).sum(),
        };
        size_of::<u64>() + vec_size
    }

    fn fd_count(&self) -> usize {
        if T::uses_fd() {
            self.iter().map(|i| i.fd_count()).sum()
        } else {
            0
        }
    }

    /// Decodes a `u64` element count followed by that many elements (same layout as
    /// `slice_write_helper` produces) and returns the vector plus the number of FDs consumed.
    ///
    /// The element count comes from the message itself, so it must not size the allocation on
    /// its own: a short message claiming a huge length would otherwise trigger a
    /// "capacity overflow" panic or an allocation-failure abort before any element is
    /// validated. Instead, preallocation is capped by how many elements the remaining bytes
    /// could hold, and elements are pushed as they are decoded.
    unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
        let mut header_len = 0;
        let len = simple_read::<u64>(buffer, &mut header_len)? as usize;
        let elements = &buffer[header_len..];

        // Every element occupies at least its fixed width, or the `u64` length prefix of a
        // variable-size element. Clamping to 1 means zero-width elements only lose the
        // preallocation hint; the vector still grows on demand.
        let min_element_size = T::fixed_size().unwrap_or(size_of::<u64>()).max(1);
        let mut msgs = Vec::with_capacity(len.min(elements.len() / min_element_size));

        let mut offset = 0usize;
        let mut fd_offset = 0usize;
        for _ in 0..len {
            let element_size = match T::fixed_size() {
                Some(s) => s,
                None => simple_read::<u64>(elements, &mut offset)? as usize,
            };
            // Assuming the unsafe caller gave valid FDs, this call should be safe.
            let (msg, fd_size) = T::read_from_buffer(&elements[offset..], &fds[fd_offset..])?;
            // Elements live in a plain `Vec<T>`, so an early `?` return drops everything
            // decoded so far without any manual cleanup.
            msgs.push(msg);
            offset += element_size;
            fd_offset += fd_size;
        }

        Ok((msgs, fd_offset))
    }

    /// Writes the `u64` element count, then the elements via `slice_write_helper`; returns the
    /// number of descriptors written.
    fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
        let mut offset = 0;
        simple_write(self.len() as u64, buffer, &mut offset)?;
        slice_write_helper(self, &mut buffer[offset..], fds)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vec_read_write_empty() {
        let vec: Vec<u32> = Vec::new();
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<u32>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(vec, read_vec);
    }

    #[test]
    fn vec_read_write_1_fixed() {
        let vec = vec![1u32];
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<u32>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(vec, read_vec);
    }

    #[test]
    fn vec_read_write_8_fixed() {
        let vec = vec![1u16, 1, 3, 5, 8, 13, 21, 34];
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<u16>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;
        assert_eq!(vec, read_vec);
    }

    #[test]
    fn vec_read_write_1() {
        let vec = vec![Some(1u64)];
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(vec, read_vec);
    }

    #[test]
    fn vec_read_write_4() {
        let vec = vec![Some(12u16), Some(0), None, None];
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(vec, read_vec);
    }

    // Exercises the variable-size element path, where every element carries its own length
    // prefix, including an empty inner vector.
    #[test]
    fn vec_read_write_nested() {
        let vec = vec![vec![1u16, 2], vec![], vec![3]];
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<Vec<u16>>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(vec, read_vec);
    }

    #[test]
    fn cow_vec_equiv() {
        let vec = vec![1u16, 1, 3, 5, 8, 13, 21, 34];

        let mut vec_buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut vec_buffer, &mut []).unwrap();

        let mut cow_borrowed_buffer = vec![0; vec.msg_size()];
        let cow_borrowed = Cow::Borrowed(&vec);
        cow_borrowed
            .write_to_buffer(&mut cow_borrowed_buffer, &mut [])
            .unwrap();

        let mut cow_owned_buffer = vec![0; vec.msg_size()];
        let cow_owned: Cow<[_]> = Cow::Owned(vec);
        cow_owned
            .write_to_buffer(&mut cow_owned_buffer, &mut [])
            .unwrap();

        assert_eq!(cow_borrowed_buffer, vec_buffer);
        assert_eq!(cow_owned_buffer, vec_buffer);
    }

    #[test]
    fn cow_read_write_1_fixed() {
        let cow = Cow::Borrowed(&[1u32][..]);
        let mut buffer = vec![0; cow.msg_size()];
        cow.write_to_buffer(&mut buffer, &mut []).unwrap();
        // Decode through the `Cow` impl itself so its read path is actually covered.
        let read_cow = unsafe { <Cow<[u32]>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(cow, read_cow);
    }

    #[test]
    fn cow_read_write_8_fixed() {
        let cow = Cow::Borrowed(&[1u16, 1, 3, 5, 8, 13, 21, 34][..]);
        let mut buffer = vec![0; cow.msg_size()];
        cow.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_cow = unsafe { <Cow<[u16]>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;
        assert_eq!(cow, read_cow);
    }

    #[test]
    fn cow_read_write_1() {
        let cow = Cow::Borrowed(&[Some(1u64)][..]);
        let mut buffer = vec![0; cow.msg_size()];
        cow.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_cow = unsafe { <Cow<_>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(cow, read_cow);
    }

    #[test]
    fn cow_read_write_4() {
        let cow = Cow::Borrowed(&[Some(12u16), Some(0), None, None][..]);
        let mut buffer = vec![0; cow.msg_size()];
        cow.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_cow = unsafe { <Cow<_>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(cow, read_cow);
    }
}