summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--devices/src/virtio/vhost/net.rs29
-rw-r--r--devices/src/virtio/vhost/worker.rs11
2 files changed, 35 insertions, 5 deletions
diff --git a/devices/src/virtio/vhost/net.rs b/devices/src/virtio/vhost/net.rs
index e9dcf14..ff72970 100644
--- a/devices/src/virtio/vhost/net.rs
+++ b/devices/src/virtio/vhost/net.rs
@@ -25,7 +25,7 @@ const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];
 pub struct Net<T: TapT, U: VhostNetT<T>> {
     workers_kill_evt: Option<EventFd>,
     kill_evt: EventFd,
-    worker_thread: Option<thread::JoinHandle<()>>,
+    worker_thread: Option<thread::JoinHandle<(Worker<U>, T)>>,
     tap: Option<T>,
     vhost_net_handle: Option<U>,
     vhost_interrupt: Option<Vec<EventFd>>,
@@ -141,6 +141,7 @@ where
         if let Some(workers_kill_evt) = &self.workers_kill_evt {
             keep_fds.push(workers_kill_evt.as_raw_fd());
         }
+        keep_fds.push(self.kill_evt.as_raw_fd());
 
         keep_fds
     }
@@ -220,6 +221,7 @@ where
                                 if let Err(e) = result {
                                     error!("net worker thread exited with error: {}", e);
                                 }
+                                (worker, tap)
                             });
 
                         match worker_result {
@@ -248,6 +250,31 @@ where
             }
         }
     }
+
+    fn reset(&mut self) -> bool {
+        // Only kill the child if it claimed its eventfd.
+        if self.workers_kill_evt.is_none() && self.kill_evt.write(1).is_err() {
+            error!("{}: failed to notify the kill event", self.debug_label());
+            return false;
+        }
+
+        if let Some(worker_thread) = self.worker_thread.take() {
+            match worker_thread.join() {
+                Err(_) => {
+                    error!("{}: failed to get back resources", self.debug_label());
+                    return false;
+                }
+                Ok((worker, tap)) => {
+                    self.vhost_net_handle = Some(worker.vhost_handle);
+                    self.tap = Some(tap);
+                    self.vhost_interrupt = Some(worker.vhost_interrupt);
+                    self.workers_kill_evt = Some(worker.kill_evt);
+                    return true;
+                }
+            }
+        }
+        false
+    }
 }
 
 #[cfg(test)]
diff --git a/devices/src/virtio/vhost/worker.rs b/devices/src/virtio/vhost/worker.rs
index 630edaa..1eff01f 100644
--- a/devices/src/virtio/vhost/worker.rs
+++ b/devices/src/virtio/vhost/worker.rs
@@ -17,10 +17,10 @@ use crate::virtio::{Interrupt, Queue};
 pub struct Worker<T: Vhost> {
     interrupt: Interrupt,
     queues: Vec<Queue>,
-    vhost_handle: T,
-    vhost_interrupt: Vec<EventFd>,
+    pub vhost_handle: T,
+    pub vhost_interrupt: Vec<EventFd>,
     acked_features: u64,
-    kill_evt: EventFd,
+    pub kill_evt: EventFd,
 }
 
 impl<T: Vhost> Worker<T> {
@@ -130,7 +130,10 @@ impl<T: Vhost> Worker<T> {
                     Token::InterruptResample => {
                         self.interrupt.interrupt_resample();
                     }
-                    Token::Kill => break 'poll,
+                    Token::Kill => {
+                        let _ = self.kill_evt.read();
+                        break 'poll;
+                    }
                 }
             }
         }