// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! This module implements the virtio wayland used by the guest to access the host's wayland server.
//!
//! The virtio wayland protocol is done over two queues: `in` and `out`. The `in` queue is used for
//! sending commands to the guest that are generated by the host, usually messages from the wayland
//! server. The `out` queue is for commands from the guest, usually requests to allocate shared
//! memory, open a wayland server connection, or send data over an existing connection.
//!
//! Each `WlVfd` represents one virtual file descriptor created by either the guest or the host.
//! Virtual file descriptors contain actual file descriptors, either a shared memory file descriptor
//! or a unix domain socket to the wayland server. In the shared memory case, there is also an
//! associated slot that indicates which KVM memory slot the memory is installed into, as well as a
//! page frame number that the guest can access the memory from.
//!
//! The types starting with `Ctrl` are structures representing the virtio wayland protocol "on the
//! wire." They are decoded and executed in the `execute` function and encoded as some variant of
//! `WlResp` for responses.
//!
//! There is one `WlState` instance that contains every known vfd and the current state of `in`
//! queue. The `in` queue requires extra state to buffer messages to the guest in case the `in`
//! queue is already full. The `WlState` also has a control socket necessary to fulfill certain
//! requests, such as those registering guest memory.
//!
//! The `Worker` is responsible for the poll loop over all possible events, encoding/decoding from
//! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill
//! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket.
use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::Arc;
use std::thread;
use msg_socket::{MsgReceiver, MsgSocket};
use sys_util::net::UnixSeqpacket;
use sys_util::{error, EventFd, GuestMemory, PollContext, PollToken};
use vm_control::{MaybeOwnedFd, VmMemoryControlRequestSocket};
use super::{
remote::{Request, Response},
Interrupt, InterruptProxyEvent, Queue, VirtioDevice,
};
use crate::{
pci::{PciAddress, PciBarConfiguration, PciCapability},
MemoryParams,
};
/// Socket used to exchange `Request`s and `Response`s with the remote device
/// implementation.
type Socket = msg_socket2::Socket<Request, Response>;
// TODO: support arbitrary number of queues
/// Value used for each entry of `QUEUE_SIZES`.
const QUEUE_SIZE: u16 = 16;
/// Only the length of this slice is used in this file: `activate` requires
/// exactly this many queues and queue events. The maximum queue sizes reported
/// by `queue_max_sizes` come from the remote (`Request::QueueMaxSizes`), not
/// from these values.
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE, QUEUE_SIZE];
/// State owned by the thread spawned in `Controller::activate`.
///
/// The worker relays `InterruptProxyEvent`s from the remote device to the
/// local `Interrupt`, and forwards kill requests to the remote.
struct Worker {
/// Socket to the remote device; an `Arc` clone of the `Controller`'s socket.
device_socket: Arc<Socket>,
/// Local interrupt that events received on `interrupt_socket` are applied to.
interrupt: Interrupt,
/// Our end of the socket pair whose other end is sent to the remote in
/// `Request::Activate`.
interrupt_socket: MsgSocket<(), InterruptProxyEvent>,
/// Set once the remote replies with `Response::Kill`; ends the `run` loop.
shutdown: bool,
}
impl Worker {
fn new(
device_socket: Arc<Socket>,
interrupt: Interrupt,
interrupt_socket: MsgSocket<(), InterruptProxyEvent>,
) -> Self {
Self {
device_socket,
interrupt,
interrupt_socket,
shutdown: false,
}
}
fn handle_response(&mut self) {
match self.device_socket.recv() {
Ok(Response::Kill) => {
self.shutdown = true;
}
Ok(msg) => {
panic!("unexpected message received: {:?}", msg);
}
Err(e) => {
panic!("recv failed: {:?}", e);
}
}
}
fn interrupt(&self) {
use InterruptProxyEvent::*;
match self.interrupt_socket.recv() {
Ok(SignalUsedQueue(value)) => self.interrupt.signal_used_queue(value).unwrap(),
Ok(SignalConfigChanged) => self.interrupt.signal_config_changed().unwrap(),
Ok(InterruptResample) => self.interrupt.interrupt_resample().unwrap(),
Err(e) => panic!("recv failed: {}", e),
}
}
fn kill(&self) {
if let Err(e) = self.device_socket.send(Request::Kill) {
error!("failed to send Kill: {}", e);
}
}
fn run(mut self, kill_evt: EventFd) {
#[derive(Debug, PollToken)]
enum Token {
Device,
Interrupt,
Kill,
}
let poll_ctx: PollContext<Token> = match PollContext::build_with(&[
(&*self.device_socket, Token::Device),
(&self.interrupt_socket, Token::Interrupt),
(&kill_evt, Token::Kill),
]) {
Ok(pc) => pc,
Err(e) => {
panic!("failed creating PollContext: {}", e);
}
};
while !self.shutdown {
let events = match poll_ctx.wait() {
Ok(v) => v,
Err(e) => {
panic!("failed polling for events: {}", e);
}
};
for event in &events {
match event.token() {
Token::Device => self.handle_response(),
Token::Interrupt => self.interrupt(),
Token::Kill => self.kill(),
}
}
}
}
}
/// `VirtioDevice` that forwards every device operation to a remote device
/// implementation over a `msg_socket2` socket.
pub struct Controller {
/// Our copy of the worker's kill event; `None` until `activate` runs.
kill_evt: Option<EventFd>,
/// Worker thread spawned by `activate`; joined when the controller is dropped.
worker_thread: Option<thread::JoinHandle<()>>,
/// Socket to the remote device; an `Arc` clone is handed to the worker thread.
socket: Arc<Socket>,
}
impl Controller {
    /// Builds a controller for a remote device.
    ///
    /// Sends `Request::Create` over `socket`, which initializes the remote
    /// device but does not activate it; activation happens later through
    /// `VirtioDevice::activate`. `vm_control_socket` is attached to the
    /// request as a borrowed file descriptor.
    ///
    /// # Errors
    /// Returns the error from sending the `Create` request.
    pub fn create(
        memory_params: MemoryParams,
        vm_control_socket: VmMemoryControlRequestSocket,
        socket: Socket,
    ) -> Result<Controller, msg_socket2::Error> {
        let vm_control_fd = vm_control_socket.as_raw_fd();
        let create_request = Request::Create {
            memory_params,
            vm_control_socket: MaybeOwnedFd::Borrowed(vm_control_fd),
        };
        socket.send(create_request)?;

        let shared_socket = Arc::new(socket);
        Ok(Self {
            kill_evt: None,
            worker_thread: None,
            socket: shared_socket,
        })
    }
}
impl Drop for Controller {
    /// Signals the worker thread (if `activate` ever ran) to stop, then waits
    /// for it to exit.
    ///
    /// Failures are deliberately ignored, because a destructor cannot report
    /// them.
    fn drop(&mut self) {
        let pending_kill = self.kill_evt.take();
        let pending_worker = self.worker_thread.take();

        if let Some(evt) = pending_kill {
            let _ = evt.write(1);
        }
        if let Some(handle) = pending_worker {
            let _ = handle.join();
        }
    }
}
impl VirtioDevice for Controller {
fn debug_label(&self) -> String {
if let Err(e) = self.socket.send(Request::DebugLabel) {
return format!("remote virtio (unknown type; {})", e);
}
let label = match self.socket.recv() {
Ok(Response::DebugLabel(label)) => label,
response => panic!("bad response to DebugLabel: {:?}", response),
};
format!("remote {}", label)
}
fn keep_fds(&self) -> Vec<RawFd> {
let mut keep_fds = Vec::new();
if let Some(ref kill_evt) = self.kill_evt {
keep_fds.push(kill_evt.as_raw_fd());
}
keep_fds.push(self.socket.as_raw_fd());
keep_fds
}
fn device_type(&self) -> u32 {
if let Err(e) = self.socket.send(Request::DeviceType) {
panic!("failed to send DeviceType: {}", e);
}
match self.socket.recv() {
Ok(Response::DeviceType(device_type)) => device_type,
response => {
panic!("bad response to Reset: {:?}", response);
}
}
}
fn queue_max_sizes(&self) -> Vec<u16> {
if let Err(e) = self.socket.send(Request::QueueMaxSizes) {
panic!("failed to send QueueMaxSizes: {}", e);
}
match self.socket.recv() {
Ok(Response::QueueMaxSizes(sizes)) => sizes,
response => {
panic!("bad response to QueueMaxSizes: {:?}", response);
}
}
}
fn features(&self) -> u64 {
if let Err(e) = self.socket.send(Request::Features) {
panic!("failed to send Features: {}", e);
}
match self.socket.recv() {
Ok(Response::Features(features)) => features,
response => {
panic!("bad response to Reset: {:?}", response);
}
}
}
fn ack_features(&mut self, value: u64) {
if let Err(e) = self.socket.send(Request::AckFeatures(value)) {
panic!("failed to send AckFeatures: {}", e);
}
}
fn read_config(&self, offset: u64, data: &mut [u8]) {
let len = data.len();
if let Err(e) = self.socket.send(Request::ReadConfig { offset, len }) {
panic!("failed to send ReadConfig: {}", e);
}
match self.socket.recv() {
Ok(Response::ReadConfig(response)) => {
data.copy_from_slice(&response[..len]); // TODO: test no panic
}
response => panic!("bad response to ReadConfig: {:?}", response),
}
}
fn write_config(&mut self, offset: u64, data: &[u8]) {
if let Err(e) = self.socket.send(Request::WriteConfig {
offset,
data: data.to_vec(),
}) {
error!("failed to send WriteConfig: {}", e);
}
}
fn activate(
&mut self,
mem: GuestMemory,
interrupt: Interrupt,
queues: Vec<Queue>,
queue_evts: Vec<EventFd>,
) {
if queues.len() != QUEUE_SIZES.len() || queue_evts.len() != QUEUE_SIZES.len() {
panic!(
"queues ({}) or queue_evts ({}) wrong size",
queues.len(),
queue_evts.len()
);
}
let (self_kill_evt, kill_evt) = match EventFd::new().and_then(|e| Ok((e.try_clone()?, e))) {
Ok(v) => v,
Err(e) => {
panic!("failed creating kill EventFd pair: {}", e);
}
};
self.kill_evt = Some(self_kill_evt);
let (ours, theirs) = UnixSeqpacket::pair().expect("pair failed");
if let Err(e) = self.socket.send(Request::Activate {
shm: MaybeOwnedFd::new_borrowed(&mem),
interrupt: MaybeOwnedFd::new_borrowed(&theirs),
interrupt_resample_evt: MaybeOwnedFd::new_borrowed(interrupt.get_resample_evt()),
queues,
queue_evts: queue_evts
.iter()
.map(|e| MaybeOwnedFd::new_borrowed(e))
.collect(),
}) {
panic!("failed to send Activate: {}", e);
}
let socket = Arc::clone(&self.socket);
let worker_result = thread::Builder::new()
.name("virtio_wl".to_string())
.spawn(move || {
Worker::new(socket, interrupt, MsgSocket::new(ours)).run(kill_evt);
});
match worker_result {
Err(e) => {
panic!("failed to spawn virtio_wl worker: {}", e);
}
Ok(join_handle) => {
self.worker_thread = Some(join_handle);
}
}
}
fn reset(&mut self) -> bool {
if let Err(e) = self.socket.send(Request::Reset) {
error!("failed to send Reset: {}", e);
return false;
}
match self.socket.recv() {
Ok(Response::Reset(result)) => result,
response => {
error!("bad response to Reset: {:?}", response);
false
}
}
}
fn get_device_bars(&mut self, address: PciAddress) -> Vec<PciBarConfiguration> {
if let Err(e) = self.socket.send(Request::GetDeviceBars(address)) {
panic!("failed to send GetDeviceBars: {}", e);
}
match self.socket.recv() {
Ok(Response::GetDeviceBars(bars)) => bars,
response => {
panic!("bad response to GetDeviceBars: {:?}", response);
}
}
}
fn get_device_caps(&self) -> Vec<Box<dyn PciCapability>> {
if let Err(e) = self.socket.send(Request::GetDeviceCaps) {
panic!("failed to send GetDeviceCaps: {}", e);
}
match self.socket.recv() {
Ok(Response::GetDeviceCaps(caps)) => caps
.into_iter()
.map(|cap| Box::new(cap) as Box<dyn PciCapability>)
.collect(),
response => {
panic!("bad response to GetDeviceCaps: {:?}", response);
}
}
}
}