summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--devices/src/virtio/balloon.rs26
-rw-r--r--devices/src/virtio/block.rs44
-rw-r--r--devices/src/virtio/net.rs47
-rw-r--r--devices/src/virtio/rng.rs26
4 files changed, 120 insertions, 23 deletions
diff --git a/devices/src/virtio/balloon.rs b/devices/src/virtio/balloon.rs
index 2d50c3d..29b4ec3 100644
--- a/devices/src/virtio/balloon.rs
+++ b/devices/src/virtio/balloon.rs
@@ -238,7 +238,7 @@ pub struct Balloon {
     config: Arc<BalloonConfig>,
     features: u64,
     kill_evt: Option<EventFd>,
-    worker_thread: Option<thread::JoinHandle<()>>,
+    worker_thread: Option<thread::JoinHandle<Worker>>,
 }
 
 impl Balloon {
@@ -349,6 +349,7 @@ impl VirtioDevice for Balloon {
                     config,
                 };
                 worker.run(queue_evts, kill_evt);
+                worker
             });
 
         match worker_result {
@@ -360,4 +361,27 @@ impl VirtioDevice for Balloon {
             }
         }
     }
+
+    fn reset(&mut self) -> bool {
+        if let Some(kill_evt) = self.kill_evt.take() {
+            if kill_evt.write(1).is_err() {
+                error!("{}: failed to notify the kill event", self.debug_label());
+                return false;
+            }
+        }
+
+        if let Some(worker_thread) = self.worker_thread.take() {
+            match worker_thread.join() {
+                Err(_) => {
+                    error!("{}: failed to get back resources", self.debug_label());
+                    return false;
+                }
+                Ok(worker) => {
+                    self.command_socket = Some(worker.command_socket);
+                    return true;
+                }
+            }
+        }
+        return false;
+    }
 }
diff --git a/devices/src/virtio/block.rs b/devices/src/virtio/block.rs
index dd37f09..51a7779 100644
--- a/devices/src/virtio/block.rs
+++ b/devices/src/virtio/block.rs
@@ -242,6 +242,7 @@ struct Worker {
     disk_size: Arc<Mutex<u64>>,
     read_only: bool,
     sparse: bool,
+    control_socket: DiskControlResponseSocket,
 }
 
 impl Worker {
@@ -350,12 +351,7 @@ impl Worker {
         DiskControlResult::Ok
     }
 
-    fn run(
-        &mut self,
-        queue_evt: EventFd,
-        kill_evt: EventFd,
-        control_socket: DiskControlResponseSocket,
-    ) {
+    fn run(&mut self, queue_evt: EventFd, kill_evt: EventFd) {
         #[derive(PollToken)]
         enum Token {
             FlushTimer,
@@ -377,7 +373,7 @@ impl Worker {
         let poll_ctx: PollContext<Token> = match PollContext::build_with(&[
             (&flush_timer, Token::FlushTimer),
             (&queue_evt, Token::QueueAvailable),
-            (&control_socket, Token::ControlRequest),
+            (&self.control_socket, Token::ControlRequest),
             (self.interrupt.get_resample_evt(), Token::InterruptResample),
             (&kill_evt, Token::Kill),
         ]) {
@@ -420,7 +416,7 @@ impl Worker {
                         }
                     }
                     Token::ControlRequest => {
-                        let req = match control_socket.recv() {
+                        let req = match self.control_socket.recv() {
                             Ok(req) => req,
                             Err(e) => {
                                 error!("control socket failed recv: {}", e);
@@ -435,7 +431,7 @@ impl Worker {
                             }
                         };
 
-                        if let Err(e) = control_socket.send(&resp) {
+                        if let Err(e) = self.control_socket.send(&resp) {
                             error!("control socket failed send: {}", e);
                             break 'poll;
                         }
@@ -456,7 +452,7 @@ impl Worker {
 /// Virtio device for exposing block level read/write operations on a host file.
 pub struct Block {
     kill_evt: Option<EventFd>,
-    worker_thread: Option<thread::JoinHandle<()>>,
+    worker_thread: Option<thread::JoinHandle<Worker>>,
     disk_image: Option<Box<dyn DiskFile>>,
     disk_size: Arc<Mutex<u64>>,
     avail_features: u64,
@@ -768,8 +764,10 @@ impl VirtioDevice for Block {
                                 disk_size,
                                 read_only,
                                 sparse,
+                                control_socket,
                             };
-                            worker.run(queue_evts.remove(0), kill_evt, control_socket);
+                            worker.run(queue_evts.remove(0), kill_evt);
+                            worker
                         });
 
                 match worker_result {
@@ -784,6 +782,30 @@ impl VirtioDevice for Block {
             }
         }
     }
+
+    fn reset(&mut self) -> bool {
+        if let Some(kill_evt) = self.kill_evt.take() {
+            if kill_evt.write(1).is_err() {
+                error!("{}: failed to notify the kill event", self.debug_label());
+                return false;
+            }
+        }
+
+        if let Some(worker_thread) = self.worker_thread.take() {
+            match worker_thread.join() {
+                Err(_) => {
+                    error!("{}: failed to get back resources", self.debug_label());
+                    return false;
+                }
+                Ok(worker) => {
+                    self.disk_image = Some(worker.disk_image);
+                    self.control_socket = Some(worker.control_socket);
+                    return true;
+                }
+            }
+        }
+        return false;
+    }
 }
 
 #[cfg(test)]
diff --git a/devices/src/virtio/net.rs b/devices/src/virtio/net.rs
index a2b1058..bacedd7 100644
--- a/devices/src/virtio/net.rs
+++ b/devices/src/virtio/net.rs
@@ -97,6 +97,7 @@ struct Worker<T: TapT> {
     // Remove once MRG_RXBUF is supported and this variable is actually used.
     #[allow(dead_code)]
     acked_features: u64,
+    kill_evt: EventFd,
 }
 
 impl<T> Worker<T>
@@ -192,12 +193,7 @@ where
         self.interrupt.signal_used_queue(self.tx_queue.vector);
     }
 
-    fn run(
-        &mut self,
-        rx_queue_evt: EventFd,
-        tx_queue_evt: EventFd,
-        kill_evt: EventFd,
-    ) -> Result<(), NetError> {
+    fn run(&mut self, rx_queue_evt: EventFd, tx_queue_evt: EventFd) -> Result<(), NetError> {
         #[derive(PollToken)]
         enum Token {
             // A frame is available for reading from the tap device to receive in the guest.
@@ -217,7 +213,7 @@ where
             (&rx_queue_evt, Token::RxQueue),
             (&tx_queue_evt, Token::TxQueue),
             (self.interrupt.get_resample_evt(), Token::InterruptResample),
-            (&kill_evt, Token::Kill),
+            (&self.kill_evt, Token::Kill),
         ])
         .map_err(NetError::CreatePollContext)?;
 
@@ -258,7 +254,10 @@ where
                     Token::InterruptResample => {
                         self.interrupt.interrupt_resample();
                     }
-                    Token::Kill => break 'poll,
+                    Token::Kill => {
+                        let _ = self.kill_evt.read();
+                        break 'poll;
+                    }
                 }
             }
         }
@@ -269,7 +268,7 @@ where
 pub struct Net<T: TapT> {
     workers_kill_evt: Option<EventFd>,
     kill_evt: EventFd,
-    worker_thread: Option<thread::JoinHandle<()>>,
+    worker_thread: Option<thread::JoinHandle<Worker<T>>>,
     tap: Option<T>,
     avail_features: u64,
     acked_features: u64,
@@ -398,6 +397,7 @@ where
         if let Some(workers_kill_evt) = &self.workers_kill_evt {
             keep_fds.push(workers_kill_evt.as_raw_fd());
         }
+        keep_fds.push(self.kill_evt.as_raw_fd());
 
         keep_fds
     }
@@ -457,13 +457,15 @@ where
                                 tx_queue,
                                 tap,
                                 acked_features,
+                                kill_evt,
                             };
                             let rx_queue_evt = queue_evts.remove(0);
                             let tx_queue_evt = queue_evts.remove(0);
-                            let result = worker.run(rx_queue_evt, tx_queue_evt, kill_evt);
+                            let result = worker.run(rx_queue_evt, tx_queue_evt);
                             if let Err(e) = result {
                                 error!("net worker thread exited with error: {}", e);
                             }
+                            worker
                         });
 
                 match worker_result {
@@ -478,4 +480,29 @@ where
             }
         }
     }
+
+    fn reset(&mut self) -> bool {
+        // Only kill the child if it claimed its eventfd.
+        if self.workers_kill_evt.is_none() {
+            if self.kill_evt.write(1).is_err() {
+                error!("{}: failed to notify the kill event", self.debug_label());
+                return false;
+            }
+        }
+
+        if let Some(worker_thread) = self.worker_thread.take() {
+            match worker_thread.join() {
+                Err(_) => {
+                    error!("{}: failed to get back resources", self.debug_label());
+                    return false;
+                }
+                Ok(worker) => {
+                    self.tap = Some(worker.tap);
+                    self.workers_kill_evt = Some(worker.kill_evt);
+                    return true;
+                }
+            }
+        }
+        return false;
+    }
 }
diff --git a/devices/src/virtio/rng.rs b/devices/src/virtio/rng.rs
index a7fc3d7..b8e4bb1 100644
--- a/devices/src/virtio/rng.rs
+++ b/devices/src/virtio/rng.rs
@@ -121,7 +121,7 @@ impl Worker {
 /// Virtio device for exposing entropy to the guest OS through virtio.
 pub struct Rng {
     kill_evt: Option<EventFd>,
-    worker_thread: Option<thread::JoinHandle<()>>,
+    worker_thread: Option<thread::JoinHandle<Worker>>,
     random_file: Option<File>,
 }
 
@@ -203,6 +203,7 @@ impl VirtioDevice for Rng {
                             random_file,
                         };
                         worker.run(queue_evts.remove(0), kill_evt);
+                        worker
                     });
 
             match worker_result {
@@ -216,4 +217,27 @@ impl VirtioDevice for Rng {
             }
         }
     }
+
+    fn reset(&mut self) -> bool {
+        if let Some(kill_evt) = self.kill_evt.take() {
+            if kill_evt.write(1).is_err() {
+                error!("{}: failed to notify the kill event", self.debug_label());
+                return false;
+            }
+        }
+
+        if let Some(worker_thread) = self.worker_thread.take() {
+            match worker_thread.join() {
+                Err(_) => {
+                    error!("{}: failed to get back resources", self.debug_label());
+                    return false;
+                }
+                Ok(worker) => {
+                    self.random_file = Some(worker.random_file);
+                    return true;
+                }
+            }
+        }
+        return false;
+    }
 }