diff options
Diffstat (limited to 'devices/src/virtio/vhost/worker.rs')
-rw-r--r-- | devices/src/virtio/vhost/worker.rs | 133 |
1 file changed, 124 insertions, 9 deletions
diff --git a/devices/src/virtio/vhost/worker.rs b/devices/src/virtio/vhost/worker.rs index 03c1066..ca02a63 100644 --- a/devices/src/virtio/vhost/worker.rs +++ b/devices/src/virtio/vhost/worker.rs @@ -4,17 +4,16 @@ use std::os::raw::c_ulonglong; -use sys_util::{EventFd, PollContext, PollToken}; +use sys_util::{error, Error as SysError, EventFd, PollContext, PollToken}; use vhost::Vhost; -use super::control_socket::VhostDevResponseSocket; +use super::control_socket::{VhostDevRequest, VhostDevResponse, VhostDevResponseSocket}; use super::{Error, Result}; use crate::virtio::{Interrupt, Queue}; +use libc::EIO; +use msg_socket::{MsgReceiver, MsgSender}; -/// Worker that takes care of running the vhost device. This mainly involves forwarding interrupts -/// from the vhost driver to the guest VM because crosvm only supports the virtio-mmio transport, -/// which requires a bit to be set in the interrupt status register before triggering the interrupt -/// and the vhost driver doesn't do this for us. +/// Worker that takes care of running the vhost device. 
pub struct Worker<T: Vhost> { interrupt: Interrupt, queues: Vec<Queue>, @@ -91,9 +90,7 @@ impl<T: Vhost> Worker<T> { self.vhost_handle .set_vring_base(queue_index, 0) .map_err(Error::VhostSetVringBase)?; - self.vhost_handle - .set_vring_call(queue_index, &self.vhost_interrupt[queue_index]) - .map_err(Error::VhostSetVringCall)?; + self.set_vring_call_for_entry(queue_index, queue.vector as usize)?; self.vhost_handle .set_vring_kick(queue_index, &queue_evts[queue_index]) .map_err(Error::VhostSetVringKick)?; @@ -106,6 +103,7 @@ impl<T: Vhost> Worker<T> { VhostIrqi { index: usize }, InterruptResample, Kill, + ControlNotify, } let poll_ctx: PollContext<Token> = PollContext::build_with(&[ @@ -119,6 +117,11 @@ impl<T: Vhost> Worker<T> { .add(vhost_int, Token::VhostIrqi { index }) .map_err(Error::CreatePollContext)?; } + if let Some(socket) = &self.response_socket { + poll_ctx + .add(socket, Token::ControlNotify) + .map_err(Error::CreatePollContext)?; + } 'poll: loop { let events = poll_ctx.wait().map_err(Error::PollError)?; @@ -138,10 +141,122 @@ impl<T: Vhost> Worker<T> { let _ = self.kill_evt.read(); break 'poll; } + Token::ControlNotify => { + if let Some(socket) = &self.response_socket { + match socket.recv() { + Ok(VhostDevRequest::MsixEntryChanged(index)) => { + let mut qindex = 0; + for (queue_index, queue) in self.queues.iter().enumerate() { + if queue.vector == index as u16 { + qindex = queue_index; + break; + } + } + let response = + match self.set_vring_call_for_entry(qindex, index) { + Ok(()) => VhostDevResponse::Ok, + Err(e) => { + error!( + "Set vring call failed for masked entry {}: {:?}", + index, e + ); + VhostDevResponse::Err(SysError::new(EIO)) + } + }; + if let Err(e) = socket.send(&response) { + error!("Vhost failed to send VhostMsixEntryMasked Response for entry {}: {:?}", index, e); + } + } + Ok(VhostDevRequest::MsixChanged) => { + let response = match self.set_vring_calls() { + Ok(()) => VhostDevResponse::Ok, + Err(e) => { + error!("Set vring calls 
failed: {:?}", e); + VhostDevResponse::Err(SysError::new(EIO)) + } + }; + if let Err(e) = socket.send(&response) { + error!( + "Vhost failed to send VhostMsixMasked Response: {:?}", + e + ); + } + } + Err(e) => { + error!("Vhost failed to receive Control request: {:?}", e); + } + } + } + } } } } cleanup_vqs(&self.vhost_handle)?; Ok(()) } + + fn set_vring_call_for_entry(&self, queue_index: usize, vector: usize) -> Result<()> { + // No response_socket means it doesn't have any control related + // with the msix. Due to this, cannot use the direct irq fd but + // should fall back to indirect irq fd. + if self.response_socket.is_some() { + if let Some(msix_config) = &self.interrupt.msix_config { + let msix_config = msix_config.lock(); + let msix_masked = msix_config.masked(); + if msix_masked { + return Ok(()); + } + if !msix_config.table_masked(vector) { + if let Some(irqfd) = msix_config.get_irqfd(vector) { + self.vhost_handle + .set_vring_call(queue_index, irqfd) + .map_err(Error::VhostSetVringCall)?; + } else { + self.vhost_handle + .set_vring_call(queue_index, &self.vhost_interrupt[queue_index]) + .map_err(Error::VhostSetVringCall)?; + } + return Ok(()); + } + } + } + + self.vhost_handle + .set_vring_call(queue_index, &self.vhost_interrupt[queue_index]) + .map_err(Error::VhostSetVringCall)?; + Ok(()) + } + + fn set_vring_calls(&self) -> Result<()> { + if let Some(msix_config) = &self.interrupt.msix_config { + let msix_config = msix_config.lock(); + if msix_config.masked() { + for (queue_index, _) in self.queues.iter().enumerate() { + self.vhost_handle + .set_vring_call(queue_index, &self.vhost_interrupt[queue_index]) + .map_err(Error::VhostSetVringCall)?; + } + } else { + for (queue_index, queue) in self.queues.iter().enumerate() { + let vector = queue.vector as usize; + if !msix_config.table_masked(vector) { + if let Some(irqfd) = msix_config.get_irqfd(vector) { + self.vhost_handle + .set_vring_call(queue_index, irqfd) + .map_err(Error::VhostSetVringCall)?; + } 
else { + self.vhost_handle + .set_vring_call(queue_index, &self.vhost_interrupt[queue_index]) + .map_err(Error::VhostSetVringCall)?; + } + } else { + self.vhost_handle + .set_vring_call(queue_index, &self.vhost_interrupt[queue_index]) + .map_err(Error::VhostSetVringCall)?; + } + } + } + } + Ok(()) + } } |