From 787c84b51b29c0715c6d3e73aca0148b6b112440 Mon Sep 17 00:00:00 2001 From: Zach Reizner Date: Fri, 31 Jan 2020 17:17:32 -0800 Subject: sys_util: recv entire UnixSeqpacket packets into Vec This change adds the `recv_*_vec` suite of methods for getting an entire packet into a `Vec` without needing to know the packet size through some other means. TEST=cargo test -p sys_util -p msg_socket BUG=None Change-Id: Ia4f931ccb91f6de6ee2103387fd95dfad3d3d38b Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2034025 Commit-Queue: Zach Reizner Tested-by: Zach Reizner Tested-by: kokoro Auto-Submit: Zach Reizner Reviewed-by: Daniel Verkamp Reviewed-by: Stephen Barber --- msg_socket/src/lib.rs | 30 +++++------ sys_util/src/net.rs | 114 ++++++++++++++++++++++++++++++++++++++++++ sys_util/src/sock_ctrl_msg.rs | 4 ++ 3 files changed, 133 insertions(+), 15 deletions(-) diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs index c6e3a38..5b9f9ce 100644 --- a/msg_socket/src/lib.rs +++ b/msg_socket/src/lib.rs @@ -145,33 +145,33 @@ pub trait MsgReceiver: AsRef { fn recv(&self) -> MsgResult { let msg_size = Self::M::msg_size(); let fd_size = Self::M::max_fd_count(); - let mut msg_buffer: Vec = vec![0; msg_size]; - let mut fd_buffer: Vec = vec![0; fd_size]; let sock: &UnixSeqpacket = self.as_ref(); - let (recv_msg_size, recv_fd_size) = { + let (msg_buffer, fd_buffer) = { if fd_size == 0 { - let size = sock - .recv(&mut msg_buffer) - .map_err(|e| MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))))?; - (size, 0) + ( + sock.recv_as_vec().map_err(|e| { + MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))) + })?, + vec![], + ) } else { - sock.recv_with_fds(&mut msg_buffer, &mut fd_buffer) - .map_err(MsgError::Recv)? + sock.recv_as_vec_with_fds() + .map_err(|e| MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))))? } }; - if msg_size != recv_msg_size { + + if msg_size != msg_buffer.len() { return Err(MsgError::BadRecvSize { expected: msg_size, - actual: recv_msg_size, + actual: msg_buffer.len(), }); } // Safe because fd buffer is read from socket. - let (v, read_fd_size) = unsafe { - Self::M::read_from_buffer(&msg_buffer[0..recv_msg_size], &fd_buffer[0..recv_fd_size])? - }; - if recv_fd_size != read_fd_size { + let (v, read_fd_size) = + unsafe { Self::M::read_from_buffer(&msg_buffer[..], &fd_buffer[..])? }; + if fd_buffer.len() != read_fd_size { return Err(MsgError::NotExpectFd); } Ok(v) diff --git a/sys_util/src/net.rs b/sys_util/src/net.rs index 70f975b..71ab3ee 100644 --- a/sys_util/src/net.rs +++ b/sys_util/src/net.rs @@ -16,6 +16,10 @@ use std::path::PathBuf; use std::ptr::null_mut; use std::time::Duration; +use libc::{recvfrom, MSG_PEEK, MSG_TRUNC}; + +use crate::sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT}; + // Offset of sun_path in structure sockaddr_un. fn sun_path_offset() -> usize { // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer. @@ -149,6 +153,28 @@ impl UnixSeqpacket { } } + /// Gets the number of bytes in the next packet. This blocks as if `recv` were called, + /// respecting the blocking and timeout settings of the underlying socket. + pub fn next_packet_size(&self) -> io::Result { + // This form of recvfrom doesn't modify any data because all null pointers are used. We only + // use the return value and check for errors on an FD owned by this structure. + let ret = unsafe { + recvfrom( + self.fd, + null_mut(), + 0, + MSG_TRUNC | MSG_PEEK, + null_mut(), + null_mut(), + ) + }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(ret as usize) + } + } + /// Write data from a given buffer to the socket fd /// /// # Arguments @@ -193,6 +219,52 @@ impl UnixSeqpacket { } } + /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size. + /// + /// # Arguments + /// * `buf` - A mut reference to a `Vec` to resize and read into. + /// + /// # Errors + /// Returns error when `libc::read` or `get_readable_bytes` failed. + pub fn recv_to_vec(&self, buf: &mut Vec) -> io::Result<()> { + let packet_size = self.next_packet_size()?; + buf.resize(packet_size, 0); + let read_bytes = self.recv(buf)?; + buf.resize(read_bytes, 0); + Ok(()) + } + + /// Read data from the socket fd to a new `Vec`. + /// + /// # Returns + /// * `vec` - A new `Vec` with the entire received packet. + /// + /// # Errors + /// Returns error when `libc::read` or `get_readable_bytes` failed. + pub fn recv_as_vec(&self) -> io::Result> { + let mut buf = Vec::new(); + self.recv_to_vec(&mut buf)?; + Ok(buf) + } + + /// Read data and fds from the socket fd to a new pair of `Vec`. + /// + /// # Returns + /// * `Vec` - A new `Vec` with the entire received packet's bytes. + /// * `Vec` - A new `Vec` with the entire received packet's fds. + /// + /// # Errors + /// Returns error when `recv_with_fds` or `get_readable_bytes` failed. + pub fn recv_as_vec_with_fds(&self) -> io::Result<(Vec, Vec)> { + let packet_size = self.next_packet_size()?; + let mut buf = vec![0; packet_size]; + let mut fd_buf = vec![-1; SCM_SOCKET_MAX_FD_COUNT]; + let (read_bytes, read_fds) = self.recv_with_fds(&mut buf, &mut fd_buf)?; + buf.resize(read_bytes, 0); + fd_buf.resize(read_fds, -1); + Ok((buf, fd_buf)) + } + fn set_timeout(&self, timeout: Option, kind: libc::c_int) -> io::Result<()> { let timeval = match timeout { Some(t) => { @@ -412,6 +484,7 @@ impl Drop for UnlinkUnixSeqpacketListener { mod tests { use super::*; use std::env; + use std::io::ErrorKind; use std::path::PathBuf; fn tmpdir() -> PathBuf { @@ -584,4 +657,45 @@ mod tests { assert_eq!(s1.get_readable_bytes().unwrap(), 0); assert_eq!(s2.get_readable_bytes().unwrap(), 0); } + + #[test] + fn unix_seqpacket_next_packet_size() { + let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + let data1 = &[0, 1, 2, 3, 4]; + s1.send(data1).expect("failed to send data"); + + assert_eq!(s2.next_packet_size().unwrap(), 5); + s1.set_read_timeout(Some(Duration::from_micros(1))) + .expect("failed to set read timeout"); + assert_eq!( + s1.next_packet_size().unwrap_err().kind(), + ErrorKind::WouldBlock + ); + drop(s2); + assert_eq!( + s1.next_packet_size().unwrap_err().kind(), + ErrorKind::ConnectionReset + ); + } + + #[test] + fn unix_seqpacket_recv_to_vec() { + let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + let data1 = &[0, 1, 2, 3, 4]; + s1.send(data1).expect("failed to send data"); + + let recv_data = &mut vec![]; + s2.recv_to_vec(recv_data).expect("failed to recv data"); + assert_eq!(recv_data, &mut vec![0, 1, 2, 3, 4]); + } + + #[test] + fn unix_seqpacket_recv_as_vec() { + let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + let data1 = &[0, 1, 2, 3, 4]; + s1.send(data1).expect("failed to send data"); + + let recv_data = s2.recv_as_vec().expect("failed to recv data"); + assert_eq!(recv_data, vec![0, 1, 2, 3, 4]); + } } diff --git a/sys_util/src/sock_ctrl_msg.rs b/sys_util/src/sock_ctrl_msg.rs index 13b9b0c..d4b953b 100644 --- a/sys_util/src/sock_ctrl_msg.rs +++ b/sys_util/src/sock_ctrl_msg.rs @@ -213,6 +213,9 @@ fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(u Ok((total_read as usize, in_fds_count)) } +/// The maximum number of FDs that can be sent in a single send. +pub const SCM_SOCKET_MAX_FD_COUNT: usize = 253; + /// Trait for file descriptors can send and receive socket control messages via `sendmsg` and /// `recvmsg`. pub trait ScmSocket { @@ -292,6 +295,7 @@ impl ScmSocket for UnixStream { self.as_raw_fd() } } + impl ScmSocket for UnixSeqpacket { fn socket_fd(&self) -> RawFd { self.as_raw_fd() -- cgit 1.4.1