summary refs log tree commit diff
path: root/devices/src/virtio/net.rs
diff options
context:
space:
mode:
Diffstat (limited to 'devices/src/virtio/net.rs')
-rw-r--r--devices/src/virtio/net.rs20
1 files changed, 12 insertions, 8 deletions
diff --git a/devices/src/virtio/net.rs b/devices/src/virtio/net.rs
index 0200daf..a15ab03 100644
--- a/devices/src/virtio/net.rs
+++ b/devices/src/virtio/net.rs
@@ -28,8 +28,6 @@ use super::{
 };
 
 const QUEUE_SIZE: u16 = 256;
-const NUM_QUEUES: usize = 3;
-const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE, QUEUE_SIZE, QUEUE_SIZE];
 
 #[derive(Debug)]
 pub enum NetError {
@@ -422,6 +420,7 @@ where
 }
 
 pub struct Net<T: TapT> {
+    queue_sizes: Box<[u16]>,
     workers_kill_evt: Vec<EventFd>,
     kill_evts: Vec<EventFd>,
     worker_threads: Vec<thread::JoinHandle<Worker<T>>>,
@@ -440,8 +439,8 @@ where
         ip_addr: Ipv4Addr,
         netmask: Ipv4Addr,
         mac_addr: MacAddress,
+        vq_pairs: u16,
     ) -> Result<Net<T>, NetError> {
-        let vq_pairs = QUEUE_SIZES.len() as u16 / 2;
         let multi_queue = if vq_pairs > 1 { true } else { false };
         let tap: T = T::new(true, multi_queue).map_err(NetError::TapOpen)?;
         tap.set_ip_addr(ip_addr).map_err(NetError::TapSetIp)?;
@@ -490,6 +489,7 @@ where
         }
 
         Ok(Net {
+            queue_sizes: vec![QUEUE_SIZE; (vq_pairs * 2 + 1) as usize].into_boxed_slice(),
             workers_kill_evt,
             kill_evts,
             worker_threads: Vec::new(),
@@ -500,7 +500,7 @@ where
     }
 
     fn build_config(&self) -> VirtioNetConfig {
-        let vq_pairs = QUEUE_SIZES.len() as u16 / 2;
+        let vq_pairs = self.queue_sizes.len() as u16 / 2;
 
         VirtioNetConfig {
             max_vq_pairs: Le16::from(vq_pairs),
@@ -600,7 +600,7 @@ where
     }
 
     fn queue_max_sizes(&self) -> &[u16] {
-        QUEUE_SIZES
+        &self.queue_sizes
     }
 
     fn features(&self) -> u64 {
@@ -643,12 +643,16 @@ where
         mut queues: Vec<Queue>,
         mut queue_evts: Vec<EventFd>,
     ) {
-        if queues.len() != NUM_QUEUES || queue_evts.len() != NUM_QUEUES {
-            error!("net: expected {} queues, got {}", NUM_QUEUES, queues.len());
+        if queues.len() != self.queue_sizes.len() || queue_evts.len() != self.queue_sizes.len() {
+            error!(
+                "net: expected {} queues, got {}",
+                self.queue_sizes.len(),
+                queues.len()
+            );
             return;
         }
 
-        let vq_pairs = QUEUE_SIZES.len() / 2;
+        let vq_pairs = self.queue_sizes.len() / 2;
         if self.taps.len() != vq_pairs {
             error!("net: expected {} taps, got {}", vq_pairs, self.taps.len());
             return;