summary refs log tree commit diff
path: root/msg_socket/src/msg_on_socket/slice.rs
blob: 7b6ef28f5aa3dd7fc9e6422b36011da1dcc1afad (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use std::mem::{size_of, ManuallyDrop, MaybeUninit};
use std::os::unix::io::RawFd;
use std::ptr::drop_in_place;

use crate::{MsgOnSocket, MsgResult};

use super::{simple_read, simple_write};

/// Helper used by the types that read a slice of homogeneously typed data.
///
/// Fills every slot of `msgs`, in order, with an element deserialized from `buffer`.
///
/// - When `T::fixed_size()` is `Some(s)`, elements are read back-to-back with a stride of
///   `s` bytes.
/// - When it is `None`, each element is preceded by a `u64` length prefix (read with
///   `simple_read`). That prefix gives the stride to the next element.
///
/// File descriptors are taken from `fds` consecutively. Each element consumes as many as its
/// `read_from_buffer` reports.
///
/// Returns the total number of entries of `fds` consumed by all elements.
///
/// # Errors
/// Returns the first error from `simple_read` or `T::read_from_buffer`. Before returning it,
/// every element that had already been read is dropped in place. On `Err`, the contents of
/// `msgs` must therefore be treated as uninitialized.
///
/// # Panics
/// Slicing `buffer` or `fds` panics if a running offset passes the end of its slice. For
/// example, this happens when a length prefix claims more bytes than remain in `buffer`.
/// If a panic happens mid-loop, the elements read so far are not dropped: they are leaked,
/// never double-dropped.
///
/// # Safety
/// This function has the same safety requirements as `T::read_from_buffer`. In addition, the
/// elements of `msgs` may only be treated as initialized (read, moved out of, or dropped)
/// after this function returns `Ok`.
pub unsafe fn slice_read_helper<T: MsgOnSocket>(
    buffer: &[u8],
    fds: &[RawFd],
    msgs: &mut [MaybeUninit<T>],
) -> MsgResult<usize> {
    // Byte position in `buffer` and index into `fds` of the next element to read.
    let mut offset = 0usize;
    let mut fd_offset = 0usize;

    // On error we need to know how many elements were initialized, so that exactly those can
    // be dropped. The loop runs inside an immediately-invoked closure so that `?` exits only
    // the closure. That leaves `last_index` available for the cleanup below instead of
    // returning from the function.
    let mut last_index = 0;
    let res = (|| {
        for msg in &mut msgs[..] {
            // Fixed-size types use their static size as the stride. Variable-size types carry
            // a `u64` length prefix read via `simple_read`, which takes `&mut offset`
            // (presumably advancing it past the prefix).
            let element_size = match T::fixed_size() {
                Some(s) => s,
                None => simple_read::<u64>(buffer, &mut offset)? as usize,
            };
            // Assuming the unsafe caller gave valid FDs, this call should be safe.
            let (m, fd_size) = T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?;
            *msg = MaybeUninit::new(m);
            // NOTE(review): `element_size` comes from the buffer and is never checked against
            // `buffer.len()`. An oversized prefix makes the next iteration's
            // `&buffer[offset..]` panic instead of returning an error.
            offset += element_size;
            fd_offset += fd_size;
            // Incremented only after the slot has been written, so `msgs[..last_index]` never
            // covers an uninitialized slot.
            last_index += 1;
        }
        Ok(())
    })();

    // Because `MaybeUninit` will not automatically call drops, we have to drop the
    // partially initialized array manually in the case of an error.
    if let Err(e) = res {
        for msg in &mut msgs[..last_index] {
            // `as_mut_ptr()` yields a `*mut T` to the value inside the `MaybeUninit`.
            // `drop_in_place` uses that pointer to run the destructor without moving the
            // value out. This is sound because `last_index` keeps the loop within the slots
            // that were initialized.
            drop_in_place(msg.as_mut_ptr());
        }
        return Err(e);
    }

    // Success: every slot of `msgs` has been initialized.
    Ok(fd_offset)
}

/// Helper used by the types that write a slice of homogeneously typed data.
///
/// Serializes every element of `msgs`, in order, back-to-back into `buffer`. Any file
/// descriptors the elements carry are placed consecutively into `fds`.
///
/// - When `T::fixed_size()` is `None`, each element is preceded by its `msg_size()` encoded
///   as a `u64`. That prefix lets a reader such as `slice_read_helper` find element
///   boundaries.
/// - Fixed-size elements are written with no prefix.
///
/// Returns the total number of file descriptors written into `fds`.
///
/// # Errors
/// Propagates the first error returned by `simple_write` or `T::write_to_buffer`. Elements
/// before the failing one may already have been written into `buffer` and `fds`.
pub fn slice_write_helper<T: MsgOnSocket>(
    msgs: &[T],
    buffer: &mut [u8],
    fds: &mut [RawFd],
) -> MsgResult<usize> {
    // Byte position in `buffer` and index into `fds` where the next element goes.
    let mut offset = 0usize;
    let mut fd_offset = 0usize;
    for msg in msgs {
        let element_size = match T::fixed_size() {
            Some(s) => s,
            None => {
                // Variable-size elements carry a u64 length prefix so the reader knows where
                // the next element starts.
                let element_size = msg.msg_size();
                simple_write(element_size as u64, buffer, &mut offset)?;
                element_size
            }
        };
        let fd_size = msg.write_to_buffer(&mut buffer[offset..], &mut fds[fd_offset..])?;
        offset += element_size;
        fd_offset += fd_size;
    }

    Ok(fd_offset)
}

/// Wire format: a `u64` element count followed by the elements, laid out by
/// `slice_write_helper` and read back by `slice_read_helper`.
impl<T: MsgOnSocket> MsgOnSocket for Vec<T> {
    /// A vector carries file descriptors exactly when its element type does.
    fn uses_fd() -> bool {
        T::uses_fd()
    }

    /// The encoded length depends on the element count, so it is never fixed.
    fn fixed_size() -> Option<usize> {
        None
    }

    /// The `u64` count header plus the encoded elements. Variable-size elements each add their
    /// own `u64` length prefix.
    fn msg_size(&self) -> usize {
        let count_header = size_of::<u64>();
        let payload: usize = if let Some(stride) = T::fixed_size() {
            stride * self.len()
        } else {
            self.iter()
                .map(|elem| elem.msg_size() + size_of::<u64>())
                .sum()
        };
        count_header + payload
    }

    /// Total descriptors carried by all elements; skips the walk when `T` never has any.
    fn fd_count(&self) -> usize {
        if !T::uses_fd() {
            return 0;
        }
        self.iter().map(T::fd_count).sum()
    }

    unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
        let mut pos = 0;
        let count = simple_read::<u64>(buffer, &mut pos)? as usize;
        // NOTE(review): `count` comes straight from the buffer and sizes this allocation
        // without being checked against the bytes that remain. A corrupt or hostile count can
        // make `Vec::with_capacity` panic or abort on allocation failure.
        let mut slots: Vec<MaybeUninit<T>> = Vec::with_capacity(count);
        // SAFETY: `MaybeUninit<T>` needs no initialization and `count` slots were reserved.
        slots.set_len(count);
        let fds_used = slice_read_helper(&buffer[pos..], fds, &mut slots)?;
        // SAFETY: `slice_read_helper` returned `Ok`, so every slot is initialized, and
        // `MaybeUninit<T>` has the same layout as `T`. `ManuallyDrop` stops the old vector
        // from freeing the allocation that the new `Vec<T>` takes over.
        let mut slots = ManuallyDrop::new(slots);
        let (ptr, len, cap) = (slots.as_mut_ptr() as *mut T, slots.len(), slots.capacity());
        Ok((Vec::from_raw_parts(ptr, len, cap), fds_used))
    }

    fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
        let mut header_len = 0;
        simple_write(self.len() as u64, buffer, &mut header_len)?;
        slice_write_helper(self.as_slice(), &mut buffer[header_len..], fds)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn read_write_1_fixed() {
        let vec = vec![1u32];
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<u32>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(vec, read_vec);
    }

    #[test]
    fn read_write_8_fixed() {
        let vec = vec![1u16, 1, 3, 5, 8, 13, 21, 34];
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<u16>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;
        assert_eq!(vec, read_vec);
    }

    #[test]
    fn read_write_1() {
        let vec = vec![Some(1u64)];
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(vec, read_vec);
    }

    #[test]
    fn read_write_4() {
        let vec = vec![Some(12u16), Some(0), None, None];
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(vec, read_vec);
    }

    /// An empty vector is encoded as just its `u64` element count and carries no FDs.
    #[test]
    fn read_write_empty() {
        let vec: Vec<u32> = Vec::new();
        assert_eq!(vec.msg_size(), std::mem::size_of::<u64>());
        let mut buffer = vec![0; vec.msg_size()];
        assert_eq!(vec.write_to_buffer(&mut buffer, &mut []).unwrap(), 0);
        let (read_vec, fd_count) =
            unsafe { <Vec<u32>>::read_from_buffer(&buffer, &[]) }.unwrap();

        assert_eq!(vec, read_vec);
        assert_eq!(fd_count, 0);
    }

    /// Inner vectors are variable-size, so this exercises the per-element `u64` length-prefix
    /// path of both `slice_write_helper` and `slice_read_helper`.
    #[test]
    fn read_write_nested_variable_size() {
        let vec = vec![vec![1u32, 2], vec![], vec![3]];
        // Outer count (8) + for each inner vector: length prefix (8) + its own count (8)
        // + 4 bytes per u32 element.
        assert_eq!(vec.msg_size(), 8 + (8 + 16) + (8 + 8) + (8 + 12));
        let mut buffer = vec![0; vec.msg_size()];
        vec.write_to_buffer(&mut buffer, &mut []).unwrap();
        let read_vec = unsafe { <Vec<Vec<u32>>>::read_from_buffer(&buffer, &[]) }
            .unwrap()
            .0;

        assert_eq!(vec, read_vec);
    }
}