summary refs log tree commit diff
path: root/devices/src/virtio/virtio_pci_device.rs
diff options
context:
space:
mode:
Diffstat (limited to 'devices/src/virtio/virtio_pci_device.rs')
-rw-r--r--devices/src/virtio/virtio_pci_device.rs19
1 files changed, 16 insertions, 3 deletions
diff --git a/devices/src/virtio/virtio_pci_device.rs b/devices/src/virtio/virtio_pci_device.rs
index 2110b75..713329f 100644
--- a/devices/src/virtio/virtio_pci_device.rs
+++ b/devices/src/virtio/virtio_pci_device.rs
@@ -19,6 +19,8 @@ use crate::pci::{
     PciConfiguration, PciDevice, PciDeviceError, PciHeaderType, PciInterruptPin, PciSubclass,
 };
 
+use vm_control::VmIrqRequestSocket;
+
 use self::virtio_pci_common_config::VirtioPciCommonConfig;
 
 pub enum PciCapabilityType {
@@ -172,7 +174,11 @@ pub struct VirtioPciDevice {
 
 impl VirtioPciDevice {
     /// Constructs a new PCI transport for the given virtio device.
-    pub fn new(mem: GuestMemory, device: Box<dyn VirtioDevice>) -> Result<Self> {
+    pub fn new(
+        mem: GuestMemory,
+        device: Box<dyn VirtioDevice>,
+        msi_device_socket: Option<VmIrqRequestSocket>,
+    ) -> Result<Self> {
         let mut queue_evts = Vec::new();
         for _ in device.queue_max_sizes() {
             queue_evts.push(EventFd::new()?)
@@ -186,8 +192,11 @@ impl VirtioPciDevice {
         let pci_device_id = VIRTIO_PCI_DEVICE_ID_BASE + device.device_type() as u16;
 
         let msix_num = device.msix_vectors();
-        let msix_config = if msix_num > 0 {
-            let msix_config = Arc::new(Mutex::new(MsixConfig::new(msix_num)));
+        let msix_config = if msix_num > 0 && msi_device_socket.is_some() {
+            let msix_config = Arc::new(Mutex::new(MsixConfig::new(
+                msix_num,
+                msi_device_socket.unwrap(),
+            )));
             Some(msix_config)
         } else {
             None
@@ -346,6 +355,10 @@ impl PciDevice for VirtioPciDevice {
         if let Some(interrupt_resample_evt) = &self.interrupt_resample_evt {
             fds.push(interrupt_resample_evt.as_raw_fd());
         }
+        if let Some(msix_config) = &self.msix_config {
+            let fd = msix_config.lock().get_msi_socket();
+            fds.push(fd);
+        }
         fds
     }