summary refs log tree commit diff
path: root/devices/src/virtio/vhost/worker.rs
diff options
context:
space:
mode:
Diffstat (limited to 'devices/src/virtio/vhost/worker.rs')
-rw-r--r--devices/src/virtio/vhost/worker.rs133
1 files changed, 124 insertions, 9 deletions
diff --git a/devices/src/virtio/vhost/worker.rs b/devices/src/virtio/vhost/worker.rs
index 03c1066..ca02a63 100644
--- a/devices/src/virtio/vhost/worker.rs
+++ b/devices/src/virtio/vhost/worker.rs
@@ -4,17 +4,16 @@
 
 use std::os::raw::c_ulonglong;
 
-use sys_util::{EventFd, PollContext, PollToken};
+use sys_util::{error, Error as SysError, EventFd, PollContext, PollToken};
 use vhost::Vhost;
 
-use super::control_socket::VhostDevResponseSocket;
+use super::control_socket::{VhostDevRequest, VhostDevResponse, VhostDevResponseSocket};
 use super::{Error, Result};
 use crate::virtio::{Interrupt, Queue};
+use libc::EIO;
+use msg_socket::{MsgReceiver, MsgSender};
 
-/// Worker that takes care of running the vhost device.  This mainly involves forwarding interrupts
-/// from the vhost driver to the guest VM because crosvm only supports the virtio-mmio transport,
-/// which requires a bit to be set in the interrupt status register before triggering the interrupt
-/// and the vhost driver doesn't do this for us.
+/// Worker that takes care of running the vhost device.
 pub struct Worker<T: Vhost> {
     interrupt: Interrupt,
     queues: Vec<Queue>,
@@ -91,9 +90,7 @@ impl<T: Vhost> Worker<T> {
             self.vhost_handle
                 .set_vring_base(queue_index, 0)
                 .map_err(Error::VhostSetVringBase)?;
-            self.vhost_handle
-                .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
-                .map_err(Error::VhostSetVringCall)?;
+            self.set_vring_call_for_entry(queue_index, queue.vector as usize)?;
             self.vhost_handle
                 .set_vring_kick(queue_index, &queue_evts[queue_index])
                 .map_err(Error::VhostSetVringKick)?;
@@ -106,6 +103,7 @@ impl<T: Vhost> Worker<T> {
             VhostIrqi { index: usize },
             InterruptResample,
             Kill,
+            ControlNotify,
         }
 
         let poll_ctx: PollContext<Token> = PollContext::build_with(&[
@@ -119,6 +117,11 @@ impl<T: Vhost> Worker<T> {
                 .add(vhost_int, Token::VhostIrqi { index })
                 .map_err(Error::CreatePollContext)?;
         }
+        if let Some(socket) = &self.response_socket {
+            poll_ctx
+                .add(socket, Token::ControlNotify)
+                .map_err(Error::CreatePollContext)?;
+        }
 
         'poll: loop {
             let events = poll_ctx.wait().map_err(Error::PollError)?;
@@ -138,10 +141,122 @@ impl<T: Vhost> Worker<T> {
                         let _ = self.kill_evt.read();
                         break 'poll;
                     }
+                    Token::ControlNotify => {
+                        if let Some(socket) = &self.response_socket {
+                            match socket.recv() {
+                                Ok(VhostDevRequest::MsixEntryChanged(index)) => {
+                                    let mut qindex = 0;
+                                    for (queue_index, queue) in self.queues.iter().enumerate() {
+                                        if queue.vector == index as u16 {
+                                            qindex = queue_index;
+                                            break;
+                                        }
+                                    }
+                                    let response =
+                                        match self.set_vring_call_for_entry(qindex, index) {
+                                            Ok(()) => VhostDevResponse::Ok,
+                                            Err(e) => {
+                                                error!(
+                                                "Set vring call failed for masked entry {}: {:?}",
+                                                index, e
+                                            );
+                                                VhostDevResponse::Err(SysError::new(EIO))
+                                            }
+                                        };
+                                    if let Err(e) = socket.send(&response) {
+                                        error!("Vhost failed to send VhostMsixEntryMasked Response for entry {}: {:?}", index, e);
+                                    }
+                                }
+                                Ok(VhostDevRequest::MsixChanged) => {
+                                    let response = match self.set_vring_calls() {
+                                        Ok(()) => VhostDevResponse::Ok,
+                                        Err(e) => {
+                                            error!("Set vring calls failed: {:?}", e);
+                                            VhostDevResponse::Err(SysError::new(EIO))
+                                        }
+                                    };
+                                    if let Err(e) = socket.send(&response) {
+                                        error!(
+                                            "Vhost failed to send VhostMsixMasked Response: {:?}",
+                                            e
+                                        );
+                                    }
+                                }
+                                Err(e) => {
+                                    error!("Vhost failed to receive Control request: {:?}", e);
+                                }
+                            }
+                        }
+                    }
                 }
             }
         }
         cleanup_vqs(&self.vhost_handle)?;
         Ok(())
     }
+
+    fn set_vring_call_for_entry(&self, queue_index: usize, vector: usize) -> Result<()> {
+        // No response_socket means it doesn't have any control related
+        // with the msix. Due to this, cannot use the direct irq fd but
+        // should fall back to indirect irq fd.
+        if self.response_socket.is_some() {
+            if let Some(msix_config) = &self.interrupt.msix_config {
+                let msix_config = msix_config.lock();
+                let msix_masked = msix_config.masked();
+                if msix_masked {
+                    return Ok(());
+                }
+                if !msix_config.table_masked(vector) {
+                    if let Some(irqfd) = msix_config.get_irqfd(vector) {
+                        self.vhost_handle
+                            .set_vring_call(queue_index, irqfd)
+                            .map_err(Error::VhostSetVringCall)?;
+                    } else {
+                        self.vhost_handle
+                            .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
+                            .map_err(Error::VhostSetVringCall)?;
+                    }
+                    return Ok(());
+                }
+            }
+        }
+
+        self.vhost_handle
+            .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
+            .map_err(Error::VhostSetVringCall)?;
+        Ok(())
+    }
+
+    fn set_vring_calls(&self) -> Result<()> {
+        if let Some(msix_config) = &self.interrupt.msix_config {
+            let msix_config = msix_config.lock();
+            if msix_config.masked() {
+                for (queue_index, _) in self.queues.iter().enumerate() {
+                    self.vhost_handle
+                        .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
+                        .map_err(Error::VhostSetVringCall)?;
+                }
+            } else {
+                for (queue_index, queue) in self.queues.iter().enumerate() {
+                    let vector = queue.vector as usize;
+                    if !msix_config.table_masked(vector) {
+                        if let Some(irqfd) = msix_config.get_irqfd(vector) {
+                            self.vhost_handle
+                                .set_vring_call(queue_index, irqfd)
+                                .map_err(Error::VhostSetVringCall)?;
+                        } else {
+                            self.vhost_handle
+                                .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
+                                .map_err(Error::VhostSetVringCall)?;
+                        }
+                    } else {
+                        self.vhost_handle
+                            .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
+                            .map_err(Error::VhostSetVringCall)?;
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
 }