summary refs log tree commit diff
diff options
context:
space:
mode:
authorZach Reizner <zachr@google.com>2020-01-31 17:17:32 -0800
committerCommit Bot <commit-bot@chromium.org>2020-02-06 21:56:37 +0000
commit787c84b51b29c0715c6d3e73aca0148b6b112440 (patch)
tree182364448e8b47f70fb74131502043361daca3b8
parent4441c01124a30b7037267fdc74aeee4b6eff111a (diff)
downloadcrosvm-787c84b51b29c0715c6d3e73aca0148b6b112440.tar
crosvm-787c84b51b29c0715c6d3e73aca0148b6b112440.tar.gz
crosvm-787c84b51b29c0715c6d3e73aca0148b6b112440.tar.bz2
crosvm-787c84b51b29c0715c6d3e73aca0148b6b112440.tar.lz
crosvm-787c84b51b29c0715c6d3e73aca0148b6b112440.tar.xz
crosvm-787c84b51b29c0715c6d3e73aca0148b6b112440.tar.zst
crosvm-787c84b51b29c0715c6d3e73aca0148b6b112440.zip
sys_util: recv entire UnixSeqpacket packets into Vec
This change adds the `recv_*_vec` suite of methods for getting an entire
packet into a `Vec` without needing to know the packet size through some
other means.

TEST=cargo test -p sys_util -p msg_socket
BUG=None

Change-Id: Ia4f931ccb91f6de6ee2103387fd95dfad3d3d38b
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2034025
Commit-Queue: Zach Reizner <zachr@chromium.org>
Tested-by: Zach Reizner <zachr@chromium.org>
Tested-by: kokoro <noreply+kokoro@google.com>
Auto-Submit: Zach Reizner <zachr@chromium.org>
Reviewed-by: Daniel Verkamp <dverkamp@chromium.org>
Reviewed-by: Stephen Barber <smbarber@chromium.org>
-rw-r--r--msg_socket/src/lib.rs30
-rw-r--r--sys_util/src/net.rs114
-rw-r--r--sys_util/src/sock_ctrl_msg.rs4
3 files changed, 133 insertions, 15 deletions
diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs
index c6e3a38..5b9f9ce 100644
--- a/msg_socket/src/lib.rs
+++ b/msg_socket/src/lib.rs
@@ -145,33 +145,33 @@ pub trait MsgReceiver: AsRef<UnixSeqpacket> {
     fn recv(&self) -> MsgResult<Self::M> {
         let msg_size = Self::M::msg_size();
         let fd_size = Self::M::max_fd_count();
-        let mut msg_buffer: Vec<u8> = vec![0; msg_size];
-        let mut fd_buffer: Vec<RawFd> = vec![0; fd_size];
 
         let sock: &UnixSeqpacket = self.as_ref();
 
-        let (recv_msg_size, recv_fd_size) = {
+        let (msg_buffer, fd_buffer) = {
             if fd_size == 0 {
-                let size = sock
-                    .recv(&mut msg_buffer)
-                    .map_err(|e| MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))))?;
-                (size, 0)
+                (
+                    sock.recv_as_vec().map_err(|e| {
+                        MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0)))
+                    })?,
+                    vec![],
+                )
             } else {
-                sock.recv_with_fds(&mut msg_buffer, &mut fd_buffer)
-                    .map_err(MsgError::Recv)?
+                sock.recv_as_vec_with_fds()
+                    .map_err(|e| MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))))?
             }
         };
-        if msg_size != recv_msg_size {
+
+        if msg_size != msg_buffer.len() {
             return Err(MsgError::BadRecvSize {
                 expected: msg_size,
-                actual: recv_msg_size,
+                actual: msg_buffer.len(),
             });
         }
         // Safe because fd buffer is read from socket.
-        let (v, read_fd_size) = unsafe {
-            Self::M::read_from_buffer(&msg_buffer[0..recv_msg_size], &fd_buffer[0..recv_fd_size])?
-        };
-        if recv_fd_size != read_fd_size {
+        let (v, read_fd_size) =
+            unsafe { Self::M::read_from_buffer(&msg_buffer[..], &fd_buffer[..])? };
+        if fd_buffer.len() != read_fd_size {
             return Err(MsgError::NotExpectFd);
         }
         Ok(v)
diff --git a/sys_util/src/net.rs b/sys_util/src/net.rs
index 70f975b..71ab3ee 100644
--- a/sys_util/src/net.rs
+++ b/sys_util/src/net.rs
@@ -16,6 +16,10 @@ use std::path::PathBuf;
 use std::ptr::null_mut;
 use std::time::Duration;
 
+use libc::{recvfrom, MSG_PEEK, MSG_TRUNC};
+
+use crate::sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT};
+
 // Offset of sun_path in structure sockaddr_un.
 fn sun_path_offset() -> usize {
     // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
@@ -149,6 +153,28 @@ impl UnixSeqpacket {
         }
     }
 
+    /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
+    /// respecting the blocking and timeout settings of the underlying socket.
+    pub fn next_packet_size(&self) -> io::Result<usize> {
+        // This form of recvfrom doesn't modify any data because all null pointers are used. We only
+        // use the return value and check for errors on an FD owned by this structure.
+        let ret = unsafe {
+            recvfrom(
+                self.fd,
+                null_mut(),
+                0,
+                MSG_TRUNC | MSG_PEEK,
+                null_mut(),
+                null_mut(),
+            )
+        };
+        if ret < 0 {
+            Err(io::Error::last_os_error())
+        } else {
+            Ok(ret as usize)
+        }
+    }
+
     /// Write data from a given buffer to the socket fd
     ///
     /// # Arguments
@@ -193,6 +219,52 @@ impl UnixSeqpacket {
         }
     }
 
+    /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
+    ///
+    /// # Arguments
+    /// * `buf` - A mut reference to a `Vec` to resize and read into.
+    ///
+    /// # Errors
+    /// Returns error when `libc::read` or `get_readable_bytes` failed.
+    pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
+        let packet_size = self.next_packet_size()?;
+        buf.resize(packet_size, 0);
+        let read_bytes = self.recv(buf)?;
+        buf.resize(read_bytes, 0);
+        Ok(())
+    }
+
+    /// Read data from the socket fd to a new `Vec`.
+    ///
+    /// # Returns
+    /// * `vec` - A new `Vec` with the entire received packet.
+    ///
+    /// # Errors
+    /// Returns error when `libc::read` or `get_readable_bytes` failed.
+    pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
+        let mut buf = Vec::new();
+        self.recv_to_vec(&mut buf)?;
+        Ok(buf)
+    }
+
+    /// Read data and fds from the socket fd to a new pair of `Vec`.
+    ///
+    /// # Returns
+    /// * `Vec<u8>` - A new `Vec` with the entire received packet's bytes.
+    /// * `Vec<RawFd>` - A new `Vec` with the entire received packet's fds.
+    ///
+    /// # Errors
+    /// Returns error when `recv_with_fds` or `get_readable_bytes` failed.
+    pub fn recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)> {
+        let packet_size = self.next_packet_size()?;
+        let mut buf = vec![0; packet_size];
+        let mut fd_buf = vec![-1; SCM_SOCKET_MAX_FD_COUNT];
+        let (read_bytes, read_fds) = self.recv_with_fds(&mut buf, &mut fd_buf)?;
+        buf.resize(read_bytes, 0);
+        fd_buf.resize(read_fds, -1);
+        Ok((buf, fd_buf))
+    }
+
     fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
         let timeval = match timeout {
             Some(t) => {
@@ -412,6 +484,7 @@ impl Drop for UnlinkUnixSeqpacketListener {
 mod tests {
     use super::*;
     use std::env;
+    use std::io::ErrorKind;
     use std::path::PathBuf;
 
     fn tmpdir() -> PathBuf {
@@ -584,4 +657,45 @@ mod tests {
         assert_eq!(s1.get_readable_bytes().unwrap(), 0);
         assert_eq!(s2.get_readable_bytes().unwrap(), 0);
     }
+
+    #[test]
+    fn unix_seqpacket_next_packet_size() {
+        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
+        let data1 = &[0, 1, 2, 3, 4];
+        s1.send(data1).expect("failed to send data");
+
+        assert_eq!(s2.next_packet_size().unwrap(), 5);
+        s1.set_read_timeout(Some(Duration::from_micros(1)))
+            .expect("failed to set read timeout");
+        assert_eq!(
+            s1.next_packet_size().unwrap_err().kind(),
+            ErrorKind::WouldBlock
+        );
+        drop(s2);
+        assert_eq!(
+            s1.next_packet_size().unwrap_err().kind(),
+            ErrorKind::ConnectionReset
+        );
+    }
+
+    #[test]
+    fn unix_seqpacket_recv_to_vec() {
+        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
+        let data1 = &[0, 1, 2, 3, 4];
+        s1.send(data1).expect("failed to send data");
+
+        let recv_data = &mut vec![];
+        s2.recv_to_vec(recv_data).expect("failed to recv data");
+        assert_eq!(recv_data, &mut vec![0, 1, 2, 3, 4]);
+    }
+
+    #[test]
+    fn unix_seqpacket_recv_as_vec() {
+        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
+        let data1 = &[0, 1, 2, 3, 4];
+        s1.send(data1).expect("failed to send data");
+
+        let recv_data = s2.recv_as_vec().expect("failed to recv data");
+        assert_eq!(recv_data, vec![0, 1, 2, 3, 4]);
+    }
 }
diff --git a/sys_util/src/sock_ctrl_msg.rs b/sys_util/src/sock_ctrl_msg.rs
index 13b9b0c..d4b953b 100644
--- a/sys_util/src/sock_ctrl_msg.rs
+++ b/sys_util/src/sock_ctrl_msg.rs
@@ -213,6 +213,9 @@ fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(u
     Ok((total_read as usize, in_fds_count))
 }
 
+/// The maximum number of FDs that can be sent in a single send.
+pub const SCM_SOCKET_MAX_FD_COUNT: usize = 253;
+
 /// Trait for file descriptors can send and receive socket control messages via `sendmsg` and
 /// `recvmsg`.
 pub trait ScmSocket {
@@ -292,6 +295,7 @@ impl ScmSocket for UnixStream {
         self.as_raw_fd()
     }
 }
+
 impl ScmSocket for UnixSeqpacket {
     fn socket_fd(&self) -> RawFd {
         self.as_raw_fd()