From b7196e2a1c1eb7123e7eace5418b7eb4a3e24dbe Mon Sep 17 00:00:00 2001 From: Zach Reizner Date: Thu, 21 Feb 2019 20:49:07 -0800 Subject: sys_util: add seqpacket features This change adds the ability to make seqpacket pairs, and set the timeouts of the sockets. This also adds a TcpListener style api for accepting UnixSeqpacket sockets. TEST=cargo test -p sys_util BUG=chromium:848187 Change-Id: I9f9bb5224cdfaf257d8e4a1bdaac8128be874951 Reviewed-on: https://chromium-review.googlesource.com/1482371 Commit-Ready: Zach Reizner Tested-by: kokoro Tested-by: Zach Reizner Reviewed-by: Daniel Verkamp --- sys_util/src/net.rs | 386 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 361 insertions(+), 25 deletions(-) (limited to 'sys_util/src/net.rs') diff --git a/sys_util/src/net.rs b/sys_util/src/net.rs index 472228e..4a88ebc 100644 --- a/sys_util/src/net.rs +++ b/sys_util/src/net.rs @@ -1,13 +1,20 @@ // Copyright 2018 The Chromium OS Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. + +use std::ffi::OsString; +use std::fs::remove_file; use std::io; use std::mem; +use std::ops::Deref; use std::os::unix::{ - ffi::OsStrExt, + ffi::{OsStrExt, OsStringExt}, io::{AsRawFd, FromRawFd, RawFd}, }; use std::path::Path; +use std::path::PathBuf; +use std::ptr::null_mut; +use std::time::Duration; // Offset of sun_path in structure sockaddr_un. fn sun_path_offset() -> usize { @@ -97,6 +104,30 @@ impl UnixSeqpacket { Ok(UnixSeqpacket { fd }) } + /// Creates a pair of connected `SOCK_SEQPACKET` sockets. + /// + /// Both returned file descriptors have the `CLOEXEC` flag set.s + pub fn pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)> { + let mut fds = [0, 0]; + unsafe { + // Safe because we give enough space to store all the fds and we check the return value. + let ret = libc::socketpair( + libc::AF_UNIX, + libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC, + 0, + &mut fds[0], + ); + if ret == 0 { + Ok(( + UnixSeqpacket::from_raw_fd(fds[0]), + UnixSeqpacket::from_raw_fd(fds[1]), + )) + } else { + Err(io::Error::last_os_error()) + } + } + } + /// Clone the underlying FD. pub fn try_clone(&self) -> io::Result { // Calling `dup` is safe as the kernel doesn't touch any user memory it the process. @@ -108,6 +139,29 @@ impl UnixSeqpacket { } } + /// Gets the number of bytes that can be read from this socket without blocking. + pub fn get_readable_bytes(&self) -> io::Result { + let mut byte_count = 0 as libc::c_int; + let ret = unsafe { libc::ioctl(self.fd, libc::FIONREAD, &mut byte_count) }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(byte_count as usize) + } + } + + #[deprecated] + /// Alias for `send`. + pub fn write(&self, buf: &[u8]) -> io::Result { + self.send(buf) + } + + #[deprecated] + /// Alias for `recv`. + pub fn read(&self, buf: &mut [u8]) -> io::Result { + self.recv(buf) + } + /// Write data from a given buffer to the socket fd /// /// # Arguments @@ -118,7 +172,7 @@ impl UnixSeqpacket { /// /// # Errors /// Returns error when `libc::write` failed. - pub fn write(&self, buf: &[u8]) -> io::Result { + pub fn send(&self, buf: &[u8]) -> io::Result { // Safe since we make sure the input `count` == `buf.len()` and handle the returned error. unsafe { let ret = libc::write(self.fd, buf.as_ptr() as *const _, buf.len()); @@ -140,7 +194,7 @@ impl UnixSeqpacket { /// /// # Errors /// Returns error when `libc::read` failed. - pub fn read(&self, buf: &mut [u8]) -> io::Result { + pub fn recv(&self, buf: &mut [u8]) -> io::Result { // Safe since we make sure the input `count` == `buf.len()` and handle the returned error. unsafe { let ret = libc::read(self.fd, buf.as_mut_ptr() as *mut _, buf.len()); @@ -152,9 +206,52 @@ impl UnixSeqpacket { } } - // Get `RawFd` from this server_socket - fn socket_fd(&self) -> RawFd { - self.fd + fn set_timeout(&self, timeout: Option, kind: libc::c_int) -> io::Result<()> { + let timeval = match timeout { + Some(t) => { + if t.as_secs() == 0 && t.subsec_micros() == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "zero timeout duration is invalid", + )); + } + libc::timeval { + tv_sec: t.as_secs() as libc::time_t, + tv_usec: t.subsec_micros() as libc::suseconds_t, + } + } + None => libc::timeval { + tv_sec: 0, + tv_usec: 0, + }, + }; + // Safe because we own the fd, and the length of the pointer's data is the same as the + // passed in length parameter. The level argument is valid, the kind is assumed to be valid, + // and the return value is checked. + let ret = unsafe { + libc::setsockopt( + self.fd, + libc::SOL_SOCKET, + kind, + &timeval as *const libc::timeval as *const libc::c_void, + mem::size_of::() as libc::socklen_t, + ) + }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } + } + + /// Sets or removes the timeout for read/recv operations on this socket. + pub fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { + self.set_timeout(timeout, libc::SO_RCVTIMEO) + } + + /// Sets or removes the timeout for write/send operations on this socket. + pub fn set_write_timeout(&self, timeout: Option) -> io::Result<()> { + self.set_timeout(timeout, libc::SO_SNDTIMEO) } } @@ -176,7 +273,149 @@ impl FromRawFd for UnixSeqpacket { impl AsRawFd for UnixSeqpacket { fn as_raw_fd(&self) -> RawFd { - self.socket_fd() + self.fd + } +} + +/// Like a `UnixListener` but for accepting `UnixSeqpacket` type sockets. +pub struct UnixSeqpacketListener { + fd: RawFd, +} + +impl UnixSeqpacketListener { + /// Creates a new `UnixSeqpacketListener` bound to the given path. + pub fn bind>(path: P) -> io::Result { + // Safe socket initialization since we handle the returned error. + let fd = unsafe { + match libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0) { + -1 => return Err(io::Error::last_os_error()), + fd => fd, + } + }; + + let (addr, len) = sockaddr_un(path.as_ref())?; + // Safe connect since we handle the error and use the right length generated from + // `sockaddr_un`. + unsafe { + let ret = handle_eintr_errno!(libc::bind(fd, &addr as *const _ as *const _, len)); + if ret < 0 { + return Err(io::Error::last_os_error()); + } + let ret = handle_eintr_errno!(libc::listen(fd, 128)); + if ret < 0 { + return Err(io::Error::last_os_error()); + } + } + Ok(UnixSeqpacketListener { fd }) + } + + /// Blocks for and accepts a new incoming connection and returns the socket associated with that + /// connection. + /// + /// The returned socket has the close-on-exec flag set. + pub fn accept(&self) -> io::Result { + // Safe because we own this fd and the kernel will not write to null pointers. + let ret = unsafe { libc::accept4(self.fd, null_mut(), null_mut(), libc::SOCK_CLOEXEC) }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + // Safe because we checked the return value of accept. Therefore, the return value must be a + // valid socket. + Ok(unsafe { UnixSeqpacket::from_raw_fd(ret) }) + } + + /// Gets the path that this listener is bound to. + pub fn path(&self) -> io::Result { + let mut addr = libc::sockaddr_un { + sun_family: libc::AF_UNIX as libc::sa_family_t, + sun_path: [0; 108], + }; + let sun_path_offset = (&addr.sun_path as *const _ as usize + - &addr.sun_family as *const _ as usize) + as libc::socklen_t; + let mut len = mem::size_of::() as libc::socklen_t; + // Safe because the length given matches the length of the data of the given pointer, and we + // check the return value. + let ret = unsafe { + handle_eintr_errno!(libc::getsockname( + self.fd, + &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr, + &mut len + )) + }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + if addr.sun_family != libc::AF_UNIX as libc::sa_family_t + || addr.sun_path[0] == 0 + || len < 1 + sun_path_offset + { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "getsockname on socket returned invalid value", + )); + } + + let path_os_str = OsString::from_vec( + addr.sun_path[..(len - sun_path_offset - 1) as usize] + .iter() + .map(|&c| c as _) + .collect(), + ); + Ok(path_os_str.into()) + } +} + +impl Drop for UnixSeqpacketListener { + fn drop(&mut self) { + // Safe if the UnixSeqpacketListener is created from Self::listen. + unsafe { + libc::close(self.fd); + } + } +} + +impl FromRawFd for UnixSeqpacketListener { + // Unsafe in drop function + unsafe fn from_raw_fd(fd: RawFd) -> Self { + Self { fd } + } +} + +impl AsRawFd for UnixSeqpacketListener { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +/// Used to attempt to clean up a `UnixSeqpacketListener` after it is dropped. +pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener); +impl AsRef for UnlinkUnixSeqpacketListener { + fn as_ref(&self) -> &UnixSeqpacketListener { + &self.0 + } +} + +impl AsRawFd for UnlinkUnixSeqpacketListener { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() + } +} + +impl Deref for UnlinkUnixSeqpacketListener { + type Target = UnixSeqpacketListener; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Drop for UnlinkUnixSeqpacketListener { + fn drop(&mut self) { + if let Ok(path) = self.0.path() { + if let Err(e) = remove_file(path) { + warn!("failed to remove control socket file: {:?}", e); + } + } } } @@ -186,6 +425,10 @@ mod tests { use std::env; use std::path::PathBuf; + fn tmpdir() -> PathBuf { + env::temp_dir() + } + #[test] fn sockaddr_un_zero_length_input() { let _res = sockaddr_un(Path::new("")).expect("sockaddr_un failed"); @@ -233,31 +476,124 @@ mod tests { assert!(res.is_err()); } - fn tmpdir() -> PathBuf { - env::temp_dir() - } - - fn mock_server_socket(socket_path: &Path) { - unsafe { - let socket_fd = libc::socket(libc::PF_UNIX, libc::SOCK_SEQPACKET, 0); - assert!(socket_fd > 0); - // Bind socket to path - let (addr, len) = sockaddr_un(socket_path).unwrap(); - libc::unlink(&addr.sun_path as *const _ as *const _); - let rc = libc::bind(socket_fd, &addr as *const _ as *const _, len); - assert_eq!(rc, 0); - // Mark the `socket_fd` as passive socket - let rc = libc::listen(socket_fd, 5); - assert_eq!(rc, 0); - }; + #[test] + fn unix_seqpacket_listener_path() { + let mut socket_path = tmpdir(); + socket_path.push("unix_seqpacket_listener_path"); + let listener = UnlinkUnixSeqpacketListener( + UnixSeqpacketListener::bind(&socket_path) + .expect("failed to create UnixSeqpacketListener"), + ); + let listener_path = listener.path().expect("failed to get socket listener path"); + assert_eq!(socket_path, listener_path); } #[test] fn unix_seqpacket_path_exists_pass() { let mut socket_path = tmpdir(); socket_path.push("path_to_socket"); - mock_server_socket(socket_path.as_path()); + let _listener = UnlinkUnixSeqpacketListener( + UnixSeqpacketListener::bind(&socket_path) + .expect("failed to create UnixSeqpacketListener"), + ); let _res = UnixSeqpacket::connect(socket_path.as_path()).expect("UnixSeqpacket::connect failed"); } + + #[test] + fn unix_seqpacket_path_listener_accept() { + let mut socket_path = tmpdir(); + socket_path.push("path_listerner_accept"); + let listener = UnlinkUnixSeqpacketListener( + UnixSeqpacketListener::bind(&socket_path) + .expect("failed to create UnixSeqpacketListener"), + ); + let s1 = + UnixSeqpacket::connect(socket_path.as_path()).expect("UnixSeqpacket::connect failed"); + + let s2 = listener.accept().expect("UnixSeqpacket::accept failed"); + + let data1 = &[0, 1, 2, 3, 4]; + let data2 = &[10, 11, 12, 13, 14]; + s2.send(data2).expect("failed to send data2"); + s1.send(data1).expect("failed to send data1"); + let recv_data = &mut [0; 5]; + s2.recv(recv_data).expect("failed to recv data"); + assert_eq!(data1, recv_data); + s1.recv(recv_data).expect("failed to recv data"); + assert_eq!(data2, recv_data); + } + + #[test] + #[should_panic] + fn unix_seqpacket_zero_timeout() { + let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + // Timeouts less than a microsecond are too small and round to zero. + s1.set_read_timeout(Some(Duration::from_nanos(10))) + .expect("failed to set read timeout for socket"); + } + + #[test] + fn unix_seqpacket_read_timeout() { + let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + s1.set_read_timeout(Some(Duration::from_millis(1))) + .expect("failed to set read timeout for socket"); + let _ = s1.recv(&mut [0]); + } + + #[test] + fn unix_seqpacket_write_timeout() { + let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + s1.set_write_timeout(Some(Duration::from_millis(1))) + .expect("failed to set write timeout for socket"); + } + + #[test] + fn unix_seqpacket_send_recv() { + let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + let data1 = &[0, 1, 2, 3, 4]; + let data2 = &[10, 11, 12, 13, 14]; + s2.send(data2).expect("failed to send data2"); + s1.send(data1).expect("failed to send data1"); + let recv_data = &mut [0; 5]; + s2.recv(recv_data).expect("failed to recv data"); + assert_eq!(data1, recv_data); + s1.recv(recv_data).expect("failed to recv data"); + assert_eq!(data2, recv_data); + } + + #[test] + fn unix_seqpacket_send_fragments() { + let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + let data1 = &[0, 1, 2, 3, 4]; + let data2 = &[10, 11, 12, 13, 14, 15, 16]; + s1.send(data1).expect("failed to send data1"); + s1.send(data2).expect("failed to send data2"); + + let recv_data = &mut [0; 32]; + let size = s2.recv(recv_data).expect("failed to recv data"); + assert_eq!(size, data1.len()); + assert_eq!(data1, &recv_data[0..size]); + + let size = s2.recv(recv_data).expect("failed to recv data"); + assert_eq!(size, data2.len()); + assert_eq!(data2, &recv_data[0..size]); + } + + #[test] + fn unix_seqpacket_get_readable_bytes() { + let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair"); + assert_eq!(s1.get_readable_bytes().unwrap(), 0); + assert_eq!(s2.get_readable_bytes().unwrap(), 0); + let data1 = &[0, 1, 2, 3, 4]; + s1.send(data1).expect("failed to send data"); + + assert_eq!(s1.get_readable_bytes().unwrap(), 0); + assert_eq!(s2.get_readable_bytes().unwrap(), data1.len()); + + let recv_data = &mut [0; 5]; + s2.recv(recv_data).expect("failed to recv data"); + assert_eq!(s1.get_readable_bytes().unwrap(), 0); + assert_eq!(s2.get_readable_bytes().unwrap(), 0); + } } -- cgit 1.4.1