diff options
author | Alyssa Ross <hi@alyssa.is> | 2020-08-11 10:49:38 +0000 |
---|---|---|
committer | Alyssa Ross <hi@alyssa.is> | 2021-05-11 12:19:50 +0000 |
commit | fe5750a3854c98635755cd9d0ceb05de896c0e67 (patch) | |
tree | 28955f2094e0903a268e4f99eb684d27f1d521fe /vhost_rs/src/vhost_user | |
parent | 8a7e4e902a4950b060ea23b40c0dfce7bfa1b2cb (diff) | |
download | crosvm-fe5750a3854c98635755cd9d0ceb05de896c0e67.tar crosvm-fe5750a3854c98635755cd9d0ceb05de896c0e67.tar.gz crosvm-fe5750a3854c98635755cd9d0ceb05de896c0e67.tar.bz2 crosvm-fe5750a3854c98635755cd9d0ceb05de896c0e67.tar.lz crosvm-fe5750a3854c98635755cd9d0ceb05de896c0e67.tar.xz crosvm-fe5750a3854c98635755cd9d0ceb05de896c0e67.tar.zst crosvm-fe5750a3854c98635755cd9d0ceb05de896c0e67.zip |
devices: port vhost-user-net from cloud-hypervisor
This is the cloud-hypervisor vhost-user-net code, modified just enough to compile as part of crosvm. There is currently no way to run crosvm with a vhost-user-net device, and even if there were, it wouldn't work without some further fixes.
Diffstat (limited to 'vhost_rs/src/vhost_user')
-rw-r--r-- | vhost_rs/src/vhost_user/connection.rs | 737 | ||||
-rw-r--r-- | vhost_rs/src/vhost_user/dummy_slave.rs | 250 | ||||
-rw-r--r-- | vhost_rs/src/vhost_user/master.rs | 757 | ||||
-rw-r--r-- | vhost_rs/src/vhost_user/master_req_handler.rs | 258 | ||||
-rw-r--r-- | vhost_rs/src/vhost_user/message.rs | 812 | ||||
-rw-r--r-- | vhost_rs/src/vhost_user/mod.rs | 251 | ||||
-rw-r--r-- | vhost_rs/src/vhost_user/slave.rs | 48 | ||||
-rw-r--r-- | vhost_rs/src/vhost_user/slave_req_handler.rs | 582 | ||||
-rw-r--r-- | vhost_rs/src/vhost_user/sock_ctrl_msg.rs | 464 |
9 files changed, 4159 insertions, 0 deletions
diff --git a/vhost_rs/src/vhost_user/connection.rs b/vhost_rs/src/vhost_user/connection.rs new file mode 100644 index 0000000..69439dd --- /dev/null +++ b/vhost_rs/src/vhost_user/connection.rs @@ -0,0 +1,737 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Structs for Unix Domain Socket listener and endpoint. + +#![allow(dead_code)] + +use libc::{c_void, iovec}; +use std::io::ErrorKind; +use std::marker::PhantomData; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::{mem, slice}; + +use super::message::*; +use super::sock_ctrl_msg::ScmSocket; +use super::{Error, Result}; + +/// Unix domain socket listener for accepting incoming connections. +pub struct Listener { + fd: UnixListener, + path: String, +} + +impl Listener { + /// Create a unix domain socket listener. + /// + /// # Return: + /// * - the new Listener object on success. + /// * - SocketError: failed to create listener socket. + pub fn new(path: &str, unlink: bool) -> Result<Self> { + if unlink { + let _ = std::fs::remove_file(path); + } + let fd = UnixListener::bind(path).map_err(Error::SocketError)?; + Ok(Listener { + fd, + path: path.to_string(), + }) + } + + /// Accept an incoming connection. + /// + /// # Return: + /// * - Some(UnixStream): new UnixStream object if new incoming connection is available. + /// * - None: no incoming connection available. + /// * - SocketError: errors from accept(). + pub fn accept(&self) -> Result<Option<UnixStream>> { + loop { + match self.fd.accept() { + Ok((socket, _addr)) => return Ok(Some(socket)), + Err(e) => { + match e.kind() { + // No incoming connection available. + ErrorKind::WouldBlock => return Ok(None), + // New connection closed by peer. 
+ ErrorKind::ConnectionAborted => return Ok(None), + // Interrupted by signals, retry + ErrorKind::Interrupted => continue, + _ => return Err(Error::SocketError(e)), + } + } + } + } + } + + /// Change blocking status on the listener. + /// + /// # Return: + /// * - () on success. + /// * - SocketError: failure from set_nonblocking(). + pub fn set_nonblocking(&self, block: bool) -> Result<()> { + self.fd.set_nonblocking(block).map_err(Error::SocketError) + } +} + +impl AsRawFd for Listener { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + +impl Drop for Listener { + fn drop(&mut self) { + let _ = std::fs::remove_file(self.path.clone()); + } +} + +/// Unix domain socket endpoint for vhost-user connection. +pub(super) struct Endpoint<R: Req> { + sock: UnixStream, + _r: PhantomData<R>, +} + +impl<R: Req> Endpoint<R> { + /// Create a new stream by connecting to server at `str`. + /// + /// # Return: + /// * - the new Endpoint object on success. + /// * - SocketConnect: failed to connect to peer. + pub fn connect(path: &str) -> Result<Self> { + let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?; + Ok(Self::from_stream(sock)) + } + + /// Create an endpoint from a stream object. + pub fn from_stream(sock: UnixStream) -> Self { + Endpoint { + sock, + _r: PhantomData, + } + } + + /// Sends bytes from scatter-gather vectors over the socket with optional attached file + /// descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> { + let rfds = match fds { + Some(rfds) => rfds, + _ => &[], + }; + self.sock.send_with_fds(iovs, rfds).map_err(Into::into) + } + + /// Sends bytes from a slice over the socket with optional attached file descriptors. 
+ /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> { + self.send_iovec(&[data], fds) + } + + /// Sends a header-only message with optional attached file descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + pub fn send_header( + &mut self, + hdr: &VhostUserMsgHeader<R>, + fds: Option<&[RawFd]>, + ) -> Result<()> { + // Safe because there can't be other mutable referance to hdr. + let iovs = unsafe { + [slice::from_raw_parts( + hdr as *const VhostUserMsgHeader<R> as *const u8, + mem::size_of::<VhostUserMsgHeader<R>>(), + )] + }; + let bytes = self.send_iovec(&iovs[..], fds)?; + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Send a message with header and body. Optional file descriptors may be attached to + /// the message. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + pub fn send_message<T: Sized>( + &mut self, + hdr: &VhostUserMsgHeader<R>, + body: &T, + fds: Option<&[RawFd]>, + ) -> Result<()> { + // Safe because there can't be other mutable referance to hdr and body. 
+ let iovs = unsafe { + [ + slice::from_raw_parts( + hdr as *const VhostUserMsgHeader<R> as *const u8, + mem::size_of::<VhostUserMsgHeader<R>>(), + ), + slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()), + ] + }; + let bytes = self.send_iovec(&iovs[..], fds)?; + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Send a message with header, body and payload. Optional file descriptors + /// may also be attached to the message. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - OversizedMsg: message size is too big. + /// * - PartialMessage: received a partial message. + /// * - IncorrectFds: wrong number of attached fds. + pub fn send_message_with_payload<T: Sized, P: Sized>( + &mut self, + hdr: &VhostUserMsgHeader<R>, + body: &T, + payload: &[P], + fds: Option<&[RawFd]>, + ) -> Result<()> { + let len = payload.len() * mem::size_of::<P>(); + if len > MAX_MSG_SIZE - mem::size_of::<T>() { + return Err(Error::OversizedMsg); + } + if let Some(fd_arr) = fds { + if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES { + return Err(Error::IncorrectFds); + } + } + + // Safe because there can't be other mutable reference to hdr, body and payload. 
+ let iovs = unsafe { + [ + slice::from_raw_parts( + hdr as *const VhostUserMsgHeader<R> as *const u8, + mem::size_of::<VhostUserMsgHeader<R>>(), + ), + slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()), + slice::from_raw_parts(payload.as_ptr() as *const u8, len), + ] + }; + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len; + let len = self.send_iovec(&iovs, fds)?; + if len != total { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Reads bytes from the socket into the given scatter/gather vectors. + /// + /// # Return: + /// * - (number of bytes received, buf) on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> { + let mut rbuf = vec![0u8; len]; + let mut iovs = [iovec { + iov_base: rbuf.as_mut_ptr() as *mut c_void, + iov_len: len, + }]; + let (bytes, _) = self.sock.recv_with_fds(&mut iovs, &mut [])?; + Ok((bytes, rbuf)) + } + + /// Reads bytes from the socket into the given scatter/gather vectors with optional attached + /// file descriptors. + /// + /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little + /// tricky to pass file descriptors through such a communication channel. Let's assume that a + /// sender sending a message with some file descriptors attached. To successfully receive those + /// attached file descriptors, the receiver must obey following rules: + /// 1) file descriptors are attached to a message. + /// 2) message(packet) boundaries must be respected on the receive side. + /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the + /// attached file descriptors will get lost. 
+ /// + /// # Return: + /// * - (number of bytes received, [received fds]) on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> { + let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES]; + let (bytes, fds) = self.sock.recv_with_fds(iovs, &mut fd_array)?; + let rfds = match fds { + 0 => None, + n => { + let mut fds = Vec::with_capacity(n); + fds.extend_from_slice(&fd_array[0..n]); + Some(fds) + } + }; + + Ok((bytes, rfds)) + } + + /// Reads bytes from the socket into a new buffer with optional attached + /// file descriptors. Received file descriptors are set close-on-exec. + /// + /// # Return: + /// * - (number of bytes received, buf, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_into_buf( + &mut self, + buf_size: usize, + ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> { + let mut buf = vec![0u8; buf_size]; + let (bytes, rfds) = { + let mut iovs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf_size, + }]; + self.recv_into_iovec(&mut iovs)? + }; + Ok((bytes, buf, rfds)) + } + + /// Receive a header-only message with optional attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. 
+ /// * - InvalidMessage: received a invalid message. + pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut iovs = [iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }]; + let (bytes, rfds) = self.recv_into_iovec(&mut iovs[..])?; + + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, rfds)) + } + + /// Receive a message with optional attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, message body, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. 
+ pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut body: T = Default::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: (&mut body as *mut T) as *mut c_void, + iov_len: mem::size_of::<T>(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec(&mut iovs[..])?; + + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); + if bytes != total { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, body, rfds)) + } + + /// Receive a message with header and optional content. Callers need to + /// pre-allocate a big enough buffer to receive the message body and + /// optional payload. If there are attached file descriptor associated + /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors + /// will be accepted and all other file descriptor will be discard + /// silently. + /// + /// # Return: + /// * - (message header, message size, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. 
+ pub fn recv_body_into_buf( + &mut self, + buf: &mut [u8], + ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec(&mut iovs[..])?; + + if bytes < mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds)) + } + + /// Receive a message with optional payload and attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, message body, size of payload, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. 
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))] + pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + buf: &mut [u8], + ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut body: T = Default::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: (&mut body as *mut T) as *mut c_void, + iov_len: mem::size_of::<T>(), + }, + iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec(&mut iovs[..])?; + + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); + if bytes < total { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, body, bytes - total, rfds)) + } + + /// Close all raw file descriptors. + pub fn close_rfds(rfds: Option<Vec<RawFd>>) { + if let Some(fds) = rfds { + for fd in fds { + // safe because the rawfds are valid and we don't care about the result. 
+ let _ = unsafe { libc::close(fd) }; + } + } + } +} + +impl<T: Req> AsRawFd for Endpoint<T> { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} + +#[cfg(test)] +mod tests { + extern crate tempfile; + + use self::tempfile::tempfile; + use super::*; + use libc; + use std::fs::File; + use std::io::{Read, Seek, SeekFrom, Write}; + use std::os::unix::io::FromRawFd; + + const UNIX_SOCKET_LISTENER: &'static str = "/tmp/vhost_user_test_rust_listener"; + const UNIX_SOCKET_CONNECTION: &'static str = "/tmp/vhost_user_test_rust_connection"; + const UNIX_SOCKET_DATA: &'static str = "/tmp/vhost_user_test_rust_data"; + const UNIX_SOCKET_FD: &'static str = "/tmp/vhost_user_test_rust_fd"; + const UNIX_SOCKET_SEND: &'static str = "/tmp/vhost_user_test_rust_send"; + + #[test] + fn create_listener() { + let _ = Listener::new(UNIX_SOCKET_LISTENER, true).unwrap(); + } + + #[test] + fn accept_connection() { + let listener = Listener::new(UNIX_SOCKET_CONNECTION, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + + // accept on a fd without incoming connection + let conn = listener.accept().unwrap(); + assert!(conn.is_none()); + + listener.set_nonblocking(true).unwrap(); + + // accept on a closed fd + unsafe { + libc::close(listener.as_raw_fd()); + } + let conn2 = listener.accept(); + assert!(conn2.is_err()); + } + + #[test] + fn send_data() { + let listener = Listener::new(UNIX_SOCKET_DATA, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_DATA).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let buf1 = vec![0x1, 0x2, 0x3, 0x4]; + let mut len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..bytes]); + + len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let 
(bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + } + + #[test] + fn send_fd() { + let listener = Listener::new(UNIX_SOCKET_FD, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_FD).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let mut fd = tempfile().unwrap(); + write!(fd, "test").unwrap(); + + // Normal case for sending/receiving file descriptors + let buf1 = vec![0x1, 0x2, 0x3, 0x4]; + let len = master + .send_slice(&buf1[..], Some(&[fd.as_raw_fd()])) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..]); + assert!(rfds.is_some()); + let fds = rfds.unwrap(); + { + assert_eq!(fds.len(), 1); + let mut file = unsafe { File::from_raw_fd(fds[0]) }; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + + // Following communication pattern should work: + // Sending side: data(header, body) with fds + // Receiving side: data(header) with fds, data(body) + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + assert!(rfds.is_some()); + let fds = rfds.unwrap(); + { + assert_eq!(fds.len(), 3); + let mut file = unsafe { File::from_raw_fd(fds[1]) }; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + let (bytes, buf2, 
rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(rfds.is_none()); + + // Following communication pattern should not work: + // Sending side: data(header, body) with fds + // Receiving side: data(header), data(body) with fds + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf4) = slave.recv_data(2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf4[..]); + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(rfds.is_none()); + + // Following communication pattern should work: + // Sending side: data, data with fds + // Receiving side: data, data with fds + let len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..]); + assert!(rfds.is_none()); + + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + assert!(rfds.is_some()); + let fds = rfds.unwrap(); + { + assert_eq!(fds.len(), 3); + let mut file = unsafe { File::from_raw_fd(fds[1]) }; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(rfds.is_none()); + + // Following communication pattern should not work: + // Sending side: data1, data2 with fds + // Receiving side: data + partial of data2, left of data2 with fds + let len = master.send_slice(&buf1[..], 
None).unwrap(); + assert_eq!(len, 4); + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, _) = slave.recv_data(5).unwrap(); + assert_eq!(bytes, 5); + + let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 3); + assert!(rfds.is_none()); + + // If the target fd array is too small, extra file descriptors will get lost. + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 4); + assert!(rfds.is_some()); + + Endpoint::<MasterReq>::close_rfds(rfds); + Endpoint::<MasterReq>::close_rfds(None); + } + + #[test] + fn send_recv() { + let listener = Listener::new(UNIX_SOCKET_SEND, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_SEND).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let mut hdr1 = + VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32); + hdr1.set_need_reply(true); + let features1 = 0x1u64; + master.send_message(&hdr1, &features1, None).unwrap(); + + let mut features2 = 0u64; + let slice = unsafe { + slice::from_raw_parts_mut( + (&mut features2 as *mut u64) as *mut u8, + mem::size_of::<u64>(), + ) + }; + let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap(); + assert_eq!(hdr1, hdr2); + assert_eq!(bytes, 8); + assert_eq!(features1, features2); + assert!(rfds.is_none()); + + master.send_header(&hdr1, None).unwrap(); + let (hdr2, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr1, hdr2); + assert!(rfds.is_none()); + } +} diff --git a/vhost_rs/src/vhost_user/dummy_slave.rs b/vhost_rs/src/vhost_user/dummy_slave.rs new file mode 100644 index 0000000..53887e2 --- /dev/null +++ 
b/vhost_rs/src/vhost_user/dummy_slave.rs @@ -0,0 +1,250 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::message::*; +use super::*; +use std::os::unix::io::RawFd; + +pub const MAX_QUEUE_NUM: usize = 2; +pub const MAX_VRING_NUM: usize = 256; +pub const VIRTIO_FEATURES: u64 = 0x40000003; + +#[derive(Default)] +pub struct DummySlaveReqHandler { + pub owned: bool, + pub features_acked: bool, + pub acked_features: u64, + pub acked_protocol_features: u64, + pub queue_num: usize, + pub vring_num: [u32; MAX_QUEUE_NUM], + pub vring_base: [u32; MAX_QUEUE_NUM], + pub call_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub kick_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub err_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub vring_started: [bool; MAX_QUEUE_NUM], + pub vring_enabled: [bool; MAX_QUEUE_NUM], +} + +impl DummySlaveReqHandler { + pub fn new() -> Self { + DummySlaveReqHandler { + queue_num: MAX_QUEUE_NUM, + ..Default::default() + } + } +} + +impl VhostUserSlaveReqHandler for DummySlaveReqHandler { + fn set_owner(&mut self) -> Result<()> { + if self.owned { + return Err(Error::InvalidOperation); + } + self.owned = true; + Ok(()) + } + + fn reset_owner(&mut self) -> Result<()> { + self.owned = false; + self.features_acked = false; + self.acked_features = 0; + self.acked_protocol_features = 0; + Ok(()) + } + + fn get_features(&mut self) -> Result<u64> { + Ok(VIRTIO_FEATURES) + } + + fn set_features(&mut self, features: u64) -> Result<()> { + if !self.owned { + return Err(Error::InvalidOperation); + } else if self.features_acked { + return Err(Error::InvalidOperation); + } else if (features & !VIRTIO_FEATURES) != 0 { + return Err(Error::InvalidParam); + } + + self.acked_features = features; + self.features_acked = true; + + // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, + // the ring is initialized in an enabled state. 
+ // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, + // the ring is initialized in a disabled state. Client must not + // pass data to/from the backend until ring is enabled by + // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has + // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0. + let vring_enabled = + self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0; + for enabled in &mut self.vring_enabled { + *enabled = vring_enabled; + } + + Ok(()) + } + + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { + Ok(VhostUserProtocolFeatures::all()) + } + + fn set_protocol_features(&mut self, features: u64) -> Result<()> { + // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must + // support this message even before VHOST_USER_SET_FEATURES was + // called. + // What happens if the master calls set_features() with + // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this + // interface? + self.acked_protocol_features = features; + Ok(()) + } + + fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> { + // TODO + Ok(()) + } + + fn get_queue_num(&mut self) -> Result<u64> { + Ok(MAX_QUEUE_NUM as u64) + } + + fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> { + if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM { + return Err(Error::InvalidParam); + } + self.vring_num[index as usize] = num; + Ok(()) + } + + fn set_vring_addr( + &mut self, + index: u32, + _flags: VhostUserVringAddrFlags, + _descriptor: u64, + _used: u64, + _available: u64, + _log: u64, + ) -> Result<()> { + if index as usize >= self.queue_num { + return Err(Error::InvalidParam); + } + Ok(()) + } + + fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> { + if index as usize >= self.queue_num || base as usize >= MAX_VRING_NUM { + return Err(Error::InvalidParam); + } + self.vring_base[index as usize] = base; + Ok(()) + } + + fn 
get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> { + if index as usize >= self.queue_num { + return Err(Error::InvalidParam); + } + // Quotation from vhost-user spec: + // Client must start ring upon receiving a kick (that is, detecting + // that file descriptor is readable) on the descriptor specified by + // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving + // VHOST_USER_GET_VRING_BASE. + self.vring_started[index as usize] = false; + Ok(VhostUserVringState::new( + index, + self.vring_base[index as usize], + )) + } + + fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + if self.kick_fd[index as usize].is_some() { + // Close file descriptor set by previous operations. + let _ = unsafe { libc::close(self.kick_fd[index as usize].unwrap()) }; + } + self.kick_fd[index as usize] = fd; + + // Quotation from vhost-user spec: + // Client must start ring upon receiving a kick (that is, detecting + // that file descriptor is readable) on the descriptor specified by + // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving + // VHOST_USER_GET_VRING_BASE. + // + // So we should add fd to event monitor(select, poll, epoll) here. + self.vring_started[index as usize] = true; + Ok(()) + } + + fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + if self.call_fd[index as usize].is_some() { + // Close file descriptor set by previous operations. 
+ let _ = unsafe { libc::close(self.call_fd[index as usize].unwrap()) }; + } + self.call_fd[index as usize] = fd; + Ok(()) + } + + fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + if self.err_fd[index as usize].is_some() { + // Close file descriptor set by previous operations. + let _ = unsafe { libc::close(self.err_fd[index as usize].unwrap()) }; + } + self.err_fd[index as usize] = fd; + Ok(()) + } + + fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> { + // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES + // has been negotiated. + if self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 { + return Err(Error::InvalidOperation); + } else if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + + // Slave must not pass data to/from the backend until ring is + // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1, + // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE + // with parameter 0. 
+        self.vring_enabled[index as usize] = enable;
+        Ok(())
+    }
+
+    /// Return `size` bytes of (dummy) device config space, filled with the
+    /// 0xa5 test pattern, after validating the requested window.
+    fn get_config(
+        &mut self,
+        offset: u32,
+        size: u32,
+        _flags: VhostUserConfigFlags,
+    ) -> Result<Vec<u8>> {
+        // NOTE(review): this masks the *virtio* acked_features with a
+        // VhostUserProtocolFeatures bit; presumably the acked protocol
+        // features were intended — confirm against the slave's fields.
+        if self.acked_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+            return Err(Error::InvalidOperation);
+        } else if offset < VHOST_USER_CONFIG_OFFSET
+            || offset >= VHOST_USER_CONFIG_SIZE
+            || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
+            // NOTE(review): `size + offset` is unchecked u32 arithmetic; the
+            // preceding bounds keep it from wrapping only while the config
+            // constants stay small — verify, or use checked_add.
+            || size + offset > VHOST_USER_CONFIG_SIZE
+        {
+            return Err(Error::InvalidParam);
+        }
+        Ok(vec![0xa5; size as usize])
+    }
+
+    /// Validate a config-space write request; the dummy slave discards the data.
+    fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> {
+        let size = buf.len() as u32;
+        // Same feature-bit and window checks as get_config() — see the
+        // NOTE(review) comments there.
+        if self.acked_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+            return Err(Error::InvalidOperation);
+        } else if offset < VHOST_USER_CONFIG_OFFSET
+            || offset >= VHOST_USER_CONFIG_SIZE
+            || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
+            || size + offset > VHOST_USER_CONFIG_SIZE
+        {
+            return Err(Error::InvalidParam);
+        }
+        Ok(())
+    }
+}
diff --git a/vhost_rs/src/vhost_user/master.rs b/vhost_rs/src/vhost_user/master.rs
new file mode 100644
index 0000000..d1e3877
--- /dev/null
+++ b/vhost_rs/src/vhost_user/master.rs
@@ -0,0 +1,757 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Traits and Struct for vhost-user master.
+
+use std::mem;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::sync::{Arc, Mutex};
+
+use vmm_sys_util::eventfd::EventFd;
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error as VhostUserError, Result as VhostUserResult};
+use crate::backend::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
+use crate::{Error, Result};
+
+/// Trait for vhost-user master to provide extra methods not covered by the VhostBackend yet.
+pub trait VhostUserMaster: VhostBackend { + /// Get the protocol feature bitmask from the underlying vhost implementation. + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; + + /// Enable protocol features in the underlying vhost implementation. + fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()>; + + /// Query how many queues the backend supports. + fn get_queue_num(&mut self) -> Result<u64>; + + /// Signal slave to enable or disable corresponding vring. + /// + /// Slave must not pass data to/from the backend until ring is enabled by + /// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been + /// disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0. + fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()>; + + /// Fetch the contents of the virtio device configuration space. + fn get_config( + &mut self, + offset: u32, + size: u32, + flags: VhostUserConfigFlags, + ) -> Result<Vec<u8>>; + + /// Change the virtio device configuration space. It also can be used for live migration on the + /// destination host to set readonly configuration space fields. + fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; + + /// Setup slave communication channel. + fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()>; +} + +fn error_code<T>(err: VhostUserError) -> Result<T> { + Err(Error::VhostUserProtocol(err)) +} + +/// Struct for the vhost-user master endpoint. +#[derive(Clone)] +pub struct Master { + node: Arc<Mutex<MasterInternal>>, +} + +impl Master { + /// Create a new instance. 
+ fn new(ep: Endpoint<MasterReq>, max_queue_num: u64) -> Self { + Master { + node: Arc::new(Mutex::new(MasterInternal { + main_sock: ep, + virtio_features: 0, + acked_virtio_features: 0, + protocol_features: 0, + acked_protocol_features: 0, + protocol_features_ready: false, + max_queue_num, + error: None, + })), + } + } + + /// Create a new instance from a Unix stream socket. + pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self { + Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num) + } + + /// Create a new vhost-user master endpoint. + /// + /// # Arguments + /// * `path` - path of Unix domain socket listener to connect to + pub fn connect(path: &str, max_queue_num: u64) -> Result<Self> { + Ok(Self::new( + Endpoint::<MasterReq>::connect(path)?, + max_queue_num, + )) + } +} + +impl VhostBackend for Master { + /// Get from the underlying vhost implementation the feature bitmask. + fn get_features(&mut self) -> Result<u64> { + let mut node = self.node.lock().unwrap(); + let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + node.virtio_features = val.value; + Ok(node.virtio_features) + } + + /// Enable features in the underlying vhost implementation using a bitmask. + fn set_features(&mut self, features: u64) -> Result<()> { + let mut node = self.node.lock().unwrap(); + let val = VhostUserU64::new(features); + let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. + node.acked_virtio_features = features & node.virtio_features; + Ok(()) + } + + /// Set the current Master as an owner of the session. + fn set_owner(&mut self) -> Result<()> { + // We unwrap() the return value to assert that we are not expecting threads to ever fail + // while holding the lock. 
+ let mut node = self.node.lock().unwrap(); + let _ = node.send_request_header(MasterReq::SET_OWNER, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. + Ok(()) + } + + fn reset_owner(&mut self) -> Result<()> { + let mut node = self.node.lock().unwrap(); + let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. + Ok(()) + } + + /// Set the memory map regions on the slave so it can translate the vring + /// addresses. In the ancillary data there is an array of file descriptors + fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES { + return error_code(VhostUserError::InvalidParam); + } + + let mut ctx = VhostUserMemoryContext::new(); + for region in regions.iter() { + if region.memory_size == 0 || region.mmap_handle < 0 { + return error_code(VhostUserError::InvalidParam); + } + let reg = VhostUserMemoryRegion { + guest_phys_addr: region.guest_phys_addr, + memory_size: region.memory_size, + user_addr: region.userspace_addr, + mmap_offset: region.mmap_offset, + }; + ctx.append(®, region.mmap_handle); + } + + let mut node = self.node.lock().unwrap(); + let body = VhostUserMemory::new(ctx.regions.len() as u32); + let hdr = node.send_request_with_payload( + MasterReq::SET_MEM_TABLE, + &body, + ctx.regions.as_slice(), + Some(ctx.fds.as_slice()), + )?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()> { + let mut node = self.node.lock().unwrap(); + let val = VhostUserU64::new(base); + if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0 + && fd.is_some() + { + let fds = [fd.unwrap()]; + let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, Some(&fds))?; + } else { + let _ = 
node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, None)?; + } + Ok(()) + } + + fn set_log_fd(&mut self, fd: RawFd) -> Result<()> { + let mut node = self.node.lock().unwrap(); + let fds = [fd]; + node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?; + Ok(()) + } + + /// Set the size of the queue. + fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> { + let mut node = self.node.lock().unwrap(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringState::new(queue_index as u32, num.into()); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_NUM, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Sets the addresses of the different aspects of the vring. + fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + let mut node = self.node.lock().unwrap(); + if queue_index as u64 >= node.max_queue_num + || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0 + { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringAddr::from_config_data(queue_index as u32, config_data); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_ADDR, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Sets the base offset in the available vring. 
+ fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> { + let mut node = self.node.lock().unwrap(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringState::new(queue_index as u32, base.into()); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_BASE, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn get_vring_base(&mut self, queue_index: usize) -> Result<u32> { + let mut node = self.node.lock().unwrap(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let req = VhostUserVringState::new(queue_index as u32, 0); + let hdr = node.send_request_with_body(MasterReq::GET_VRING_BASE, &req, None)?; + let reply = node.recv_reply::<VhostUserVringState>(&hdr)?; + Ok(reply.num) + } + + /// Set the event file descriptor to signal when buffers are used. + /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag + /// is set when there is no file descriptor in the ancillary data. This signals that polling + /// will be used instead of waiting for the call. + fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node.lock().unwrap(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?; + Ok(()) + } + + /// Set the event file descriptor for adding buffers to the vring. + /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag + /// is set when there is no file descriptor in the ancillary data. This signals that polling + /// should be used instead of waiting for a kick. 
+    fn set_vring_kick(&mut self, queue_index: usize, fd: &sys_util::EventFd) -> Result<()> {
+        // NOTE(review): this parameter is `&sys_util::EventFd` while the
+        // sibling set_vring_call()/set_vring_err() take vmm_sys_util's
+        // `EventFd` — presumably a leftover of the crosvm port; confirm the
+        // VhostBackend trait declares exactly this type before unifying.
+        let mut node = self.node.lock().unwrap();
+        if queue_index as u64 >= node.max_queue_num {
+            return error_code(VhostUserError::InvalidParam);
+        }
+        node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?;
+        Ok(())
+    }
+
+    /// Set the event file descriptor to signal when error occurs.
+    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
+    /// is set when there is no file descriptor in the ancillary data.
+    fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
+        let mut node = self.node.lock().unwrap();
+        if queue_index as u64 >= node.max_queue_num {
+            return error_code(VhostUserError::InvalidParam);
+        }
+        node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?;
+        Ok(())
+    }
+}
+
+impl VhostUserMaster for Master {
+    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
+        let mut node = self.node.lock().unwrap();
+        let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+        if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+        let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?;
+        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+        node.protocol_features = val.value;
+        // Should we support forward compatibility?
+        // If so just mask out unrecognized flags instead of return errors.
+ match VhostUserProtocolFeatures::from_bits(node.protocol_features) { + Some(val) => Ok(val), + None => error_code(VhostUserError::InvalidMessage), + } + } + + fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> { + let mut node = self.node.lock().unwrap(); + let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 { + return error_code(VhostUserError::InvalidOperation); + } + let val = VhostUserU64::new(features.bits()); + let _ = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. + node.acked_protocol_features = features.bits(); + node.protocol_features_ready = true; + Ok(()) + } + + fn get_queue_num(&mut self) -> Result<u64> { + let mut node = self.node.lock().unwrap(); + if !node.is_feature_mq_available() { + return error_code(VhostUserError::InvalidOperation); + } + + let hdr = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + if val.value > VHOST_USER_MAX_VRINGS { + return error_code(VhostUserError::InvalidMessage); + } + node.max_queue_num = val.value; + Ok(node.max_queue_num) + } + + fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> { + let mut node = self.node.lock().unwrap(); + // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled. 
+        if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        } else if queue_index as u64 >= node.max_queue_num {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let flag = if enable { 1 } else { 0 };
+        let val = VhostUserVringState::new(queue_index as u32, flag);
+        let hdr = node.send_request_with_body(MasterReq::SET_VRING_ENABLE, &val, None)?;
+        node.wait_for_ack(&hdr).map_err(|e| e.into())
+    }
+
+    fn get_config(
+        &mut self,
+        offset: u32,
+        size: u32,
+        flags: VhostUserConfigFlags,
+    ) -> Result<Vec<u8>> {
+        let body = VhostUserConfig::new(offset, size, flags);
+        if !body.is_valid() {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let mut node = self.node.lock().unwrap();
+        // depends on VhostUserProtocolFeatures::CONFIG
+        // Fix: CONFIG is a vhost-user *protocol* feature, so it must be tested
+        // against the acked protocol features (as set_slave_request_fd() does),
+        // not against the acked virtio feature bits.
+        if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+
+        // TODO: vhost-user spec states that:
+        // "Master payload: virtio device config space"
+        // But what content should the payload contains for a get_config() request?
+        // So current implementation doesn't conform to the spec.
+        let hdr = node.send_request_with_body(MasterReq::GET_CONFIG, &body, None)?;
+        let (reply, buf, rfds) = node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?;
+        if rfds.is_some() {
+            Endpoint::<MasterReq>::close_rfds(rfds);
+            return error_code(VhostUserError::InvalidMessage);
+        } else if reply.size == 0 {
+            return error_code(VhostUserError::SlaveInternalError);
+        } else if reply.size != body.size || reply.size as usize != buf.len() {
+            return error_code(VhostUserError::InvalidMessage);
+        }
+        Ok(buf)
+    }
+
+    fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
+        if buf.len() > MAX_MSG_SIZE {
+            return error_code(VhostUserError::InvalidParam);
+        }
+        let body = VhostUserConfig::new(offset, buf.len() as u32, flags);
+        if !body.is_valid() {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let mut node = self.node.lock().unwrap();
+        // depends on VhostUserProtocolFeatures::CONFIG
+        // Fix: CONFIG is a vhost-user *protocol* feature — check the acked
+        // protocol features, not the acked virtio feature bits.
+        if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+
+        // Fix: a config-space write must be sent as VHOST_USER_SET_CONFIG;
+        // the original sent MasterReq::GET_CONFIG here.
+        let hdr = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?;
+        node.wait_for_ack(&hdr).map_err(|e| e.into())
+    }
+
+    fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> {
+        let mut node = self.node.lock().unwrap();
+        if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+
+        let fds = [fd];
+        node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
+        Ok(())
+    }
+}
+
+impl AsRawFd for Master {
+    fn as_raw_fd(&self) -> RawFd {
+        let node = self.node.lock().unwrap();
+        node.main_sock.as_raw_fd()
+    }
+}
+
+/// Context object to pass guest memory configuration to VhostUserMaster::set_mem_table().
+struct VhostUserMemoryContext {
+    regions: VhostUserMemoryPayload,
+    fds: Vec<RawFd>,
+}
+
+impl VhostUserMemoryContext {
+    /// Create a context object.
+ pub fn new() -> Self { + VhostUserMemoryContext { + regions: VhostUserMemoryPayload::new(), + fds: Vec::new(), + } + } + + /// Append a user memory region and corresponding RawFd into the context object. + pub fn append(&mut self, region: &VhostUserMemoryRegion, fd: RawFd) { + self.regions.push(*region); + self.fds.push(fd); + } +} + +struct MasterInternal { + // Used to send requests to the slave. + main_sock: Endpoint<MasterReq>, + // Cached virtio features from the slave. + virtio_features: u64, + // Cached acked virtio features from the driver. + acked_virtio_features: u64, + // Cached vhost-user protocol features from the slave. + protocol_features: u64, + // Cached vhost-user protocol features. + acked_protocol_features: u64, + // Cached vhost-user protocol features are ready to use. + protocol_features_ready: bool, + // Cached maxinum number of queues supported from the slave. + max_queue_num: u64, + // Internal flag to mark failure state. + error: Option<i32>, +} + +impl MasterInternal { + fn send_request_header( + &mut self, + code: MasterReq, + fds: Option<&[RawFd]>, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + self.check_state()?; + let hdr = Self::new_request_header(code, 0); + self.main_sock.send_header(&hdr, fds)?; + Ok(hdr) + } + + fn send_request_with_body<T: Sized>( + &mut self, + code: MasterReq, + msg: &T, + fds: Option<&[RawFd]>, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let hdr = Self::new_request_header(code, mem::size_of::<T>() as u32); + self.main_sock.send_message(&hdr, msg, fds)?; + Ok(hdr) + } + + fn send_request_with_payload<T: Sized, P: Sized>( + &mut self, + code: MasterReq, + msg: &T, + payload: &[P], + fds: Option<&[RawFd]>, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + let len = mem::size_of::<T>() + payload.len() * mem::size_of::<P>(); + if len > MAX_MSG_SIZE { + return 
Err(VhostUserError::InvalidParam); + } + if let Some(ref fd_arr) = fds { + if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES { + return Err(VhostUserError::InvalidParam); + } + } + self.check_state()?; + + let hdr = Self::new_request_header(code, len as u32); + self.main_sock + .send_message_with_payload(&hdr, msg, payload, fds)?; + Ok(hdr) + } + + fn send_fd_for_vring( + &mut self, + code: MasterReq, + queue_index: usize, + fd: RawFd, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + if queue_index as u64 >= self.max_queue_num { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + // Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. + // This flag is set when there is no file descriptor in the ancillary data. This signals + // that polling will be used instead of waiting for the call. + let msg = VhostUserU64::new(queue_index as u64); + let hdr = Self::new_request_header(code, mem::size_of::<VhostUserU64>() as u32); + self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?; + Ok(hdr) + } + + fn recv_reply<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + ) -> VhostUserResult<T> { + if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let (reply, body, rfds) = self.main_sock.recv_body::<T>()?; + if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(VhostUserError::InvalidMessage); + } + Ok(body) + } + + fn recv_reply_with_payload<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> { + if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let mut buf = vec![0; MAX_MSG_SIZE - mem::size_of::<T>()]; + let (reply, body, bytes, rfds) = 
self.main_sock.recv_payload_into_buf::<T>(&mut buf)?; + if !reply.is_reply_for(hdr) + || reply.get_size() as usize != mem::size_of::<T>() + bytes + || rfds.is_some() + || body.is_valid() + { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(VhostUserError::InvalidMessage); + } else if bytes > MAX_MSG_SIZE - mem::size_of::<T>() { + return Err(VhostUserError::InvalidMessage); + } else if bytes < buf.len() { + // It's safe because we have checked the buffer size + unsafe { buf.set_len(bytes) }; + } + Ok((body, buf, rfds)) + } + + fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> { + if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() == 0 + || !hdr.is_need_reply() + { + return Ok(()); + } + self.check_state()?; + + let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?; + if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(VhostUserError::InvalidMessage); + } + if body.value != 0 { + return Err(VhostUserError::SlaveInternalError); + } + Ok(()) + } + + fn is_feature_mq_available(&self) -> bool { + self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 + } + + fn check_state(&self) -> VhostUserResult<()> { + match self.error { + Some(e) => Err(VhostUserError::SocketBroken( + std::io::Error::from_raw_os_error(e), + )), + None => Ok(()), + } + } + + #[inline] + fn new_request_header(request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> { + // TODO: handle NEED_REPLY flag + VhostUserMsgHeader::new(request, 0, size) + } +} + +#[cfg(test)] +mod tests { + use super::super::connection::Listener; + use super::*; + + const UNIX_SOCKET_MASTER: &'static str = "/tmp/vhost_user_test_rust_master"; + const UNIX_SOCKET_MASTER2: &'static str = "/tmp/vhost_user_test_rust_master2"; + const UNIX_SOCKET_MASTER3: &'static str = "/tmp/vhost_user_test_rust_master3"; + const UNIX_SOCKET_MASTER4: &'static str = 
"/tmp/vhost_user_test_rust_master4"; + + fn create_pair(path: &str) -> (Master, Endpoint<MasterReq>) { + let listener = Listener::new(path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let master = Master::connect(path, 2).unwrap(); + let slave = listener.accept().unwrap().unwrap(); + (master, Endpoint::from_stream(slave)) + } + + #[test] + fn create_master() { + let listener = Listener::new(UNIX_SOCKET_MASTER, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + + let mut master = Master::connect(UNIX_SOCKET_MASTER, 2).unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap()); + + // Send two messages continuously + master.set_owner().unwrap(); + master.reset_owner().unwrap(); + + let (hdr, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_size(), 0); + assert_eq!(hdr.get_version(), 0x1); + assert!(rfds.is_none()); + + let (hdr, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::RESET_OWNER); + assert_eq!(hdr.get_size(), 0); + assert_eq!(hdr.get_version(), 0x1); + assert!(rfds.is_none()); + } + + #[test] + fn test_create_failure() { + let _ = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap(); + let _ = Listener::new(UNIX_SOCKET_MASTER2, false).is_err(); + assert!(Master::connect(UNIX_SOCKET_MASTER2, 2).is_err()); + + let listener = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap(); + assert!(Listener::new(UNIX_SOCKET_MASTER2, false).is_err()); + listener.set_nonblocking(true).unwrap(); + + let _master = Master::connect(UNIX_SOCKET_MASTER2, 2).unwrap(); + let _slave = listener.accept().unwrap().unwrap(); + } + + #[test] + fn test_features() { + let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER3); + + master.set_owner().unwrap(); + let (hdr, rfds) = peer.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_size(), 0); + assert_eq!(hdr.get_version(), 0x1); + 
assert!(rfds.is_none()); + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(0x15); + peer.send_message(&hdr, &msg, None).unwrap(); + let features = master.get_features().unwrap(); + assert_eq!(features, 0x15u64); + let (_hdr, rfds) = peer.recv_header().unwrap(); + assert!(rfds.is_none()); + + master.set_features(0x15).unwrap(); + let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap(); + assert!(rfds.is_none()); + let val = msg.value; + assert_eq!(val, 0x15); + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8); + let msg = 0x15u32; + peer.send_message(&hdr, &msg, None).unwrap(); + assert!(master.get_features().is_err()); + } + + #[test] + fn test_protocol_features() { + let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER4); + + master.set_owner().unwrap(); + let (hdr, rfds) = peer.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert!(rfds.is_none()); + + assert!(master.get_protocol_features().is_err()); + assert!(master + .set_protocol_features(VhostUserProtocolFeatures::all()) + .is_err()); + + let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(vfeatures); + peer.send_message(&hdr, &msg, None).unwrap(); + let features = master.get_features().unwrap(); + assert_eq!(features, vfeatures); + let (_hdr, rfds) = peer.recv_header().unwrap(); + assert!(rfds.is_none()); + + master.set_features(vfeatures).unwrap(); + let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap(); + assert!(rfds.is_none()); + let val = msg.value; + assert_eq!(val, vfeatures); + + let pfeatures = VhostUserProtocolFeatures::all(); + let hdr = VhostUserMsgHeader::new(MasterReq::GET_PROTOCOL_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(pfeatures.bits()); + peer.send_message(&hdr, &msg, None).unwrap(); + let features = 
master.get_protocol_features().unwrap(); + assert_eq!(features, pfeatures); + let (_hdr, rfds) = peer.recv_header().unwrap(); + assert!(rfds.is_none()); + + master.set_protocol_features(pfeatures).unwrap(); + let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap(); + assert!(rfds.is_none()); + let val = msg.value; + assert_eq!(val, pfeatures.bits()); + + let hdr = VhostUserMsgHeader::new(MasterReq::SET_PROTOCOL_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(pfeatures.bits()); + peer.send_message(&hdr, &msg, None).unwrap(); + assert!(master.get_protocol_features().is_err()); + } + + #[test] + fn test_set_mem_table() { + // TODO + } + + #[test] + fn test_get_ring_num() { + // TODO + } +} diff --git a/vhost_rs/src/vhost_user/master_req_handler.rs b/vhost_rs/src/vhost_user/master_req_handler.rs new file mode 100644 index 0000000..cc82708 --- /dev/null +++ b/vhost_rs/src/vhost_user/master_req_handler.rs @@ -0,0 +1,258 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Traits and Structs to handle vhost-user requests from the slave to the master. + +use libc; +use std::mem; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::sync::{Arc, Mutex}; + +use super::connection::Endpoint; +use super::message::*; +use super::{Error, HandlerResult, Result}; + +/// Trait to handle vhost-user requests from the slave to the master. +pub trait VhostUserMasterReqHandler { + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); + + /// Handle device configuration change notifications from the slave. + fn handle_config_change(&mut self) -> HandlerResult<()> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs map file requests from the slave. 
+ fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<()> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs unmap file requests from the slave. + fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<()> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs sync file requests from the slave. + fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<()> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } +} + +/// A vhost-user master request endpoint which relays all received requests from the slave to the +/// provided request handler. +pub struct MasterReqHandler<S: VhostUserMasterReqHandler> { + // underlying Unix domain socket for communication + sub_sock: Endpoint<SlaveReq>, + tx_sock: UnixStream, + // the VirtIO backend device object + backend: Arc<Mutex<S>>, + // whether the endpoint has encountered any failure + error: Option<i32>, +} + +impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { + /// Create a vhost-user slave request handler. + /// This opens a pair of connected anonymous sockets. + /// Returns Self and the socket that must be sent to the slave via SET_SLAVE_REQ_FD. + pub fn new(backend: Arc<Mutex<S>>) -> Result<Self> { + let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?; + + Ok(MasterReqHandler { + sub_sock: Endpoint::<SlaveReq>::from_stream(rx), + tx_sock: tx, + backend, + error: None, + }) + } + + /// Get the raw fd to send to the slave as slave communication channel. + pub fn get_tx_raw_fd(&self) -> RawFd { + self.tx_sock.as_raw_fd() + } + + /// Mark endpoint as failed or normal state. + pub fn set_failed(&mut self, error: i32) { + self.error = Some(error); + } + + /// Receive and handle one incoming request message from the slave. + /// The caller needs to: + /// . 
serialize calls to this function + /// . decide what to do when errer happens + /// . optional recover from failure + pub fn handle_request(&mut self) -> Result<()> { + // Return error if the endpoint is already in failed state. + self.check_state()?; + + // The underlying communication channel is a Unix domain socket in + // stream mode, and recvmsg() is a little tricky here. To successfully + // receive attached file descriptors, we need to receive messages and + // corresponding attached file descriptors in this way: + // . recv messsage header and optional attached file + // . validate message header + // . recv optional message body and payload according size field in + // message header + // . validate message body and optional payload + let (hdr, rfds) = self.sub_sock.recv_header()?; + let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (size, buf) = match hdr.get_size() { + 0 => (0, vec![0u8; 0]), + len => { + let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?; + if size2 != len as usize { + return Err(Error::InvalidMessage); + } + (size2, rbuf) + } + }; + + let res = match hdr.get_code() { + SlaveReq::CONFIG_CHANGE_MSG => { + self.check_msg_size(&hdr, size, 0)?; + self.backend + .lock() + .unwrap() + .handle_config_change() + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_MAP => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + self.backend + .lock() + .unwrap() + .fs_slave_map(msg, rfds.unwrap()[0]) + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_UNMAP => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + self.backend + .lock() + .unwrap() + .fs_slave_unmap(msg) + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_SYNC => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + self.backend + .lock() + .unwrap() + .fs_slave_sync(msg) + .map_err(Error::ReqHandlerError) + } + _ => Err(Error::InvalidMessage), + }; + + self.send_ack_message(&hdr, 
&res)?; + + res + } + + fn check_state(&self) -> Result<()> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(()), + } + } + + fn check_msg_size( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + size: usize, + expected: usize, + ) -> Result<()> { + if hdr.get_size() as usize != expected + || hdr.is_reply() + || hdr.get_version() != 0x1 + || size != expected + { + return Err(Error::InvalidMessage); + } + Ok(()) + } + + fn check_attached_rfds( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + rfds: Option<Vec<RawFd>>, + ) -> Result<Option<Vec<RawFd>>> { + match hdr.get_code() { + SlaveReq::FS_MAP => { + // Expect an fd set with a single fd. + match rfds { + None => Err(Error::InvalidMessage), + Some(fds) => { + if fds.len() != 1 { + Endpoint::<SlaveReq>::close_rfds(Some(fds)); + Err(Error::InvalidMessage) + } else { + Ok(Some(fds)) + } + } + } + } + _ => { + if rfds.is_some() { + Endpoint::<SlaveReq>::close_rfds(rfds); + Err(Error::InvalidMessage) + } else { + Ok(rfds) + } + } + } + } + + fn extract_msg_body<'a, T: Sized + VhostUserMsgValidator>( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + size: usize, + buf: &'a [u8], + ) -> Result<&'a T> { + self.check_msg_size(hdr, size, mem::size_of::<T>())?; + let msg = unsafe { &*(buf.as_ptr() as *const T) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + Ok(msg) + } + + fn new_reply_header<T: Sized>( + &self, + req: &VhostUserMsgHeader<SlaveReq>, + ) -> Result<VhostUserMsgHeader<SlaveReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::InvalidParam); + } + self.check_state()?; + Ok(VhostUserMsgHeader::new( + req.get_code(), + VhostUserHeaderFlag::REPLY.bits(), + mem::size_of::<T>() as u32, + )) + } + + fn send_ack_message( + &mut self, + req: &VhostUserMsgHeader<SlaveReq>, + res: &Result<()>, + ) -> Result<()> { + if req.is_need_reply() { + let hdr = self.new_reply_header::<VhostUserU64>(req)?; + let val = match res { + Ok(_) 
=> 0, + Err(_) => 1, + }; + let msg = VhostUserU64::new(val); + self.sub_sock.send_message(&hdr, &msg, None)?; + } + Ok(()) + } +} + +impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> { + fn as_raw_fd(&self) -> RawFd { + self.sub_sock.as_raw_fd() + } +} diff --git a/vhost_rs/src/vhost_user/message.rs b/vhost_rs/src/vhost_user/message.rs new file mode 100644 index 0000000..834397f --- /dev/null +++ b/vhost_rs/src/vhost_user/message.rs @@ -0,0 +1,812 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Define communication messages for the vhost-user protocol. +//! +//! For message definition, please refer to the [vhost-user spec](https://github.com/qemu/qemu/blob/f7526eece29cd2e36a63b6703508b24453095eb8/docs/interop/vhost-user.txt). + +#![allow(dead_code)] +#![allow(non_camel_case_types)] + +use std::fmt::Debug; +use std::marker::PhantomData; + +use VringConfigData; + +/// The vhost-user specification uses a field of u32 to store message length. +/// On the other hand, preallocated buffers are needed to receive messages from the Unix domain +/// socket. To preallocating a 4GB buffer for each vhost-user message is really just an overhead. +/// Among all defined vhost-user messages, only the VhostUserConfig and VhostUserMemory has variable +/// message size. For the VhostUserConfig, a maximum size of 4K is enough because the user +/// configuration space for virtio devices is (4K - 0x100) bytes at most. For the VhostUserMemory, +/// 4K should be enough too because it can support 255 memory regions at most. +pub const MAX_MSG_SIZE: usize = 0x1000; + +/// The VhostUserMemory message has variable message size and variable number of attached file +/// descriptors. Each user memory region entry in the message payload occupies 32 bytes, +/// so setting maximum number of attached file descriptors based on the maximum message size. 
+/// But rust only implements Default and AsMut traits for arrays with 0 - 32 entries, so further +/// reduce the maximum number... +// pub const MAX_ATTACHED_FD_ENTRIES: usize = (MAX_MSG_SIZE - 8) / 32; +pub const MAX_ATTACHED_FD_ENTRIES: usize = 32; + +/// Starting position (inclusion) of the device configuration space in virtio devices. +pub const VHOST_USER_CONFIG_OFFSET: u32 = 0x100; + +/// Ending position (exclusion) of the device configuration space in virtio devices. +pub const VHOST_USER_CONFIG_SIZE: u32 = 0x1000; + +/// Maximum number of vrings supported. +pub const VHOST_USER_MAX_VRINGS: u64 = 0xFFu64; + +pub(super) trait Req: + Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Into<u32> +{ + fn is_valid(&self) -> bool; +} + +/// Type of requests sending from masters to slaves. +#[repr(u32)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum MasterReq { + /// Null operation. + NOOP = 0, + /// Get from the underlying vhost implementation the features bit mask. + GET_FEATURES = 1, + /// Enable features in the underlying vhost implementation using a bit mask. + SET_FEATURES = 2, + /// Set the current Master as an owner of the session. + SET_OWNER = 3, + /// No longer used. + RESET_OWNER = 4, + /// Set the memory map regions on the slave so it can translate the vring addresses. + SET_MEM_TABLE = 5, + /// Set logging shared memory space. + SET_LOG_BASE = 6, + /// Set the logging file descriptor, which is passed as ancillary data. + SET_LOG_FD = 7, + /// Set the size of the queue. + SET_VRING_NUM = 8, + /// Set the addresses of the different aspects of the vring. + SET_VRING_ADDR = 9, + /// Set the base offset in the available vring. + SET_VRING_BASE = 10, + /// Get the available vring base offset. + GET_VRING_BASE = 11, + /// Set the event file descriptor for adding buffers to the vring. + SET_VRING_KICK = 12, + /// Set the event file descriptor to signal when buffers are used. 
+ SET_VRING_CALL = 13, + /// Set the event file descriptor to signal when error occurs. + SET_VRING_ERR = 14, + /// Get the protocol feature bit mask from the underlying vhost implementation. + GET_PROTOCOL_FEATURES = 15, + /// Enable protocol features in the underlying vhost implementation. + SET_PROTOCOL_FEATURES = 16, + /// Query how many queues the backend supports. + GET_QUEUE_NUM = 17, + /// Signal slave to enable or disable corresponding vring. + SET_VRING_ENABLE = 18, + /// Ask vhost user backend to broadcast a fake RARP to notify the migration is terminated + /// for guest that does not support GUEST_ANNOUNCE. + SEND_RARP = 19, + /// Set host MTU value exposed to the guest. + NET_SET_MTU = 20, + /// Set the socket file descriptor for slave initiated requests. + SET_SLAVE_REQ_FD = 21, + /// Send IOTLB messages with struct vhost_iotlb_msg as payload. + IOTLB_MSG = 22, + /// Set the endianness of a VQ for legacy devices. + SET_VRING_ENDIAN = 23, + /// Fetch the contents of the virtio device configuration space. + GET_CONFIG = 24, + /// Change the contents of the virtio device configuration space. + SET_CONFIG = 25, + /// Create a session for crypto operation. + CREATE_CRYPTO_SESSION = 26, + /// Close a session for crypto operation. + CLOSE_CRYPTO_SESSION = 27, + /// Advise slave that a migration with postcopy enabled is underway. + POSTCOPY_ADVISE = 28, + /// Advise slave that a transition to postcopy mode has happened. + POSTCOPY_LISTEN = 29, + /// Advise that postcopy migration has now completed. + POSTCOPY_END = 30, + /// Get a shared buffer from slave. + GET_INFLIGHT_FD = 31, + /// Send the shared inflight buffer back to slave + SET_INFLIGHT_FD = 32, + /// Upper bound of valid commands. 
+ MAX_CMD = 33, +} + +impl Into<u32> for MasterReq { + fn into(self) -> u32 { + self as u32 + } +} + +impl Req for MasterReq { + fn is_valid(&self) -> bool { + (*self > MasterReq::NOOP) && (*self < MasterReq::MAX_CMD) + } +} + +/// Type of requests sending from slaves to masters. +#[repr(u32)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum SlaveReq { + /// Null operation. + NOOP = 0, + /// Send IOTLB messages with struct vhost_iotlb_msg as payload. + IOTLB_MSG = 1, + /// Notify that the virtio device's configuration space has changed. + CONFIG_CHANGE_MSG = 2, + /// Set host notifier for a specified queue. + VRING_HOST_NOTIFIER_MSG = 3, + /// Virtio-fs draft: map file content into the window. + FS_MAP = 4, + /// Virtio-fs draft: unmap file content from the window. + FS_UNMAP = 5, + /// Virtio-fs draft: sync file content. + FS_SYNC = 6, + /// Upper bound of valid commands. + MAX_CMD = 7, +} + +impl Into<u32> for SlaveReq { + fn into(self) -> u32 { + self as u32 + } +} + +impl Req for SlaveReq { + fn is_valid(&self) -> bool { + (*self > SlaveReq::NOOP) && (*self < SlaveReq::MAX_CMD) + } +} + +/// Vhost message Validator. +pub trait VhostUserMsgValidator { + /// Validate message syntax only. + /// It doesn't validate message semantics such as protocol version number and dependency + /// on feature flags etc. + fn is_valid(&self) -> bool { + true + } +} + +bitflags! { + /// Common message flags for vhost-user requests and replies. + pub struct VhostUserHeaderFlag: u32 { + /// Bits[0..2] is message version number. + const VERSION = 0x3; + /// Mark message as reply. + const REPLY = 0x4; + /// Sender anticipates a reply message from the peer. + const NEED_REPLY = 0x8; + /// All valid bits. + const ALL_FLAGS = 0xc; + /// All reserved bits. + const RESERVED_BITS = !0xf; + } +} + +/// Common message header for vhost-user requests and replies. +/// A vhost-user message consists of 3 header fields and an optional payload. 
All numbers are in the +/// machine native byte order. +#[allow(safe_packed_borrows)] +#[repr(packed)] +#[derive(Debug, Clone, Copy, PartialEq)] +pub(super) struct VhostUserMsgHeader<R: Req> { + request: u32, + flags: u32, + size: u32, + _r: PhantomData<R>, +} + +impl<R: Req> VhostUserMsgHeader<R> { + /// Create a new instance of `VhostUserMsgHeader`. + pub fn new(request: R, flags: u32, size: u32) -> Self { + // Default to protocol version 1 + let fl = (flags & VhostUserHeaderFlag::ALL_FLAGS.bits()) | 0x1; + VhostUserMsgHeader { + request: request.into(), + flags: fl, + size, + _r: PhantomData, + } + } + + /// Get message type. + pub fn get_code(&self) -> R { + // It's safe because R is marked as repr(u32). + unsafe { std::mem::transmute_copy::<u32, R>(&self.request) } + } + + /// Set message type. + pub fn set_code(&mut self, request: R) { + self.request = request.into(); + } + + /// Get message version number. + pub fn get_version(&self) -> u32 { + self.flags & 0x3 + } + + /// Set message version number. + pub fn set_version(&mut self, ver: u32) { + self.flags &= !0x3; + self.flags |= ver & 0x3; + } + + /// Check whether it's a reply message. + pub fn is_reply(&self) -> bool { + (self.flags & VhostUserHeaderFlag::REPLY.bits()) != 0 + } + + /// Mark message as reply. + pub fn set_reply(&mut self, is_reply: bool) { + if is_reply { + self.flags |= VhostUserHeaderFlag::REPLY.bits(); + } else { + self.flags &= !VhostUserHeaderFlag::REPLY.bits(); + } + } + + /// Check whether reply for this message is requested. + pub fn is_need_reply(&self) -> bool { + (self.flags & VhostUserHeaderFlag::NEED_REPLY.bits()) != 0 + } + + /// Mark that reply for this message is needed. + pub fn set_need_reply(&mut self, need_reply: bool) { + if need_reply { + self.flags |= VhostUserHeaderFlag::NEED_REPLY.bits(); + } else { + self.flags &= !VhostUserHeaderFlag::NEED_REPLY.bits(); + } + } + + /// Check whether it's the reply message for the request `req`. 
+ pub fn is_reply_for(&self, req: &VhostUserMsgHeader<R>) -> bool { + self.is_reply() && !req.is_reply() && self.get_code() == req.get_code() + } + + /// Get message size. + pub fn get_size(&self) -> u32 { + self.size + } + + /// Set message size. + pub fn set_size(&mut self, size: u32) { + self.size = size; + } +} + +impl<R: Req> Default for VhostUserMsgHeader<R> { + fn default() -> Self { + VhostUserMsgHeader { + request: 0, + flags: 0x1, + size: 0, + _r: PhantomData, + } + } +} + +impl<T: Req> VhostUserMsgValidator for VhostUserMsgHeader<T> { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if !self.get_code().is_valid() { + return false; + } else if self.size as usize > MAX_MSG_SIZE { + return false; + } else if self.get_version() != 0x1 { + return false; + } else if (self.flags & VhostUserHeaderFlag::RESERVED_BITS.bits()) != 0 { + return false; + } + true + } +} + +bitflags! { + /// Transport specific flags in VirtIO feature set defined by vhost-user. + pub struct VhostUserVirtioFeatures: u64 { + /// Feature flag for the protocol feature. + const PROTOCOL_FEATURES = 0x4000_0000; + } +} + +bitflags! { + /// Vhost-user protocol feature flags. + pub struct VhostUserProtocolFeatures: u64 { + /// Support multiple queues. + const MQ = 0x0000_0001; + /// Support logging through shared memory fd. + const LOG_SHMFD = 0x0000_0002; + /// Support broadcasting fake RARP packet. + const RARP = 0x0000_0004; + /// Support sending reply messages for requests with NEED_REPLY flag set. + const REPLY_ACK = 0x0000_0008; + /// Support setting MTU for virtio-net devices. + const MTU = 0x0000_0010; + /// Allow the slave to send requests to the master by an optional communication channel. + const SLAVE_REQ = 0x0000_0020; + /// Support setting slave endian by SET_VRING_ENDIAN. + const CROSS_ENDIAN = 0x0000_0040; + /// Support crypto operations. + const CRYPTO_SESSION = 0x0000_0080; + /// Support sending userfault_fd from slaves to masters. 
+ const PAGEFAULT = 0x0000_0100; + /// Support Virtio device configuration. + const CONFIG = 0x0000_0200; + /// Allow the slave to send fds (at most 8 descriptors in each message) to the master. + const SLAVE_SEND_FD = 0x0000_0400; + /// Allow the slave to register a host notifier. + const HOST_NOTIFIER = 0x0000_0800; + } +} + +/// A generic message to encapsulate a 64-bit value. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserU64 { + /// The encapsulated 64-bit common value. + pub value: u64, +} + +impl VhostUserU64 { + /// Create a new instance. + pub fn new(value: u64) -> Self { + VhostUserU64 { value } + } +} + +impl VhostUserMsgValidator for VhostUserU64 {} + +/// Memory region descriptor for the SET_MEM_TABLE request. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserMemory { + /// Number of memory regions in the payload. + pub num_regions: u32, + /// Padding for alignment. + pub padding1: u32, +} + +impl VhostUserMemory { + /// Create a new instance. + pub fn new(cnt: u32) -> Self { + VhostUserMemory { + num_regions: cnt, + padding1: 0, + } + } +} + +impl VhostUserMsgValidator for VhostUserMemory { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if self.padding1 != 0 { + return false; + } else if self.num_regions == 0 || self.num_regions > MAX_ATTACHED_FD_ENTRIES as u32 { + return false; + } + true + } +} + +/// Memory region descriptors as payload for the SET_MEM_TABLE request. +#[repr(packed)] +#[derive(Default, Clone, Copy)] +pub struct VhostUserMemoryRegion { + /// Guest physical address of the memory region. + pub guest_phys_addr: u64, + /// Size of the memory region. + pub memory_size: u64, + /// Virtual address in the current process. + pub user_addr: u64, + /// Offset where region starts in the mapped memory. + pub mmap_offset: u64, +} + +impl VhostUserMemoryRegion { + /// Create a new instance. 
+ pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self { + VhostUserMemoryRegion { + guest_phys_addr, + memory_size, + user_addr, + mmap_offset, + } + } +} + +impl VhostUserMsgValidator for VhostUserMemoryRegion { + fn is_valid(&self) -> bool { + if self.memory_size == 0 + || self.guest_phys_addr.checked_add(self.memory_size).is_none() + || self.user_addr.checked_add(self.memory_size).is_none() + || self.mmap_offset.checked_add(self.memory_size).is_none() + { + return false; + } + true + } +} + +/// Payload of the VhostUserMemory message. +pub type VhostUserMemoryPayload = Vec<VhostUserMemoryRegion>; + +/// Vring state descriptor. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserVringState { + /// Vring index. + pub index: u32, + /// A common 32bit value to encapsulate vring state etc. + pub num: u32, +} + +impl VhostUserVringState { + /// Create a new instance. + pub fn new(index: u32, num: u32) -> Self { + VhostUserVringState { index, num } + } +} + +impl VhostUserMsgValidator for VhostUserVringState {} + +bitflags! { + /// Flags for vring address. + pub struct VhostUserVringAddrFlags: u32 { + /// Support log of vring operations. + /// Modifications to "used" vring should be logged. + const VHOST_VRING_F_LOG = 0x1; + } +} + +/// Vring address descriptor. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserVringAddr { + /// Vring index. + pub index: u32, + /// Vring flags defined by VhostUserVringAddrFlags. + pub flags: u32, + /// Ring address of the vring descriptor table. + pub descriptor: u64, + /// Ring address of the vring used ring. + pub used: u64, + /// Ring address of the vring available ring. + pub available: u64, + /// Guest address for logging. + pub log: u64, +} + +impl VhostUserVringAddr { + /// Create a new instance. 
+ pub fn new( + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Self { + VhostUserVringAddr { + index, + flags: flags.bits(), + descriptor, + used, + available, + log, + } + } + + /// Create a new instance from `VringConfigData`. + #[cfg_attr(feature = "cargo-clippy", allow(clippy::identity_conversion))] + pub fn from_config_data(index: u32, config_data: &VringConfigData) -> Self { + let log_addr = config_data.log_addr.unwrap_or(0); + VhostUserVringAddr { + index, + flags: config_data.flags, + descriptor: config_data.desc_table_addr, + used: config_data.used_ring_addr, + available: config_data.avail_ring_addr, + log: log_addr, + } + } +} + +impl VhostUserMsgValidator for VhostUserVringAddr { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if (self.flags & !VhostUserVringAddrFlags::all().bits()) != 0 { + return false; + } else if self.descriptor & 0xf != 0 { + return false; + } else if self.available & 0x1 != 0 { + return false; + } else if self.used & 0x3 != 0 { + return false; + } + true + } +} + +bitflags! { + /// Flags for the device configuration message. + pub struct VhostUserConfigFlags: u32 { + /// TODO: seems the vhost-user spec has refined the definition, EMPTY is removed. + const EMPTY = 0x0; + /// Vhost master messages used for writable fields + const WRITABLE = 0x1; + /// Mark that message is part of an ongoing live-migration operation. + const LIVE_MIGRATION = 0x2; + } +} + +/// Message to read/write device configuration space. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserConfig { + /// Offset of virtio device's configuration space. + pub offset: u32, + /// Configuration space access size in bytes. + pub size: u32, + /// Flags for the device configuration operation. + pub flags: u32, +} + +impl VhostUserConfig { + /// Create a new instance. 
+ pub fn new(offset: u32, size: u32, flags: VhostUserConfigFlags) -> Self { + VhostUserConfig { + offset, + size, + flags: flags.bits(), + } + } +} + +impl VhostUserMsgValidator for VhostUserConfig { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 { + return false; + } else if self.offset < VHOST_USER_CONFIG_OFFSET + || self.offset >= VHOST_USER_CONFIG_SIZE + || self.size == 0 + || self.size > (VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET) + || self.size + self.offset > VHOST_USER_CONFIG_SIZE + { + return false; + } + true + } +} + +/// Payload for the VhostUserConfig message. +pub type VhostUserConfigPayload = Vec<u8>; + +/* + * TODO: support dirty log, live migration and IOTLB operations. +#[repr(packed)] +pub struct VhostUserVringArea { + pub index: u32, + pub flags: u32, + pub size: u64, + pub offset: u64, +} + +#[repr(packed)] +pub struct VhostUserLog { + pub size: u64, + pub offset: u64, +} + +#[repr(packed)] +pub struct VhostUserIotlb { + pub iova: u64, + pub size: u64, + pub user_addr: u64, + pub permission: u8, + pub optype: u8, +} +*/ + +bitflags! { + #[derive(Default)] + /// Flags for virtio-fs slave messages. + pub struct VhostUserFSSlaveMsgFlags: u64 { + /// Empty permission. + const EMPTY = 0x0; + /// Read permission. + const MAP_R = 0x1; + /// Write permission. + const MAP_W = 0x2; + } +} + +/// Max entries in one virtio-fs slave request. +const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8; + +/// Slave request message to update the MMIO window. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserFSSlaveMsg { + /// TODO: + pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// TODO: + pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Size of region to map. 
+ pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Flags for the mmap operation + pub flags: [VhostUserFSSlaveMsgFlags; VHOST_USER_FS_SLAVE_ENTRIES], +} + +impl VhostUserMsgValidator for VhostUserFSSlaveMsg { + fn is_valid(&self) -> bool { + for i in 0..VHOST_USER_FS_SLAVE_ENTRIES { + if ({ self.flags[i] }.bits() & !VhostUserFSSlaveMsgFlags::all().bits()) != 0 + || self.fd_offset[i].checked_add(self.len[i]).is_none() + || self.cache_offset[i].checked_add(self.len[i]).is_none() + { + return false; + } + } + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem; + + #[test] + fn check_request_code() { + let code = MasterReq::NOOP; + assert!(!code.is_valid()); + let code = MasterReq::MAX_CMD; + assert!(!code.is_valid()); + let code = MasterReq::GET_FEATURES; + assert!(code.is_valid()); + } + + #[test] + fn msg_header_ops() { + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100); + assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES); + hdr.set_code(MasterReq::SET_FEATURES); + assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES); + + assert_eq!(hdr.get_version(), 0x1); + + assert_eq!(hdr.is_reply(), false); + hdr.set_reply(true); + assert_eq!(hdr.is_reply(), true); + hdr.set_reply(false); + + assert_eq!(hdr.is_need_reply(), false); + hdr.set_need_reply(true); + assert_eq!(hdr.is_need_reply(), true); + hdr.set_need_reply(false); + + assert_eq!(hdr.get_size(), 0x100); + hdr.set_size(0x200); + assert_eq!(hdr.get_size(), 0x200); + + assert_eq!(hdr.is_need_reply(), false); + assert_eq!(hdr.is_reply(), false); + assert_eq!(hdr.get_version(), 0x1); + + // Check message length + assert!(hdr.is_valid()); + hdr.set_size(0x2000); + assert!(!hdr.is_valid()); + hdr.set_size(0x100); + assert_eq!(hdr.get_size(), 0x100); + assert!(hdr.is_valid()); + hdr.set_size((MAX_MSG_SIZE - mem::size_of::<VhostUserMsgHeader<MasterReq>>()) as u32); + assert!(hdr.is_valid()); + hdr.set_size(0x0); + assert!(hdr.is_valid()); + + // Check version + 
hdr.set_version(0x0); + assert!(!hdr.is_valid()); + hdr.set_version(0x2); + assert!(!hdr.is_valid()); + hdr.set_version(0x1); + assert!(hdr.is_valid()); + } + + #[test] + fn check_user_memory() { + let mut msg = VhostUserMemory::new(1); + assert!(msg.is_valid()); + msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32; + assert!(msg.is_valid()); + + msg.num_regions += 1; + assert!(!msg.is_valid()); + msg.num_regions = 0xFFFFFFFF; + assert!(!msg.is_valid()); + msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32; + msg.padding1 = 1; + assert!(!msg.is_valid()); + } + + #[test] + fn check_user_memory_region() { + let mut msg = VhostUserMemoryRegion { + guest_phys_addr: 0, + memory_size: 0x1000, + user_addr: 0, + mmap_offset: 0, + }; + assert!(msg.is_valid()); + msg.guest_phys_addr = 0xFFFFFFFFFFFFEFFF; + assert!(msg.is_valid()); + msg.guest_phys_addr = 0xFFFFFFFFFFFFF000; + assert!(!msg.is_valid()); + msg.guest_phys_addr = 0xFFFFFFFFFFFF0000; + msg.memory_size = 0; + assert!(!msg.is_valid()); + } + + #[test] + fn check_user_vring_addr() { + let mut msg = + VhostUserVringAddr::new(0, VhostUserVringAddrFlags::all(), 0x0, 0x0, 0x0, 0x0); + assert!(msg.is_valid()); + + msg.descriptor = 1; + assert!(!msg.is_valid()); + msg.descriptor = 0; + + msg.available = 1; + assert!(!msg.is_valid()); + msg.available = 0; + + msg.used = 1; + assert!(!msg.is_valid()); + msg.used = 0; + + msg.flags |= 0x80000000; + assert!(!msg.is_valid()); + msg.flags &= !0x80000000; + } + + #[test] + fn check_user_config_msg() { + let mut msg = VhostUserConfig::new( + VHOST_USER_CONFIG_OFFSET, + VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET, + VhostUserConfigFlags::EMPTY, + ); + + assert!(msg.is_valid()); + msg.size = 0; + assert!(!msg.is_valid()); + msg.size = 1; + assert!(msg.is_valid()); + msg.offset = 0; + assert!(!msg.is_valid()); + msg.offset = VHOST_USER_CONFIG_SIZE; + assert!(!msg.is_valid()); + msg.offset = VHOST_USER_CONFIG_SIZE - 1; + assert!(msg.is_valid()); + msg.size = 2; + 
assert!(!msg.is_valid()); + msg.size = 1; + msg.flags |= VhostUserConfigFlags::WRITABLE.bits(); + assert!(msg.is_valid()); + msg.flags |= 0x4; + assert!(!msg.is_valid()); + } +} diff --git a/vhost_rs/src/vhost_user/mod.rs b/vhost_rs/src/vhost_user/mod.rs new file mode 100644 index 0000000..af2c6d1 --- /dev/null +++ b/vhost_rs/src/vhost_user/mod.rs @@ -0,0 +1,251 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! The protocol for vhost-user is based on the existing implementation of vhost for the Linux +//! Kernel. The protocol defines two sides of the communication, master and slave. Master is +//! the application that shares its virtqueues. Slave is the consumer of the virtqueues. +//! +//! The communication channel between the master and the slave includes two sub channels. One is +//! used to send requests from the master to the slave and optional replies from the slave to the +//! master. This sub channel is created on master startup by connecting to the slave service +//! endpoint. The other is used to send requests from the slave to the master and optional replies +//! from the master to the slave. This sub channel is created by the master issuing a +//! VHOST_USER_SET_SLAVE_REQ_FD request to the slave with an auxiliary file descriptor. +//! +//! Unix domain socket is used as the underlying communication channel because the master needs to +//! send file descriptors to the slave. +//! +//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an +//! equivalent ioctl to the kernel implementation. 
+ +use libc; +use std::io::Error as IOError; + +mod connection; +pub mod message; +pub use self::connection::Listener; +#[cfg(feature = "vhost-user-master")] +mod master; +#[cfg(feature = "vhost-user-master")] +pub use self::master::{Master, VhostUserMaster}; +#[cfg(feature = "vhost-user-master")] +mod master_req_handler; +#[cfg(feature = "vhost-user-master")] +pub use self::master_req_handler::{MasterReqHandler, VhostUserMasterReqHandler}; + +#[cfg(feature = "vhost-user-slave")] +mod slave; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave::SlaveListener; +#[cfg(feature = "vhost-user-slave")] +mod slave_req_handler; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler}; + +pub mod sock_ctrl_msg; + +/// Errors for vhost-user operations +#[derive(Debug)] +pub enum Error { + /// Invalid parameters. + InvalidParam, + /// Unsupported operations due to that the protocol feature hasn't been negotiated. + InvalidOperation, + /// Invalid message format, flag or content. + InvalidMessage, + /// Only part of a message have been sent or received successfully + PartialMessage, + /// Message is too large + OversizedMsg, + /// Fd array in question is too big or too small + IncorrectFds, + /// Can't connect to peer. + SocketConnect(std::io::Error), + /// Generic socket errors. + SocketError(std::io::Error), + /// The socket is broken or has been closed. + SocketBroken(std::io::Error), + /// Should retry the socket operation again. + SocketRetry(std::io::Error), + /// Failure from the slave side. + SlaveInternalError, + /// Virtio/protocol features mismatch. 
+ FeatureMismatch, + /// Error from request handler + ReqHandlerError(IOError), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Error::InvalidParam => write!(f, "invalid parameters"), + Error::InvalidOperation => write!(f, "invalid operation"), + Error::InvalidMessage => write!(f, "invalid message"), + Error::PartialMessage => write!(f, "partial message"), + Error::OversizedMsg => write!(f, "oversized message"), + Error::IncorrectFds => write!(f, "wrong number of attached fds"), + Error::SocketError(e) => write!(f, "socket error: {}", e), + Error::SocketConnect(e) => write!(f, "can't connect to peer: {}", e), + Error::SocketBroken(e) => write!(f, "socket is broken: {}", e), + Error::SocketRetry(e) => write!(f, "temporary socket error: {}", e), + Error::SlaveInternalError => write!(f, "slave internal error"), + Error::FeatureMismatch => write!(f, "virtio/protocol features mismatch"), + Error::ReqHandlerError(e) => write!(f, "handler failed to handle request: {}", e), + } + } +} + +impl Error { + /// Determine whether to rebuild the underline communication channel. + pub fn should_reconnect(&self) -> bool { + match *self { + // Should reconnect because it may be caused by temporary network errors. + Error::PartialMessage => true, + // Should reconnect because the underline socket is broken. + Error::SocketBroken(_) => true, + // Slave internal error, hope it recovers on reconnect. + Error::SlaveInternalError => true, + // Should just retry the IO operation instead of rebuilding the underline connection. 
+ Error::SocketRetry(_) => false, + Error::InvalidParam | Error::InvalidOperation => false, + Error::InvalidMessage | Error::IncorrectFds | Error::OversizedMsg => false, + Error::SocketError(_) | Error::SocketConnect(_) => false, + Error::FeatureMismatch => false, + Error::ReqHandlerError(_) => false, + } + } +} + +impl std::convert::From<vmm_sys_util::errno::Error> for Error { + /// Convert raw socket errors into meaningful vhost-user errors. + /// + /// The vmm_sys_util::errno::Error is a simple wrapper over the raw errno, which doesn't means much + /// to the vhost-user connection manager. So convert it into meaningful errors to simplify + /// the connection manager logic. + /// + /// # Return: + /// * - Error::SocketRetry: temporary error caused by signals or short of resources. + /// * - Error::SocketBroken: the underline socket is broken. + /// * - Error::SocketError: other socket related errors. + #[allow(unreachable_patterns)] // EWOULDBLOCK equals to EGAIN on linux + fn from(err: vmm_sys_util::errno::Error) -> Self { + match err.errno() { + // The socket is marked nonblocking and the requested operation would block. + libc::EAGAIN => Error::SocketRetry(IOError::from_raw_os_error(libc::EAGAIN)), + // The socket is marked nonblocking and the requested operation would block. + libc::EWOULDBLOCK => Error::SocketRetry(IOError::from_raw_os_error(libc::EWOULDBLOCK)), + // A signal occurred before any data was transmitted + libc::EINTR => Error::SocketRetry(IOError::from_raw_os_error(libc::EINTR)), + // The output queue for a network interface was full. This generally indicates + // that the interface has stopped sending, but may be caused by transient congestion. + libc::ENOBUFS => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOBUFS)), + // No memory available. + libc::ENOMEM => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOMEM)), + // Connection reset by peer. 
+ libc::ECONNRESET => Error::SocketBroken(IOError::from_raw_os_error(libc::ECONNRESET)), + // The local end has been shut down on a connection oriented socket. In this case the + // process will also receive a SIGPIPE unless MSG_NOSIGNAL is set. + libc::EPIPE => Error::SocketBroken(IOError::from_raw_os_error(libc::EPIPE)), + // Write permission is denied on the destination socket file, or search permission is + // denied for one of the directories the path prefix. + libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)), + // Catch all other errors + e => Error::SocketError(IOError::from_raw_os_error(e)), + } + } +} + +/// Result of vhost-user operations +pub type Result<T> = std::result::Result<T, Error>; + +/// Result of request handler. +pub type HandlerResult<T> = std::result::Result<T, IOError>; + +#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))] +mod dummy_slave; + +#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))] +mod tests { + use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES}; + use super::message::*; + use super::*; + use crate::backend::VhostBackend; + use std::sync::{Arc, Barrier, Mutex}; + use std::thread; + + fn create_slave<S: VhostUserSlaveReqHandler>( + path: &str, + backend: Arc<Mutex<S>>, + ) -> (Master, SlaveReqHandler<S>) { + let mut slave_listener = SlaveListener::new(path, true, backend).unwrap(); + let master = Master::connect(path).unwrap(); + (master, slave_listener.accept().unwrap().unwrap()) + } + + #[test] + fn create_dummy_slave() { + let mut slave = DummySlaveReqHandler::new(); + + slave.set_owner().unwrap(); + assert!(slave.set_owner().is_err()); + } + + #[test] + fn test_set_owner() { + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let (mut master, mut slave) = + create_slave("/tmp/vhost_user_lib_unit_test_owner", slave_be.clone()); + + assert_eq!(slave_be.lock().unwrap().owned, false); + 
master.set_owner().unwrap(); + slave.handle_request().unwrap(); + assert_eq!(slave_be.lock().unwrap().owned, true); + master.set_owner().unwrap(); + assert!(slave.handle_request().is_err()); + assert_eq!(slave_be.lock().unwrap().owned, true); + } + + #[test] + fn test_set_features() { + let mbar = Arc::new(Barrier::new(2)); + let sbar = mbar.clone(); + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let (mut master, mut slave) = + create_slave("/tmp/vhost_user_lib_unit_test_feature", slave_be.clone()); + + thread::spawn(move || { + slave.handle_request().unwrap(); + assert_eq!(slave_be.lock().unwrap().owned, true); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_features, + VIRTIO_FEATURES & !0x1 + ); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_protocol_features, + VhostUserProtocolFeatures::all().bits() + ); + + sbar.wait(); + }); + + master.set_owner().unwrap(); + + // set virtio features + let features = master.get_features().unwrap(); + assert_eq!(features, VIRTIO_FEATURES); + master.set_features(VIRTIO_FEATURES & !0x1).unwrap(); + + // set vhost protocol features + let features = master.get_protocol_features().unwrap(); + assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); + master.set_protocol_features(features).unwrap(); + + mbar.wait(); + } +} diff --git a/vhost_rs/src/vhost_user/slave.rs b/vhost_rs/src/vhost_user/slave.rs new file mode 100644 index 0000000..3f097b8 --- /dev/null +++ b/vhost_rs/src/vhost_user/slave.rs @@ -0,0 +1,48 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Traits and Structs for vhost-user slave. 
+ +use std::sync::{Arc, Mutex}; + +use super::connection::{Endpoint, Listener}; +use super::message::*; +use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler}; + +/// Vhost-user slave side connection listener. +pub struct SlaveListener<S: VhostUserSlaveReqHandler> { + listener: Listener, + backend: Option<Arc<Mutex<S>>>, +} + +/// Sets up a listener for incoming master connections, and handles construction +/// of a Slave on success. +impl<S: VhostUserSlaveReqHandler> SlaveListener<S> { + /// Create a unix domain socket for incoming master connections. + /// + /// Be careful, the file at `path` will be unlinked if unlink is true + pub fn new(path: &str, unlink: bool, backend: Arc<Mutex<S>>) -> Result<Self> { + Ok(SlaveListener { + listener: Listener::new(path, unlink)?, + backend: Some(backend), + }) + } + + /// Accept an incoming connection from the master, returning Some(Slave) on + /// success, or None if the socket is nonblocking and no incoming connection + /// was detected + pub fn accept(&mut self) -> Result<Option<SlaveReqHandler<S>>> { + if let Some(fd) = self.listener.accept()? { + return Ok(Some(SlaveReqHandler::new( + Endpoint::<MasterReq>::from_stream(fd), + self.backend.take().unwrap(), + ))); + } + Ok(None) + } + + /// Change blocking status on the listener. + pub fn set_nonblocking(&self, block: bool) -> Result<()> { + self.listener.set_nonblocking(block) + } +} diff --git a/vhost_rs/src/vhost_user/slave_req_handler.rs b/vhost_rs/src/vhost_user/slave_req_handler.rs new file mode 100644 index 0000000..934c6d4 --- /dev/null +++ b/vhost_rs/src/vhost_user/slave_req_handler.rs @@ -0,0 +1,582 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Traits and Structs to handle vhost-user requests from the master to the slave. 
+ +use std::mem; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::slice; +use std::sync::{Arc, Mutex}; + +use super::connection::Endpoint; +use super::message::*; +use super::{Error, Result}; + +/// Trait to handle vhost-user requests from the master to the slave. +#[allow(missing_docs)] +pub trait VhostUserSlaveReqHandler { + fn set_owner(&mut self) -> Result<()>; + fn reset_owner(&mut self) -> Result<()>; + fn get_features(&mut self) -> Result<u64>; + fn set_features(&mut self, features: u64) -> Result<()>; + fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; + fn set_vring_addr( + &mut self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()>; + fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; + fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; + fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; + fn set_protocol_features(&mut self, features: u64) -> Result<()>; + fn get_queue_num(&mut self) -> Result<u64>; + fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>; + fn get_config( + &mut self, + offset: u32, + size: u32, + flags: VhostUserConfigFlags, + ) -> Result<Vec<u8>>; + fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; +} + +/// A vhost-user slave endpoint which relays all received requests from the +/// master to the virtio backend device object. +/// +/// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain +/// Socket, so it gets simpler to recover from disconnect. 
+pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> { + // underlying Unix domain socket for communication + main_sock: Endpoint<MasterReq>, + // the vhost-user backend device object + backend: Arc<Mutex<S>>, + + virtio_features: u64, + acked_virtio_features: u64, + protocol_features: VhostUserProtocolFeatures, + acked_protocol_features: u64, + + // sending ack for messages without payload + reply_ack_enabled: bool, + // whether the endpoint has encountered any failure + error: Option<i32>, +} + +impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { + /// Create a vhost-user slave endpoint. + pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<Mutex<S>>) -> Self { + SlaveReqHandler { + main_sock, + backend, + virtio_features: 0, + acked_virtio_features: 0, + protocol_features: VhostUserProtocolFeatures::empty(), + acked_protocol_features: 0, + reply_ack_enabled: false, + error: None, + } + } + + /// Create a new vhost-user slave endpoint. + /// + /// # Arguments + /// * - `path` - path of Unix domain socket listener to connect to + /// * - `backend` - handler for requests from the master to the slave + pub fn connect(path: &str, backend: Arc<Mutex<S>>) -> Result<Self> { + Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend)) + } + + /// Mark endpoint as failed with specified error code. + pub fn set_failed(&mut self, error: i32) { + self.error = Some(error); + } + + /// Receive and handle one incoming request message from the master. + /// The caller needs to: + /// . serialize calls to this function + /// . decide what to do when error happens + /// . optional recover from failure + pub fn handle_request(&mut self) -> Result<()> { + // Return error if the endpoint is already in failed state. + self.check_state()?; + + // The underlying communication channel is a Unix domain socket in + // stream mode, and recvmsg() is a little tricky here. 
To successfully + // receive attached file descriptors, we need to receive messages and + // corresponding attached file descriptors in this way: + // . recv message header and optional attached file + // . validate message header + // . recv optional message body and payload according to the size field in + // message header + // . validate message body and optional payload + let (hdr, rfds) = self.main_sock.recv_header()?; + let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (size, buf) = match hdr.get_size() { + 0 => (0, vec![0u8; 0]), + len => { + let (size2, rbuf) = self.main_sock.recv_data(len as usize)?; + if size2 != len as usize { + return Err(Error::InvalidMessage); + } + (size2, rbuf) + } + }; + + match hdr.get_code() { + MasterReq::SET_OWNER => { + self.check_request_size(&hdr, size, 0)?; + self.backend.lock().unwrap().set_owner()?; + } + MasterReq::RESET_OWNER => { + self.check_request_size(&hdr, size, 0)?; + self.backend.lock().unwrap().reset_owner()?; + } + MasterReq::GET_FEATURES => { + self.check_request_size(&hdr, size, 0)?; + let features = self.backend.lock().unwrap().get_features()?; + let msg = VhostUserU64::new(features); + self.send_reply_message(&hdr, &msg)?; + self.virtio_features = features; + self.update_reply_ack_flag(); + } + MasterReq::SET_FEATURES => { + let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; + self.backend.lock().unwrap().set_features(msg.value)?; + self.acked_virtio_features = msg.value; + self.update_reply_ack_flag(); + } + MasterReq::SET_MEM_TABLE => { + let res = self.set_mem_table(&hdr, size, &buf, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_NUM => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let res = self + .backend + .lock() + .unwrap() + .set_vring_num(msg.index, msg.num); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_ADDR => { + let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?; + 
let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { + Some(val) => val, + None => return Err(Error::InvalidMessage), + }; + let res = self.backend.lock().unwrap().set_vring_addr( + msg.index, + flags, + msg.descriptor, + msg.used, + msg.available, + msg.log, + ); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_BASE => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let res = self + .backend + .lock() + .unwrap() + .set_vring_base(msg.index, msg.num); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_VRING_BASE => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let reply = self.backend.lock().unwrap().get_vring_base(msg.index)?; + self.send_reply_message(&hdr, &reply)?; + } + MasterReq::SET_VRING_CALL => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; + let res = self.backend.lock().unwrap().set_vring_call(index, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_KICK => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; + let res = self.backend.lock().unwrap().set_vring_kick(index, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_ERR => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; + let res = self.backend.lock().unwrap().set_vring_err(index, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_PROTOCOL_FEATURES => { + self.check_request_size(&hdr, size, 0)?; + let features = self.backend.lock().unwrap().get_protocol_features()?; + let msg = VhostUserU64::new(features.bits()); + self.send_reply_message(&hdr, &msg)?; + self.protocol_features = features; + self.update_reply_ack_flag(); + } + MasterReq::SET_PROTOCOL_FEATURES => { + let 
msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; + self.backend + .lock() + .unwrap() + .set_protocol_features(msg.value)?; + self.acked_protocol_features = msg.value; + self.update_reply_ack_flag(); + } + MasterReq::GET_QUEUE_NUM => { + if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, 0)?; + let num = self.backend.lock().unwrap().get_queue_num()?; + let msg = VhostUserU64::new(num); + self.send_reply_message(&hdr, &msg)?; + } + MasterReq::SET_VRING_ENABLE => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 + && msg.index > 0 + { + return Err(Error::InvalidOperation); + } + let enable = match msg.num { + 1 => true, + 0 => false, + _ => return Err(Error::InvalidParam), + }; + + let res = self + .backend + .lock() + .unwrap() + .set_vring_enable(msg.index, enable); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_CONFIG => { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, mem::size_of::<VhostUserConfig>())?; + self.get_config(&hdr, &buf)?; + } + MasterReq::SET_CONFIG => { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + self.set_config(&hdr, size, &buf)?; + } + _ => { + return Err(Error::InvalidMessage); + } + } + Ok(()) + } + + fn set_mem_table( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + buf: &[u8], + rfds: Option<Vec<RawFd>>, + ) -> Result<()> { + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + + // check message size is consistent + let hdrsize = mem::size_of::<VhostUserMemory>(); + if size < hdrsize { + 
Endpoint::<MasterReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) }; + if !msg.is_valid() { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + + // validate number of fds matching number of memory regions + let fds = match rfds { + None => return Err(Error::InvalidMessage), + Some(fds) => { + if fds.len() != msg.num_regions as usize { + Endpoint::<MasterReq>::close_rfds(Some(fds)); + return Err(Error::InvalidMessage); + } + fds + } + }; + + // Validate memory regions + let regions = unsafe { + slice::from_raw_parts( + buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion, + msg.num_regions as usize, + ) + }; + for region in regions.iter() { + if !region.is_valid() { + Endpoint::<MasterReq>::close_rfds(Some(fds)); + return Err(Error::InvalidMessage); + } + } + + self.backend.lock().unwrap().set_mem_table(®ions, &fds) + } + + fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { + let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + let flags = match VhostUserConfigFlags::from_bits(msg.flags) { + Some(val) => val, + None => return Err(Error::InvalidMessage), + }; + let res = self + .backend + .lock() + .unwrap() + .get_config(msg.offset, msg.size, flags); + + // vhost-user slave's payload size MUST match master's request + // on success, uses zero length of payload to indicate an error + // to vhost-user master. 
+ match res { + Ok(ref buf) if buf.len() == msg.size as usize => { + let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags); + self.send_reply_with_payload(&hdr, &reply, buf.as_slice())?; + } + Ok(_) => { + let reply = VhostUserConfig::new(msg.offset, 0, flags); + self.send_reply_message(&hdr, &reply)?; + } + Err(_) => { + let reply = VhostUserConfig::new(msg.offset, 0, flags); + self.send_reply_message(&hdr, &reply)?; + } + } + Ok(()) + } + + fn set_config( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + buf: &[u8], + ) -> Result<()> { + if size < mem::size_of::<VhostUserConfig>() { + return Err(Error::InvalidMessage); + } + let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + if size - mem::size_of::<VhostUserConfig>() != msg.size as usize { + return Err(Error::InvalidMessage); + } + let flags: VhostUserConfigFlags; + match VhostUserConfigFlags::from_bits(msg.flags) { + Some(val) => flags = val, + None => return Err(Error::InvalidMessage), + } + + let res = self + .backend + .lock() + .unwrap() + .set_config(msg.offset, buf, flags); + self.send_ack_message(&hdr, res)?; + Ok(()) + } + + fn handle_vring_fd_request( + &mut self, + buf: &[u8], + rfds: Option<Vec<RawFd>>, + ) -> Result<(u8, Option<RawFd>)> { + let msg = unsafe { &*(buf.as_ptr() as *const VhostUserU64) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + + // Bits (0-7) of the payload contain the vring index. Bit 8 is the + // invalid FD flag. This flag is set when there is no file descriptor + // in the ancillary data. This signals that polling will be used + // instead of waiting for the call. 
+ let nofd = match msg.value & 0x100u64 { + 0x100u64 => true, + _ => false, + }; + + let mut rfd = None; + match rfds { + Some(fds) => { + if !nofd && fds.len() == 1 { + rfd = Some(fds[0]); + } else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) { + Endpoint::<MasterReq>::close_rfds(Some(fds)); + return Err(Error::InvalidMessage); + } + } + None => { + if !nofd { + return Err(Error::InvalidMessage); + } + } + } + Ok((msg.value as u8, rfd)) + } + + fn check_state(&self) -> Result<()> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(()), + } + } + + fn check_request_size( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + expected: usize, + ) -> Result<()> { + if hdr.get_size() as usize != expected + || hdr.is_reply() + || hdr.get_version() != 0x1 + || size != expected + { + return Err(Error::InvalidMessage); + } + Ok(()) + } + + fn check_attached_rfds( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + rfds: Option<Vec<RawFd>>, + ) -> Result<Option<Vec<RawFd>>> { + match hdr.get_code() { + MasterReq::SET_MEM_TABLE => Ok(rfds), + MasterReq::SET_VRING_CALL => Ok(rfds), + MasterReq::SET_VRING_KICK => Ok(rfds), + MasterReq::SET_VRING_ERR => Ok(rfds), + MasterReq::SET_LOG_BASE => Ok(rfds), + MasterReq::SET_LOG_FD => Ok(rfds), + MasterReq::SET_SLAVE_REQ_FD => Ok(rfds), + MasterReq::SET_INFLIGHT_FD => Ok(rfds), + _ => { + if rfds.is_some() { + Endpoint::<MasterReq>::close_rfds(rfds); + Err(Error::InvalidMessage) + } else { + Ok(rfds) + } + } + } + } + + fn extract_request_body<'a, T: Sized + VhostUserMsgValidator>( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + buf: &'a [u8], + ) -> Result<&'a T> { + self.check_request_size(hdr, size, mem::size_of::<T>())?; + let msg = unsafe { &*(buf.as_ptr() as *const T) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + Ok(msg) + } + + fn update_reply_ack_flag(&mut self) { + let vflag = 
VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let pflag = VhostUserProtocolFeatures::REPLY_ACK; + if (self.virtio_features & vflag) != 0 + && (self.acked_virtio_features & vflag) != 0 + && self.protocol_features.contains(pflag) + && (self.acked_protocol_features & pflag.bits()) != 0 + { + self.reply_ack_enabled = true; + } else { + self.reply_ack_enabled = false; + } + } + + fn new_reply_header<T: Sized>( + &self, + req: &VhostUserMsgHeader<MasterReq>, + ) -> Result<VhostUserMsgHeader<MasterReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::InvalidParam); + } + self.check_state()?; + Ok(VhostUserMsgHeader::new( + req.get_code(), + VhostUserHeaderFlag::REPLY.bits(), + mem::size_of::<T>() as u32, + )) + } + + fn send_ack_message( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + res: Result<()>, + ) -> Result<()> { + if self.reply_ack_enabled { + let hdr = self.new_reply_header::<VhostUserU64>(req)?; + let val = match res { + Ok(_) => 0, + Err(_) => 1, + }; + let msg = VhostUserU64::new(val); + self.main_sock.send_message(&hdr, &msg, None)?; + } + Ok(()) + } + + fn send_reply_message<T>( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + msg: &T, + ) -> Result<()> { + let hdr = self.new_reply_header::<T>(req)?; + self.main_sock.send_message(&hdr, msg, None)?; + Ok(()) + } + + fn send_reply_with_payload<T, P>( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + msg: &T, + payload: &[P], + ) -> Result<()> + where + T: Sized, + P: Sized, + { + let hdr = self.new_reply_header::<T>(req)?; + self.main_sock + .send_message_with_payload(&hdr, msg, payload, None)?; + Ok(()) + } +} + +impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> { + fn as_raw_fd(&self) -> RawFd { + self.main_sock.as_raw_fd() + } +} diff --git a/vhost_rs/src/vhost_user/sock_ctrl_msg.rs b/vhost_rs/src/vhost_user/sock_ctrl_msg.rs new file mode 100644 index 0000000..76d760f --- /dev/null +++ b/vhost_rs/src/vhost_user/sock_ctrl_msg.rs @@ -0,0 +1,464 @@ +// 
Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Used to send and receive messages with file descriptors on sockets that accept control messages +//! (e.g. Unix domain sockets). + +// TODO: move this file into the vmm-sys-util crate + +use std::fs::File; +use std::mem::size_of; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::net::{UnixDatagram, UnixStream}; +use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned}; + +use libc::{ + c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET, +}; +use vmm_sys_util::errno::{Error, Result}; + +// Each of the following macros performs the same function as their C counterparts. They are each +// macros because they are used to size statically allocated arrays. + +macro_rules! CMSG_ALIGN { + ($len:expr) => { + (($len) + size_of::<c_long>() - 1) & !(size_of::<c_long>() - 1) + }; +} + +macro_rules! CMSG_SPACE { + ($len:expr) => { + size_of::<cmsghdr>() + CMSG_ALIGN!($len) + }; +} + +macro_rules! CMSG_LEN { + ($len:expr) => { + size_of::<cmsghdr>() + ($len) + }; +} + +// This function (macro in the C version) is not used in any compile time constant slots, so is just +// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this +// module supports. +#[allow(non_snake_case)] +#[inline(always)] +fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd { + // Essentially returns a pointer to just past the header. + cmsg_buffer.wrapping_offset(1) as *mut RawFd +} + +// This function is like CMSG_NEXT, but safer because it reads only from references, although it +// does some pointer arithmetic on cmsg_ptr. 
+#[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_ptr_alignment))] +fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr { + let next_cmsg = (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len)) as *mut cmsghdr; + if next_cmsg + .wrapping_offset(1) + .wrapping_sub(msghdr.msg_control as usize) as usize + > msghdr.msg_controllen + { + null_mut() + } else { + next_cmsg + } +} + +const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32); + +enum CmsgBuffer { + Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]), + Heap(Box<[cmsghdr]>), +} + +impl CmsgBuffer { + fn with_capacity(capacity: usize) -> CmsgBuffer { + let cap_in_cmsghdr_units = + (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>(); + if capacity <= CMSG_BUFFER_INLINE_CAPACITY { + CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]) + } else { + CmsgBuffer::Heap( + vec![ + cmsghdr { + cmsg_len: 0, + cmsg_level: 0, + cmsg_type: 0, + }; + cap_in_cmsghdr_units + ] + .into_boxed_slice(), + ) + } + } + + fn as_mut_ptr(&mut self) -> *mut cmsghdr { + match self { + CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr, + CmsgBuffer::Heap(a) => a.as_mut_ptr(), + } + } +} + +fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> { + let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len()); + let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); + + let mut iovecs = Vec::with_capacity(out_data.len()); + for data in out_data { + iovecs.push(iovec { + iov_base: data.as_ptr() as *mut c_void, + iov_len: data.size(), + }); + } + + let mut msg = msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: iovecs.as_mut_ptr(), + msg_iovlen: iovecs.len(), + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + if !out_fds.is_empty() { + let cmsg = cmsghdr { + cmsg_len: CMSG_LEN!(size_of::<RawFd>() * out_fds.len()), + cmsg_level: 
SOL_SOCKET, + cmsg_type: SCM_RIGHTS, + }; + unsafe { + // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr. + write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg); + // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len() + // file descriptors. + copy_nonoverlapping( + out_fds.as_ptr(), + CMSG_DATA(cmsg_buffer.as_mut_ptr()), + out_fds.len(), + ); + } + + msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; + msg.msg_controllen = cmsg_capacity; + } + + // Safe because the msghdr was properly constructed from valid (or null) pointers of the + // indicated length and we check the return value. + let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) }; + + if write_count == -1 { + Err(Error::last()) + } else { + Ok(write_count as usize) + } +} + +fn raw_recvmsg(fd: RawFd, iovecs: &mut [iovec], in_fds: &mut [RawFd]) -> Result<(usize, usize)> { + let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len()); + let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); + let mut msg = msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: iovecs.as_mut_ptr(), + msg_iovlen: iovecs.len(), + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + if !in_fds.is_empty() { + msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; + msg.msg_controllen = cmsg_capacity; + } + + // Safe because the msghdr was properly constructed from valid (or null) pointers of the + // indicated length and we check the return value. 
+ let total_read = unsafe { recvmsg(fd, &mut msg, libc::MSG_WAITALL) }; + + if total_read == -1 { + return Err(Error::last()); + } + + if total_read == 0 && msg.msg_controllen < size_of::<cmsghdr>() { + return Ok((0, 0)); + } + + let mut cmsg_ptr = msg.msg_control as *mut cmsghdr; + let mut in_fds_count = 0; + while !cmsg_ptr.is_null() { + // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that + // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read. + let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() }; + + if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS { + let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) / size_of::<RawFd>(); + unsafe { + copy_nonoverlapping( + CMSG_DATA(cmsg_ptr), + in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(), + fd_count, + ); + } + in_fds_count += fd_count; + } + + cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr); + } + + Ok((total_read as usize, in_fds_count)) +} + +/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and +/// `recvmsg`. +pub trait ScmSocket { + /// Gets the file descriptor of this socket. + fn socket_fd(&self) -> RawFd; + + /// Sends the given data and file descriptor over the socket. + /// + /// On success, returns the number of bytes sent. + /// + /// # Arguments + /// + /// * `buf` - A buffer of data to send on the `socket`. + /// * `fd` - A file descriptors to be sent. + fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> { + self.send_with_fds(&[buf], &[fd]) + } + + /// Sends the given data and file descriptors over the socket. + /// + /// On success, returns the number of bytes sent. + /// + /// # Arguments + /// + /// * `bufs` - A list of data buffer to send on the `socket`. + /// * `fds` - A list of file descriptors to be sent. 
+ fn send_with_fds<D: IntoIovec>(&self, bufs: &[D], fds: &[RawFd]) -> Result<usize> { + raw_sendmsg(self.socket_fd(), bufs, fds) + } + + /// Receives data and potentially a file descriptor from the socket. + /// + /// On success, returns the number of bytes and an optional file descriptor. + /// + /// # Arguments + /// + /// * `buf` - A buffer to receive data from the socket. + fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> { + let mut fd = [0]; + let mut iovecs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }]; + + let (read_count, fd_count) = self.recv_with_fds(&mut iovecs[..], &mut fd)?; + let file = if fd_count == 0 { + None + } else { + // Safe because the first fd from recv_with_fds is owned by us and valid because this + // branch was taken. + Some(unsafe { File::from_raw_fd(fd[0]) }) + }; + Ok((read_count, file)) + } + + /// Receives data and file descriptors from the socket. + /// + /// On success, returns the number of bytes and file descriptors received as a tuple + /// `(bytes count, files count)`. + /// + /// # Arguments + /// + /// * `iovecs` - A list of iovec to receive data from the socket. + /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the + /// number of valid file descriptors is indicated by the second element of the + /// returned tuple. The caller owns these file descriptors, but they will not be + /// closed on drop like a `File`-like type would be. It is recommended that each valid + /// file descriptor gets wrapped in a drop type that closes it after this returns. 
+ fn recv_with_fds(&self, iovecs: &mut [iovec], fds: &mut [RawFd]) -> Result<(usize, usize)> { + raw_recvmsg(self.socket_fd(), iovecs, fds) + } +} + +impl ScmSocket for UnixDatagram { + fn socket_fd(&self) -> RawFd { + self.as_raw_fd() + } +} + +impl ScmSocket for UnixStream { + fn socket_fd(&self) -> RawFd { + self.as_raw_fd() + } +} + +/// Trait for types that can be converted into an `iovec` that can be referenced by a syscall for +/// the lifetime of this object. +/// +/// This trait is unsafe because interfaces that use this trait depend on the base pointer and size +/// being accurate. +pub unsafe trait IntoIovec { + /// Gets the base pointer of this `iovec`. + fn as_ptr(&self) -> *const c_void; + + /// Gets the size in bytes of this `iovec`. + fn size(&self) -> usize; +} + +// Safe because this slice can not have another mutable reference and it's pointer and size are +// guaranteed to be valid. +unsafe impl<'a> IntoIovec for &'a [u8] { + // Clippy false positive: https://github.com/rust-lang/rust-clippy/issues/3480 + #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_asref))] + fn as_ptr(&self) -> *const c_void { + self.as_ref().as_ptr() as *const c_void + } + + fn size(&self) -> usize { + self.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::Write; + use std::mem::size_of; + use std::os::raw::c_long; + use std::os::unix::net::UnixDatagram; + use std::slice::from_raw_parts; + + use libc::cmsghdr; + + use vmm_sys_util::eventfd::EventFd; + + #[test] + fn buffer_len() { + assert_eq!(CMSG_SPACE!(0 * size_of::<RawFd>()), size_of::<cmsghdr>()); + assert_eq!( + CMSG_SPACE!(1 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() + ); + if size_of::<RawFd>() == 4 { + assert_eq!( + CMSG_SPACE!(2 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() + ); + assert_eq!( + CMSG_SPACE!(3 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() * 2 + ); + assert_eq!( + CMSG_SPACE!(4 * 
size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() * 2 + ); + } else if size_of::<RawFd>() == 8 { + assert_eq!( + CMSG_SPACE!(2 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() * 2 + ); + assert_eq!( + CMSG_SPACE!(3 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() * 3 + ); + assert_eq!( + CMSG_SPACE!(4 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() * 4 + ); + } + } + + #[test] + fn send_recv_no_fd() { + let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); + + let write_count = s1 + .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[]) + .expect("failed to send data"); + + assert_eq!(write_count, 6); + + let mut buf = [0u8; 6]; + let mut files = [0; 1]; + let mut iovecs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }]; + let (read_count, file_count) = s2 + .recv_with_fds(&mut iovecs[..], &mut files) + .expect("failed to recv data"); + + assert_eq!(read_count, 6); + assert_eq!(file_count, 0); + assert_eq!(buf, [1, 1, 2, 21, 34, 55]); + } + + #[test] + fn send_recv_only_fd() { + let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); + + let evt = EventFd::new(0).expect("failed to create eventfd"); + let write_count = s1 + .send_with_fd([].as_ref(), evt.as_raw_fd()) + .expect("failed to send fd"); + + assert_eq!(write_count, 0); + + let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd"); + + let mut file = file_opt.unwrap(); + + assert_eq!(read_count, 0); + assert!(file.as_raw_fd() >= 0); + assert_ne!(file.as_raw_fd(), s1.as_raw_fd()); + assert_ne!(file.as_raw_fd(), s2.as_raw_fd()); + assert_ne!(file.as_raw_fd(), evt.as_raw_fd()); + + file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) + .expect("failed to write to sent fd"); + + assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); + } + + #[test] + fn send_recv_with_fd() { + let (s1, 
s2) = UnixDatagram::pair().expect("failed to create socket pair"); + + let evt = EventFd::new(0).expect("failed to create eventfd"); + let write_count = s1 + .send_with_fds(&[[237].as_ref()], &[evt.as_raw_fd()]) + .expect("failed to send fd"); + + assert_eq!(write_count, 1); + + let mut files = [0; 2]; + let mut buf = [0u8]; + let mut iovecs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }]; + let (read_count, file_count) = s2 + .recv_with_fds(&mut iovecs[..], &mut files) + .expect("failed to recv fd"); + + assert_eq!(read_count, 1); + assert_eq!(buf[0], 237); + assert_eq!(file_count, 1); + assert!(files[0] >= 0); + assert_ne!(files[0], s1.as_raw_fd()); + assert_ne!(files[0], s2.as_raw_fd()); + assert_ne!(files[0], evt.as_raw_fd()); + + let mut file = unsafe { File::from_raw_fd(files[0]) }; + + file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) + .expect("failed to write to sent fd"); + + assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); + } +} |