summary refs log tree commit diff
path: root/devices/src/virtio/vhost/worker.rs
diff options
context:
space:
mode:
Diffstat (limited to 'devices/src/virtio/vhost/worker.rs')
-rw-r--r--devices/src/virtio/vhost/worker.rs135
1 files changed, 127 insertions, 8 deletions
diff --git a/devices/src/virtio/vhost/worker.rs b/devices/src/virtio/vhost/worker.rs
index 1eff01f..ca02a63 100644
--- a/devices/src/virtio/vhost/worker.rs
+++ b/devices/src/virtio/vhost/worker.rs
@@ -4,16 +4,16 @@
 
 use std::os::raw::c_ulonglong;
 
-use sys_util::{EventFd, PollContext, PollToken};
+use sys_util::{error, Error as SysError, EventFd, PollContext, PollToken};
 use vhost::Vhost;
 
+use super::control_socket::{VhostDevRequest, VhostDevResponse, VhostDevResponseSocket};
 use super::{Error, Result};
 use crate::virtio::{Interrupt, Queue};
+use libc::EIO;
+use msg_socket::{MsgReceiver, MsgSender};
 
-/// Worker that takes care of running the vhost device.  This mainly involves forwarding interrupts
-/// from the vhost driver to the guest VM because crosvm only supports the virtio-mmio transport,
-/// which requires a bit to be set in the interrupt status register before triggering the interrupt
-/// and the vhost driver doesn't do this for us.
+/// Worker that takes care of running the vhost device.
 pub struct Worker<T: Vhost> {
     interrupt: Interrupt,
     queues: Vec<Queue>,
@@ -21,6 +21,7 @@ pub struct Worker<T: Vhost> {
     pub vhost_interrupt: Vec<EventFd>,
     acked_features: u64,
     pub kill_evt: EventFd,
+    pub response_socket: Option<VhostDevResponseSocket>,
 }
 
 impl<T: Vhost> Worker<T> {
@@ -31,6 +32,7 @@ impl<T: Vhost> Worker<T> {
         interrupt: Interrupt,
         acked_features: u64,
         kill_evt: EventFd,
+        response_socket: Option<VhostDevResponseSocket>,
     ) -> Worker<T> {
         Worker {
             interrupt,
@@ -39,6 +41,7 @@ impl<T: Vhost> Worker<T> {
             vhost_interrupt,
             acked_features,
             kill_evt,
+            response_socket,
         }
     }
 
@@ -87,9 +90,7 @@ impl<T: Vhost> Worker<T> {
             self.vhost_handle
                 .set_vring_base(queue_index, 0)
                 .map_err(Error::VhostSetVringBase)?;
-            self.vhost_handle
-                .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
-                .map_err(Error::VhostSetVringCall)?;
+            self.set_vring_call_for_entry(queue_index, queue.vector as usize)?;
             self.vhost_handle
                 .set_vring_kick(queue_index, &queue_evts[queue_index])
                 .map_err(Error::VhostSetVringKick)?;
@@ -102,6 +103,7 @@ impl<T: Vhost> Worker<T> {
             VhostIrqi { index: usize },
             InterruptResample,
             Kill,
+            ControlNotify,
         }
 
         let poll_ctx: PollContext<Token> = PollContext::build_with(&[
@@ -115,6 +117,11 @@ impl<T: Vhost> Worker<T> {
                 .add(vhost_int, Token::VhostIrqi { index })
                 .map_err(Error::CreatePollContext)?;
         }
+        if let Some(socket) = &self.response_socket {
+            poll_ctx
+                .add(socket, Token::ControlNotify)
+                .map_err(Error::CreatePollContext)?;
+        }
 
         'poll: loop {
             let events = poll_ctx.wait().map_err(Error::PollError)?;
@@ -134,10 +141,122 @@ impl<T: Vhost> Worker<T> {
                         let _ = self.kill_evt.read();
                         break 'poll;
                     }
+                    Token::ControlNotify => {
+                        if let Some(socket) = &self.response_socket {
+                            match socket.recv() {
+                                Ok(VhostDevRequest::MsixEntryChanged(index)) => {
+                                    let mut qindex = 0;
+                                    for (queue_index, queue) in self.queues.iter().enumerate() {
+                                        if queue.vector == index as u16 {
+                                            qindex = queue_index;
+                                            break;
+                                        }
+                                    }
+                                    let response =
+                                        match self.set_vring_call_for_entry(qindex, index) {
+                                            Ok(()) => VhostDevResponse::Ok,
+                                            Err(e) => {
+                                                error!(
+                                                "Set vring call failed for masked entry {}: {:?}",
+                                                index, e
+                                            );
+                                                VhostDevResponse::Err(SysError::new(EIO))
+                                            }
+                                        };
+                                    if let Err(e) = socket.send(&response) {
+                                        error!("Vhost failed to send VhostMsixEntryMasked Response for entry {}: {:?}", index, e);
+                                    }
+                                }
+                                Ok(VhostDevRequest::MsixChanged) => {
+                                    let response = match self.set_vring_calls() {
+                                        Ok(()) => VhostDevResponse::Ok,
+                                        Err(e) => {
+                                            error!("Set vring calls failed: {:?}", e);
+                                            VhostDevResponse::Err(SysError::new(EIO))
+                                        }
+                                    };
+                                    if let Err(e) = socket.send(&response) {
+                                        error!(
+                                            "Vhost failed to send VhostMsixMasked Response: {:?}",
+                                            e
+                                        );
+                                    }
+                                }
+                                Err(e) => {
+                                    error!("Vhost failed to receive Control request: {:?}", e);
+                                }
+                            }
+                        }
+                    }
                 }
             }
         }
         cleanup_vqs(&self.vhost_handle)?;
         Ok(())
     }
+
+    fn set_vring_call_for_entry(&self, queue_index: usize, vector: usize) -> Result<()> {
+        // No response_socket means it doesn't have any control related
+        // with the msix. Due to this, cannot use the direct irq fd but
+        // should fall back to indirect irq fd.
+        if self.response_socket.is_some() {
+            if let Some(msix_config) = &self.interrupt.msix_config {
+                let msix_config = msix_config.lock();
+                let msix_masked = msix_config.masked();
+                if msix_masked {
+                    return Ok(());
+                }
+                if !msix_config.table_masked(vector) {
+                    if let Some(irqfd) = msix_config.get_irqfd(vector) {
+                        self.vhost_handle
+                            .set_vring_call(queue_index, irqfd)
+                            .map_err(Error::VhostSetVringCall)?;
+                    } else {
+                        self.vhost_handle
+                            .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
+                            .map_err(Error::VhostSetVringCall)?;
+                    }
+                    return Ok(());
+                }
+            }
+        }
+
+        self.vhost_handle
+            .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
+            .map_err(Error::VhostSetVringCall)?;
+        Ok(())
+    }
+
+    fn set_vring_calls(&self) -> Result<()> {
+        if let Some(msix_config) = &self.interrupt.msix_config {
+            let msix_config = msix_config.lock();
+            if msix_config.masked() {
+                for (queue_index, _) in self.queues.iter().enumerate() {
+                    self.vhost_handle
+                        .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
+                        .map_err(Error::VhostSetVringCall)?;
+                }
+            } else {
+                for (queue_index, queue) in self.queues.iter().enumerate() {
+                    let vector = queue.vector as usize;
+                    if !msix_config.table_masked(vector) {
+                        if let Some(irqfd) = msix_config.get_irqfd(vector) {
+                            self.vhost_handle
+                                .set_vring_call(queue_index, irqfd)
+                                .map_err(Error::VhostSetVringCall)?;
+                        } else {
+                            self.vhost_handle
+                                .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
+                                .map_err(Error::VhostSetVringCall)?;
+                        }
+                    } else {
+                        self.vhost_handle
+                            .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
+                            .map_err(Error::VhostSetVringCall)?;
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
 }