summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--devices/src/pci/vfio_pci.rs152
-rw-r--r--src/linux.rs7
2 files changed, 138 insertions, 21 deletions
diff --git a/devices/src/pci/vfio_pci.rs b/devices/src/pci/vfio_pci.rs
index af70aef..a1e7a74 100644
--- a/devices/src/pci/vfio_pci.rs
+++ b/devices/src/pci/vfio_pci.rs
@@ -7,10 +7,12 @@ use std::sync::Arc;
 use std::u32;
 
 use kvm::Datamatch;
+use msg_socket::{MsgReceiver, MsgSender};
 use resources::{Alloc, SystemAllocator};
 use sys_util::{error, EventFd};
 
 use vfio_sys::*;
+use vm_control::{MaybeOwnedFd, VmIrqRequest, VmIrqRequestSocket, VmIrqResponse};
 
 use crate::pci::pci_device::{Error as PciDeviceError, PciDevice};
 use crate::pci::PciInterruptPin;
@@ -109,10 +111,13 @@ struct VfioMsiCap {
     ctl: u16,
     address: u64,
     data: u16,
+    vm_socket_irq: VmIrqRequestSocket,
+    irqfd: Option<EventFd>,
+    gsi: Option<u32>,
 }
 
 impl VfioMsiCap {
-    fn new(config: &VfioPciConfig) -> Option<Self> {
+    fn new(config: &VfioPciConfig, vm_socket_irq: VmIrqRequestSocket) -> Option<Self> {
         // msi minimum size is 0xa
         let mut msi_len: u32 = MSI_LENGTH_32BIT;
         let mut cap_next: u32 = config.read_config_byte(PCI_CAPABILITY_LIST).into();
@@ -133,6 +138,9 @@ impl VfioMsiCap {
                     ctl: 0,
                     address: 0,
                     data: 0,
+                    vm_socket_irq,
+                    irqfd: None,
+                    gsi: None,
                 });
             }
             let offset = cap_next + PCI_MSI_NEXT_POINTER;
@@ -157,6 +165,8 @@ impl VfioMsiCap {
         let len = data.len();
         let offset = index as u32 - self.offset;
         let mut ret: Option<VfioMsiChange> = None;
+        let old_address = self.address;
+        let old_data = self.data;
 
         // write msi ctl
         if len == 2 && offset == PCI_MSI_FLAGS {
@@ -165,6 +175,7 @@ impl VfioMsiCap {
             self.ctl = u16::from_le_bytes(value);
             let is_enabled = self.is_msi_enabled();
             if !was_enabled && is_enabled {
+                self.enable();
                 ret = Some(VfioMsiChange::Enable);
             } else if was_enabled && !is_enabled {
                 ret = Some(VfioMsiChange::Disable)
@@ -199,12 +210,79 @@ impl VfioMsiCap {
             self.data = u16::from_le_bytes(value);
         }
 
+        if self.is_msi_enabled() && (old_address != self.address || old_data != self.data) {
+            self.add_msi_route();
+        }
+
         ret
     }
 
     fn is_msi_enabled(&self) -> bool {
         self.ctl & PCI_MSI_FLAGS_ENABLE == PCI_MSI_FLAGS_ENABLE
     }
+
+    fn add_msi_route(&self) {
+        let gsi = match self.gsi {
+            Some(g) => g,
+            None => {
+                error!("Add msi route but gsi is none");
+                return;
+            }
+        };
+        if let Err(e) = self.vm_socket_irq.send(&VmIrqRequest::AddMsiRoute {
+            gsi,
+            msi_address: self.address,
+            msi_data: self.data.into(),
+        }) {
+            error!("failed to send AddMsiRoute request at {:?}", e);
+            return;
+        }
+        match self.vm_socket_irq.recv() {
+            Ok(VmIrqResponse::Err(e)) => error!("failed to call AddMsiRoute request {:?}", e),
+            Ok(_) => {}
+            Err(e) => error!("failed to receive AddMsiRoute response {:?}", e),
+        }
+    }
+
+    fn allocate_one_msi(&mut self) {
+        if self.irqfd.is_none() {
+            match EventFd::new() {
+                Ok(fd) => self.irqfd = Some(fd),
+                Err(e) => {
+                    error!("failed to create eventfd: {:?}", e);
+                    return;
+                }
+            };
+        }
+
+        if let Err(e) = self.vm_socket_irq.send(&VmIrqRequest::AllocateOneMsi {
+            irqfd: MaybeOwnedFd::Borrowed(self.irqfd.as_ref().unwrap().as_raw_fd()),
+        }) {
+            error!("failed to send AllocateOneMsi request: {:?}", e);
+            return;
+        }
+
+        match self.vm_socket_irq.recv() {
+            Ok(VmIrqResponse::AllocateOneMsi { gsi }) => self.gsi = Some(gsi),
+            _ => error!("failed to receive AllocateOneMsi Response"),
+        }
+    }
+
+    fn enable(&mut self) {
+        if self.gsi.is_none() || self.irqfd.is_none() {
+            self.allocate_one_msi();
+        }
+
+        self.add_msi_route();
+    }
+
+    fn get_msi_irqfd(&self) -> Option<&EventFd> {
+        self.irqfd.as_ref()
+    }
+
+    fn get_vm_socket(&self) -> RawFd {
+        self.vm_socket_irq.as_ref().as_raw_fd()
+    }
 }
 
 struct MmioInfo {
@@ -232,10 +310,10 @@ pub struct VfioPciDevice {
 
 impl VfioPciDevice {
     /// Constructs a new Vfio Pci device for the give Vfio device
-    pub fn new(device: VfioDevice) -> Self {
+    pub fn new(device: VfioDevice, vfio_device_socket_irq: VmIrqRequestSocket) -> Self {
         let dev = Arc::new(device);
         let config = VfioPciConfig::new(Arc::clone(&dev));
-        let msi_cap = VfioMsiCap::new(&config);
+        let msi_cap = VfioMsiCap::new(&config, vfio_device_socket_irq);
 
         VfioPciDevice {
             device: dev,
@@ -302,6 +380,47 @@ impl VfioPciDevice {
         }
         self.irq_type = None;
     }
+
+    fn enable_msi(&mut self) {
+        if let Some(irq_type) = &self.irq_type {
+            match irq_type {
+                VfioIrqType::Intx => self.disable_intx(),
+                _ => return,
+            }
+        }
+
+        let irqfd = match &self.msi_cap {
+            Some(cap) => {
+                if let Some(fd) = cap.get_msi_irqfd() {
+                    fd
+                } else {
+                    self.enable_intx();
+                    return;
+                }
+            }
+            None => {
+                self.enable_intx();
+                return;
+            }
+        };
+
+        if let Err(e) = self.device.irq_enable(irqfd, VfioIrqType::Msi) {
+            error!("failed to enable msi: {}", e);
+            self.enable_intx();
+            return;
+        }
+
+        self.irq_type = Some(VfioIrqType::Msi);
+    }
+
+    fn disable_msi(&mut self) {
+        if let Err(e) = self.device.irq_disable(VfioIrqType::Msi) {
+            error!("failed to disable msi: {}", e);
+            return;
+        }
+
+        self.enable_intx();
+    }
 }
 
 impl PciDevice for VfioPciDevice {
@@ -321,6 +440,9 @@ impl PciDevice for VfioPciDevice {
         if let Some(ref interrupt_resample_evt) = self.interrupt_resample_evt {
             fds.push(interrupt_resample_evt.as_raw_fd());
         }
+        if let Some(msi_cap) = &self.msi_cap {
+            fds.push(msi_cap.get_vm_socket());
+        }
         fds
     }
 
@@ -462,27 +584,19 @@ impl PciDevice for VfioPciDevice {
     fn write_config_register(&mut self, reg_idx: usize, offset: u64, data: &[u8]) {
         let start = (reg_idx * 4) as u64 + offset;
 
+        let mut msi_change: Option<VfioMsiChange> = None;
         if let Some(msi_cap) = self.msi_cap.as_mut() {
             if msi_cap.is_msi_reg(start, data.len()) {
-                if let Some(ref interrupt_evt) = self.interrupt_evt {
-                    match msi_cap.write_msi_reg(start, data) {
-                        Some(VfioMsiChange::Enable) => {
-                            if let Err(e) = self.device.irq_enable(interrupt_evt, VfioIrqType::Msi)
-                            {
-                                error!("{}", e);
-                            }
-                        }
-                        Some(VfioMsiChange::Disable) => {
-                            if let Err(e) = self.device.irq_disable(VfioIrqType::Msi) {
-                                error!("{}", e);
-                            }
-                        }
-                        None => (),
-                    }
-                }
+                msi_change = msi_cap.write_msi_reg(start, data);
             }
         }
 
+        match msi_change {
+            Some(VfioMsiChange::Enable) => self.enable_msi(),
+            Some(VfioMsiChange::Disable) => self.disable_msi(),
+            None => (),
+        }
+
         self.device
             .region_write(VFIO_PCI_CONFIG_REGION_INDEX, data, start);
     }
diff --git a/src/linux.rs b/src/linux.rs
index 31b9d32..0266721 100644
--- a/src/linux.rs
+++ b/src/linux.rs
@@ -258,7 +258,6 @@ type Result<T> = std::result::Result<T, Error>;
 enum TaggedControlSocket {
     Vm(VmControlResponseSocket),
     VmMemory(VmMemoryControlResponseSocket),
-    #[allow(dead_code)]
     VmIrq(VmIrqResponseSocket),
 }
 
@@ -1000,10 +999,14 @@ fn create_devices(
     pci_devices.push((usb_controller, simple_jail(&cfg, "xhci.policy")?));
 
     if cfg.vfio.is_some() {
+        let (vfio_host_socket_irq, vfio_device_socket_irq) =
+            msg_socket::pair::<VmIrqResponse, VmIrqRequest>().map_err(Error::CreateSocket)?;
+        control_sockets.push(TaggedControlSocket::VmIrq(vfio_host_socket_irq));
+
         let vfio_path = cfg.vfio.as_ref().unwrap().as_path();
         let vfiodevice =
             VfioDevice::new(vfio_path, vm, mem.clone()).map_err(Error::CreateVfioDevice)?;
-        let vfiopcidevice = Box::new(VfioPciDevice::new(vfiodevice));
+        let vfiopcidevice = Box::new(VfioPciDevice::new(vfiodevice, vfio_device_socket_irq));
         pci_devices.push((vfiopcidevice, simple_jail(&cfg, "vfio_device.policy")?));
     }