// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! This module implements the virtio wayland device used by the guest to access the host's wayland server.
//!
//! The virtio wayland protocol is done over two queues: `in` and `out`. The `in` queue is used for
//! sending commands to the guest that are generated by the host, usually messages from the wayland
//! server. The `out` queue is for commands from the guest, usually requests to allocate shared
//! memory, open a wayland server connection, or send data over an existing connection.
//!
//! Each `WlVfd` represents one virtual file descriptor created by either the guest or the host.
//! Virtual file descriptors contain actual file descriptors, either a shared memory file descriptor
//! or a unix domain socket to the wayland server. In the shared memory case, there is also an
//! associated slot that indicates which KVM memory slot the memory is installed into, as well as a
//! page frame number that the guest can access the memory from.
//!
//! The types starting with `Ctrl` are structures representing the virtio wayland protocol "on the
//! wire." They are decoded and executed in the `execute` function and encoded as some variant of
//! `WlResp` for responses.
//!
//! There is one `WlState` instance that contains every known vfd and the current state of `in`
//! queue. The `in` queue requires extra state to buffer messages to the guest in case the `in`
//! queue is already full. The `WlState` also has a control socket necessary to fulfill certain
//! requests, such as those registering guest memory.
//!
//! The `Worker` is responsible for the poll loop over all possible events, encoding/decoding from
//! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill
//! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket.
use std::borrow::Cow;
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap as Map, BTreeSet as Set, VecDeque};
use std::convert::{From, Infallible};
use std::error::Error as StdError;
use std::fmt::{self, Display};
use std::fs::File;
use std::io::{self, Read, Seek, SeekFrom, Write};
use std::mem::size_of;
#[cfg(feature = "wl-dmabuf")]
use std::os::raw::{c_uint, c_ulonglong};
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::thread;
use std::time::Duration;
#[cfg(feature = "wl-dmabuf")]
use libc::{EBADF, EINVAL};
use data_model::VolatileMemoryError;
use data_model::*;
use msg_socket::{MsgError, MsgReceiver, MsgSender};
use msg_socket2::de::{EnumAccessWithFds, VariantAccessWithFds, VisitorWithFds};
use msg_socket2::ser::SerializeAdapter;
use msg_socket2::Deserialize;
#[cfg(feature = "wl-dmabuf")]
use resources::GpuMemoryDesc;
#[cfg(feature = "wl-dmabuf")]
use sys_util::ioctl_iow_nr;
use sys_util::{
error, pipe, round_up_to_page_size, warn, Error, EventFd, FileFlags, GuestMemory,
GuestMemoryError, PollContext, PollToken, ScmSocket, SharedMemory,
};
#[cfg(feature = "wl-dmabuf")]
use sys_util::ioctl_with_ref;
use super::resource_bridge::*;
use super::{
DescriptorChain, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_WL, VIRTIO_F_VERSION_1,
};
use vm_control::{
MaybeOwnedFd, VmMemoryControlRequestSocket, VmMemoryRequest, VmMemoryResponse,
WlControlCommand, WlControlResponseSocket, WlControlResult,
};
// Maximum number of vfd ids (and therefore file descriptors) that a single
// send or recv message may carry.
const VIRTWL_SEND_MAX_ALLOCS: usize = 28;
// Command codes for messages on the `in`/`out` queues ("on the wire" protocol
// values; see the module documentation).
const VIRTIO_WL_CMD_VFD_NEW: u32 = 256;
const VIRTIO_WL_CMD_VFD_CLOSE: u32 = 257;
const VIRTIO_WL_CMD_VFD_SEND: u32 = 258;
const VIRTIO_WL_CMD_VFD_RECV: u32 = 259;
const VIRTIO_WL_CMD_VFD_NEW_CTX: u32 = 260;
const VIRTIO_WL_CMD_VFD_NEW_PIPE: u32 = 261;
const VIRTIO_WL_CMD_VFD_HUP: u32 = 262;
#[cfg(feature = "wl-dmabuf")]
const VIRTIO_WL_CMD_VFD_NEW_DMABUF: u32 = 263;
#[cfg(feature = "wl-dmabuf")]
const VIRTIO_WL_CMD_VFD_DMABUF_SYNC: u32 = 264;
#[cfg(feature = "gpu")]
const VIRTIO_WL_CMD_VFD_SEND_FOREIGN_ID: u32 = 265;
const VIRTIO_WL_CMD_VFD_NEW_CTX_NAMED: u32 = 266;
// Response codes returned to the guest for the commands above.
const VIRTIO_WL_RESP_OK: u32 = 4096;
const VIRTIO_WL_RESP_VFD_NEW: u32 = 4097;
#[cfg(feature = "wl-dmabuf")]
const VIRTIO_WL_RESP_VFD_NEW_DMABUF: u32 = 4098;
const VIRTIO_WL_RESP_ERR: u32 = 4352;
const VIRTIO_WL_RESP_OUT_OF_MEMORY: u32 = 4353;
const VIRTIO_WL_RESP_INVALID_ID: u32 = 4354;
const VIRTIO_WL_RESP_INVALID_TYPE: u32 = 4355;
const VIRTIO_WL_RESP_INVALID_FLAGS: u32 = 4356;
const VIRTIO_WL_RESP_INVALID_CMD: u32 = 4357;
// Per-vfd capability flags reported to the guest. Note that WRITE/READ and
// MAP overlap; which set applies depends on the transition-flags mode.
const VIRTIO_WL_VFD_WRITE: u32 = 0x1;
const VIRTIO_WL_VFD_READ: u32 = 0x2;
const VIRTIO_WL_VFD_MAP: u32 = 0x2;
const VIRTIO_WL_VFD_CONTROL: u32 = 0x4;
// Virtio feature bit selecting the transitional flag semantics.
const VIRTIO_WL_F_TRANS_FLAGS: u32 = 0x01;
const QUEUE_SIZE: u16 = 16;
// Two queues: `in` (host -> guest) and `out` (guest -> host).
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE, QUEUE_SIZE];
// Ids at or above this base are host-allocated; guest-chosen ids must not
// have this bit set.
const NEXT_VFD_ID_BASE: u32 = 0x40000000;
const VFD_ID_HOST_MASK: u32 = NEXT_VFD_ID_BASE;
// Each in-vq buffer is one page, so we need to leave space for the control header and the maximum
// number of allocs.
const IN_BUFFER_LEN: usize =
    0x1000 - size_of::<CtrlVfdRecv>() - VIRTWL_SEND_MAX_ALLOCS * size_of::<Le32>();
#[cfg(feature = "wl-dmabuf")]
const VIRTIO_WL_VFD_DMABUF_SYNC_VALID_FLAG_MASK: u32 = 0x7;
// dma-buf ioctl definitions, mirroring the kernel's dma-buf uapi.
#[cfg(feature = "wl-dmabuf")]
const DMA_BUF_IOCTL_BASE: c_uint = 0x62;
#[cfg(feature = "wl-dmabuf")]
#[repr(C)]
#[derive(Copy, Clone)]
struct dma_buf_sync {
    flags: c_ulonglong,
}
#[cfg(feature = "wl-dmabuf")]
ioctl_iow_nr!(DMA_BUF_IOCTL_SYNC, DMA_BUF_IOCTL_BASE, 0, dma_buf_sync);
// Kinds for ids carried by SEND_FOREIGN_ID: plain local vfds or virtgpu
// resource ids resolved through the resource bridge.
const VIRTIO_WL_CTRL_VFD_SEND_KIND_LOCAL: u32 = 0;
const VIRTIO_WL_CTRL_VFD_SEND_KIND_VIRTGPU: u32 = 1;
/// If `name` contains "%d", tries replacing "%d" with successive
/// integers starting from 0, until name satisfies predicate.
///
/// `name` can only contain one "%", and it must be followed by "d".
/// Returns `None` when the placeholder is malformed ('%' not followed by 'd',
/// or more than one '%').
fn unique_name<'a, S, P>(name: S, predicate: P) -> Option<Vec<u8>>
where
    S: Into<Cow<'a, [u8]>>,
    P: Fn(&[u8]) -> bool,
{
    let name = name.into();
    if let Some(pos) = name.iter().position(|b| *b == b'%') {
        if name.get(pos + 1) != Some(&b'd') {
            // '%' must be immediately followed by 'd'.
            None
        } else if name[(pos + 1)..].contains(&b'%') {
            // Only a single placeholder is allowed.
            None
        } else {
            let mut i = 0;
            let mut resolved_name = Vec::with_capacity(name.len());
            loop {
                // Rebuild `<prefix><i><suffix>` in the reused buffer.
                resolved_name.clear();
                resolved_name.extend(&name[0..pos]);
                resolved_name.extend(i.to_string().bytes());
                resolved_name.extend(&name[(pos + 2)..]);
                if predicate(&resolved_name) {
                    break;
                }
                i += 1;
            }
            Some(resolved_name)
        }
    } else {
        // No placeholder; the name is used verbatim.
        Some(name.into_owned())
    }
}
#[test]
fn test_unique_name() {
    // The placeholder is substituted starting from 0.
    assert_eq!(unique_name(&b"%d"[..], |_| true).unwrap(), b"0");
    // The counter advances until the predicate accepts a candidate.
    assert_eq!(
        unique_name(&b"-%d-"[..], |name| name != b"-0-").unwrap(),
        b"-1-"
    );
    // Malformed placeholders are all rejected.
    for bad in &[&b"%"[..], &b"%%"[..], &b"%d%d"[..]] {
        assert_eq!(unique_name(*bad, |_| true), None);
    }
}
/// Encodes a `VFD_NEW` message into `writer`. The same wire layout serves as
/// a host command or a response to the guest, selected by `resp`.
fn encode_vfd_new(
    writer: &mut Writer,
    resp: bool,
    vfd_id: u32,
    flags: u32,
    pfn: u64,
    size: u32,
) -> WlResult<()> {
    let type_ = if resp {
        VIRTIO_WL_RESP_VFD_NEW
    } else {
        VIRTIO_WL_CMD_VFD_NEW
    };
    let msg = CtrlVfdNew {
        hdr: CtrlHeader {
            type_: Le32::from(type_),
            flags: Le32::from(0),
        },
        id: Le32::from(vfd_id),
        flags: Le32::from(flags),
        pfn: Le64::from(pfn),
        size: Le32::from(size),
    };
    writer.write_obj(msg).map_err(WlError::WriteResponse)
}
/// Encodes a `RESP_VFD_NEW_DMABUF` response into `writer`, carrying the
/// per-plane stride/offset layout of the allocated GPU buffer.
#[cfg(feature = "wl-dmabuf")]
fn encode_vfd_new_dmabuf(
    writer: &mut Writer,
    vfd_id: u32,
    flags: u32,
    pfn: u64,
    size: u32,
    desc: GpuMemoryDesc,
) -> WlResult<()> {
    let hdr = CtrlHeader {
        type_: Le32::from(VIRTIO_WL_RESP_VFD_NEW_DMABUF),
        flags: Le32::from(0),
    };
    let msg = CtrlVfdNewDmabuf {
        hdr,
        id: Le32::from(vfd_id),
        flags: Le32::from(flags),
        pfn: Le64::from(pfn),
        size: Le32::from(size),
        // Width, height, and format are not echoed back to the guest.
        width: Le32::from(0),
        height: Le32::from(0),
        format: Le32::from(0),
        stride0: Le32::from(desc.planes[0].stride),
        stride1: Le32::from(desc.planes[1].stride),
        stride2: Le32::from(desc.planes[2].stride),
        offset0: Le32::from(desc.planes[0].offset),
        offset1: Le32::from(desc.planes[1].offset),
        offset2: Le32::from(desc.planes[2].offset),
    };
    writer.write_obj(msg).map_err(WlError::WriteResponse)
}
/// Encodes a `VFD_RECV` message into `writer`: the header, then the ids of
/// any vfds created from received file descriptors, then the payload bytes.
fn encode_vfd_recv(writer: &mut Writer, vfd_id: u32, data: &[u8], vfd_ids: &[u32]) -> WlResult<()> {
    let hdr = CtrlVfdRecv {
        hdr: CtrlHeader {
            type_: Le32::from(VIRTIO_WL_CMD_VFD_RECV),
            flags: Le32::from(0),
        },
        id: Le32::from(vfd_id),
        vfd_count: Le32::from(vfd_ids.len() as u32),
    };
    writer.write_obj(hdr).map_err(WlError::WriteResponse)?;
    for &recv_id in vfd_ids {
        writer
            .write_obj(Le32::from(recv_id))
            .map_err(WlError::WriteResponse)?;
    }
    writer.write_all(data).map_err(WlError::WriteResponse)
}
/// Encodes a `VFD_HUP` message telling the guest that vfd `vfd_id` hung up.
fn encode_vfd_hup(writer: &mut Writer, vfd_id: u32) -> WlResult<()> {
    let ctrl_vfd_hup = CtrlVfd {
        hdr: CtrlHeader {
            type_: Le32::from(VIRTIO_WL_CMD_VFD_HUP),
            flags: Le32::from(0),
        },
        id: Le32::from(vfd_id),
    };
    writer
        .write_obj(ctrl_vfd_hup)
        .map_err(WlError::WriteResponse)
}
/// Encodes `resp` into `writer` in wire format. Variants carrying a payload
/// get a full message body; all others are just their 32-bit status code.
fn encode_resp(writer: &mut Writer, resp: WlResp) -> WlResult<()> {
    match resp {
        WlResp::VfdNew { id, flags, pfn, size, resp } => {
            encode_vfd_new(writer, resp, id, flags, pfn, size)
        }
        #[cfg(feature = "wl-dmabuf")]
        WlResp::VfdNewDmabuf { id, flags, pfn, size, desc } => {
            encode_vfd_new_dmabuf(writer, id, flags, pfn, size, desc)
        }
        WlResp::VfdRecv { id, data, vfds } => encode_vfd_recv(writer, id, data, vfds),
        WlResp::VfdHup { id } => encode_vfd_hup(writer, id),
        // Everything else is a bare status code.
        other => writer
            .write_obj(Le32::from(other.get_code()))
            .map_err(WlError::WriteResponse),
    }
}
// Every failure mode of the wayland device; `Display` below provides the
// user-facing descriptions.
#[allow(dead_code)]
#[derive(Debug)]
enum WlError {
    // Creating or sizing the shared memory backing an allocation failed.
    NewAlloc(Error),
    NewPipe(Error),
    AllocSetSize(Error),
    // Connecting to the wayland server socket at `path` failed.
    SocketConnect { path: PathBuf, inner: io::Error },
    SocketNonBlock(io::Error),
    // Communication with the main VM process failed or returned nonsense.
    VmControl(MsgError),
    VmBadResponse,
    CheckedOffset,
    // Reading a control message out of a descriptor chain failed.
    ParseDesc(io::Error),
    GuestMemory(GuestMemoryError),
    VolatileMemory(VolatileMemoryError),
    // Socket/pipe transfer failures in either direction.
    SendVfd(Error),
    WritePipe(io::Error),
    RecvVfd(Error),
    ReadPipe(io::Error),
    PollContextAdd(Error),
    DmabufSync(io::Error),
    WriteResponse(io::Error),
    // The guest asked for a context name that was never registered.
    UnknownSocketName(Vec<u8>),
}
// Human-readable description for each error variant.
impl Display for WlError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        use self::WlError::*;
        match self {
            NewAlloc(e) => write!(f, "failed to create shared memory allocation: {}", e),
            NewPipe(e) => write!(f, "failed to create pipe: {}", e),
            AllocSetSize(e) => write!(f, "failed to set size of shared memory: {}", e),
            SocketConnect { path, inner } => {
                write!(f, "failed to connect socket at {:?}: {}", path, inner)
            }
            SocketNonBlock(e) => write!(f, "failed to set socket as non-blocking: {}", e),
            VmControl(e) => write!(f, "failed to control parent VM: {}", e),
            VmBadResponse => write!(f, "invalid response from parent VM"),
            CheckedOffset => write!(f, "overflow in calculation"),
            ParseDesc(e) => write!(f, "error parsing descriptor: {}", e),
            GuestMemory(e) => write!(f, "access violation in guest memory: {}", e),
            // Fixed typo: was "access violating".
            VolatileMemory(e) => write!(f, "access violation in guest volatile memory: {}", e),
            SendVfd(e) => write!(f, "failed to send on a socket: {}", e),
            WritePipe(e) => write!(f, "failed to write to a pipe: {}", e),
            RecvVfd(e) => write!(f, "failed to recv on a socket: {}", e),
            ReadPipe(e) => write!(f, "failed to read a pipe: {}", e),
            PollContextAdd(e) => write!(f, "failed to listen to FD on poll context: {}", e),
            DmabufSync(e) => write!(f, "failed to synchronize DMABuf access: {}", e),
            WriteResponse(e) => write!(f, "failed to write response: {}", e),
            UnknownSocketName(name) => write!(
                f,
                "unknown socket name: '{}'",
                String::from_utf8_lossy(name)
            ),
        }
    }
}
// Marker impl so WlError can be boxed as `dyn std::error::Error`.
impl std::error::Error for WlError {}
// Convenience alias used by every fallible operation in this module.
type WlResult<T> = Result<T, WlError>;
impl From<GuestMemoryError> for WlError {
    fn from(e: GuestMemoryError) -> WlError {
        WlError::GuestMemory(e)
    }
}
impl From<VolatileMemoryError> for WlError {
    fn from(e: VolatileMemoryError) -> WlError {
        WlError::VolatileMemory(e)
    }
}
// Cheaply clonable handle to the control socket used for registering and
// unregistering guest memory with the main VM process.
#[derive(Clone)]
struct VmRequester {
    vm_socket: Rc<VmMemoryControlRequestSocket>,
}
impl VmRequester {
    fn new(vm_socket: VmMemoryControlRequestSocket) -> VmRequester {
        VmRequester {
            vm_socket: Rc::new(vm_socket),
        }
    }
    // Sends `request` and synchronously waits for the matching response.
    fn request(&self, request: VmMemoryRequest) -> WlResult<VmMemoryResponse> {
        self.vm_socket.send(&request).map_err(WlError::VmControl)?;
        self.vm_socket.recv().map_err(WlError::VmControl)
    }
}
// Header common to every control message on the wire.
#[repr(C)]
#[derive(Copy, Clone, Default)]
struct CtrlHeader {
    type_: Le32,
    flags: Le32,
}
// Wire format of VFD_NEW (command and response share the layout).
#[repr(C)]
#[derive(Copy, Clone, Default)]
struct CtrlVfdNew {
    hdr: CtrlHeader,
    id: Le32,
    flags: Le32,
    pfn: Le64,
    size: Le32,
}
// DataInit allows reading/writing this struct from guest memory as raw bytes.
unsafe impl DataInit for CtrlVfdNew {}
// Wire format of VFD_NEW_CTX_NAMED; `name` selects the wayland socket.
#[repr(C)]
#[derive(Copy, Clone, Default)]
struct CtrlVfdNewCtxNamed {
    hdr: CtrlHeader,
    id: Le32,
    flags: Le32, // Ignored.
    pfn: Le64,   // Ignored.
    size: Le32,  // Ignored.
    name: [u8; 32],
}
unsafe impl DataInit for CtrlVfdNewCtxNamed {}
// Wire format of VFD_NEW_DMABUF and its response, with per-plane layout.
#[repr(C)]
#[derive(Copy, Clone, Default)]
#[cfg(feature = "wl-dmabuf")]
struct CtrlVfdNewDmabuf {
    hdr: CtrlHeader,
    id: Le32,
    flags: Le32,
    pfn: Le64,
    size: Le32,
    width: Le32,
    height: Le32,
    format: Le32,
    stride0: Le32,
    stride1: Le32,
    stride2: Le32,
    offset0: Le32,
    offset1: Le32,
    offset2: Le32,
}
#[cfg(feature = "wl-dmabuf")]
unsafe impl DataInit for CtrlVfdNewDmabuf {}
// Wire format of VFD_DMABUF_SYNC.
#[repr(C)]
#[derive(Copy, Clone, Default)]
#[cfg(feature = "wl-dmabuf")]
struct CtrlVfdDmabufSync {
    hdr: CtrlHeader,
    id: Le32,
}
#[cfg(feature = "wl-dmabuf")]
unsafe impl DataInit for CtrlVfdDmabufSync {}
// Wire format of VFD_RECV; followed by `vfd_count` ids and then payload data.
#[repr(C)]
#[derive(Copy, Clone)]
struct CtrlVfdRecv {
    hdr: CtrlHeader,
    id: Le32,
    vfd_count: Le32,
}
unsafe impl DataInit for CtrlVfdRecv {}
// Wire format of messages that carry only a vfd id (e.g. CLOSE, HUP).
#[repr(C)]
#[derive(Copy, Clone, Default)]
struct CtrlVfd {
    hdr: CtrlHeader,
    id: Le32,
}
unsafe impl DataInit for CtrlVfd {}
// Wire format of VFD_SEND.
#[repr(C)]
#[derive(Copy, Clone, Default)]
struct CtrlVfdSend {
    hdr: CtrlHeader,
    id: Le32,
    vfd_count: Le32,
    // Remainder is an array of vfd_count IDs followed by data.
}
unsafe impl DataInit for CtrlVfdSend {}
// One (kind, id) entry of a VFD_SEND_FOREIGN_ID id array.
#[repr(C)]
#[derive(Copy, Clone, Default)]
struct CtrlVfdSendVfd {
    kind: Le32,
    id: Le32,
}
unsafe impl DataInit for CtrlVfdSendVfd {}
// In-memory form of a message destined for the guest, encoded onto the wire
// by `encode_resp`. Despite the name, some variants double as host-initiated
// commands (see `get_code`).
#[derive(Debug)]
#[allow(dead_code)]
enum WlResp<'a> {
    Ok,
    VfdNew {
        id: u32,
        flags: u32,
        pfn: u64,
        size: u32,
        // The VfdNew variant can be either a response or a command depending on this `resp`. This
        // is important for the `get_code` method.
        resp: bool,
    },
    #[cfg(feature = "wl-dmabuf")]
    VfdNewDmabuf {
        id: u32,
        flags: u32,
        pfn: u64,
        size: u32,
        desc: GpuMemoryDesc,
    },
    // Data and/or vfd ids received on a vfd, forwarded to the guest.
    VfdRecv {
        id: u32,
        data: &'a [u8],
        vfds: &'a [u32],
    },
    // The host side of the given vfd hung up.
    VfdHup {
        id: u32,
    },
    // Error responses; each maps to one VIRTIO_WL_RESP_* code.
    Err(Box<dyn StdError>),
    OutOfMemory,
    InvalidId,
    InvalidType,
    InvalidFlags,
    InvalidCommand,
}
impl<'a> WlResp<'a> {
    // Returns the numeric protocol code that identifies this message on the
    // wire (a VIRTIO_WL_RESP_* or, for host-initiated commands, a
    // VIRTIO_WL_CMD_* value).
    fn get_code(&self) -> u32 {
        match *self {
            WlResp::Ok => VIRTIO_WL_RESP_OK,
            WlResp::VfdNew { resp, .. } => {
                if resp {
                    VIRTIO_WL_RESP_VFD_NEW
                } else {
                    VIRTIO_WL_CMD_VFD_NEW
                }
            }
            #[cfg(feature = "wl-dmabuf")]
            WlResp::VfdNewDmabuf { .. } => VIRTIO_WL_RESP_VFD_NEW_DMABUF,
            WlResp::VfdRecv { .. } => VIRTIO_WL_CMD_VFD_RECV,
            WlResp::VfdHup { .. } => VIRTIO_WL_CMD_VFD_HUP,
            WlResp::Err(_) => VIRTIO_WL_RESP_ERR,
            WlResp::OutOfMemory => VIRTIO_WL_RESP_OUT_OF_MEMORY,
            WlResp::InvalidId => VIRTIO_WL_RESP_INVALID_ID,
            WlResp::InvalidType => VIRTIO_WL_RESP_INVALID_TYPE,
            WlResp::InvalidFlags => VIRTIO_WL_RESP_INVALID_FLAGS,
            WlResp::InvalidCommand => VIRTIO_WL_RESP_INVALID_CMD,
        }
    }
}
// One virtual file descriptor, backed by at most one of: a wayland socket,
// guest-visible shared memory, or a pipe.
#[derive(Default)]
struct WlVfd {
    socket: Option<UnixStream>,
    guest_shared_memory: Option<(u64 /* size */, File)>,
    // Pipe end handed to the peer when sent; dropped by `close_remote` so the
    // local end observes hangups.
    remote_pipe: Option<File>,
    local_pipe: Option<(u32 /* flags */, File)>,
    // KVM memory slot and guest page frame number of the shared memory, plus
    // the requester needed to unregister the slot on close.
    slot: Option<(u32 /* slot */, u64 /* pfn */, VmRequester)>,
    #[cfg(feature = "wl-dmabuf")]
    is_dmabuf: bool,
}
// Compact debug output showing which backings are present, their raw fds,
// and the memory slot/pfn if mapped.
impl fmt::Debug for WlVfd {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "WlVfd {{")?;
        if let Some(s) = &self.socket {
            write!(f, " socket: {}", s.as_raw_fd())?;
        }
        if let Some((slot, pfn, _)) = &self.slot {
            write!(f, " slot: {} pfn: {}", slot, pfn)?;
        }
        if let Some(s) = &self.remote_pipe {
            write!(f, " remote: {}", s.as_raw_fd())?;
        }
        if let Some((_, s)) = &self.local_pipe {
            write!(f, " local: {}", s.as_raw_fd())?;
        }
        write!(f, " }}")
    }
}
impl WlVfd {
    // Creates a vfd by connecting a new stream to the wayland server
    // listening at `path`.
    fn connect<P: AsRef<Path>>(path: P) -> WlResult<WlVfd> {
        let path = path.as_ref();
        let socket = UnixStream::connect(path).map_err(|e| WlError::SocketConnect {
            path: path.to_path_buf(),
            inner: e,
        })?;
        let mut vfd = WlVfd::default();
        vfd.socket = Some(socket);
        Ok(vfd)
    }
    // Wraps an already-connected stream in a vfd.
    fn from_socket(socket: UnixStream) -> WlVfd {
        let mut vfd = WlVfd::default();
        vfd.socket = Some(socket);
        vfd
    }
    // Creates a shared-memory vfd of at least `size` bytes (rounded up to a
    // whole page) and registers it with the VM so the guest can map it.
    fn allocate(vm: VmRequester, size: u64) -> WlResult<WlVfd> {
        let size_page_aligned = round_up_to_page_size(size as usize) as u64;
        let mut vfd_shm = SharedMemory::named("virtwl_alloc").map_err(WlError::NewAlloc)?;
        vfd_shm
            .set_size(size_page_aligned)
            .map_err(WlError::AllocSetSize)?;
        let register_response = vm.request(VmMemoryRequest::RegisterMemory(
            MaybeOwnedFd::Borrowed(vfd_shm.as_raw_fd()),
            vfd_shm.size() as usize,
        ))?;
        match register_response {
            VmMemoryResponse::RegisterMemory { pfn, slot } => {
                let mut vfd = WlVfd::default();
                vfd.guest_shared_memory = Some((vfd_shm.size(), vfd_shm.into()));
                // Keep the slot so close() can unregister the memory.
                vfd.slot = Some((slot, pfn, vm));
                Ok(vfd)
            }
            _ => Err(WlError::VmBadResponse),
        }
    }
    // Asks the VM to allocate a GPU buffer of the given dimensions/format,
    // then wraps the registered memory in a vfd. Also returns the plane
    // layout so it can be reported to the guest.
    #[cfg(feature = "wl-dmabuf")]
    fn dmabuf(
        vm: VmRequester,
        width: u32,
        height: u32,
        format: u32,
    ) -> WlResult<(WlVfd, GpuMemoryDesc)> {
        let allocate_and_register_gpu_memory_response =
            vm.request(VmMemoryRequest::AllocateAndRegisterGpuMemory {
                width,
                height,
                format,
            })?;
        match allocate_and_register_gpu_memory_response {
            VmMemoryResponse::AllocateAndRegisterGpuMemory {
                fd: MaybeOwnedFd::Owned(file),
                pfn,
                slot,
                desc,
            } => {
                let mut vfd = WlVfd::default();
                let vfd_shm = SharedMemory::from_file(file).map_err(WlError::NewAlloc)?;
                vfd.guest_shared_memory = Some((vfd_shm.size(), vfd_shm.into()));
                vfd.slot = Some((slot, pfn, vm));
                // Marks the vfd as eligible for dmabuf_sync below.
                vfd.is_dmabuf = true;
                Ok((vfd, desc))
            }
            _ => Err(WlError::VmBadResponse),
        }
    }
    // Issues DMA_BUF_IOCTL_SYNC with `flags` on the backing dmabuf fd.
    // Fails with EINVAL for non-dmabuf vfds and EBADF when there is no
    // backing memory.
    #[cfg(feature = "wl-dmabuf")]
    fn dmabuf_sync(&self, flags: u32) -> WlResult<()> {
        if !self.is_dmabuf {
            return Err(WlError::DmabufSync(io::Error::from_raw_os_error(EINVAL)));
        }
        match &self.guest_shared_memory {
            Some((_, fd)) => {
                let sync = dma_buf_sync {
                    flags: flags as u64,
                };
                // Safe as fd is a valid dmabuf and incorrect flags will return an error.
                if unsafe { ioctl_with_ref(fd, DMA_BUF_IOCTL_SYNC(), &sync) } < 0 {
                    Err(WlError::DmabufSync(io::Error::last_os_error()))
                } else {
                    Ok(())
                }
            }
            None => Err(WlError::DmabufSync(io::Error::from_raw_os_error(EBADF))),
        }
    }
    // Creates a pipe vfd the guest writes to and the remote end reads from.
    fn pipe_remote_read_local_write() -> WlResult<WlVfd> {
        let (read_pipe, write_pipe) = pipe(true).map_err(WlError::NewPipe)?;
        let mut vfd = WlVfd::default();
        vfd.remote_pipe = Some(read_pipe);
        vfd.local_pipe = Some((VIRTIO_WL_VFD_WRITE, write_pipe));
        Ok(vfd)
    }
    // Creates a pipe vfd the guest reads from and the remote end writes to.
    fn pipe_remote_write_local_read() -> WlResult<WlVfd> {
        let (read_pipe, write_pipe) = pipe(true).map_err(WlError::NewPipe)?;
        let mut vfd = WlVfd::default();
        vfd.remote_pipe = Some(write_pipe);
        vfd.local_pipe = Some((VIRTIO_WL_VFD_READ, read_pipe));
        Ok(vfd)
    }
    // Builds a vfd from an arbitrary file received over a socket, treating it
    // as shared memory when seekable and as a pipe otherwise.
    fn from_file(vm: VmRequester, mut fd: File) -> WlResult<WlVfd> {
        // We need to determine if the given file is more like shared memory or a pipe/socket. A
        // quick and easy check is to seek to the end of the file. If it works we assume it's not a
        // pipe/socket because those have no end. We can even use that seek location as an indicator
        // for how big the shared memory chunk to map into guest memory is. If seeking to the end
        // fails, we assume it's a socket or pipe with read/write semantics.
        match fd.seek(SeekFrom::End(0)) {
            Ok(fd_size) => {
                let size = round_up_to_page_size(fd_size as usize) as u64;
                let register_response = vm.request(VmMemoryRequest::RegisterMemory(
                    MaybeOwnedFd::Borrowed(fd.as_raw_fd()),
                    size as usize,
                ))?;
                match register_response {
                    VmMemoryResponse::RegisterMemory { pfn, slot } => {
                        let mut vfd = WlVfd::default();
                        vfd.guest_shared_memory = Some((size, fd));
                        vfd.slot = Some((slot, pfn, vm));
                        Ok(vfd)
                    }
                    _ => Err(WlError::VmBadResponse),
                }
            }
            _ => {
                // Derive the guest-visible read/write flags from the fd's
                // open mode; unknown modes report no capabilities.
                let flags = match FileFlags::from_file(&fd) {
                    Ok(FileFlags::Read) => VIRTIO_WL_VFD_READ,
                    Ok(FileFlags::Write) => VIRTIO_WL_VFD_WRITE,
                    Ok(FileFlags::ReadWrite) => VIRTIO_WL_VFD_READ | VIRTIO_WL_VFD_WRITE,
                    _ => 0,
                };
                let mut vfd = WlVfd::default();
                vfd.local_pipe = Some((flags, fd));
                Ok(vfd)
            }
        }
    }
    // Computes the guest-visible flag bits for this vfd. The meaning of the
    // bits depends on whether the transitional flag semantics are in use.
    fn flags(&self, use_transition_flags: bool) -> u32 {
        let mut flags = 0;
        if use_transition_flags {
            if self.socket.is_some() {
                flags |= VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_READ;
            }
            if let Some((f, _)) = self.local_pipe {
                flags |= f;
            }
        } else {
            if self.socket.is_some() {
                flags |= VIRTIO_WL_VFD_CONTROL;
            }
            if self.slot.is_some() {
                flags |= VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_MAP
            }
        }
        flags
    }
    // Page frame number in the guest this VFD was mapped at.
    fn pfn(&self) -> Option<u64> {
        self.slot.as_ref().map(|s| s.1)
    }
    // Size in bytes of the shared memory VFD.
    fn size(&self) -> Option<u64> {
        self.guest_shared_memory.as_ref().map(|&(size, _)| size)
    }
    // The FD that gets sent if this VFD is sent over a socket.
    fn send_fd(&self) -> Option<RawFd> {
        self.guest_shared_memory
            .as_ref()
            .map(|(_, fd)| fd.as_raw_fd())
            .or(self.socket.as_ref().map(|s| s.as_raw_fd()))
            .or(self.remote_pipe.as_ref().map(|p| p.as_raw_fd()))
    }
    // The FD that is used for polling for events on this VFD.
    fn poll_fd(&self) -> Option<&dyn AsRawFd> {
        self.socket
            .as_ref()
            .map(|s| s as &dyn AsRawFd)
            .or(self.local_pipe.as_ref().map(|(_, p)| p as &dyn AsRawFd))
    }
    // Sends data/files from the guest to the host over this VFD.
    fn send(&mut self, fds: &[RawFd], data: &mut Reader) -> WlResult<WlResp> {
        if let Some(socket) = &self.socket {
            socket
                .send_with_fds(data.get_remaining(), fds)
                .map_err(WlError::SendVfd)?;
            // All remaining data in `data` is now considered consumed.
            data.consume(::std::usize::MAX);
            Ok(WlResp::Ok)
        } else if let Some((_, local_pipe)) = &mut self.local_pipe {
            // Impossible to send fds over a simple pipe.
            if !fds.is_empty() {
                return Ok(WlResp::InvalidType);
            }
            data.read_to(local_pipe, usize::max_value())
                .map_err(WlError::WritePipe)?;
            Ok(WlResp::Ok)
        } else {
            Ok(WlResp::InvalidType)
        }
    }
    // Receives data/files from the host for this VFD and queues it for the guest.
    fn recv(&mut self, in_file_queue: &mut Vec<File>) -> WlResult<Vec<u8>> {
        if let Some(socket) = self.socket.take() {
            let mut buf = vec![0; IN_BUFFER_LEN];
            let mut fd_buf = [0; VIRTWL_SEND_MAX_ALLOCS];
            // If any errors happen, the socket will get dropped, preventing more reading.
            let (len, file_count) = socket
                .recv_with_fds(&mut buf[..], &mut fd_buf)
                .map_err(WlError::RecvVfd)?;
            // If any data gets read, put the socket back for future recv operations.
            if len != 0 || file_count != 0 {
                buf.truncate(len);
                buf.shrink_to_fit();
                self.socket = Some(socket);
                // Safe because the first file_counts fds from recv_with_fds are owned by us and
                // valid.
                in_file_queue.extend(
                    fd_buf[..file_count]
                        .iter()
                        .map(|&fd| unsafe { File::from_raw_fd(fd) }),
                );
                return Ok(buf);
            }
            // Zero bytes and zero fds means the peer hung up; leaving
            // `socket` as None records that.
            Ok(Vec::new())
        } else if let Some((flags, mut local_pipe)) = self.local_pipe.take() {
            let mut buf = Vec::new();
            buf.resize(IN_BUFFER_LEN, 0);
            let len = local_pipe.read(&mut buf[..]).map_err(WlError::ReadPipe)?;
            if len != 0 {
                buf.truncate(len);
                buf.shrink_to_fit();
                // Put the pipe back only if it is still producing data.
                self.local_pipe = Some((flags, local_pipe));
                return Ok(buf);
            }
            Ok(Vec::new())
        } else {
            Ok(Vec::new())
        }
    }
    // Called after this VFD is sent over a socket to ensure the local end of the VFD receives hang
    // up events.
    fn close_remote(&mut self) {
        self.remote_pipe = None;
    }
    // Releases all backing resources, unregistering any guest memory slot
    // with the VM first.
    fn close(&mut self) -> WlResult<()> {
        if let Some((slot, _, vm)) = self.slot.take() {
            vm.request(VmMemoryRequest::UnregisterMemory(slot))?;
        }
        self.socket = None;
        self.remote_pipe = None;
        self.local_pipe = None;
        Ok(())
    }
}
impl Drop for WlVfd {
    fn drop(&mut self) {
        // Best-effort cleanup; unregistration errors are ignored on drop.
        let _ = self.close();
    }
}
// One item queued for delivery to the guest on the `in` queue.
#[derive(Debug)]
enum WlRecv {
    // A newly created vfd, built from a received file descriptor.
    Vfd { id: u32 },
    // Raw bytes received on a vfd.
    Data { buf: Vec<u8> },
    // The vfd hung up.
    Hup,
}
// A host-side wayland endpoint that guest contexts attach to by name.
#[derive(Debug)]
enum WaylandSocket {
    // Path of a listening server socket; each new context connects to it.
    Listening(PathBuf),
    // An already-connected stream, consumed by the first context that uses it.
    NonListening(UnixStream),
}
// Manual serialization split into a data pass and an fd pass so the
// NonListening socket's file descriptor can travel alongside the message.
impl SerializeWithFds for WaylandSocket {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use WaylandSocket::*;
        match self {
            Listening(path) => {
                serializer.serialize_newtype_variant("WaylandSocket", 0, "Listening", path)
            }
            // The adapter serializes the socket's data-side representation;
            // the actual fd is emitted by serialize_fds below.
            NonListening(socket) => serializer.serialize_newtype_variant(
                "WaylandSocket",
                1,
                "NonListening",
                &SerializeAdapter::new(socket),
            ),
        }
    }
    fn serialize_fds<'fds, S: FdSerializer<'fds>>(
        &'fds self,
        serializer: S,
    ) -> Result<S::Ok, S::Error> {
        use WaylandSocket::*;
        match self {
            // A path carries no fds; this arm exists to keep the variant
            // indices aligned with serialize above.
            Listening(path) => {
                serializer.serialize_newtype_variant("WaylandSocket", 0, "Listening", path)
            }
            NonListening(socket) => {
                serializer.serialize_newtype_variant("WaylandSocket", 1, "NonListening", socket)
            }
        }
    }
}
// Counterpart of SerializeWithFds above: decodes the enum tag, then the
// variant payload (path, or socket reconstructed from a received fd).
impl<'de> DeserializeWithFds<'de> for WaylandSocket {
    fn deserialize<D: DeserializerWithFds<'de>>(deserializer: D) -> Result<Self, D::Error> {
        struct Visitor;
        impl<'de> VisitorWithFds<'de> for Visitor {
            type Value = WaylandSocket;
            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                write!(f, "enum WaylandSocket")
            }
            fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
            where
                A: EnumAccessWithFds<'de>,
            {
                // Tag-only mirror of WaylandSocket used to decode the variant
                // index before reading the payload.
                #[derive(Debug, Deserialize)]
                enum Variant {
                    Listening,
                    NonListening,
                }
                // NOTE(review): both arms delegate to newtype_variant() with
                // the same inferred target type; confirm the two payloads
                // (PathBuf vs fd-backed socket) round-trip as intended.
                match data.variant()? {
                    (Variant::Listening, variant) => variant.newtype_variant(),
                    (Variant::NonListening, variant) => variant.newtype_variant(),
                }
            }
        }
        deserializer.deserialize_enum("WaylandSocket", &["Listening", "NonListening"], Visitor)
    }
}
// All mutable state of the wayland device (see the module documentation).
struct WlState {
    // Named host endpoints available to `new_context`.
    wayland_sockets: Map<Vec<u8>, WaylandSocket>,
    // Handle for registering/unregistering guest memory with the VM.
    vm: VmRequester,
    // Optional socket to the GPU device for resolving virtgpu resource ids.
    resource_bridge: Option<ResourceRequestSocket>,
    // Whether the guest negotiated VIRTIO_WL_F_TRANS_FLAGS semantics.
    use_transition_flags: bool,
    // Polls every live vfd's fd, keyed by vfd id.
    poll_ctx: PollContext<u32>,
    vfds: Map<u32, WlVfd>,
    // Next id to assign to a host-created vfd (starts at NEXT_VFD_ID_BASE).
    next_vfd_id: u32,
    // Files received on sockets, pending conversion into vfds by `recv`.
    in_file_queue: Vec<File>,
    // Messages buffered until there is space on the guest's `in` queue.
    in_queue: VecDeque<(u32 /* vfd_id */, WlRecv)>,
    // Bookkeeping for a multi-part recv currently being delivered.
    current_recv_vfd: Option<u32>,
    recv_vfds: Vec<u32>,
}
impl WlState {
fn new(
wayland_sockets: Map<Vec<u8>, WaylandSocket>,
vm_socket: VmMemoryControlRequestSocket,
use_transition_flags: bool,
resource_bridge: Option<ResourceRequestSocket>,
) -> WlState {
WlState {
wayland_sockets,
vm: VmRequester::new(vm_socket),
resource_bridge,
poll_ctx: PollContext::new().expect("failed to create PollContext"),
use_transition_flags,
vfds: Map::new(),
next_vfd_id: NEXT_VFD_ID_BASE,
in_file_queue: Vec::new(),
in_queue: VecDeque::new(),
current_recv_vfd: None,
recv_vfds: Vec::new(),
}
}
    // Handles VFD_NEW_PIPE: creates a unidirectional pipe vfd with the
    // guest-chosen id `id`. `flags` must be exactly one of
    // VIRTIO_WL_VFD_WRITE (guest writes) or VIRTIO_WL_VFD_READ (guest reads).
    fn new_pipe(&mut self, id: u32, flags: u32) -> WlResult<WlResp> {
        // Guest-created vfds must not use the id space reserved for the host.
        if id & VFD_ID_HOST_MASK != 0 {
            return Ok(WlResp::InvalidId);
        }
        if flags & !(VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_READ) != 0 {
            return Ok(WlResp::InvalidFlags);
        }
        // Pipes are unidirectional; both flags at once is invalid.
        if flags & VIRTIO_WL_VFD_WRITE != 0 && flags & VIRTIO_WL_VFD_READ != 0 {
            return Ok(WlResp::InvalidFlags);
        }
        match self.vfds.entry(id) {
            Entry::Vacant(entry) => {
                let vfd = if flags & VIRTIO_WL_VFD_WRITE != 0 {
                    WlVfd::pipe_remote_read_local_write()?
                } else if flags & VIRTIO_WL_VFD_READ != 0 {
                    WlVfd::pipe_remote_write_local_read()?
                } else {
                    return Ok(WlResp::InvalidFlags);
                };
                // Watch the local end so incoming data can be forwarded.
                self.poll_ctx
                    .add(vfd.poll_fd().unwrap(), id)
                    .map_err(WlError::PollContextAdd)?;
                let resp = WlResp::VfdNew {
                    id,
                    flags: 0,
                    pfn: 0,
                    size: 0,
                    resp: true,
                };
                entry.insert(vfd);
                Ok(resp)
            }
            Entry::Occupied(_) => Ok(WlResp::InvalidId),
        }
    }
    // Handles VFD_NEW: creates a shared-memory vfd of `size` bytes with the
    // guest-chosen id `id` and reports back its pfn and rounded-up size.
    fn new_alloc(&mut self, id: u32, flags: u32, size: u32) -> WlResult<WlResp> {
        if id & VFD_ID_HOST_MASK != 0 {
            return Ok(WlResp::InvalidId);
        }
        if self.use_transition_flags {
            // Under transitional semantics allocations carry no flags at all.
            if flags != 0 {
                return Ok(WlResp::InvalidFlags);
            }
        } else if flags & !(VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_MAP) != 0 {
            // NOTE(review): this arm answers with a generic error message
            // while the arm above uses InvalidFlags — confirm the asymmetry
            // is intentional.
            return Ok(WlResp::Err(Box::from("invalid flags")));
        }
        match self.vfds.entry(id) {
            Entry::Vacant(entry) => {
                let vfd = WlVfd::allocate(self.vm.clone(), size as u64)?;
                let resp = WlResp::VfdNew {
                    id,
                    flags,
                    pfn: vfd.pfn().unwrap_or_default(),
                    size: vfd.size().unwrap_or_default() as u32,
                    resp: true,
                };
                entry.insert(vfd);
                Ok(resp)
            }
            Entry::Occupied(_) => Ok(WlResp::InvalidId),
        }
    }
#[cfg(feature = "wl-dmabuf")]
fn new_dmabuf(&mut self, id: u32, width: u32, height: u32, format: u32) -> WlResult<WlResp> {
if id & VFD_ID_HOST_MASK != 0 {
return Ok(WlResp::InvalidId);
}
match self.vfds.entry(id) {
Entry::Vacant(entry) => {
let (vfd, desc) = WlVfd::dmabuf(self.vm.clone(), width, height, format)?;
let resp = WlResp::VfdNewDmabuf {
id,
flags: 0,
pfn: vfd.pfn().unwrap_or_default(),
size: vfd.size().unwrap_or_default() as u32,
desc,
};
entry.insert(vfd);
Ok(resp)
}
Entry::Occupied(_) => Ok(WlResp::InvalidId),
}
}
#[cfg(feature = "wl-dmabuf")]
fn dmabuf_sync(&mut self, vfd_id: u32, flags: u32) -> WlResult<WlResp> {
if flags & !(VIRTIO_WL_VFD_DMABUF_SYNC_VALID_FLAG_MASK) != 0 {
return Ok(WlResp::InvalidFlags);
}
match self.vfds.get_mut(&vfd_id) {
Some(vfd) => {
vfd.dmabuf_sync(flags)?;
Ok(WlResp::Ok)
}
None => Ok(WlResp::InvalidId),
}
}
    // Handles VFD_NEW_CTX / VFD_NEW_CTX_NAMED: creates a vfd connected to the
    // wayland endpoint registered under `name`.
    fn new_context(&mut self, id: u32, name: &[u8]) -> WlResult<WlResp> {
        if id & VFD_ID_HOST_MASK != 0 {
            return Ok(WlResp::InvalidId);
        }
        let flags = if self.use_transition_flags {
            VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_READ
        } else {
            VIRTIO_WL_VFD_CONTROL
        };
        match self.vfds.entry(id) {
            Entry::Vacant(entry) => {
                // A Listening endpoint gets a fresh connection per context; a
                // NonListening stream is removed and consumed by this context.
                let vfd =
                    if let Some(WaylandSocket::Listening(path)) = self.wayland_sockets.get(name) {
                        WlVfd::connect(path)?
                    } else if let Some(WaylandSocket::NonListening(socket)) =
                        self.wayland_sockets.remove(name)
                    {
                        WlVfd::from_socket(socket)
                    } else {
                        return Err(WlError::UnknownSocketName(name.to_vec()));
                    };
                let vfd = entry.insert(vfd);
                self.poll_ctx
                    .add(vfd.poll_fd().unwrap(), id)
                    .map_err(WlError::PollContextAdd)?;
                Ok(WlResp::VfdNew {
                    id,
                    flags,
                    pfn: 0,
                    size: 0,
                    resp: true,
                })
            }
            Entry::Occupied(_) => Ok(WlResp::InvalidId),
        }
    }
fn add_socket(&mut self, name: Vec<u8>, socket: UnixStream) -> Result<(), Error> {
if name.len() > 32 {
return Err(Error::new(libc::EINVAL));
}
if self.wayland_sockets.contains_key(&name) {
return Err(Error::new(libc::EADDRINUSE));
}
self.wayland_sockets
.insert(name, WaylandSocket::NonListening(socket));
Ok(())
}
fn process_poll_context(&mut self) {
let events = match self.poll_ctx.wait_timeout(Duration::from_secs(0)) {
Ok(v) => v.to_owned(),
Err(e) => {
error!("failed polling for vfd evens: {}", e);
return;
}
};
for event in events.as_ref().iter_readable() {
if let Err(e) = self.recv(event.token()) {
error!("failed to recv from vfd: {}", e)
}
}
for event in events.as_ref().iter_hungup() {
if !event.readable() {
let vfd_id = event.token();
if let Some(fd) = self.vfds.get(&vfd_id).and_then(|vfd| vfd.poll_fd()) {
if let Err(e) = self.poll_ctx.delete(fd) {
warn!("failed to remove hungup vfd from poll context: {}", e);
}
}
self.in_queue.push_back((vfd_id, WlRecv::Hup));
}
}
}
    // Closes vfd `vfd_id`, recursively closing any vfds that were queued for
    // delivery to it but never received by the guest.
    fn close(&mut self, vfd_id: u32) -> WlResult<WlResp> {
        // Collect the ids of vfds still in-flight to the vfd being closed;
        // the guest can never receive them now.
        let mut to_delete = Set::new();
        for (dest_vfd_id, q) in &self.in_queue {
            if *dest_vfd_id == vfd_id {
                if let WlRecv::Vfd { id } = q {
                    to_delete.insert(*id);
                }
            }
        }
        for vfd_id in to_delete {
            // Sorry sub-error, we can't have cascading errors leaving us in an inconsistent state.
            let _ = self.close(vfd_id);
        }
        match self.vfds.remove(&vfd_id) {
            Some(mut vfd) => {
                // Drop any queued messages destined for the closed vfd.
                self.in_queue.retain(|&(id, _)| id != vfd_id);
                vfd.close()?;
                Ok(WlResp::Ok)
            }
            None => Ok(WlResp::InvalidId),
        }
    }
/// Handles a guest send request: forwards the remaining bytes in `reader` together with up to
/// `vfd_count` file descriptors over the socket of the vfd identified by `vfd_id`.
///
/// When `foreign_id` is true, each transferred id is read as a full `CtrlVfdSendVfd`
/// (kind + id), allowing non-local kinds such as gpu resources; otherwise each id is read as a
/// bare id and treated as `VIRTIO_WL_CTRL_VFD_SEND_KIND_LOCAL`.
fn send(
&mut self,
vfd_id: u32,
vfd_count: usize,
foreign_id: bool,
reader: &mut Reader,
) -> WlResult<WlResp> {
// First stage gathers and normalizes all id information from guest memory.
let mut send_vfd_ids = [CtrlVfdSendVfd::default(); VIRTWL_SEND_MAX_ALLOCS];
for vfd_id in send_vfd_ids.iter_mut().take(vfd_count) {
*vfd_id = if foreign_id {
reader.read_obj().map_err(WlError::ParseDesc)?
} else {
// Plain ids are wrapped as LOCAL so both paths share one representation below.
CtrlVfdSendVfd {
kind: Le32::from(VIRTIO_WL_CTRL_VFD_SEND_KIND_LOCAL),
id: reader.read_obj().map_err(WlError::ParseDesc)?,
}
}
}
// Next stage collects corresponding file descriptors for each id.
let mut fds = [0; VIRTWL_SEND_MAX_ALLOCS];
// Keeps the files returned by the resource bridge alive until after the send below; only the
// raw fds are copied into `fds`.
#[cfg(feature = "gpu")]
let mut bridged_files = Vec::new();
for (&send_vfd_id, fd) in send_vfd_ids[..vfd_count].iter().zip(fds.iter_mut()) {
let id = send_vfd_id.id.to_native();
match send_vfd_id.kind.to_native() {
VIRTIO_WL_CTRL_VFD_SEND_KIND_LOCAL => match self.vfds.get(&id) {
Some(vfd) => match vfd.send_fd() {
// Vfd exists but has no sendable fd (e.g. wrong vfd type).
Some(vfd_fd) => *fd = vfd_fd,
None => return Ok(WlResp::InvalidType),
},
None => {
warn!("attempt to send non-existant vfd 0x{:08x}", id);
return Ok(WlResp::InvalidId);
}
},
#[cfg(feature = "gpu")]
VIRTIO_WL_CTRL_VFD_SEND_KIND_VIRTGPU if self.resource_bridge.is_some() => {
// unwrap is safe: the match guard checked `is_some()`.
let sock = self.resource_bridge.as_ref().unwrap();
match get_resource_info(sock, id) {
Ok(info) => {
*fd = info.file.as_raw_fd();
bridged_files.push(info.file);
}
Err(ResourceBridgeError::InvalidResource(id)) => {
warn!("attempt to send non-existent gpu resource {}", id);
return Ok(WlResp::InvalidId);
}
Err(e) => {
error!("{}", e);
// If there was an error with the resource bridge, it can no longer be
// trusted to continue to function.
self.resource_bridge = None;
return Ok(WlResp::InvalidId);
}
}
}
VIRTIO_WL_CTRL_VFD_SEND_KIND_VIRTGPU => {
// Presumably keeps `resource_bridge` "used" when the gpu feature is compiled out.
// NOTE(review): this arm warns but falls through with `*fd` left at 0 — confirm
// that is intended rather than returning an error response.
let _ = self.resource_bridge.as_ref();
warn!("attempt to send foreign resource kind but feature is disabled");
}
kind => {
warn!(
"attempt to send unknown foreign resource kind: {} id: {:08x}",
kind, id
);
return Ok(WlResp::InvalidId);
}
}
}
// Final stage sends file descriptors and data to the target vfd's socket.
match self.vfds.get_mut(&vfd_id) {
Some(vfd) => match vfd.send(&fds[..vfd_count], reader)? {
WlResp::Ok => {}
_ => return Ok(WlResp::InvalidType),
},
None => return Ok(WlResp::InvalidId),
}
// The vfds with remote FDs need to be closed so that the local side can receive
// hangup events.
for &send_vfd_id in &send_vfd_ids[..vfd_count] {
if send_vfd_id.kind == VIRTIO_WL_CTRL_VFD_SEND_KIND_LOCAL {
if let Some(vfd) = self.vfds.get_mut(&send_vfd_id.id.into()) {
vfd.close_remote();
}
}
}
Ok(WlResp::Ok)
}
/// Reads available data and incoming files from the given vfd and queues the results on the
/// `in` queue for delivery to the guest.
fn recv(&mut self, vfd_id: u32) -> WlResult<()> {
// Unknown vfd ids are silently ignored.
let buf = match self.vfds.get_mut(&vfd_id) {
Some(vfd) => vfd.recv(&mut self.in_file_queue)?,
None => return Ok(()),
};
// A recv producing neither files nor data is treated as a hangup of the vfd.
if self.in_file_queue.is_empty() && buf.is_empty() {
self.in_queue.push_back((vfd_id, WlRecv::Hup));
return Ok(());
}
// Each received file becomes a new vfd: registered with the poll context (when pollable),
// stored in the vfd table, and announced to the guest *before* the data payload below.
for file in self.in_file_queue.drain(..) {
let vfd = WlVfd::from_file(self.vm.clone(), file)?;
if let Some(poll_fd) = vfd.poll_fd() {
self.poll_ctx
.add(poll_fd, self.next_vfd_id)
.map_err(WlError::PollContextAdd)?;
}
self.vfds.insert(self.next_vfd_id, vfd);
self.in_queue.push_back((
vfd_id,
WlRecv::Vfd {
id: self.next_vfd_id,
},
));
self.next_vfd_id += 1;
}
// Queue the data last so the guest sees any accompanying vfds first.
self.in_queue.push_back((vfd_id, WlRecv::Data { buf }));
Ok(())
}
/// Decodes one virtio wayland command from `reader` and dispatches it to the matching handler,
/// returning the response to be encoded back to the guest.
fn execute(&mut self, reader: &mut Reader) -> WlResult<WlResp> {
    // Peek at the command type without consuming it: each arm below re-reads the complete
    // control structure (whose first field is the type) from the original reader.
    let type_ = {
        let mut type_reader = reader.clone();
        type_reader.read_obj::<Le32>().map_err(WlError::ParseDesc)?
    };
    match type_.into() {
        VIRTIO_WL_CMD_VFD_NEW => {
            let ctrl = reader
                .read_obj::<CtrlVfdNew>()
                .map_err(WlError::ParseDesc)?;
            self.new_alloc(ctrl.id.into(), ctrl.flags.into(), ctrl.size.into())
        }
        VIRTIO_WL_CMD_VFD_CLOSE => {
            let ctrl = reader.read_obj::<CtrlVfd>().map_err(WlError::ParseDesc)?;
            self.close(ctrl.id.into())
        }
        VIRTIO_WL_CMD_VFD_SEND => {
            let ctrl = reader
                .read_obj::<CtrlVfdSend>()
                .map_err(WlError::ParseDesc)?;
            let foreign_id = false;
            self.send(
                ctrl.id.into(),
                ctrl.vfd_count.to_native() as usize,
                foreign_id,
                reader,
            )
        }
        #[cfg(feature = "gpu")]
        VIRTIO_WL_CMD_VFD_SEND_FOREIGN_ID => {
            let ctrl = reader
                .read_obj::<CtrlVfdSend>()
                .map_err(WlError::ParseDesc)?;
            let foreign_id = true;
            self.send(
                ctrl.id.into(),
                ctrl.vfd_count.to_native() as usize,
                foreign_id,
                reader,
            )
        }
        VIRTIO_WL_CMD_VFD_NEW_CTX => {
            let ctrl = reader.read_obj::<CtrlVfd>().map_err(WlError::ParseDesc)?;
            self.new_context(ctrl.id.into(), b"")
        }
        VIRTIO_WL_CMD_VFD_NEW_PIPE => {
            let ctrl = reader
                .read_obj::<CtrlVfdNew>()
                .map_err(WlError::ParseDesc)?;
            self.new_pipe(ctrl.id.into(), ctrl.flags.into())
        }
        #[cfg(feature = "wl-dmabuf")]
        VIRTIO_WL_CMD_VFD_NEW_DMABUF => {
            let ctrl = reader
                .read_obj::<CtrlVfdNewDmabuf>()
                .map_err(WlError::ParseDesc)?;
            self.new_dmabuf(
                ctrl.id.into(),
                ctrl.width.into(),
                ctrl.height.into(),
                ctrl.format.into(),
            )
        }
        #[cfg(feature = "wl-dmabuf")]
        VIRTIO_WL_CMD_VFD_DMABUF_SYNC => {
            let ctrl = reader
                .read_obj::<CtrlVfdDmabufSync>()
                .map_err(WlError::ParseDesc)?;
            self.dmabuf_sync(ctrl.id.into(), ctrl.flags.into())
        }
        VIRTIO_WL_CMD_VFD_NEW_CTX_NAMED => {
            let ctrl = reader
                .read_obj::<CtrlVfdNewCtxNamed>()
                .map_err(WlError::ParseDesc)?;
            // The name occupies a fixed-size array; use everything up to the first NUL byte
            // (or the whole array when there is none).
            let name_len = ctrl
                .name
                .iter()
                .position(|x| x == &0)
                .unwrap_or(ctrl.name.len());
            // Plain slice borrow instead of the legacy `let ref` binding form.
            let name = &ctrl.name[..name_len];
            self.new_context(ctrl.id.into(), name)
        }
        op_type => {
            warn!("unexpected command {:#x}", op_type);
            Ok(WlResp::InvalidCommand)
        }
    }
}
/// Peeks at the front of the `in` queue and builds the response that should be delivered to the
/// guest next, without consuming the entry (see `pop_recv` for the consuming half).
///
/// Responses for a vfd other than the in-progress `current_recv_vfd` are replaced by a "flush":
/// an empty `VfdRecv` for the current vfd carrying the accumulated `recv_vfds`.
fn next_recv(&self) -> Option<WlResp> {
if let Some(q) = self.in_queue.front() {
match *q {
(vfd_id, WlRecv::Vfd { id }) => {
if self.current_recv_vfd.is_none() || self.current_recv_vfd == Some(vfd_id) {
match self.vfds.get(&id) {
Some(vfd) => Some(WlResp::VfdNew {
id,
flags: vfd.flags(self.use_transition_flags),
pfn: vfd.pfn().unwrap_or_default(),
size: vfd.size().unwrap_or_default() as u32,
resp: false,
}),
// The vfd is already gone; announce it with zeroed attributes.
_ => Some(WlResp::VfdNew {
id,
flags: 0,
pfn: 0,
size: 0,
resp: false,
}),
}
} else {
// Flush the previous vfd's accumulated state before switching vfds.
Some(WlResp::VfdRecv {
// unwrap is safe: this branch requires current_recv_vfd to be Some.
id: self.current_recv_vfd.unwrap(),
data: &[],
vfds: &self.recv_vfds[..],
})
}
}
(vfd_id, WlRecv::Data { ref buf }) => {
if self.current_recv_vfd.is_none() || self.current_recv_vfd == Some(vfd_id) {
Some(WlResp::VfdRecv {
id: vfd_id,
data: &buf[..],
vfds: &self.recv_vfds[..],
})
} else {
// Flush the previous vfd's accumulated state before switching vfds.
Some(WlResp::VfdRecv {
// unwrap is safe: this branch requires current_recv_vfd to be Some.
id: self.current_recv_vfd.unwrap(),
data: &[],
vfds: &self.recv_vfds[..],
})
}
}
(vfd_id, WlRecv::Hup) => Some(WlResp::VfdHup { id: vfd_id }),
}
} else {
None
}
}
/// Consumes the `in` queue entry whose response (from `next_recv`) was just delivered to the
/// guest, updating the in-progress recv bookkeeping (`recv_vfds`/`current_recv_vfd`).
fn pop_recv(&mut self) {
    if let Some(q) = self.in_queue.front() {
        match *q {
            (vfd_id, WlRecv::Vfd { id }) => {
                if self.current_recv_vfd.is_none() || self.current_recv_vfd == Some(vfd_id) {
                    // The vfd announcement was delivered; remember it so it accompanies the
                    // next data payload for this vfd.
                    self.recv_vfds.push(id);
                    self.current_recv_vfd = Some(vfd_id);
                } else {
                    // What was delivered was the flush of the previous vfd's state, not this
                    // entry; reset the bookkeeping and keep the entry queued for next time.
                    self.recv_vfds.clear();
                    self.current_recv_vfd = None;
                    return;
                }
            }
            (vfd_id, WlRecv::Data { .. }) => {
                // Bug fix: the mismatch check must run *before* the state is cleared and must
                // skip the pop. Previously the state was cleared first, which made this
                // condition constant-false (dead `return`) and caused the Data entry to be
                // dropped whenever the delivered response was actually the flush of the
                // previous vfd's state.
                if !(self.current_recv_vfd.is_none() || self.current_recv_vfd == Some(vfd_id)) {
                    self.recv_vfds.clear();
                    self.current_recv_vfd = None;
                    return;
                }
                self.recv_vfds.clear();
                self.current_recv_vfd = None;
            }
            (_, WlRecv::Hup) => {
                self.recv_vfds.clear();
                self.current_recv_vfd = None;
            }
        }
    }
    self.in_queue.pop_front();
}
}
/// Worker that owns the device's queues and state and runs the poll loop (see `Worker::run`).
struct Worker {
interrupt: Interrupt,
mem: GuestMemory,
in_queue: Queue,  // host -> guest messages (see module docs)
out_queue: Queue, // guest -> host commands (see module docs)
state: WlState,
control_socket: WlControlResponseSocket,
}
impl Worker {
fn new(
mem: GuestMemory,
interrupt: Interrupt,
in_queue: Queue,
out_queue: Queue,
wayland_sockets: Map<Vec<u8>, WaylandSocket>,
vm_socket: VmMemoryControlRequestSocket,
use_transition_flags: bool,
resource_bridge: Option<ResourceRequestSocket>,
control_socket: WlControlResponseSocket,
) -> Worker {
Worker {
interrupt,
mem,
in_queue,
out_queue,
state: WlState::new(
wayland_sockets,
vm_socket,
use_transition_flags,
resource_bridge,
),
control_socket,
}
}
pub fn run(&mut self, mut queue_evts: Vec<EventFd>, kill_evt: EventFd) {
let mut in_desc_chains: VecDeque<DescriptorChain> =
VecDeque::with_capacity(QUEUE_SIZE as usize);
let in_queue_evt = queue_evts.remove(0);
let out_queue_evt = queue_evts.remove(0);
#[derive(Debug, PollToken)]
enum Token {
InQueue,
OutQueue,
CommandSocket,
Kill,
State,
InterruptResample,
}
let poll_ctx: PollContext<Token> = match PollContext::build_with(&[
(&in_queue_evt, Token::InQueue),
(&out_queue_evt, Token::OutQueue),
(&self.control_socket, Token::CommandSocket),
(&kill_evt, Token::Kill),
(&self.state.poll_ctx, Token::State),
(self.interrupt.get_resample_evt(), Token::InterruptResample),
]) {
Ok(pc) => pc,
Err(e) => {
error!("failed creating PollContext: {}", e);
return;
}
};
if let Err(e) = self.control_socket.send(&WlControlResult::Ready) {
error!("control socket failed to notify readiness: {}", e);
return;
}
'poll: loop {
let mut signal_used_in = false;
let mut signal_used_out = false;
let events = match poll_ctx.wait() {
Ok(v) => v,
Err(e) => {
error!("failed polling for events: {}", e);
break;
}
};
for event in &events {
match dbg!(event.token()) {
Token::InQueue => {
let _ = in_queue_evt.read();
// Used to buffer descriptor indexes that are invalid for our uses.
let mut rejects = [0u16; QUEUE_SIZE as usize];
let mut rejects_len = 0;
let min_in_desc_len = (size_of::<CtrlVfdRecv>()
+ size_of::<Le32>() * VIRTWL_SEND_MAX_ALLOCS)
as u32;
in_desc_chains.extend(self.in_queue.iter(&self.mem).filter_map(|d| {
if d.len >= min_in_desc_len && d.is_write_only() {
Some(d)
} else {
// Can not use queue.add_used directly because it's being borrowed
// for the iterator chain, so we buffer the descriptor index in
// rejects.
rejects[rejects_len] = d.index;
rejects_len += 1;
None
}
}));
for &reject in &rejects[..rejects_len] {
signal_used_in = true;
self.in_queue.add_used(&self.mem, reject, 0);
}
}
Token::OutQueue => {
let _ = out_queue_evt.read();
while let Some(desc) = self.out_queue.pop(&self.mem) {
let desc_index = desc.index;
match (
Reader::new(&self.mem, desc.clone()),
Writer::new(&self.mem, desc),
) {
(Ok(mut reader), Ok(mut writer)) => {
let resp = match self.state.execute(&mut reader) {
Ok(r) => r,
Err(e) => {
error!("{}", e);
WlResp::Err(Box::new(e))
}
};
match encode_resp(&mut writer, resp) {
Ok(()) => {}
Err(e) => {
error!(
"failed to encode response to descriptor chain: {}",
e
);
}
}
self.out_queue.add_used(
&self.mem,
desc_index,
writer.bytes_written() as u32,
);
signal_used_out = true;
}
(_, Err(e)) | (Err(e), _) => {
error!("invalid descriptor: {}", e);
self.out_queue.add_used(&self.mem, desc_index, 0);
signal_used_out = true;
}
}
}
}
Token::CommandSocket => {
let resp: WlControlResult = match self.control_socket.recv() {
Ok(WlControlCommand::AddSocket { name, socket }) => {
match unique_name(Cow::Owned(name), |name| !self
.state
.wayland_sockets
.contains_key(name))
{
Some(name) => match self.state.add_socket(name.clone(), socket)
{
Ok(()) => WlControlResult::SocketAdded(name),
Err(e) => WlControlResult::Err(e),
},
None => WlControlResult::Err(Error::new(libc::EINVAL)),
}
}
Err(MsgError::InvalidData) => {
WlControlResult::Err(Error::new(libc::EINVAL))
}
Err(e) => {
error!("control socket failed recv: {}", e);
break 'poll;
}
};
if let Err(e) = self.control_socket.send(dbg!(&resp)) {
error!("control socket failed send: {}", e);
break 'poll;
}
}
Token::Kill => break 'poll,
Token::State => self.state.process_poll_context(),
Token::InterruptResample => {
self.interrupt.interrupt_resample();
}
}
}
// Because this loop should be retried after the in queue is usable or after one of the
// VFDs was read, we do it after the poll event responses.
while !in_desc_chains.is_empty() {
let mut should_pop = false;
if let Some(in_resp) = self.state.next_recv() {
// in_desc_chains is not empty (checked by loop condition) so unwrap is safe.
let desc = in_desc_chains.pop_front().unwrap();
let index = desc.index;
match Writer::new(&self.mem, desc) {
Ok(mut writer) => {
match encode_resp(&mut writer, in_resp) {
Ok(()) => {
should_pop = true;
}
Err(e) => {
error!("failed to encode response to descriptor chain: {}", e);
}
};
signal_used_in = true;
self.in_queue
.add_used(&self.mem, index, writer.bytes_written() as u32);
}
Err(e) => {
error!("invalid descriptor: {}", e);
self.in_queue.add_used(&self.mem, index, 0);
signal_used_in = true;
}
}
} else {
break;
}
if should_pop {
self.state.pop_recv();
}
}
if signal_used_in {
self.interrupt.signal_used_queue(self.in_queue.vector);
}
if signal_used_out {
self.interrupt.signal_used_queue(self.out_queue.vector);
}
}
}
}
/// Virtio device giving the guest access to the host's wayland server (see module docs).
#[derive(Debug)]
pub struct Wl {
kill_evt: Option<EventFd>, // written on drop to stop the worker thread
worker_thread: Option<thread::JoinHandle<()>>,
// The socket fields are `Option` because they are moved into the worker on `activate`.
wayland_sockets: Option<Map<Vec<u8>, WaylandSocket>>,
vm_socket: Option<VmMemoryControlRequestSocket>,
resource_bridge: Option<ResourceRequestSocket>, // also genuinely optional (gpu feature)
use_transition_flags: bool, // set when the guest acks VIRTIO_WL_F_TRANS_FLAGS
control_socket: Option<WlControlResponseSocket>,
}
use msg_socket2::{
de::{DeserializeWithFds, DeserializerWithFds},
ser::{FdSerializer, SerializeStruct, SerializeStructFds, SerializeWithFds, Serializer},
};
use std::fmt::Formatter;
use super::VirtioDeviceNew;
/// Parameters for constructing a `Wl` device; serializable (with fds) across processes.
#[derive(Debug)]
pub struct Params {
pub wayland_paths: Map<Vec<u8>, PathBuf>, // socket name -> wayland server socket path
pub vm_socket: VmMemoryControlRequestSocket,
pub resource_bridge: Option<ResourceRequestSocket>,
pub control_socket: WlControlResponseSocket,
}
impl SerializeWithFds for Params {
// Serializes the data portion; sockets go through `SerializeAdapter`.
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut state = serializer.serialize_struct("Params", 4)?;
state.serialize_field("wayland_paths", &self.wayland_paths)?;
state.serialize_field("vm_socket", &SerializeAdapter::new(&self.vm_socket))?;
state.serialize_field(
"resource_bridge",
&SerializeAdapter::new(&self.resource_bridge),
)?;
state.serialize_field(
"control_socket",
&SerializeAdapter::new(&self.control_socket),
)?;
state.end()
}
// Serializes the fd portion; fields are visited in the same order as `serialize` above,
// presumably so fds line up with their fields on deserialization — confirm against
// msg_socket2's contract.
fn serialize_fds<'fds, S>(&'fds self, serializer: S) -> Result<S::Ok, S::Error>
where
S: FdSerializer<'fds>,
{
let mut state = serializer.serialize_struct("Params", 4)?;
state.serialize_field("wayland_paths", &self.wayland_paths)?;
state.serialize_field("vm_socket", &self.vm_socket)?;
state.serialize_field("resource_bridge", &self.resource_bridge)?;
state.serialize_field("control_socket", &self.control_socket)?;
state.end()
}
}
use msg_socket::MsgSocket;
use msg_socket2::de::SeqAccessWithFds;
impl<'de> DeserializeWithFds<'de> for Params {
    /// Deserializes `Params`, rebuilding the sockets (via `MsgSocket::new`) from the fds carried
    /// alongside the serialized data — the inverse of the `SerializeWithFds` impl.
    fn deserialize<D: DeserializerWithFds<'de>>(deserializer: D) -> Result<Self, D::Error> {
        struct Visitor;
        impl<'de> VisitorWithFds<'de> for Visitor {
            type Value = Params;
            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                write!(f, "struct Params")
            }
            fn visit_seq<A: SeqAccessWithFds<'de>>(
                self,
                mut seq: A,
            ) -> Result<Self::Value, A::Error> {
                use serde::de::Error;
                // Helper for the error reported when the sequence ends early.
                fn too_short_error<E: Error>(len: usize) -> E {
                    E::invalid_length(len, &"struct Params with 4 elements")
                }
                Ok(Params {
                    wayland_paths: seq.next_element()?.ok_or_else(|| too_short_error(0))?,
                    vm_socket: seq
                        .next_element()?
                        .map(MsgSocket::new)
                        .ok_or_else(|| too_short_error(1))?,
                    // The bridge itself is optional; only wrap it when present.
                    resource_bridge: seq
                        .next_element::<Option<_>>()?
                        .ok_or_else(|| too_short_error(2))?
                        .map(MsgSocket::new),
                    control_socket: seq
                        .next_element()?
                        .map(MsgSocket::new)
                        .ok_or_else(|| too_short_error(3))?,
                })
            }
        }
        // Bug fix: this list must match the field names written by `SerializeWithFds for
        // Params` and cover all four fields. It previously named "wayland_sockets" (the
        // serializer writes "wayland_paths") and omitted "control_socket".
        deserializer.deserialize_struct(
            "Params",
            &[
                "wayland_paths",
                "vm_socket",
                "resource_bridge",
                "control_socket",
            ],
            Visitor,
        )
    }
}
impl VirtioDeviceNew for Wl {
type Params = Params;
type Error = Infallible;
fn new(params: Params) -> Result<Self, Self::Error> {
let Params {
wayland_paths,
vm_socket,
resource_bridge,
control_socket,
} = params;
let wayland_sockets = wayland_paths
.into_iter()
.map(|(n, path)| (n, WaylandSocket::Listening(path)))
.collect();
Ok(Self {
kill_evt: None,
worker_thread: None,
wayland_sockets: Some(wayland_sockets),
vm_socket: Some(vm_socket),
resource_bridge,
use_transition_flags: false,
control_socket: Some(control_socket),
})
}
}
impl Drop for Wl {
    fn drop(&mut self) {
        // Ask the worker to exit; there is nothing useful to do if the write fails.
        if let Some(evt) = self.kill_evt.take() {
            let _ = evt.write(1);
        }
        // Wait for the worker thread to finish before the device goes away.
        if let Some(handle) = self.worker_thread.take() {
            let _ = handle.join();
        }
    }
}
impl VirtioDevice for Wl {
    /// Collects the raw fds this device holds (non-listening wayland sockets, vm socket,
    /// resource bridge, control socket).
    fn keep_fds(&self) -> Vec<RawFd> {
        let mut keep_fds = Vec::new();
        if let Some(ref wayland_sockets) = self.wayland_sockets {
            for (_, socket) in wayland_sockets {
                if let WaylandSocket::NonListening(socket) = socket {
                    keep_fds.push(socket.as_raw_fd());
                }
            }
        }
        if let Some(vm_socket) = &self.vm_socket {
            keep_fds.push(vm_socket.as_raw_fd());
        }
        if let Some(resource_bridge) = &self.resource_bridge {
            keep_fds.push(resource_bridge.as_raw_fd());
        }
        if let Some(control_socket) = &self.control_socket {
            keep_fds.push(control_socket.as_raw_fd());
        }
        keep_fds
    }

    fn device_type(&self) -> u32 {
        TYPE_WL
    }

    fn queue_max_sizes(&self) -> Vec<u16> {
        QUEUE_SIZES.to_vec()
    }

    fn features(&self) -> u64 {
        1 << VIRTIO_WL_F_TRANS_FLAGS | 1 << VIRTIO_F_VERSION_1
    }

    fn ack_features(&mut self, value: u64) {
        // Remember whether the guest negotiated transition flags for later responses.
        if value & (1 << VIRTIO_WL_F_TRANS_FLAGS) != 0 {
            self.use_transition_flags = true;
        }
    }

    /// Starts the device: moves the sockets into a `Worker` and spawns its poll loop thread.
    ///
    /// Bug fix: removed leftover debug output (`eprintln!("+++++++ Activating Wl device")` and
    /// `println!("creating worker")`) from this method.
    fn activate(
        &mut self,
        mem: GuestMemory,
        interrupt: Interrupt,
        mut queues: Vec<Queue>,
        queue_evts: Vec<EventFd>,
    ) {
        if queues.len() != QUEUE_SIZES.len() || queue_evts.len() != QUEUE_SIZES.len() {
            return;
        }
        // One end is kept here (written on drop), the clone is handed to the worker.
        let (self_kill_evt, kill_evt) = match EventFd::new().and_then(|e| Ok((e.try_clone()?, e))) {
            Ok(v) => v,
            Err(e) => {
                error!("failed creating kill EventFd pair: {}", e);
                return;
            }
        };
        self.kill_evt = Some(self_kill_evt);
        if let Some(vm_socket) = self.vm_socket.take() {
            // unwraps are only reached on first activation, while these fields are still Some.
            let wayland_sockets = self.wayland_sockets.take().unwrap();
            let use_transition_flags = self.use_transition_flags;
            let resource_bridge = self.resource_bridge.take();
            let control_socket = self.control_socket.take().unwrap();
            let worker_result =
                thread::Builder::new()
                    .name("virtio_wl".to_string())
                    .spawn(move || {
                        Worker::new(
                            mem,
                            interrupt,
                            queues.remove(0),
                            queues.remove(0),
                            wayland_sockets,
                            vm_socket,
                            use_transition_flags,
                            resource_bridge,
                            control_socket,
                        )
                        .run(queue_evts, kill_evt);
                    });
            match worker_result {
                Err(e) => {
                    error!("failed to spawn virtio_wl worker: {}", e);
                    return;
                }
                Ok(join_handle) => {
                    self.worker_thread = Some(join_handle);
                }
            }
        }
    }
}