From 5f1a64892b714885f6c7405084a390467c03201a Mon Sep 17 00:00:00 2001 From: Zach Reizner Date: Mon, 27 Apr 2020 12:52:08 -0700 Subject: msg_socket: implement MsgOnSocket for Vec and tuples These container types are similar to arrays except tuples have heterogeneous data types and Vec has a dynamic number of elements. BUG=None TEST=cargo test -p msg_socket Change-Id: I2cbbaeb7f13b7700294ac50751530510098ba16d Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2168588 Reviewed-by: Daniel Verkamp Reviewed-by: Dylan Reid Tested-by: kokoro Tested-by: Zach Reizner Commit-Queue: Zach Reizner --- msg_socket/src/msg_on_socket/tuple.rs | 205 ++++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 msg_socket/src/msg_on_socket/tuple.rs (limited to 'msg_socket/src/msg_on_socket/tuple.rs') diff --git a/msg_socket/src/msg_on_socket/tuple.rs b/msg_socket/src/msg_on_socket/tuple.rs new file mode 100644 index 0000000..f960ce5 --- /dev/null +++ b/msg_socket/src/msg_on_socket/tuple.rs @@ -0,0 +1,205 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::mem::size_of; +use std::os::unix::io::RawFd; + +use crate::{MsgOnSocket, MsgResult}; + +use super::{simple_read, simple_write}; + +// Returns the size of one part of a tuple. +fn tuple_size_helper(v: &T) -> usize { + T::fixed_size().unwrap_or_else(|| v.msg_size() + size_of::()) +} + +unsafe fn tuple_read_helper( + buffer: &[u8], + fds: &[RawFd], + buffer_index: &mut usize, + fd_index: &mut usize, +) -> MsgResult { + let end = match T::fixed_size() { + Some(_) => buffer.len(), + None => { + let len = simple_read::(buffer, buffer_index)? as usize; + *buffer_index + len + } + }; + let (v, fd_read) = T::read_from_buffer(&buffer[*buffer_index..end], &fds[*fd_index..])?; + *buffer_index += v.msg_size(); + *fd_index += fd_read; + Ok(v) +} + +fn tuple_write_helper( + v: &T, + buffer: &mut [u8], + fds: &mut [RawFd], + buffer_index: &mut usize, + fd_index: &mut usize, +) -> MsgResult<()> { + let end = match T::fixed_size() { + Some(_) => buffer.len(), + None => { + let len = v.msg_size(); + simple_write(len as u64, buffer, buffer_index)?; + *buffer_index + len + } + }; + let fd_written = v.write_to_buffer(&mut buffer[*buffer_index..end], &mut fds[*fd_index..])?; + *buffer_index += v.msg_size(); + *fd_index += fd_written; + Ok(()) +} + +macro_rules! tuple_impls { + () => {}; + ($t: ident) => { + #[allow(unused_variables, non_snake_case)] + impl<$t: MsgOnSocket> MsgOnSocket for ($t,) { + fn uses_fd() -> bool { + $t::uses_fd() + } + + fn fd_count(&self) -> usize { + self.0.fd_count() + } + + fn fixed_size() -> Option { + $t::fixed_size() + } + + fn msg_size(&self) -> usize { + self.0.msg_size() + } + + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + let (t, s) = $t::read_from_buffer(buffer, fds)?; + Ok(((t,), s)) + } + + fn write_to_buffer( + &self, + buffer: &mut [u8], + fds: &mut [RawFd], + ) -> MsgResult { + self.0.write_to_buffer(buffer, fds) + } + } + }; + ($t: ident, $($ts:ident),*) => { + #[allow(unused_variables, non_snake_case)] + impl<$t: MsgOnSocket $(, $ts: MsgOnSocket)*> MsgOnSocket for ($t$(, $ts)*) { + fn uses_fd() -> bool { + $t::uses_fd() $(|| $ts::uses_fd())* + } + + fn fd_count(&self) -> usize { + if Self::uses_fd() { + return 0; + } + let ($t $(,$ts)*) = self; + $t.fd_count() $(+ $ts.fd_count())* + } + + fn fixed_size() -> Option { + // Returns None if any element is not fixed size. + Some($t::fixed_size()? $(+ $ts::fixed_size()?)*) + } + + fn msg_size(&self) -> usize { + if let Some(size) = Self::fixed_size() { + return size + } + + let ($t $(,$ts)*) = self; + tuple_size_helper($t) $(+ tuple_size_helper($ts))* + } + + unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { + let mut buffer_index = 0; + let mut fd_index = 0; + Ok(( + ( + tuple_read_helper(buffer, fds, &mut buffer_index, &mut fd_index)?, + $({ + // Dummy let used to trigger the correct number of iterations. + let $ts = (); + tuple_read_helper(buffer, fds, &mut buffer_index, &mut fd_index)? + },)* + ), + fd_index + )) + } + + fn write_to_buffer( + &self, + buffer: &mut [u8], + fds: &mut [RawFd], + ) -> MsgResult { + let mut buffer_index = 0; + let mut fd_index = 0; + let ($t $(,$ts)*) = self; + tuple_write_helper($t, buffer, fds, &mut buffer_index, &mut fd_index)?; + $( + tuple_write_helper($ts, buffer, fds, &mut buffer_index, &mut fd_index)?; + )* + Ok(fd_index) + } + } + tuple_impls!{ $($ts),* } + } +} + +// Imlpement tuple for up to 8 elements. +tuple_impls! { A, B, C, D, E, F, G, H } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn read_write_1_fixed() { + let tuple = (1,); + let mut buffer = vec![0; tuple.msg_size()]; + tuple.write_to_buffer(&mut buffer, &mut []).unwrap(); + let read_tuple = unsafe { <(u32,)>::read_from_buffer(&buffer, &[]) } + .unwrap() + .0; + + assert_eq!(tuple, read_tuple); + } + + #[test] + fn read_write_8_fixed() { + let tuple = (1u32, 2u8, 3u16, 4u64, 5u32, 6u16, 7u8, 8u8); + let mut buffer = vec![0; tuple.msg_size()]; + tuple.write_to_buffer(&mut buffer, &mut []).unwrap(); + let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0; + + assert_eq!(tuple, read_tuple); + } + + #[test] + fn read_write_1() { + let tuple = (Some(1u64),); + let mut buffer = vec![0; tuple.msg_size()]; + tuple.write_to_buffer(&mut buffer, &mut []).unwrap(); + let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0; + + assert_eq!(tuple, read_tuple); + } + + #[test] + fn read_write_4() { + let tuple = (Some(12u16), Some(false), None::, None::); + let mut buffer = vec![0; tuple.msg_size()]; + println!("{:?}", tuple.msg_size()); + tuple.write_to_buffer(&mut buffer, &mut []).unwrap(); + let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0; + + assert_eq!(tuple, read_tuple); + } +} -- cgit 1.4.1