summary refs log tree commit diff
path: root/devices
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-03-12 01:04:03 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:36:24 +0000
commita24b1e1c1d287b5de7946a6225f2ef9326c433b3 (patch)
tree01ca117855e908c148fe2d8ecce09a6aab81dc93 /devices
parent9ecffa4880b741d0de23c6d0ee4755bd66db01fb (diff)
downloadcrosvm-a24b1e1c1d287b5de7946a6225f2ef9326c433b3.tar
crosvm-a24b1e1c1d287b5de7946a6225f2ef9326c433b3.tar.gz
crosvm-a24b1e1c1d287b5de7946a6225f2ef9326c433b3.tar.bz2
crosvm-a24b1e1c1d287b5de7946a6225f2ef9326c433b3.tar.lz
crosvm-a24b1e1c1d287b5de7946a6225f2ef9326c433b3.tar.xz
crosvm-a24b1e1c1d287b5de7946a6225f2ef9326c433b3.tar.zst
crosvm-a24b1e1c1d287b5de7946a6225f2ef9326c433b3.zip
forward kill to/from wl
Diffstat (limited to 'devices')
-rw-r--r--devices/src/virtio/controller.rs202
-rw-r--r--devices/src/virtio/wl.rs1
2 files changed, 140 insertions, 63 deletions
diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs
index c29112d..83834e4 100644
--- a/devices/src/virtio/controller.rs
+++ b/devices/src/virtio/controller.rs
@@ -34,7 +34,7 @@ use std::path::PathBuf;
 use std::thread;
 
 use msg_socket::{MsgReceiver, MsgSender};
-use sys_util::{error, EventFd, GuestMemory, Result, SharedMemory};
+use sys_util::{error, EventFd, GuestMemory, PollContext, PollToken, Result, SharedMemory};
 
 use super::resource_bridge::*;
 use super::{Interrupt, InterruptProxyEvent, Queue, VirtioDevice, TYPE_WL, VIRTIO_F_VERSION_1};
@@ -44,18 +44,26 @@ use msg_socket::{MsgOnSocket, MsgSocket};
 use sys_util::net::UnixSeqpacket;
 
 #[derive(Debug, MsgOnSocket)]
-pub struct Activate {
-    pub shm: MaybeOwnedFd<SharedMemory>,
-    pub interrupt: MaybeOwnedFd<UnixSeqpacket>,
-    pub interrupt_resample_evt: MaybeOwnedFd<EventFd>,
-    pub in_queue: Queue,
-    pub out_queue: Queue,
-    pub vm_socket: MaybeOwnedFd<UnixSeqpacket>,
-    pub in_queue_evt: MaybeOwnedFd<EventFd>,
-    pub out_queue_evt: MaybeOwnedFd<EventFd>,
+pub enum Request {
+    Activate {
+        shm: MaybeOwnedFd<SharedMemory>,
+        interrupt: MaybeOwnedFd<UnixSeqpacket>,
+        interrupt_resample_evt: MaybeOwnedFd<EventFd>,
+        in_queue: Queue,
+        out_queue: Queue,
+        vm_socket: MaybeOwnedFd<UnixSeqpacket>,
+        in_queue_evt: MaybeOwnedFd<EventFd>,
+        out_queue_evt: MaybeOwnedFd<EventFd>,
+    },
+    Kill,
 }
 
-type Socket = MsgSocket<Activate, ()>;
+#[derive(Debug, MsgOnSocket)]
+pub enum Response {
+    Kill,
+}
+
+type Socket = MsgSocket<Request, Response>;
 
 const VIRTIO_WL_F_TRANS_FLAGS: u32 = 0x01;
 
@@ -63,30 +71,94 @@ const VIRTIO_WL_F_TRANS_FLAGS: u32 = 0x01;
 const QUEUE_SIZE: u16 = 16;
 const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE, QUEUE_SIZE];
 
-struct InterruptWorker {
-    socket: MsgSocket<(), InterruptProxyEvent>,
+struct Worker {
+    device_socket: Socket,
     interrupt: Interrupt,
+    interrupt_socket: MsgSocket<(), InterruptProxyEvent>,
+
+    shutdown: bool,
 }
 
-impl InterruptWorker {
-    fn new(socket: MsgSocket<(), InterruptProxyEvent>, interrupt: Interrupt) -> Self {
-        Self { socket, interrupt }
+impl Worker {
+    fn new(
+        device_socket: Socket,
+        interrupt: Interrupt,
+        interrupt_socket: MsgSocket<(), InterruptProxyEvent>,
+    ) -> Self {
+        Self {
+            device_socket,
+            interrupt,
+            interrupt_socket,
+            shutdown: false,
+        }
+    }
+
+    fn handle_response(&mut self) {
+        match self.device_socket.recv() {
+            Ok(Response::Kill) => {
+                self.shutdown = true;
+            }
+
+            Err(e) => {
+                error!("recv failed: {:?}", e);
+            }
+        }
+    }
+
+    fn interrupt(&self) {
+        use InterruptProxyEvent::*;
+        match self.interrupt_socket.recv() {
+            Ok(SignalUsedQueue(value)) => self.interrupt.signal_used_queue(value).unwrap(),
+            Ok(SignalConfigChanged) => self.interrupt.signal_config_changed().unwrap(),
+            Ok(InterruptResample) => self.interrupt.interrupt_resample().unwrap(),
+
+            Err(e) => {
+                eprintln!("recv error: {}", e);
+                panic!("recv error: {}", e)
+            }
+        }
+    }
+
+    fn kill(&self) {
+        if let Err(e) = self.device_socket.send(&Request::Kill) {
+            error!("failed to send Kill message: {}", e);
+        }
     }
 
-    fn run(&self) {
-        // TODO: handle Kill
+    fn run(mut self, kill_evt: EventFd) {
+        #[derive(Debug, PollToken)]
+        enum Token {
+            Device,
+            Interrupt,
+            Kill,
+        }
 
-        loop {
-            use InterruptProxyEvent::*;
-            let val = self.socket.recv();
-            match val {
-                Ok(SignalUsedQueue(value)) => self.interrupt.signal_used_queue(value).unwrap(),
-                Ok(SignalConfigChanged) => self.interrupt.signal_config_changed().unwrap(),
-                Ok(InterruptResample) => self.interrupt.interrupt_resample().unwrap(),
+        let poll_ctx: PollContext<Token> = match PollContext::build_with(&[
+            (&self.device_socket, Token::Device),
+            (&self.interrupt_socket, Token::Interrupt),
+            (&kill_evt, Token::Kill),
+        ]) {
+            Ok(pc) => pc,
+            Err(e) => {
+                error!("failed creating PollContext: {}", e);
+                return;
+            }
+        };
 
+        while !self.shutdown {
+            let events = match poll_ctx.wait() {
+                Ok(v) => v,
                 Err(e) => {
-                    eprintln!("recv error: {}", e);
-                    panic!("recv error: {}", e)
+                    error!("failed polling for events: {}", e);
+                    break;
+                }
+            };
+
+            for event in &events {
+                match event.token() {
+                    Token::Device => self.handle_response(),
+                    Token::Interrupt => self.interrupt(),
+                    Token::Kill => self.kill(),
                 }
             }
         }
@@ -100,7 +172,7 @@ pub struct Controller {
     vm_socket: Option<VmMemoryControlRequestSocket>,
     resource_bridge: Option<ResourceRequestSocket>,
     use_transition_flags: bool,
-    socket: Socket,
+    socket: Option<Socket>,
 }
 
 impl Controller {
@@ -117,7 +189,7 @@ impl Controller {
             vm_socket: Some(vm_socket),
             resource_bridge,
             use_transition_flags: false,
-            socket,
+            socket: Some(socket),
         })
     }
 }
@@ -150,7 +222,9 @@ impl VirtioDevice for Controller {
             keep_fds.push(kill_evt.as_raw_fd());
         }
 
-        keep_fds.push(self.socket.as_raw_fd());
+        if let Some(ref socket) = self.socket {
+            keep_fds.push(socket.as_raw_fd());
+        }
 
         keep_fds
     }
@@ -194,40 +268,44 @@ impl VirtioDevice for Controller {
         self.kill_evt = Some(self_kill_evt);
 
         if let Some(vm_socket) = self.vm_socket.take() {
-            let wayland_paths = self.wayland_paths.clone();
-            let use_transition_flags = self.use_transition_flags;
-            let resource_bridge = self.resource_bridge.take();
-
-            let (ours, theirs) = UnixSeqpacket::pair().expect("pair failed");
-
-            if let Err(e) = self.socket.send(&Activate {
-                shm: MaybeOwnedFd::new_borrowed(&mem),
-                interrupt: MaybeOwnedFd::new_borrowed(&theirs),
-                interrupt_resample_evt: MaybeOwnedFd::new_borrowed(interrupt.get_resample_evt()),
-                in_queue: queues.remove(0),
-                out_queue: queues.remove(0),
-                vm_socket: MaybeOwnedFd::new_borrowed(&vm_socket),
-                in_queue_evt: MaybeOwnedFd::new_borrowed(&queue_evts[0]),
-                out_queue_evt: MaybeOwnedFd::new_borrowed(&queue_evts[1]),
-            }) {
-                error!("failed to send Activate: {}", e);
-                return;
-            }
-
-            let worker_result =
-                thread::Builder::new()
-                    .name("virtio_wl".to_string())
-                    .spawn(move || {
-                        InterruptWorker::new(MsgSocket::new(ours), interrupt).run();
-                    });
-
-            match worker_result {
-                Err(e) => {
-                    error!("failed to spawn virtio_wl worker: {}", e);
+            if let Some(socket) = self.socket.take() {
+                let wayland_paths = self.wayland_paths.clone();
+                let use_transition_flags = self.use_transition_flags;
+                let resource_bridge = self.resource_bridge.take();
+
+                let (ours, theirs) = UnixSeqpacket::pair().expect("pair failed");
+
+                if let Err(e) = socket.send(&Request::Activate {
+                    shm: MaybeOwnedFd::new_borrowed(&mem),
+                    interrupt: MaybeOwnedFd::new_borrowed(&theirs),
+                    interrupt_resample_evt: MaybeOwnedFd::new_borrowed(
+                        interrupt.get_resample_evt(),
+                    ),
+                    in_queue: queues.remove(0),
+                    out_queue: queues.remove(0),
+                    vm_socket: MaybeOwnedFd::new_borrowed(&vm_socket),
+                    in_queue_evt: MaybeOwnedFd::new_borrowed(&queue_evts[0]),
+                    out_queue_evt: MaybeOwnedFd::new_borrowed(&queue_evts[1]),
+                }) {
+                    error!("failed to send Activate: {}", e);
                     return;
                 }
-                Ok(join_handle) => {
-                    self.worker_thread = Some(join_handle);
+
+                let worker_result =
+                    thread::Builder::new()
+                        .name("virtio_wl".to_string())
+                        .spawn(move || {
+                            Worker::new(socket, interrupt, MsgSocket::new(ours)).run(kill_evt);
+                        });
+
+                match worker_result {
+                    Err(e) => {
+                        error!("failed to spawn virtio_wl worker: {}", e);
+                        return;
+                    }
+                    Ok(join_handle) => {
+                        self.worker_thread = Some(join_handle);
+                    }
                 }
             }
         }
diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs
index 1e7ea17..bc65039 100644
--- a/devices/src/virtio/wl.rs
+++ b/devices/src/virtio/wl.rs
@@ -1409,7 +1409,6 @@ impl Worker {
             };
 
             for event in &events {
-                dbg!(event.token());
                 match event.token() {
                     Token::InQueue => {
                         let _ = in_queue_evt.read();