summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-03-13 23:10:27 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:36:39 +0000
commit94f83d7bafb86770555457d7a11f9b2b2cb7166c (patch)
tree87503a019cb53d740ce629e490b3d9d615c52e1b
parent135afeb7e4934d744542dad725c5996d6bcd70dc (diff)
downloadcrosvm-94f83d7bafb86770555457d7a11f9b2b2cb7166c.tar
crosvm-94f83d7bafb86770555457d7a11f9b2b2cb7166c.tar.gz
crosvm-94f83d7bafb86770555457d7a11f9b2b2cb7166c.tar.bz2
crosvm-94f83d7bafb86770555457d7a11f9b2b2cb7166c.tar.lz
crosvm-94f83d7bafb86770555457d7a11f9b2b2cb7166c.tar.xz
crosvm-94f83d7bafb86770555457d7a11f9b2b2cb7166c.tar.zst
crosvm-94f83d7bafb86770555457d7a11f9b2b2cb7166c.zip
Don't give worker ownership of socket
-rw-r--r--devices/src/virtio/controller.rs75
1 files changed, 36 insertions, 39 deletions
diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs
index 5d5d76a..a1656c6 100644
--- a/devices/src/virtio/controller.rs
+++ b/devices/src/virtio/controller.rs
@@ -31,6 +31,7 @@
 use std::collections::BTreeMap as Map;
 use std::os::unix::io::{AsRawFd, RawFd};
 use std::path::PathBuf;
+use std::sync::Arc;
 use std::thread;
 
 use super::resource_bridge::*;
@@ -115,7 +116,7 @@ const QUEUE_SIZE: u16 = 16;
 const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE, QUEUE_SIZE];
 
 struct Worker {
-    device_socket: Socket,
+    device_socket: Arc<Socket>,
     interrupt: Interrupt,
     interrupt_socket: MsgSocket<(), InterruptProxyEvent>,
 
@@ -124,7 +125,7 @@ struct Worker {
 
 impl Worker {
     fn new(
-        device_socket: Socket,
+        device_socket: Arc<Socket>,
         interrupt: Interrupt,
         interrupt_socket: MsgSocket<(), InterruptProxyEvent>,
     ) -> Self {
@@ -182,7 +183,7 @@ impl Worker {
         }
 
         let poll_ctx: PollContext<Token> = match PollContext::build_with(&[
-            (&self.device_socket, Token::Device),
+            (&*self.device_socket, Token::Device),
             (&self.interrupt_socket, Token::Interrupt),
             (&kill_evt, Token::Kill),
         ]) {
@@ -217,7 +218,7 @@ pub struct Controller {
     kill_evt: Option<EventFd>,
     worker_thread: Option<thread::JoinHandle<()>>,
     use_transition_flags: bool,
-    socket: Option<Socket>,
+    socket: Arc<Socket>,
 }
 
 impl Controller {
@@ -239,7 +240,7 @@ impl Controller {
             kill_evt: None,
             worker_thread: None,
             use_transition_flags: false,
-            socket: Some(socket),
+            socket: Arc::new(socket),
         })
     }
 }
@@ -265,9 +266,7 @@ impl VirtioDevice for Controller {
             keep_fds.push(kill_evt.as_raw_fd());
         }
 
-        if let Some(ref socket) = self.socket {
-            keep_fds.push(socket.as_raw_fd());
-        }
+        keep_fds.push(self.socket.as_raw_fd());
 
         keep_fds
     }
@@ -310,39 +309,37 @@ impl VirtioDevice for Controller {
         };
         self.kill_evt = Some(self_kill_evt);
 
-        if let Some(socket) = self.socket.take() {
-            let use_transition_flags = self.use_transition_flags;
-
-            let (ours, theirs) = UnixSeqpacket::pair().expect("pair failed");
-
-            if let Err(e) = socket.send(MsgOnSocketRequest::Activate {
-                shm: MaybeOwnedFd::new_borrowed(&mem),
-                interrupt: MaybeOwnedFd::new_borrowed(&theirs),
-                interrupt_resample_evt: MaybeOwnedFd::new_borrowed(interrupt.get_resample_evt()),
-                in_queue: queues.remove(0),
-                out_queue: queues.remove(0),
-                in_queue_evt: MaybeOwnedFd::new_borrowed(&queue_evts[0]),
-                out_queue_evt: MaybeOwnedFd::new_borrowed(&queue_evts[1]),
-            }) {
-                error!("failed to send Activate: {}", e);
-                return;
-            }
+        let use_transition_flags = self.use_transition_flags;
 
-            let worker_result =
-                thread::Builder::new()
-                    .name("virtio_wl".to_string())
-                    .spawn(move || {
-                        Worker::new(socket, interrupt, MsgSocket::new(ours)).run(kill_evt);
-                    });
+        let (ours, theirs) = UnixSeqpacket::pair().expect("pair failed");
 
-            match worker_result {
-                Err(e) => {
-                    error!("failed to spawn virtio_wl worker: {}", e);
-                    return;
-                }
-                Ok(join_handle) => {
-                    self.worker_thread = Some(join_handle);
-                }
+        if let Err(e) = self.socket.send(MsgOnSocketRequest::Activate {
+            shm: MaybeOwnedFd::new_borrowed(&mem),
+            interrupt: MaybeOwnedFd::new_borrowed(&theirs),
+            interrupt_resample_evt: MaybeOwnedFd::new_borrowed(interrupt.get_resample_evt()),
+            in_queue: queues.remove(0),
+            out_queue: queues.remove(0),
+            in_queue_evt: MaybeOwnedFd::new_borrowed(&queue_evts[0]),
+            out_queue_evt: MaybeOwnedFd::new_borrowed(&queue_evts[1]),
+        }) {
+            error!("failed to send Activate: {}", e);
+            return;
+        }
+
+        let socket = Arc::clone(&self.socket);
+        let worker_result = thread::Builder::new()
+            .name("virtio_wl".to_string())
+            .spawn(move || {
+                Worker::new(socket, interrupt, MsgSocket::new(ours)).run(kill_evt);
+            });
+
+        match worker_result {
+            Err(e) => {
+                error!("failed to spawn virtio_wl worker: {}", e);
+                return;
+            }
+            Ok(join_handle) => {
+                self.worker_thread = Some(join_handle);
             }
         }
     }