diff options
author | Alyssa Ross <hi@alyssa.is> | 2020-03-11 19:33:55 +0000 |
---|---|---|
committer | Alyssa Ross <hi@alyssa.is> | 2020-06-15 09:36:15 +0000 |
commit | 0444d328a6d7198e59aca9f8dd2e2d91501f9bac (patch) | |
tree | 25453c681b3cd348535ed17a1d525ec910efd49f /devices/src/virtio | |
parent | 2507cc57bc0145eb57305e60f6a7c21f3b4c9192 (diff) | |
download | crosvm-0444d328a6d7198e59aca9f8dd2e2d91501f9bac.tar crosvm-0444d328a6d7198e59aca9f8dd2e2d91501f9bac.tar.gz crosvm-0444d328a6d7198e59aca9f8dd2e2d91501f9bac.tar.bz2 crosvm-0444d328a6d7198e59aca9f8dd2e2d91501f9bac.tar.lz crosvm-0444d328a6d7198e59aca9f8dd2e2d91501f9bac.tar.xz crosvm-0444d328a6d7198e59aca9f8dd2e2d91501f9bac.tar.zst crosvm-0444d328a6d7198e59aca9f8dd2e2d91501f9bac.zip |
hacky construct Wl in external proc
Diffstat (limited to 'devices/src/virtio')
-rw-r--r-- | devices/src/virtio/mod.rs | 1 | ||||
-rw-r--r-- | devices/src/virtio/wl2.rs | 1663 |
2 files changed, 1664 insertions, 0 deletions
diff --git a/devices/src/virtio/mod.rs b/devices/src/virtio/mod.rs index 6c25ce1..e807970 100644 --- a/devices/src/virtio/mod.rs +++ b/devices/src/virtio/mod.rs @@ -23,6 +23,7 @@ mod virtio_device; mod virtio_pci_common_config; mod virtio_pci_device; mod wl; +pub mod wl2; use std::cmp; use std::convert::TryFrom; diff --git a/devices/src/virtio/wl2.rs b/devices/src/virtio/wl2.rs new file mode 100644 index 0000000..1429b27 --- /dev/null +++ b/devices/src/virtio/wl2.rs @@ -0,0 +1,1663 @@ +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! This module implements the virtio wayland used by the guest to access the host's wayland server. +//! +//! The virtio wayland protocol is done over two queues: `in` and `out`. The `in` queue is used for +//! sending commands to the guest that are generated by the host, usually messages from the wayland +//! server. The `out` queue is for commands from the guest, usually requests to allocate shared +//! memory, open a wayland server connection, or send data over an existing connection. +//! +//! Each `WlVfd` represents one virtual file descriptor created by either the guest or the host. +//! Virtual file descriptors contain actual file descriptors, either a shared memory file descriptor +//! or a unix domain socket to the wayland server. In the shared memory case, there is also an +//! associated slot that indicates which KVM memory slot the memory is installed into, as well as a +//! page frame number that the guest can access the memory from. +//! +//! The types starting with `Ctrl` are structures representing the virtio wayland protocol "on the +//! wire." They are decoded and executed in the `execute` function and encoded as some variant of +//! `WlResp` for responses. +//! +//! There is one `WlState` instance that contains every known vfd and the current state of `in` +//! queue. The `in` queue requires extra state to buffer messages to the guest in case the `in` +//! queue is already full. The `WlState` also has a control socket necessary to fulfill certain +//! requests, such as those registering guest memory. +//! +//! The `Worker` is responsible for the poll loop over all possible events, encoding/decoding from +//! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill +//! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket. + +use std::cell::RefCell; +use std::collections::btree_map::Entry; +use std::collections::{BTreeMap as Map, BTreeSet as Set, VecDeque}; +use std::convert::From; +use std::error::Error as StdError; +use std::fmt::{self, Display}; +use std::fs::File; +use std::io::{self, Read, Seek, SeekFrom, Write}; +use std::mem::size_of; +#[cfg(feature = "wl-dmabuf")] +use std::os::raw::{c_uint, c_ulonglong}; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::path::{Path, PathBuf}; +use std::rc::Rc; +use std::result; +use std::thread; +use std::time::Duration; + +#[cfg(feature = "wl-dmabuf")] +use libc::{dup, EBADF, EINVAL}; + +use data_model::VolatileMemoryError; +use data_model::*; + +use msg_socket::{MsgError, MsgReceiver, MsgSender}; +#[cfg(feature = "wl-dmabuf")] +use resources::GpuMemoryDesc; +#[cfg(feature = "wl-dmabuf")] +use sys_util::ioctl_iow_nr; +use sys_util::{ + error, pipe, round_up_to_page_size, warn, Error, EventFd, FileFlags, GuestMemory, + GuestMemoryError, PollContext, PollToken, Result, ScmSocket, SharedMemory, +}; + +#[cfg(feature = "wl-dmabuf")] +use sys_util::ioctl_with_ref; + +use super::resource_bridge::*; +use super::{ + DescriptorChain, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_WL, VIRTIO_F_VERSION_1, +}; +use vm_control::{MaybeOwnedFd, VmMemoryControlRequestSocket, VmMemoryRequest, VmMemoryResponse}; + +const VIRTWL_SEND_MAX_ALLOCS: usize = 28; +const VIRTIO_WL_CMD_VFD_NEW: u32 = 256; +const VIRTIO_WL_CMD_VFD_CLOSE: u32 = 257; +const VIRTIO_WL_CMD_VFD_SEND: u32 = 258; +const VIRTIO_WL_CMD_VFD_RECV: u32 = 259; +const VIRTIO_WL_CMD_VFD_NEW_CTX: u32 = 260; +const VIRTIO_WL_CMD_VFD_NEW_PIPE: u32 = 261; +const VIRTIO_WL_CMD_VFD_HUP: u32 = 262; +#[cfg(feature = "wl-dmabuf")] +const VIRTIO_WL_CMD_VFD_NEW_DMABUF: u32 = 263; +#[cfg(feature = "wl-dmabuf")] +const VIRTIO_WL_CMD_VFD_DMABUF_SYNC: u32 = 264; +#[cfg(feature = "gpu")] +const VIRTIO_WL_CMD_VFD_SEND_FOREIGN_ID: u32 = 265; +const VIRTIO_WL_CMD_VFD_NEW_CTX_NAMED: u32 = 266; +const VIRTIO_WL_RESP_OK: u32 = 4096; +const VIRTIO_WL_RESP_VFD_NEW: u32 = 4097; +#[cfg(feature = "wl-dmabuf")] +const VIRTIO_WL_RESP_VFD_NEW_DMABUF: u32 = 4098; +const VIRTIO_WL_RESP_ERR: u32 = 4352; +const VIRTIO_WL_RESP_OUT_OF_MEMORY: u32 = 4353; +const VIRTIO_WL_RESP_INVALID_ID: u32 = 4354; +const VIRTIO_WL_RESP_INVALID_TYPE: u32 = 4355; +const VIRTIO_WL_RESP_INVALID_FLAGS: u32 = 4356; +const VIRTIO_WL_RESP_INVALID_CMD: u32 = 4357; +const VIRTIO_WL_VFD_WRITE: u32 = 0x1; +const VIRTIO_WL_VFD_READ: u32 = 0x2; +const VIRTIO_WL_VFD_MAP: u32 = 0x2; +const VIRTIO_WL_VFD_CONTROL: u32 = 0x4; +const VIRTIO_WL_F_TRANS_FLAGS: u32 = 0x01; + +const QUEUE_SIZE: u16 = 16; +const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE, QUEUE_SIZE]; + +const NEXT_VFD_ID_BASE: u32 = 0x40000000; +const VFD_ID_HOST_MASK: u32 = NEXT_VFD_ID_BASE; +// Each in-vq buffer is one page, so we need to leave space for the control header and the maximum +// number of allocs. +const IN_BUFFER_LEN: usize = + 0x1000 - size_of::<CtrlVfdRecv>() - VIRTWL_SEND_MAX_ALLOCS * size_of::<Le32>(); + +#[cfg(feature = "wl-dmabuf")] +const VIRTIO_WL_VFD_DMABUF_SYNC_VALID_FLAG_MASK: u32 = 0x7; + +#[cfg(feature = "wl-dmabuf")] +const DMA_BUF_IOCTL_BASE: c_uint = 0x62; + +#[cfg(feature = "wl-dmabuf")] +#[repr(C)] +#[derive(Copy, Clone)] +struct dma_buf_sync { + flags: c_ulonglong, +} + +#[cfg(feature = "wl-dmabuf")] +ioctl_iow_nr!(DMA_BUF_IOCTL_SYNC, DMA_BUF_IOCTL_BASE, 0, dma_buf_sync); + +const VIRTIO_WL_CTRL_VFD_SEND_KIND_LOCAL: u32 = 0; +const VIRTIO_WL_CTRL_VFD_SEND_KIND_VIRTGPU: u32 = 1; + +fn encode_vfd_new( + writer: &mut Writer, + resp: bool, + vfd_id: u32, + flags: u32, + pfn: u64, + size: u32, +) -> WlResult<()> { + let ctrl_vfd_new = CtrlVfdNew { + hdr: CtrlHeader { + type_: Le32::from(if resp { + VIRTIO_WL_RESP_VFD_NEW + } else { + VIRTIO_WL_CMD_VFD_NEW + }), + flags: Le32::from(0), + }, + id: Le32::from(vfd_id), + flags: Le32::from(flags), + pfn: Le64::from(pfn), + size: Le32::from(size), + }; + + writer + .write_obj(ctrl_vfd_new) + .map_err(WlError::WriteResponse) +} + +#[cfg(feature = "wl-dmabuf")] +fn encode_vfd_new_dmabuf( + writer: &mut Writer, + vfd_id: u32, + flags: u32, + pfn: u64, + size: u32, + desc: GpuMemoryDesc, +) -> WlResult<()> { + let ctrl_vfd_new_dmabuf = CtrlVfdNewDmabuf { + hdr: CtrlHeader { + type_: Le32::from(VIRTIO_WL_RESP_VFD_NEW_DMABUF), + flags: Le32::from(0), + }, + id: Le32::from(vfd_id), + flags: Le32::from(flags), + pfn: Le64::from(pfn), + size: Le32::from(size), + width: Le32::from(0), + height: Le32::from(0), + format: Le32::from(0), + stride0: Le32::from(desc.planes[0].stride), + stride1: Le32::from(desc.planes[1].stride), + stride2: Le32::from(desc.planes[2].stride), + offset0: Le32::from(desc.planes[0].offset), + offset1: Le32::from(desc.planes[1].offset), + offset2: Le32::from(desc.planes[2].offset), + }; + + writer + .write_obj(ctrl_vfd_new_dmabuf) + .map_err(WlError::WriteResponse) +} + +fn encode_vfd_recv(writer: &mut Writer, vfd_id: u32, data: &[u8], vfd_ids: &[u32]) -> WlResult<()> { + let ctrl_vfd_recv = CtrlVfdRecv { + hdr: CtrlHeader { + type_: Le32::from(VIRTIO_WL_CMD_VFD_RECV), + flags: Le32::from(0), + }, + id: Le32::from(vfd_id), + vfd_count: Le32::from(vfd_ids.len() as u32), + }; + writer + .write_obj(ctrl_vfd_recv) + .map_err(WlError::WriteResponse)?; + + for &recv_vfd_id in vfd_ids.iter() { + writer + .write_obj(Le32::from(recv_vfd_id)) + .map_err(WlError::WriteResponse)?; + } + + writer.write_all(data).map_err(WlError::WriteResponse) +} + +fn encode_vfd_hup(writer: &mut Writer, vfd_id: u32) -> WlResult<()> { + let ctrl_vfd_new = CtrlVfd { + hdr: CtrlHeader { + type_: Le32::from(VIRTIO_WL_CMD_VFD_HUP), + flags: Le32::from(0), + }, + id: Le32::from(vfd_id), + }; + + writer + .write_obj(ctrl_vfd_new) + .map_err(WlError::WriteResponse) +} + +fn encode_resp(writer: &mut Writer, resp: WlResp) -> WlResult<()> { + match resp { + WlResp::VfdNew { + id, + flags, + pfn, + size, + resp, + } => encode_vfd_new(writer, resp, id, flags, pfn, size), + #[cfg(feature = "wl-dmabuf")] + WlResp::VfdNewDmabuf { + id, + flags, + pfn, + size, + desc, + } => encode_vfd_new_dmabuf(writer, id, flags, pfn, size, desc), + WlResp::VfdRecv { id, data, vfds } => encode_vfd_recv(writer, id, data, vfds), + WlResp::VfdHup { id } => encode_vfd_hup(writer, id), + r => writer + .write_obj(Le32::from(r.get_code())) + .map_err(WlError::WriteResponse), + } +} + +#[allow(dead_code)] +#[derive(Debug)] +enum WlError { + NewAlloc(Error), + NewPipe(Error), + AllocSetSize(Error), + SocketConnect(io::Error), + SocketNonBlock(io::Error), + VmControl(MsgError), + VmBadResponse, + CheckedOffset, + ParseDesc(io::Error), + GuestMemory(GuestMemoryError), + VolatileMemory(VolatileMemoryError), + SendVfd(Error), + WritePipe(io::Error), + RecvVfd(Error), + ReadPipe(io::Error), + PollContextAdd(Error), + DmabufSync(io::Error), + WriteResponse(io::Error), + InvalidString(std::str::Utf8Error), + UnknownSocketName(String), +} + +impl Display for WlError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::WlError::*; + + match self { + NewAlloc(e) => write!(f, "failed to create shared memory allocation: {}", e), + NewPipe(e) => write!(f, "failed to create pipe: {}", e), + AllocSetSize(e) => write!(f, "failed to set size of shared memory: {}", e), + SocketConnect(e) => write!(f, "failed to connect socket: {}", e), + SocketNonBlock(e) => write!(f, "failed to set socket as non-blocking: {}", e), + VmControl(e) => write!(f, "failed to control parent VM: {}", e), + VmBadResponse => write!(f, "invalid response from parent VM"), + CheckedOffset => write!(f, "overflow in calculation"), + ParseDesc(e) => write!(f, "error parsing descriptor: {}", e), + GuestMemory(e) => write!(f, "access violation in guest memory: {}", e), + VolatileMemory(e) => write!(f, "access violating in guest volatile memory: {}", e), + SendVfd(e) => write!(f, "failed to send on a socket: {}", e), + WritePipe(e) => write!(f, "failed to write to a pipe: {}", e), + RecvVfd(e) => write!(f, "failed to recv on a socket: {}", e), + ReadPipe(e) => write!(f, "failed to read a pipe: {}", e), + PollContextAdd(e) => write!(f, "failed to listen to FD on poll context: {}", e), + DmabufSync(e) => write!(f, "failed to synchronize DMABuf access: {}", e), + WriteResponse(e) => write!(f, "failed to write response: {}", e), + InvalidString(e) => write!(f, "invalid string: {}", e), + UnknownSocketName(name) => write!(f, "unknown socket name: {}", name), + } + } +} + +impl std::error::Error for WlError {} + +type WlResult<T> = result::Result<T, WlError>; + +impl From<GuestMemoryError> for WlError { + fn from(e: GuestMemoryError) -> WlError { + WlError::GuestMemory(e) + } +} + +impl From<VolatileMemoryError> for WlError { + fn from(e: VolatileMemoryError) -> WlError { + WlError::VolatileMemory(e) + } +} + +#[derive(Clone)] +struct VmRequester { + inner: Rc<RefCell<VmMemoryControlRequestSocket>>, +} + +impl VmRequester { + fn new(vm_socket: VmMemoryControlRequestSocket) -> VmRequester { + VmRequester { + inner: Rc::new(RefCell::new(vm_socket)), + } + } + + fn request(&self, request: VmMemoryRequest) -> WlResult<VmMemoryResponse> { + let mut inner = self.inner.borrow_mut(); + let vm_socket = &mut *inner; + vm_socket.send(&request).map_err(WlError::VmControl)?; + vm_socket.recv().map_err(WlError::VmControl) + } +} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +struct CtrlHeader { + type_: Le32, + flags: Le32, +} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +struct CtrlVfdNew { + hdr: CtrlHeader, + id: Le32, + flags: Le32, + pfn: Le64, + size: Le32, +} + +unsafe impl DataInit for CtrlVfdNew {} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +struct CtrlVfdNewCtxNamed { + hdr: CtrlHeader, + id: Le32, + flags: Le32, // Ignored. + pfn: Le64, // Ignored. + size: Le32, // Ignored. + name: [u8; 32], +} + +unsafe impl DataInit for CtrlVfdNewCtxNamed {} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +#[cfg(feature = "wl-dmabuf")] +struct CtrlVfdNewDmabuf { + hdr: CtrlHeader, + id: Le32, + flags: Le32, + pfn: Le64, + size: Le32, + width: Le32, + height: Le32, + format: Le32, + stride0: Le32, + stride1: Le32, + stride2: Le32, + offset0: Le32, + offset1: Le32, + offset2: Le32, +} + +#[cfg(feature = "wl-dmabuf")] +unsafe impl DataInit for CtrlVfdNewDmabuf {} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +#[cfg(feature = "wl-dmabuf")] +struct CtrlVfdDmabufSync { + hdr: CtrlHeader, + id: Le32, + flags: Le32, +} + +#[cfg(feature = "wl-dmabuf")] +unsafe impl DataInit for CtrlVfdDmabufSync {} + +#[repr(C)] +#[derive(Copy, Clone)] +struct CtrlVfdRecv { + hdr: CtrlHeader, + id: Le32, + vfd_count: Le32, +} + +unsafe impl DataInit for CtrlVfdRecv {} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +struct CtrlVfd { + hdr: CtrlHeader, + id: Le32, +} + +unsafe impl DataInit for CtrlVfd {} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +struct CtrlVfdSend { + hdr: CtrlHeader, + id: Le32, + vfd_count: Le32, + // Remainder is an array of vfd_count IDs followed by data. +} + +unsafe impl DataInit for CtrlVfdSend {} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +struct CtrlVfdSendVfd { + kind: Le32, + id: Le32, +} + +unsafe impl DataInit for CtrlVfdSendVfd {} + +#[derive(Debug)] +#[allow(dead_code)] +enum WlResp<'a> { + Ok, + VfdNew { + id: u32, + flags: u32, + pfn: u64, + size: u32, + // The VfdNew variant can be either a response or a command depending on this `resp`. This + // is important for the `get_code` method. + resp: bool, + }, + #[cfg(feature = "wl-dmabuf")] + VfdNewDmabuf { + id: u32, + flags: u32, + pfn: u64, + size: u32, + desc: GpuMemoryDesc, + }, + VfdRecv { + id: u32, + data: &'a [u8], + vfds: &'a [u32], + }, + VfdHup { + id: u32, + }, + Err(Box<dyn StdError>), + OutOfMemory, + InvalidId, + InvalidType, + InvalidFlags, + InvalidCommand, +} + +impl<'a> WlResp<'a> { + fn get_code(&self) -> u32 { + match *self { + WlResp::Ok => VIRTIO_WL_RESP_OK, + WlResp::VfdNew { resp, .. } => { + if resp { + VIRTIO_WL_RESP_VFD_NEW + } else { + VIRTIO_WL_CMD_VFD_NEW + } + } + #[cfg(feature = "wl-dmabuf")] + WlResp::VfdNewDmabuf { .. } => VIRTIO_WL_RESP_VFD_NEW_DMABUF, + WlResp::VfdRecv { .. } => VIRTIO_WL_CMD_VFD_RECV, + WlResp::VfdHup { .. } => VIRTIO_WL_CMD_VFD_HUP, + WlResp::Err(_) => VIRTIO_WL_RESP_ERR, + WlResp::OutOfMemory => VIRTIO_WL_RESP_OUT_OF_MEMORY, + WlResp::InvalidId => VIRTIO_WL_RESP_INVALID_ID, + WlResp::InvalidType => VIRTIO_WL_RESP_INVALID_TYPE, + WlResp::InvalidFlags => VIRTIO_WL_RESP_INVALID_FLAGS, + WlResp::InvalidCommand => VIRTIO_WL_RESP_INVALID_CMD, + } + } +} + +#[derive(Default)] +struct WlVfd { + socket: Option<UnixStream>, + guest_shared_memory: Option<(u64 /* size */, File)>, + remote_pipe: Option<File>, + local_pipe: Option<(u32 /* flags */, File)>, + slot: Option<(u32 /* slot */, u64 /* pfn */, VmRequester)>, + #[cfg(feature = "wl-dmabuf")] + is_dmabuf: bool, +} + +impl fmt::Debug for WlVfd { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "WlVfd {{")?; + if let Some(s) = &self.socket { + write!(f, " socket: {}", s.as_raw_fd())?; + } + if let Some((slot, pfn, _)) = &self.slot { + write!(f, " slot: {} pfn: {}", slot, pfn)?; + } + if let Some(s) = &self.remote_pipe { + write!(f, " remote: {}", s.as_raw_fd())?; + } + if let Some((_, s)) = &self.local_pipe { + write!(f, " local: {}", s.as_raw_fd())?; + } + write!(f, " }}") + } +} + +impl WlVfd { + fn connect<P: AsRef<Path>>(path: P) -> WlResult<WlVfd> { + let socket = UnixStream::connect(path).map_err(WlError::SocketConnect)?; + socket + .set_nonblocking(true) + .map_err(WlError::SocketNonBlock)?; + let mut vfd = WlVfd::default(); + vfd.socket = Some(socket); + Ok(vfd) + } + + fn allocate(vm: VmRequester, size: u64) -> WlResult<WlVfd> { + let size_page_aligned = round_up_to_page_size(size as usize) as u64; + let mut vfd_shm = SharedMemory::named("virtwl_alloc").map_err(WlError::NewAlloc)?; + vfd_shm + .set_size(size_page_aligned) + .map_err(WlError::AllocSetSize)?; + let register_response = vm.request(VmMemoryRequest::RegisterMemory( + MaybeOwnedFd::Borrowed(vfd_shm.as_raw_fd()), + vfd_shm.size() as usize, + ))?; + match register_response { + VmMemoryResponse::RegisterMemory { pfn, slot } => { + let mut vfd = WlVfd::default(); + vfd.guest_shared_memory = Some((vfd_shm.size(), vfd_shm.into())); + vfd.slot = Some((slot, pfn, vm)); + Ok(vfd) + } + _ => Err(WlError::VmBadResponse), + } + } + + #[cfg(feature = "wl-dmabuf")] + fn dmabuf( + vm: VmRequester, + width: u32, + height: u32, + format: u32, + ) -> WlResult<(WlVfd, GpuMemoryDesc)> { + let allocate_and_register_gpu_memory_response = + vm.request(VmMemoryRequest::AllocateAndRegisterGpuMemory { + width, + height, + format, + })?; + match allocate_and_register_gpu_memory_response { + VmMemoryResponse::AllocateAndRegisterGpuMemory { + fd, + pfn, + slot, + desc, + } => { + let mut vfd = WlVfd::default(); + // Duplicate FD for shared memory instance. + let raw_fd = unsafe { File::from_raw_fd(dup(fd.as_raw_fd())) }; + let vfd_shm = SharedMemory::from_raw_fd(raw_fd).map_err(WlError::NewAlloc)?; + vfd.guest_shared_memory = Some((vfd_shm.size(), vfd_shm.into())); + vfd.slot = Some((slot, pfn, vm)); + vfd.is_dmabuf = true; + Ok((vfd, desc)) + } + _ => Err(WlError::VmBadResponse), + } + } + + #[cfg(feature = "wl-dmabuf")] + fn dmabuf_sync(&self, flags: u32) -> WlResult<()> { + if !self.is_dmabuf { + return Err(WlError::DmabufSync(io::Error::from_raw_os_error(EINVAL))); + } + + match &self.guest_shared_memory { + Some((_, fd)) => { + let sync = dma_buf_sync { + flags: flags as u64, + }; + // Safe as fd is a valid dmabuf and incorrect flags will return an error. + if unsafe { ioctl_with_ref(fd, DMA_BUF_IOCTL_SYNC(), &sync) } < 0 { + Err(WlError::DmabufSync(io::Error::last_os_error())) + } else { + Ok(()) + } + } + None => Err(WlError::DmabufSync(io::Error::from_raw_os_error(EBADF))), + } + } + + fn pipe_remote_read_local_write() -> WlResult<WlVfd> { + let (read_pipe, write_pipe) = pipe(true).map_err(WlError::NewPipe)?; + let mut vfd = WlVfd::default(); + vfd.remote_pipe = Some(read_pipe); + vfd.local_pipe = Some((VIRTIO_WL_VFD_WRITE, write_pipe)); + Ok(vfd) + } + + fn pipe_remote_write_local_read() -> WlResult<WlVfd> { + let (read_pipe, write_pipe) = pipe(true).map_err(WlError::NewPipe)?; + let mut vfd = WlVfd::default(); + vfd.remote_pipe = Some(write_pipe); + vfd.local_pipe = Some((VIRTIO_WL_VFD_READ, read_pipe)); + Ok(vfd) + } + + fn from_file(vm: VmRequester, mut fd: File) -> WlResult<WlVfd> { + // We need to determine if the given file is more like shared memory or a pipe/socket. A + // quick and easy check is to seek to the end of the file. If it works we assume it's not a + // pipe/socket because those have no end. We can even use that seek location as an indicator + // for how big the shared memory chunk to map into guest memory is. If seeking to the end + // fails, we assume it's a socket or pipe with read/write semantics. + match fd.seek(SeekFrom::End(0)) { + Ok(fd_size) => { + let size = round_up_to_page_size(fd_size as usize) as u64; + let register_response = vm.request(VmMemoryRequest::RegisterMemory( + MaybeOwnedFd::Borrowed(fd.as_raw_fd()), + size as usize, + ))?; + + match register_response { + VmMemoryResponse::RegisterMemory { pfn, slot } => { + let mut vfd = WlVfd::default(); + vfd.guest_shared_memory = Some((size, fd)); + vfd.slot = Some((slot, pfn, vm)); + Ok(vfd) + } + _ => Err(WlError::VmBadResponse), + } + } + _ => { + let flags = match FileFlags::from_file(&fd) { + Ok(FileFlags::Read) => VIRTIO_WL_VFD_READ, + Ok(FileFlags::Write) => VIRTIO_WL_VFD_WRITE, + Ok(FileFlags::ReadWrite) => VIRTIO_WL_VFD_READ | VIRTIO_WL_VFD_WRITE, + _ => 0, + }; + let mut vfd = WlVfd::default(); + vfd.local_pipe = Some((flags, fd)); + Ok(vfd) + } + } + } + + fn flags(&self, use_transition_flags: bool) -> u32 { + let mut flags = 0; + if use_transition_flags { + if self.socket.is_some() { + flags |= VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_READ; + } + if let Some((f, _)) = self.local_pipe { + flags |= f; + } + } else { + if self.socket.is_some() { + flags |= VIRTIO_WL_VFD_CONTROL; + } + if self.slot.is_some() { + flags |= VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_MAP + } + } + flags + } + + // Page frame number in the guest this VFD was mapped at. + fn pfn(&self) -> Option<u64> { + self.slot.as_ref().map(|s| s.1) + } + + // Size in bytes of the shared memory VFD. + fn size(&self) -> Option<u64> { + self.guest_shared_memory.as_ref().map(|&(size, _)| size) + } + + // The FD that gets sent if this VFD is sent over a socket. + fn send_fd(&self) -> Option<RawFd> { + self.guest_shared_memory + .as_ref() + .map(|(_, fd)| fd.as_raw_fd()) + .or(self.socket.as_ref().map(|s| s.as_raw_fd())) + .or(self.remote_pipe.as_ref().map(|p| p.as_raw_fd())) + } + + // The FD that is used for polling for events on this VFD. + fn poll_fd(&self) -> Option<&dyn AsRawFd> { + self.socket + .as_ref() + .map(|s| s as &dyn AsRawFd) + .or(self.local_pipe.as_ref().map(|(_, p)| p as &dyn AsRawFd)) + } + + // Sends data/files from the guest to the host over this VFD. + fn send(&mut self, fds: &[RawFd], data: &mut Reader) -> WlResult<WlResp> { + if let Some(socket) = &self.socket { + socket + .send_with_fds(data.get_remaining(), fds) + .map_err(WlError::SendVfd)?; + Ok(WlResp::Ok) + } else if let Some((_, local_pipe)) = &mut self.local_pipe { + // Impossible to send fds over a simple pipe. + if !fds.is_empty() { + return Ok(WlResp::InvalidType); + } + data.read_to(local_pipe, usize::max_value()) + .map_err(WlError::WritePipe)?; + Ok(WlResp::Ok) + } else { + Ok(WlResp::InvalidType) + } + } + + // Receives data/files from the host for this VFD and queues it for the guest. + fn recv(&mut self, in_file_queue: &mut Vec<File>) -> WlResult<Vec<u8>> { + if let Some(socket) = self.socket.take() { + let mut buf = vec![0; IN_BUFFER_LEN]; + let mut fd_buf = [0; VIRTWL_SEND_MAX_ALLOCS]; + // If any errors happen, the socket will get dropped, preventing more reading. + let (len, file_count) = socket + .recv_with_fds(&mut buf[..], &mut fd_buf) + .map_err(WlError::RecvVfd)?; + // If any data gets read, the put the socket back for future recv operations. + if len != 0 || file_count != 0 { + buf.truncate(len); + buf.shrink_to_fit(); + self.socket = Some(socket); + // Safe because the first file_counts fds from recv_with_fds are owned by us and + // valid. + in_file_queue.extend( + fd_buf[..file_count] + .iter() + .map(|&fd| unsafe { File::from_raw_fd(fd) }), + ); + return Ok(buf); + } + Ok(Vec::new()) + } else if let Some((flags, mut local_pipe)) = self.local_pipe.take() { + let mut buf = Vec::new(); + buf.resize(IN_BUFFER_LEN, 0); + let len = local_pipe.read(&mut buf[..]).map_err(WlError::ReadPipe)?; + if len != 0 { + buf.truncate(len); + buf.shrink_to_fit(); + self.local_pipe = Some((flags, local_pipe)); + return Ok(buf); + } + Ok(Vec::new()) + } else { + Ok(Vec::new()) + } + } + + // Called after this VFD is sent over a socket to ensure the local end of the VFD receives hang + // up events. + fn close_remote(&mut self) { + self.remote_pipe = None; + } + + fn close(&mut self) -> WlResult<()> { + if let Some((slot, _, vm)) = self.slot.take() { + vm.request(VmMemoryRequest::UnregisterMemory(slot))?; + } + self.socket = None; + self.remote_pipe = None; + self.local_pipe = None; + Ok(()) + } +} + +impl Drop for WlVfd { + fn drop(&mut self) { + let _ = self.close(); + } +} + +#[derive(Debug)] +enum WlRecv { + Vfd { id: u32 }, + Data { buf: Vec<u8> }, + Hup, +} + +struct WlState { + wayland_paths: Map<String, PathBuf>, + vm: VmRequester, + resource_bridge: Option<ResourceRequestSocket>, + use_transition_flags: bool, + poll_ctx: PollContext<u32>, + vfds: Map<u32, WlVfd>, + next_vfd_id: u32, + in_file_queue: Vec<File>, + in_queue: VecDeque<(u32 /* vfd_id */, WlRecv)>, + current_recv_vfd: Option<u32>, + recv_vfds: Vec<u32>, +} + +impl WlState { + fn new( + wayland_paths: Map<String, PathBuf>, + vm_socket: VmMemoryControlRequestSocket, + use_transition_flags: bool, + resource_bridge: Option<ResourceRequestSocket>, + ) -> WlState { + WlState { + wayland_paths, + vm: VmRequester::new(vm_socket), + resource_bridge, + poll_ctx: PollContext::new().expect("failed to create PollContext"), + use_transition_flags, + vfds: Map::new(), + next_vfd_id: NEXT_VFD_ID_BASE, + in_file_queue: Vec::new(), + in_queue: VecDeque::new(), + current_recv_vfd: None, + recv_vfds: Vec::new(), + } + } + + fn new_pipe(&mut self, id: u32, flags: u32) -> WlResult<WlResp> { + if id & VFD_ID_HOST_MASK != 0 { + return Ok(WlResp::InvalidId); + } + + if flags & !(VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_READ) != 0 { + return Ok(WlResp::InvalidFlags); + } + + if flags & VIRTIO_WL_VFD_WRITE != 0 && flags & VIRTIO_WL_VFD_READ != 0 { + return Ok(WlResp::InvalidFlags); + } + + match self.vfds.entry(id) { + Entry::Vacant(entry) => { + let vfd = if flags & VIRTIO_WL_VFD_WRITE != 0 { + WlVfd::pipe_remote_read_local_write()? + } else if flags & VIRTIO_WL_VFD_READ != 0 { + WlVfd::pipe_remote_write_local_read()? + } else { + return Ok(WlResp::InvalidFlags); + }; + self.poll_ctx + .add(vfd.poll_fd().unwrap(), id) + .map_err(WlError::PollContextAdd)?; + let resp = WlResp::VfdNew { + id, + flags: 0, + pfn: 0, + size: 0, + resp: true, + }; + entry.insert(vfd); + Ok(resp) + } + Entry::Occupied(_) => Ok(WlResp::InvalidId), + } + } + + fn new_alloc(&mut self, id: u32, flags: u32, size: u32) -> WlResult<WlResp> { + if id & VFD_ID_HOST_MASK != 0 { + return Ok(WlResp::InvalidId); + } + + if self.use_transition_flags { + if flags != 0 { + return Ok(WlResp::InvalidFlags); + } + } else if flags & !(VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_MAP) != 0 { + return Ok(WlResp::Err(Box::from("invalid flags"))); + } + + match self.vfds.entry(id) { + Entry::Vacant(entry) => { + let vfd = WlVfd::allocate(self.vm.clone(), size as u64)?; + let resp = WlResp::VfdNew { + id, + flags, + pfn: vfd.pfn().unwrap_or_default(), + size: vfd.size().unwrap_or_default() as u32, + resp: true, + }; + entry.insert(vfd); + Ok(resp) + } + Entry::Occupied(_) => Ok(WlResp::InvalidId), + } + } + + #[cfg(feature = "wl-dmabuf")] + fn new_dmabuf(&mut self, id: u32, width: u32, height: u32, format: u32) -> WlResult<WlResp> { + if id & VFD_ID_HOST_MASK != 0 { + return Ok(WlResp::InvalidId); + } + + match self.vfds.entry(id) { + Entry::Vacant(entry) => { + let (vfd, desc) = WlVfd::dmabuf(self.vm.clone(), width, height, format)?; + let resp = WlResp::VfdNewDmabuf { + id, + flags: 0, + pfn: vfd.pfn().unwrap_or_default(), + size: vfd.size().unwrap_or_default() as u32, + desc, + }; + entry.insert(vfd); + Ok(resp) + } + Entry::Occupied(_) => Ok(WlResp::InvalidId), + } + } + + #[cfg(feature = "wl-dmabuf")] + fn dmabuf_sync(&mut self, vfd_id: u32, flags: u32) -> WlResult<WlResp> { + if flags & !(VIRTIO_WL_VFD_DMABUF_SYNC_VALID_FLAG_MASK) != 0 { + return Ok(WlResp::InvalidFlags); + } + + match self.vfds.get_mut(&vfd_id) { + Some(vfd) => { + vfd.dmabuf_sync(flags)?; + Ok(WlResp::Ok) + } + None => Ok(WlResp::InvalidId), + } + } + + fn new_context(&mut self, id: u32, name: &str) -> WlResult<WlResp> { + if id & VFD_ID_HOST_MASK != 0 { + return Ok(WlResp::InvalidId); + } + + let flags = if self.use_transition_flags { + VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_READ + } else { + VIRTIO_WL_VFD_CONTROL + }; + + match self.vfds.entry(id) { + Entry::Vacant(entry) => { + let vfd = entry.insert(WlVfd::connect( + &self + .wayland_paths + .get(name) + .ok_or(WlError::UnknownSocketName(name.to_string()))?, + )?); + self.poll_ctx + .add(vfd.poll_fd().unwrap(), id) + .map_err(WlError::PollContextAdd)?; + Ok(WlResp::VfdNew { + id, + flags, + pfn: 0, + size: 0, + resp: true, + }) + } + Entry::Occupied(_) => Ok(WlResp::InvalidId), + } + } + + fn process_poll_context(&mut self) { + let events = match self.poll_ctx.wait_timeout(Duration::from_secs(0)) { + Ok(v) => v.to_owned(), + Err(e) => { + error!("failed polling for vfd evens: {}", e); + return; + } + }; + + for event in events.as_ref().iter_readable() { + if let Err(e) = self.recv(event.token()) { + error!("failed to recv from vfd: {}", e) + } + } + + for event in events.as_ref().iter_hungup() { + if !event.readable() { + let vfd_id = event.token(); + if let Some(fd) = self.vfds.get(&vfd_id).and_then(|vfd| vfd.poll_fd()) { + if let Err(e) = self.poll_ctx.delete(fd) { + warn!("failed to remove hungup vfd from poll context: {}", e); + } + } + self.in_queue.push_back((vfd_id, WlRecv::Hup)); + } + } + } + + fn close(&mut self, vfd_id: u32) -> WlResult<WlResp> { + let mut to_delete = Set::new(); + for (dest_vfd_id, q) in &self.in_queue { + if *dest_vfd_id == vfd_id { + if let WlRecv::Vfd { id } = q { + to_delete.insert(*id); + } + } + } + for vfd_id in to_delete { + // Sorry sub-error, we can't have cascading errors leaving us in an inconsistent state. + let _ = self.close(vfd_id); + } + match self.vfds.remove(&vfd_id) { + Some(mut vfd) => { + self.in_queue.retain(|&(id, _)| id != vfd_id); + vfd.close()?; + Ok(WlResp::Ok) + } + None => Ok(WlResp::InvalidId), + } + } + + fn send( + &mut self, + vfd_id: u32, + vfd_count: usize, + foreign_id: bool, + reader: &mut Reader, + ) -> WlResult<WlResp> { + // First stage gathers and normalizes all id information from guest memory. + let mut send_vfd_ids = [CtrlVfdSendVfd::default(); VIRTWL_SEND_MAX_ALLOCS]; + for vfd_id in send_vfd_ids.iter_mut().take(vfd_count) { + *vfd_id = if foreign_id { + reader.read_obj().map_err(WlError::ParseDesc)? + } else { + CtrlVfdSendVfd { + kind: Le32::from(VIRTIO_WL_CTRL_VFD_SEND_KIND_LOCAL), + id: reader.read_obj().map_err(WlError::ParseDesc)?, + } + } + } + + // Next stage collects corresponding file descriptors for each id. + let mut fds = [0; VIRTWL_SEND_MAX_ALLOCS]; + #[cfg(feature = "gpu")] + let mut bridged_files = Vec::new(); + for (&send_vfd_id, fd) in send_vfd_ids[..vfd_count].iter().zip(fds.iter_mut()) { + let id = send_vfd_id.id.to_native(); + match send_vfd_id.kind.to_native() { + VIRTIO_WL_CTRL_VFD_SEND_KIND_LOCAL => match self.vfds.get(&id) { + Some(vfd) => match vfd.send_fd() { + Some(vfd_fd) => *fd = vfd_fd, + None => return Ok(WlResp::InvalidType), + }, + None => { + warn!("attempt to send non-existant vfd 0x{:08x}", id); + return Ok(WlResp::InvalidId); + } + }, + #[cfg(feature = "gpu")] + VIRTIO_WL_CTRL_VFD_SEND_KIND_VIRTGPU if self.resource_bridge.is_some() => { + let sock = self.resource_bridge.as_ref().unwrap(); + match get_resource_info(sock, id) { + Ok(info) => { + *fd = info.file.as_raw_fd(); + bridged_files.push(info.file); + } + Err(ResourceBridgeError::InvalidResource(id)) => { + warn!("attempt to send non-existent gpu resource {}", id); + return Ok(WlResp::InvalidId); + } + Err(e) => { + error!("{}", e); + // If there was an error with the resource bridge, it can no longer be + // trusted to continue to function. + self.resource_bridge = None; + return Ok(WlResp::InvalidId); + } + } + } + VIRTIO_WL_CTRL_VFD_SEND_KIND_VIRTGPU => { + let _ = self.resource_bridge.as_ref(); + warn!("attempt to send foreign resource kind but feature is disabled"); + } + kind => { + warn!( + "attempt to send unknown foreign resource kind: {} id: {:08x}", + kind, id + ); + return Ok(WlResp::InvalidId); + } + } + } + + // Final stage sends file descriptors and data to the target vfd's socket. + match self.vfds.get_mut(&vfd_id) { + Some(vfd) => match vfd.send(&fds[..vfd_count], reader)? { + WlResp::Ok => {} + _ => return Ok(WlResp::InvalidType), + }, + None => return Ok(WlResp::InvalidId), + } + // The vfds with remote FDs need to be closed so that the local side can receive + // hangup events. + for &send_vfd_id in &send_vfd_ids[..vfd_count] { + if send_vfd_id.kind == VIRTIO_WL_CTRL_VFD_SEND_KIND_LOCAL { + if let Some(vfd) = self.vfds.get_mut(&send_vfd_id.id.into()) { + vfd.close_remote(); + } + } + } + Ok(WlResp::Ok) + } + + fn recv(&mut self, vfd_id: u32) -> WlResult<()> { + let buf = match self.vfds.get_mut(&vfd_id) { + Some(vfd) => vfd.recv(&mut self.in_file_queue)?, + None => return Ok(()), + }; + if self.in_file_queue.is_empty() && buf.is_empty() { + self.in_queue.push_back((vfd_id, WlRecv::Hup)); + return Ok(()); + } + for file in self.in_file_queue.drain(..) { + let vfd = WlVfd::from_file(self.vm.clone(), file)?; + if let Some(poll_fd) = vfd.poll_fd() { + self.poll_ctx + .add(poll_fd, self.next_vfd_id) + .map_err(WlError::PollContextAdd)?; + } + self.vfds.insert(self.next_vfd_id, vfd); + self.in_queue.push_back(( + vfd_id, + WlRecv::Vfd { + id: self.next_vfd_id, + }, + )); + self.next_vfd_id += 1; + } + self.in_queue.push_back((vfd_id, WlRecv::Data { buf })); + + Ok(()) + } + + fn execute(&mut self, reader: &mut Reader) -> WlResult<WlResp> { + let type_ = { + let mut type_reader = reader.clone(); + type_reader.read_obj::<Le32>().map_err(WlError::ParseDesc)? + }; + match type_.into() { + VIRTIO_WL_CMD_VFD_NEW => { + let ctrl = reader + .read_obj::<CtrlVfdNew>() + .map_err(WlError::ParseDesc)?; + self.new_alloc(ctrl.id.into(), ctrl.flags.into(), ctrl.size.into()) + } + VIRTIO_WL_CMD_VFD_CLOSE => { + let ctrl = reader.read_obj::<CtrlVfd>().map_err(WlError::ParseDesc)?; + self.close(ctrl.id.into()) + } + VIRTIO_WL_CMD_VFD_SEND => { + let ctrl = reader + .read_obj::<CtrlVfdSend>() + .map_err(WlError::ParseDesc)?; + let foreign_id = false; + self.send( + ctrl.id.into(), + ctrl.vfd_count.to_native() as usize, + foreign_id, + reader, + ) + } + #[cfg(feature = "gpu")] + VIRTIO_WL_CMD_VFD_SEND_FOREIGN_ID => { + let ctrl = reader + .read_obj::<CtrlVfdSend>() + .map_err(WlError::ParseDesc)?; + let foreign_id = true; + self.send( + ctrl.id.into(), + ctrl.vfd_count.to_native() as usize, + foreign_id, + reader, + ) + } + VIRTIO_WL_CMD_VFD_NEW_CTX => { + let ctrl = reader.read_obj::<CtrlVfd>().map_err(WlError::ParseDesc)?; + self.new_context(ctrl.id.into(), "") + } + VIRTIO_WL_CMD_VFD_NEW_PIPE => { + let ctrl = reader + .read_obj::<CtrlVfdNew>() + .map_err(WlError::ParseDesc)?; + self.new_pipe(ctrl.id.into(), ctrl.flags.into()) + } + #[cfg(feature = "wl-dmabuf")] + VIRTIO_WL_CMD_VFD_NEW_DMABUF => { + let ctrl = reader + .read_obj::<CtrlVfdNewDmabuf>() + .map_err(WlError::ParseDesc)?; + self.new_dmabuf( + ctrl.id.into(), + ctrl.width.into(), + ctrl.height.into(), + ctrl.format.into(), + ) + } + #[cfg(feature = "wl-dmabuf")] + VIRTIO_WL_CMD_VFD_DMABUF_SYNC => { + let ctrl = reader + .read_obj::<CtrlVfdDmabufSync>() + .map_err(WlError::ParseDesc)?; + self.dmabuf_sync(ctrl.id.into(), ctrl.flags.into()) + } + VIRTIO_WL_CMD_VFD_NEW_CTX_NAMED => { + let ctrl = reader + .read_obj::<CtrlVfdNewCtxNamed>() + .map_err(WlError::ParseDesc)?; + let name_len = ctrl + .name + .iter() + .position(|x| x == &0) + .unwrap_or(ctrl.name.len()); + let name = + std::str::from_utf8(&ctrl.name[..name_len]).map_err(WlError::InvalidString)?; + self.new_context(ctrl.id.into(), name) + } + op_type => { + warn!("unexpected command {}", op_type); + Ok(WlResp::InvalidCommand) + } + } + } + + fn next_recv(&self) -> Option<WlResp> { + if let Some(q) = self.in_queue.front() { + match *q { + (vfd_id, WlRecv::Vfd { id }) => { + if self.current_recv_vfd.is_none() || self.current_recv_vfd == Some(vfd_id) { + match self.vfds.get(&id) { + Some(vfd) => Some(WlResp::VfdNew { + id, + flags: vfd.flags(self.use_transition_flags), + pfn: vfd.pfn().unwrap_or_default(), + size: vfd.size().unwrap_or_default() as u32, + resp: false, + }), + _ => Some(WlResp::VfdNew { + id, + flags: 0, + pfn: 0, + size: 0, + resp: false, + }), + } + } else { + Some(WlResp::VfdRecv { + id: self.current_recv_vfd.unwrap(), + data: &[], + vfds: &self.recv_vfds[..], + }) + } + } + (vfd_id, WlRecv::Data { ref buf }) => { + if self.current_recv_vfd.is_none() || self.current_recv_vfd == Some(vfd_id) { + Some(WlResp::VfdRecv { + id: vfd_id, + data: &buf[..], + vfds: &self.recv_vfds[..], + }) + } else { + Some(WlResp::VfdRecv { + id: self.current_recv_vfd.unwrap(), + data: &[], + vfds: &self.recv_vfds[..], + }) + } + } + (vfd_id, WlRecv::Hup) => Some(WlResp::VfdHup { id: vfd_id }), + } + } else { + None + } + } + + fn pop_recv(&mut self) { + if let Some(q) = self.in_queue.front() { + match *q { + (vfd_id, WlRecv::Vfd { id }) => { + if self.current_recv_vfd.is_none() || self.current_recv_vfd == Some(vfd_id) { + self.recv_vfds.push(id); + self.current_recv_vfd = Some(vfd_id); + } else { + self.recv_vfds.clear(); + self.current_recv_vfd = None; + return; + } + } + (vfd_id, WlRecv::Data { .. }) => { + self.recv_vfds.clear(); + self.current_recv_vfd = None; + if !(self.current_recv_vfd.is_none() || self.current_recv_vfd == Some(vfd_id)) { + return; + } + } + (_, WlRecv::Hup) => { + self.recv_vfds.clear(); + self.current_recv_vfd = None; + } + } + } + self.in_queue.pop_front(); + } +} + +pub struct Worker { + interrupt: Interrupt, + mem: GuestMemory, + in_queue: Queue, + out_queue: Queue, + state: WlState, +} + +impl Worker { + pub fn new( + mem: GuestMemory, + interrupt: Interrupt, + in_queue: Queue, + out_queue: Queue, + wayland_paths: Map<String, PathBuf>, + vm_socket: VmMemoryControlRequestSocket, + use_transition_flags: bool, + resource_bridge: Option<ResourceRequestSocket>, + ) -> Worker { + Worker { + interrupt, + mem, + in_queue, + out_queue, + state: WlState::new( + wayland_paths, + vm_socket, + use_transition_flags, + resource_bridge, + ), + } + } + + pub fn run(&mut self, mut queue_evts: Vec<EventFd>, kill_evt: EventFd) { + let mut in_desc_chains: VecDeque<DescriptorChain> = + VecDeque::with_capacity(QUEUE_SIZE as usize); + let in_queue_evt = queue_evts.remove(0); + let out_queue_evt = queue_evts.remove(0); + #[derive(Debug, PollToken)] + enum Token { + InQueue, + OutQueue, + Kill, + State, + InterruptResample, + } + + let poll_ctx: PollContext<Token> = match PollContext::build_with(&[ + (&in_queue_evt, Token::InQueue), + (&out_queue_evt, Token::OutQueue), + (&kill_evt, Token::Kill), + (&self.state.poll_ctx, Token::State), + (self.interrupt.get_resample_evt(), Token::InterruptResample), + ]) { + Ok(pc) => pc, + Err(e) => { + error!("failed creating PollContext: {}", e); + return; + } + }; + + 'poll: loop { + let mut signal_used_in = false; + let mut signal_used_out = false; + let events = match poll_ctx.wait() { + Ok(v) => v, + Err(e) => { + error!("failed polling for events: {}", e); + break; + } + }; + + for event in &events { + dbg!(event.token()); + match event.token() { + Token::InQueue => { + let _ = in_queue_evt.read(); + // Used to buffer descriptor indexes that are invalid for our uses. + let mut rejects = [0u16; QUEUE_SIZE as usize]; + let mut rejects_len = 0; + let min_in_desc_len = (size_of::<CtrlVfdRecv>() + + size_of::<Le32>() * VIRTWL_SEND_MAX_ALLOCS) + as u32; + in_desc_chains.extend(self.in_queue.iter(&self.mem).filter_map(|d| { + if d.len >= min_in_desc_len && d.is_write_only() { + Some(d) + } else { + // Can not use queue.add_used directly because it's being borrowed + // for the iterator chain, so we buffer the descriptor index in + // rejects. + rejects[rejects_len] = d.index; + rejects_len += 1; + None + } + })); + for &reject in &rejects[..rejects_len] { + signal_used_in = true; + self.in_queue.add_used(&self.mem, reject, 0); + } + } + Token::OutQueue => { + let _ = out_queue_evt.read(); + while let Some(desc) = self.out_queue.pop(&self.mem) { + let desc_index = desc.index; + match ( + Reader::new(&self.mem, desc.clone()), + Writer::new(&self.mem, desc), + ) { + (Ok(mut reader), Ok(mut writer)) => { + let resp = match self.state.execute(&mut reader) { + Ok(r) => r, + Err(e) => WlResp::Err(Box::new(e)), + }; + + match encode_resp(&mut writer, resp) { + Ok(()) => {} + Err(e) => { + error!( + "failed to encode response to descriptor chain: {}", + e + ); + } + } + + self.out_queue.add_used( + &self.mem, + desc_index, + writer.bytes_written() as u32, + ); + signal_used_out = true; + } + (_, Err(e)) | (Err(e), _) => { + error!("invalid descriptor: {}", e); + self.out_queue.add_used(&self.mem, desc_index, 0); + signal_used_out = true; + } + } + } + } + Token::Kill => break 'poll, + Token::State => self.state.process_poll_context(), + Token::InterruptResample => { + self.interrupt.interrupt_resample(); + } + } + } + + // Because this loop should be retried after the in queue is usable or after one of the + // VFDs was read, we do it after the poll event responses. + while !in_desc_chains.is_empty() { + let mut should_pop = false; + if let Some(in_resp) = self.state.next_recv() { + // in_desc_chains is not empty (checked by loop condition) so unwrap is safe. + let desc = in_desc_chains.pop_front().unwrap(); + let index = desc.index; + match Writer::new(&self.mem, desc) { + Ok(mut writer) => { + match encode_resp(&mut writer, in_resp) { + Ok(()) => { + should_pop = true; + } + Err(e) => { + error!("failed to encode response to descriptor chain: {}", e); + } + }; + signal_used_in = true; + self.in_queue + .add_used(&self.mem, index, writer.bytes_written() as u32); + } + Err(e) => { + error!("invalid descriptor: {}", e); + self.in_queue.add_used(&self.mem, index, 0); + signal_used_in = true; + } + } + } else { + break; + } + if should_pop { + self.state.pop_recv(); + } + } + + if signal_used_in { + self.interrupt.signal_used_queue(self.in_queue.vector); + } + + if signal_used_out { + self.interrupt.signal_used_queue(self.out_queue.vector); + } + } + } +} + +pub struct Wl { + kill_evt: Option<EventFd>, + worker_thread: Option<thread::JoinHandle<()>>, + wayland_paths: Map<String, PathBuf>, + vm_socket: Option<VmMemoryControlRequestSocket>, + resource_bridge: Option<ResourceRequestSocket>, + use_transition_flags: bool, +} + +impl Wl { + pub fn new( + wayland_paths: Map<String, PathBuf>, + vm_socket: VmMemoryControlRequestSocket, + resource_bridge: Option<ResourceRequestSocket>, + ) -> Result<Wl> { + Ok(Wl { + kill_evt: None, + worker_thread: None, + wayland_paths, + vm_socket: Some(vm_socket), + resource_bridge, + use_transition_flags: false, + }) + } +} + +impl Drop for Wl { + fn drop(&mut self) { + if let Some(kill_evt) = self.kill_evt.take() { + // Ignore the result because there is nothing we can do about it. + let _ = kill_evt.write(1); + } + + if let Some(worker_thread) = self.worker_thread.take() { + let _ = worker_thread.join(); + } + } +} + +impl VirtioDevice for Wl { + fn keep_fds(&self) -> Vec<RawFd> { + let mut keep_fds = Vec::new(); + + if let Some(vm_socket) = &self.vm_socket { + keep_fds.push(vm_socket.as_raw_fd()); + } + if let Some(resource_bridge) = &self.resource_bridge { + keep_fds.push(resource_bridge.as_raw_fd()); + } + + keep_fds + } + + fn device_type(&self) -> u32 { + TYPE_WL + } + + fn queue_max_sizes(&self) -> &[u16] { + QUEUE_SIZES + } + + fn features(&self) -> u64 { + 1 << VIRTIO_WL_F_TRANS_FLAGS | 1 << VIRTIO_F_VERSION_1 + } + + fn ack_features(&mut self, value: u64) { + if value & (1 << VIRTIO_WL_F_TRANS_FLAGS) != 0 { + self.use_transition_flags = true; + } + } + + fn activate( + &mut self, + mem: GuestMemory, + interrupt: Interrupt, + mut queues: Vec<Queue>, + queue_evts: Vec<EventFd>, + ) { + if queues.len() != QUEUE_SIZES.len() || queue_evts.len() != QUEUE_SIZES.len() { + return; + } + + let (self_kill_evt, kill_evt) = match EventFd::new().and_then(|e| Ok((e.try_clone()?, e))) { + Ok(v) => v, + Err(e) => { + error!("failed creating kill EventFd pair: {}", e); + return; + } + }; + self.kill_evt = Some(self_kill_evt); + + if let Some(vm_socket) = self.vm_socket.take() { + let wayland_paths = self.wayland_paths.clone(); + let use_transition_flags = self.use_transition_flags; + let resource_bridge = self.resource_bridge.take(); + println!("creating worker"); + let worker_result = + thread::Builder::new() + .name("virtio_wl".to_string()) + .spawn(move || { + Worker::new( + mem, + interrupt, + queues.remove(0), + queues.remove(0), + wayland_paths, + vm_socket, + use_transition_flags, + resource_bridge, + ) + .run(queue_evts, kill_evt); + }); + + match worker_result { + Err(e) => { + error!("failed to spawn virtio_wl worker: {}", e); + return; + } + Ok(join_handle) => { + self.worker_thread = Some(join_handle); + } + } + } + } +} |