diff options
Diffstat (limited to 'sys_util/src/net.rs')
-rw-r--r-- | sys_util/src/net.rs | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/sys_util/src/net.rs b/sys_util/src/net.rs index 70f975b..71ab3ee 100644 --- a/sys_util/src/net.rs +++ b/sys_util/src/net.rs @@ -16,6 +16,10 @@ use std::path::PathBuf; use std::ptr::null_mut; use std::time::Duration; +use libc::{recvfrom, MSG_PEEK, MSG_TRUNC}; + +use crate::sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT}; + // Offset of sun_path in structure sockaddr_un. fn sun_path_offset() -> usize { // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer. @@ -149,6 +153,28 @@ impl UnixSeqpacket { } } + /// Gets the number of bytes in the next packet. This blocks as if `recv` were called, + /// respecting the blocking and timeout settings of the underlying socket. + pub fn next_packet_size(&self) -> io::Result<usize> { + // This form of recvfrom doesn't modify any data because all null pointers are used. We only + // use the return value and check for errors on an FD owned by this structure. + let ret = unsafe { + recvfrom( + self.fd, + null_mut(), + 0, + MSG_TRUNC | MSG_PEEK, + null_mut(), + null_mut(), + ) + }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(ret as usize) + } + } + /// Write data from a given buffer to the socket fd /// /// # Arguments @@ -193,6 +219,52 @@ impl UnixSeqpacket { } } + /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size. + /// + /// # Arguments + /// * `buf` - A mut reference to a `Vec` to resize and read into. + /// + /// # Errors + /// Returns error when `libc::read` or `get_readable_bytes` failed. + pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> { + let packet_size = self.next_packet_size()?; + buf.resize(packet_size, 0); + let read_bytes = self.recv(buf)?; + buf.resize(read_bytes, 0); + Ok(()) + } + + /// Read data from the socket fd to a new `Vec`. + /// + /// # Returns + /// * `vec` - A new `Vec` with the entire received packet. + /// + /// # Errors + /// Returns error when `libc::read` or `get_readable_bytes` failed. + pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> { + let mut buf = Vec::new(); + self.recv_to_vec(&mut buf)?; + Ok(buf) + } + + /// Read data and fds from the socket fd to a new pair of `Vec`. + /// + /// # Returns + /// * `Vec<u8>` - A new `Vec` with the entire received packet's bytes. + /// * `Vec<RawFd>` - A new `Vec` with the entire received packet's fds. + /// + /// # Errors + /// Returns error when `recv_with_fds` or `get_readable_bytes` failed. + pub fn recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)> { + let packet_size = self.next_packet_size()?; + let mut buf = vec![0; packet_size]; + let mut fd_buf = vec![-1; SCM_SOCKET_MAX_FD_COUNT]; + let (read_bytes, read_fds) = self.recv_with_fds(&mut buf, &mut fd_buf)?; + buf.resize(read_bytes, 0); + fd_buf.resize(read_fds, -1); + Ok((buf, fd_buf)) + } + fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> { let timeval = match timeout { Some(t) => { @@ -412,6 +484,7 @@ impl Drop for UnlinkUnixSeqpacketListener { mod tests { use super::*; use std::env; + use std::io::ErrorKind; use std::path::PathBuf; fn tmpdir() -> PathBuf { @@ -584,4 +657,45 @@ mod tests { assert_eq!(s1.get_readable_bytes().unwrap(), 0); assert_eq!(s2.get_readable_bytes().unwrap(), 0); } + + #[test] + fn unix_seqpacket_next_packet_size() { + let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + let data1 = &[0, 1, 2, 3, 4]; + s1.send(data1).expect("failed to send data"); + + assert_eq!(s2.next_packet_size().unwrap(), 5); + s1.set_read_timeout(Some(Duration::from_micros(1))) + .expect("failed to set read timeout"); + assert_eq!( + s1.next_packet_size().unwrap_err().kind(), + ErrorKind::WouldBlock + ); + drop(s2); + assert_eq!( + s1.next_packet_size().unwrap_err().kind(), + ErrorKind::ConnectionReset + ); + } + + #[test] + fn unix_seqpacket_recv_to_vec() { + let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + let data1 = &[0, 1, 2, 3, 4]; + s1.send(data1).expect("failed to send data"); + + let recv_data = &mut vec![]; + s2.recv_to_vec(recv_data).expect("failed to recv data"); + assert_eq!(recv_data, &mut vec![0, 1, 2, 3, 4]); + } + + #[test] + fn unix_seqpacket_recv_as_vec() { + let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + let data1 = &[0, 1, 2, 3, 4]; + s1.send(data1).expect("failed to send data"); + + let recv_data = s2.recv_as_vec().expect("failed to recv data"); + assert_eq!(recv_data, vec![0, 1, 2, 3, 4]); + } } |