summary refs log tree commit diff
diff options
context:
space:
mode:
authorJason D. Clinton <jclinton@chromium.org>2017-09-27 22:04:03 -0600
committerchrome-bot <chrome-bot@chromium.org>2018-02-02 16:32:12 -0800
commit865323d0ed8b6913ed7dfe6e31c3b86eb46775bd (patch)
treebe835e928e8932ab66dc00412d5f96430289e94c
parent19e57b9532f9be830fab7fad685957afc8f5ab78 (diff)
downloadcrosvm-865323d0ed8b6913ed7dfe6e31c3b86eb46775bd.tar
crosvm-865323d0ed8b6913ed7dfe6e31c3b86eb46775bd.tar.gz
crosvm-865323d0ed8b6913ed7dfe6e31c3b86eb46775bd.tar.bz2
crosvm-865323d0ed8b6913ed7dfe6e31c3b86eb46775bd.tar.lz
crosvm-865323d0ed8b6913ed7dfe6e31c3b86eb46775bd.tar.xz
crosvm-865323d0ed8b6913ed7dfe6e31c3b86eb46775bd.tar.zst
crosvm-865323d0ed8b6913ed7dfe6e31c3b86eb46775bd.zip
hw/virtio/vhost: Add simple tests backed by fakes
This slightly advances the use of fakes to test higher level
application logic. The fakes are rudimentary at this point, but I
wanted to get feedback on the addition of generics in order to
facilitate swaping concrete implementations out with fakes in higher
level code.

BUG=none
TEST=./build_test and
cargo test -p crosvm -p data_model -p syscall_defines -p kernel_loader
-p net_util -p x86_64 -p virtio_sys -p kvm_sys -p vhost -p io_jail -p
net_sys -p sys_util -p kvm

Change-Id: Ib64581014391f49cff30ada10677bbbcd0088f20
Reviewed-on: https://chromium-review.googlesource.com/689740
Commit-Ready: Jason Clinton <jclinton@chromium.org>
Tested-by: Jason Clinton <jclinton@chromium.org>
Reviewed-by: Stephen Barber <smbarber@chromium.org>
-rw-r--r--Cargo.lock2
-rw-r--r--Cargo.toml2
-rw-r--r--devices/src/virtio/net.rs61
-rw-r--r--devices/src/virtio/vhost/net.rs104
-rw-r--r--net_util/src/lib.rs124
-rw-r--r--src/linux.rs8
-rw-r--r--src/main.rs2
-rw-r--r--vhost/src/lib.rs8
-rw-r--r--vhost/src/net.rs69
9 files changed, 293 insertions, 87 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 5c0d367..51c183c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -24,10 +24,12 @@ dependencies = [
  "kernel_loader 0.1.0",
  "kvm 0.1.0",
  "libc 0.2.34 (registry+https://github.com/rust-lang/crates.io-index)",
+ "net_util 0.1.0",
  "plugin_proto 0.5.0",
  "qcow 0.1.0",
  "qcow_utils 0.1.0",
  "sys_util 0.1.0",
+ "vhost 0.1.0",
  "vm_control 0.1.0",
  "x86_64 0.1.0",
 ]
diff --git a/Cargo.toml b/Cargo.toml
index 80a7c5f..c6a5575 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,6 +20,8 @@ sys_util = { path = "sys_util" }
 kernel_loader = { path = "kernel_loader" }
 libc = "=0.2.34"
 byteorder = "=1.1.0"
+net_util = { path = "net_util" }
+vhost = { path = "vhost" }
 vm_control = { path = "vm_control" }
 data_model = { path = "data_model" }
 qcow = { path = "qcow" }
diff --git a/devices/src/virtio/net.rs b/devices/src/virtio/net.rs
index a4df8c9..2c15271 100644
--- a/devices/src/virtio/net.rs
+++ b/devices/src/virtio/net.rs
@@ -4,7 +4,6 @@
 
 use std::cmp;
 use std::mem;
-use std::io::{Read, Write};
 use std::net::Ipv4Addr;
 use std::os::unix::io::{AsRawFd, RawFd};
 use std::sync::Arc;
@@ -13,8 +12,8 @@ use std::thread;
 
 use libc::EAGAIN;
 use net_sys;
-use net_util::{Tap, Error as TapError};
-use sys_util::{Error as SysError};
+use net_util::{Error as TapError, TapT};
+use sys_util::Error as SysError;
 use sys_util::{EventFd, GuestMemory, Pollable, Poller};
 use virtio_sys::{vhost, virtio_net};
 use virtio_sys::virtio_net::virtio_net_hdr_v1;
@@ -50,11 +49,11 @@ pub enum NetError {
     PollError(SysError),
 }
 
-struct Worker {
+struct Worker<T: TapT> {
     mem: GuestMemory,
     rx_queue: Queue,
     tx_queue: Queue,
-    tap: Tap,
+    tap: T,
     interrupt_status: Arc<AtomicUsize>,
     interrupt_evt: EventFd,
     rx_buf: [u8; MAX_BUFFER_SIZE],
@@ -66,7 +65,10 @@ struct Worker {
     acked_features: u64,
 }
 
-impl Worker {
+impl<T> Worker<T>
+where
+    T: TapT,
+{
     fn signal_used_queue(&self) {
         self.interrupt_status
             .fetch_or(INTERRUPT_STATUS_USED_RING as usize, Ordering::SeqCst);
@@ -226,7 +228,7 @@ impl Worker {
         const KILL: u32 = 4;
 
         'poll: loop {
-            let tokens = match poller.poll(&[(RX_TAP, &self.tap as &Pollable),
+            let tokens = match poller.poll(&[(RX_TAP, &self.tap),
                                              (RX_QUEUE, &rx_queue_evt as &Pollable),
                                              (TX_QUEUE, &tx_queue_evt as &Pollable),
                                              (KILL, &kill_evt as &Pollable)]) {
@@ -274,21 +276,24 @@ impl Worker {
     }
 }
 
-pub struct Net {
+pub struct Net<T: TapT> {
     workers_kill_evt: Option<EventFd>,
     kill_evt: EventFd,
-    tap: Option<Tap>,
+    tap: Option<T>,
     avail_features: u64,
     acked_features: u64,
 }
 
-impl Net {
+impl<T> Net<T>
+where
+    T: TapT,
+{
     /// Create a new virtio network device with the given IP address and
     /// netmask.
-    pub fn new(ip_addr: Ipv4Addr, netmask: Ipv4Addr) -> Result<Net, NetError> {
+    pub fn new(ip_addr: Ipv4Addr, netmask: Ipv4Addr) -> Result<Net<T>, NetError> {
         let kill_evt = EventFd::new().map_err(NetError::CreateKillEventFd)?;
 
-        let tap = Tap::new().map_err(NetError::TapOpen)?;
+        let tap: T = T::new().map_err(NetError::TapOpen)?;
         tap.set_ip_addr(ip_addr).map_err(NetError::TapSetIp)?;
         tap.set_netmask(netmask)
             .map_err(NetError::TapSetNetmask)?;
@@ -306,24 +311,25 @@ impl Net {
 
         let avail_features =
             1 << virtio_net::VIRTIO_NET_F_GUEST_CSUM | 1 << virtio_net::VIRTIO_NET_F_CSUM |
-            1 << virtio_net::VIRTIO_NET_F_GUEST_TSO4 |
-            1 << virtio_net::VIRTIO_NET_F_GUEST_UFO |
-            1 << virtio_net::VIRTIO_NET_F_HOST_TSO4 |
-            1 << virtio_net::VIRTIO_NET_F_HOST_UFO | 1 << vhost::VIRTIO_F_VERSION_1;
+                1 << virtio_net::VIRTIO_NET_F_GUEST_TSO4 |
+                1 << virtio_net::VIRTIO_NET_F_GUEST_UFO |
+                1 << virtio_net::VIRTIO_NET_F_HOST_TSO4 |
+                1 << virtio_net::VIRTIO_NET_F_HOST_UFO | 1 << vhost::VIRTIO_F_VERSION_1;
 
         Ok(Net {
-               workers_kill_evt: Some(kill_evt
-                                          .try_clone()
-                                          .map_err(NetError::CloneKillEventFd)?),
-               kill_evt: kill_evt,
-               tap: Some(tap),
-               avail_features: avail_features,
-               acked_features: 0u64,
-           })
+            workers_kill_evt: Some(kill_evt.try_clone().map_err(NetError::CloneKillEventFd)?),
+            kill_evt: kill_evt,
+            tap: Some(tap),
+            avail_features: avail_features,
+            acked_features: 0u64,
+        })
     }
 }
 
-impl Drop for Net {
+impl<T> Drop for Net<T>
+where
+    T: TapT,
+{
     fn drop(&mut self) {
         // Only kill the child if it claimed its eventfd.
         if self.workers_kill_evt.is_none() {
@@ -333,7 +339,10 @@ impl Drop for Net {
     }
 }
 
-impl VirtioDevice for Net {
+impl<T> VirtioDevice for Net<T>
+where
+    T: 'static + TapT,
+{
     fn keep_fds(&self) -> Vec<RawFd> {
         let mut keep_fds = Vec::new();
 
diff --git a/devices/src/virtio/vhost/net.rs b/devices/src/virtio/vhost/net.rs
index 473c7e8..fa8df76 100644
--- a/devices/src/virtio/vhost/net.rs
+++ b/devices/src/virtio/vhost/net.rs
@@ -10,10 +10,10 @@ use std::sync::atomic::AtomicUsize;
 use std::thread;
 
 use net_sys;
-use net_util::Tap;
+use net_util::TapT;
+
 use sys_util::{EventFd, GuestMemory};
-use vhost::Net as VhostNetHandle;
-use vhost::net::NetT;
+use vhost::NetT as VhostNetT;
 use virtio_sys::{vhost, virtio_net};
 
 use super::{Error, Result};
@@ -24,23 +24,27 @@ const QUEUE_SIZE: u16 = 256;
 const NUM_QUEUES: usize = 2;
 const QUEUE_SIZES: &'static [u16] = &[QUEUE_SIZE; NUM_QUEUES];
 
-pub struct Net {
+pub struct Net<T: TapT, U: VhostNetT<T>> {
     workers_kill_evt: Option<EventFd>,
     kill_evt: EventFd,
-    tap: Option<Tap>,
-    vhost_net_handle: Option<VhostNetHandle>,
+    tap: Option<T>,
+    vhost_net_handle: Option<U>,
     vhost_interrupt: Option<EventFd>,
     avail_features: u64,
     acked_features: u64,
 }
 
-impl Net {
+impl<T, U> Net<T, U>
+where
+    T: TapT,
+    U: VhostNetT<T>,
+{
     /// Create a new virtio network device with the given IP address and
     /// netmask.
-    pub fn new(ip_addr: Ipv4Addr, netmask: Ipv4Addr, mem: &GuestMemory) -> Result<Net> {
+    pub fn new(ip_addr: Ipv4Addr, netmask: Ipv4Addr, mem: &GuestMemory) -> Result<Net<T, U>> {
         let kill_evt = EventFd::new().map_err(Error::CreateKillEventFd)?;
 
-        let tap = Tap::new().map_err(Error::TapOpen)?;
+        let tap: T = T::new().map_err(Error::TapOpen)?;
         tap.set_ip_addr(ip_addr).map_err(Error::TapSetIp)?;
         tap.set_netmask(netmask).map_err(Error::TapSetNetmask)?;
 
@@ -54,7 +58,7 @@ impl Net {
         tap.set_vnet_hdr_size(vnet_hdr_size).map_err(Error::TapSetVnetHdrSize)?;
 
         tap.enable().map_err(Error::TapEnable)?;
-        let vhost_net_handle = VhostNetHandle::new(mem).map_err(Error::VhostOpen)?;
+        let vhost_net_handle = U::new(mem).map_err(Error::VhostOpen)?;
 
         let avail_features =
             1 << virtio_net::VIRTIO_NET_F_GUEST_CSUM | 1 << virtio_net::VIRTIO_NET_F_CSUM |
@@ -79,7 +83,11 @@ impl Net {
     }
 }
 
-impl Drop for Net {
+impl<T, U> Drop for Net<T, U>
+where
+    T: TapT,
+    U: VhostNetT<T>,
+{
     fn drop(&mut self) {
         // Only kill the child if it claimed its eventfd.
         if self.workers_kill_evt.is_none() {
@@ -89,7 +97,11 @@ impl Drop for Net {
     }
 }
 
-impl VirtioDevice for Net {
+impl<T, U> VirtioDevice for Net<T, U>
+where
+    T: TapT + 'static,
+    U: VhostNetT<T> + 'static,
+{
     fn keep_fds(&self) -> Vec<RawFd> {
         let mut keep_fds = Vec::new();
 
@@ -182,7 +194,7 @@ impl VirtioDevice for Net {
                                                              status,
                                                              interrupt_evt,
                                                              acked_features);
-                                let activate_vqs = |handle: &VhostNetHandle| -> Result<()> {
+                                let activate_vqs = |handle: &U| -> Result<()> {
                                     for idx in 0..NUM_QUEUES {
                                         handle
                                             .set_backend(idx, &tap)
@@ -207,3 +219,69 @@ impl VirtioDevice for Net {
         }
     }
 }
+
+#[cfg(test)]
+pub mod tests {
+    use super::*;
+    use std::result;
+    use net_util::fakes::FakeTap;
+    use sys_util::{GuestAddress, GuestMemory, GuestMemoryError};
+    use vhost::net::fakes::FakeNet;
+
+    fn create_guest_memory() -> result::Result<GuestMemory, GuestMemoryError> {
+        let start_addr1 = GuestAddress(0x0);
+        let start_addr2 = GuestAddress(0x100);
+        GuestMemory::new(&vec![(start_addr1, 0x100), (start_addr2, 0x400)])
+    }
+
+    fn create_net_common() -> Net<FakeTap, FakeNet<FakeTap>> {
+        let guest_memory = create_guest_memory().unwrap();
+        Net::<FakeTap, FakeNet<FakeTap>>::new(
+            Ipv4Addr::new(127, 0, 0, 1),
+            Ipv4Addr::new(255, 255, 255, 0),
+            &guest_memory,
+        ).unwrap()
+    }
+
+    #[test]
+    fn create_net() {
+        create_net_common();
+    }
+
+    #[test]
+    fn keep_fds() {
+        let net = create_net_common();
+        let fds = net.keep_fds();
+        assert!(fds.len() >= 1, "We should have gotten at least one fd");
+    }
+
+    #[test]
+    fn features() {
+        let net = create_net_common();
+        assert_eq!(net.features(0), 822135939);
+        assert_eq!(net.features(1), 1);
+        assert_eq!(net.features(2), 0);
+    }
+
+    #[test]
+    fn ack_features() {
+        let mut net = create_net_common();
+        // Just testing that we don't panic, for now
+        net.ack_features(0, 1);
+        net.ack_features(1, 1);
+    }
+
+    #[test]
+    fn activate() {
+        let mut net = create_net_common();
+        let guest_memory = create_guest_memory().unwrap();
+        // Just testing that we don't panic, for now
+        net.activate(
+            guest_memory,
+            EventFd::new().unwrap(),
+            Arc::new(AtomicUsize::new(0)),
+            vec![Queue::new(1)],
+            vec![EventFd::new().unwrap()],
+        );
+    }
+}
diff --git a/net_util/src/lib.rs b/net_util/src/lib.rs
index 76de2ef..34cc686 100644
--- a/net_util/src/lib.rs
+++ b/net_util/src/lib.rs
@@ -67,9 +67,30 @@ pub struct Tap {
     if_name: [u8; 16usize],
 }
 
-impl Tap {
+pub trait TapT: Read + Write + AsRawFd + Pollable + Send + Sized {
     /// Create a new tap interface.
-    pub fn new() -> Result<Tap> {
+    fn new() -> Result<Self>;
+
+    /// Set the host-side IP address for the tap interface.
+    fn set_ip_addr(&self, ip_addr: net::Ipv4Addr) -> Result<()>;
+
+    /// Set the netmask for the subnet that the tap interface will exist on.
+    fn set_netmask(&self, netmask: net::Ipv4Addr) -> Result<()>;
+
+    /// Set the offload flags for the tap interface.
+    fn set_offload(&self, flags: c_uint) -> Result<()>;
+
+    /// Enable the tap interface.
+    fn enable(&self) -> Result<()>;
+
+    /// Set the size of the vnet hdr.
+    fn set_vnet_hdr_size(&self, size: c_int) -> Result<()>;
+
+    fn get_ifreq(&self) -> net_sys::ifreq;
+}
+
+impl TapT for Tap {
+    fn new() -> Result<Tap> {
         // Open calls are safe because we give a constant nul-terminated
         // string and verify the result.
         let fd = unsafe {
@@ -118,8 +139,7 @@ impl Tap {
            })
     }
 
-    /// Set the host-side IP address for the tap interface.
-    pub fn set_ip_addr(&self, ip_addr: net::Ipv4Addr) -> Result<()> {
+    fn set_ip_addr(&self, ip_addr: net::Ipv4Addr) -> Result<()> {
         let sock = create_socket()?;
         let addr = create_sockaddr(ip_addr);
 
@@ -141,8 +161,7 @@ impl Tap {
         Ok(())
     }
 
-    /// Set the netmask for the subnet that the tap interface will exist on.
-    pub fn set_netmask(&self, netmask: net::Ipv4Addr) -> Result<()> {
+    fn set_netmask(&self, netmask: net::Ipv4Addr) -> Result<()> {
         let sock = create_socket()?;
         let addr = create_sockaddr(netmask);
 
@@ -164,8 +183,7 @@ impl Tap {
         Ok(())
     }
 
-    /// Set the offload flags for the tap interface.
-    pub fn set_offload(&self, flags: c_uint) -> Result<()> {
+    fn set_offload(&self, flags: c_uint) -> Result<()> {
         // ioctl is safe. Called with a valid tap fd, and we check the return.
         let ret =
             unsafe { ioctl_with_val(&self.tap_file, net_sys::TUNSETOFFLOAD(), flags as c_ulong) };
@@ -176,8 +194,7 @@ impl Tap {
         Ok(())
     }
 
-    /// Enable the tap interface.
-    pub fn enable(&self) -> Result<()> {
+    fn enable(&self) -> Result<()> {
         let sock = create_socket()?;
 
         let mut ifreq = self.get_ifreq();
@@ -199,8 +216,7 @@ impl Tap {
         Ok(())
     }
 
-    /// Set the size of the vnet hdr.
-    pub fn set_vnet_hdr_size(&self, size: c_int) -> Result<()> {
+    fn set_vnet_hdr_size(&self, size: c_int) -> Result<()> {
         // ioctl is safe. Called with a valid tap fd, and we check the return.
         let ret = unsafe { ioctl_with_ref(&self.tap_file, net_sys::TUNSETVNETHDRSZ(), &size) };
         if ret < 0 {
@@ -254,6 +270,90 @@ unsafe impl Pollable for Tap {
     }
 }
 
+pub mod fakes {
+    use super::*;
+    use std::fs::OpenOptions;
+    use std::fs::remove_file;
+
+    const TMP_FILE: &str = "/tmp/crosvm_tap_test_file";
+
+    pub struct FakeTap {
+        tap_file: File,
+    }
+
+    impl TapT for FakeTap {
+        fn new() -> Result<FakeTap> {
+            Ok(FakeTap {
+                tap_file: OpenOptions::new()
+                    .read(true)
+                    .append(true)
+                    .create(true)
+                    .open(TMP_FILE)
+                    .unwrap()
+            })
+        }
+
+        fn set_ip_addr(&self, _: net::Ipv4Addr) -> Result<()> {
+            Ok(())
+        }
+
+        fn set_netmask(&self, _: net::Ipv4Addr) -> Result<()> {
+            Ok(())
+        }
+
+        fn set_offload(&self, _: c_uint) -> Result<()> {
+            Ok(())
+        }
+
+        fn enable(&self) -> Result<()> {
+            Ok(())
+        }
+
+        fn set_vnet_hdr_size(&self, _: c_int) -> Result<()> {
+            Ok(())
+        }
+
+        fn get_ifreq(&self) -> net_sys::ifreq {
+            let ifreq: net_sys::ifreq = Default::default();
+            ifreq
+        }
+    }
+
+    impl Drop for FakeTap {
+        fn drop(&mut self) {
+            let _ = remove_file(TMP_FILE);
+        }
+    }
+
+    impl Read for FakeTap {
+        fn read(&mut self, _: &mut [u8]) -> IoResult<usize> {
+            Ok(0)
+        }
+    }
+
+    impl Write for FakeTap {
+        fn write(&mut self, _: &[u8]) -> IoResult<usize> {
+            Ok(0)
+        }
+
+        fn flush(&mut self) -> IoResult<()> {
+            Ok(())
+        }
+    }
+
+    impl AsRawFd for FakeTap {
+        fn as_raw_fd(&self) -> RawFd {
+            self.tap_file.as_raw_fd()
+        }
+    }
+
+    unsafe impl Pollable for FakeTap {
+        fn pollable_fd(&self) -> RawFd {
+            self.tap_file.as_raw_fd()
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/linux.rs b/src/linux.rs
index abf39f9..1d0a311 100644
--- a/src/linux.rs
+++ b/src/linux.rs
@@ -22,9 +22,11 @@ use io_jail::{self, Minijail};
 use kernel_cmdline;
 use kernel_loader;
 use kvm::*;
+use net_util::Tap;
 use qcow::{self, QcowFile};
 use sys_util::*;
 use sys_util;
+use vhost;
 use vm_control::VmRequest;
 
 use Config;
@@ -352,10 +354,10 @@ fn setup_mmio_bus(cfg: &Config,
     if let Some(host_ip) = cfg.host_ip {
         if let Some(netmask) = cfg.netmask {
             let net_box: Box<devices::virtio::VirtioDevice> = if cfg.vhost_net {
-                Box::new(devices::virtio::vhost::Net::new(host_ip, netmask, &mem)
-                             .map_err(Error::VhostNetDeviceNew)?)
+                Box::new(devices::virtio::vhost::Net::<Tap, vhost::Net<Tap>>::new(host_ip, netmask, &mem)
+                                   .map_err(|e| Error::VhostNetDeviceNew(e))?)
             } else {
-                Box::new(devices::virtio::Net::new(host_ip, netmask)
+                Box::new(devices::virtio::Net::<Tap>::new(host_ip, netmask)
                                    .map_err(|e| Error::NetDeviceNew(e))?)
             };
 
diff --git a/src/main.rs b/src/main.rs
index fb2d151..4248195 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -12,9 +12,11 @@ extern crate kvm;
 extern crate x86_64;
 extern crate kernel_loader;
 extern crate byteorder;
+extern crate net_util;
 extern crate qcow;
 #[macro_use]
 extern crate sys_util;
+extern crate vhost;
 extern crate vm_control;
 extern crate data_model;
 
diff --git a/vhost/src/lib.rs b/vhost/src/lib.rs
index dc1bf81..e0aefdb 100644
--- a/vhost/src/lib.rs
+++ b/vhost/src/lib.rs
@@ -11,6 +11,7 @@ pub mod net;
 mod vsock;
 
 pub use net::Net;
+pub use net::NetT;
 pub use vsock::Vsock;
 
 use std::io::Error as IoError;
@@ -333,7 +334,8 @@ pub trait Vhost: AsRawFd + std::marker::Sized {
 mod tests {
     use super::*;
 
-    use net::tests::FakeNet;
+    use net::fakes::FakeNet;
+    use net_util::fakes::FakeTap;
     use std::result;
     use sys_util::{GuestAddress, GuestMemory, GuestMemoryError};
 
@@ -352,9 +354,9 @@ mod tests {
         }
     }
 
-    fn create_fake_vhost_net () -> FakeNet {
+    fn create_fake_vhost_net() -> FakeNet<FakeTap> {
         let gm = create_guest_memory().unwrap();
-        FakeNet::new(&gm).unwrap()
+        FakeNet::<FakeTap>::new(&gm).unwrap()
     }
 
     #[test]
diff --git a/vhost/src/net.rs b/vhost/src/net.rs
index 0c6eaad..1d3f94e 100644
--- a/vhost/src/net.rs
+++ b/vhost/src/net.rs
@@ -3,8 +3,9 @@
 // found in the LICENSE file.
 
 use libc;
-use net_util;
+use net_util::TapT;
 use std::fs::{File, OpenOptions};
+use std::marker::PhantomData;
 use std::os::unix::fs::OpenOptionsExt;
 use std::os::unix::io::{AsRawFd, RawFd};
 use virtio_sys;
@@ -19,30 +20,37 @@ static DEVICE: &'static str = "/dev/vhost-net";
 ///
 /// This provides a simple wrapper around a VHOST_NET file descriptor and
 /// methods that safely run ioctls on that file descriptor.
-pub struct Net {
+pub struct Net<T> {
     // fd must be dropped first, which will stop and tear down the
     // vhost-net worker before GuestMemory can potentially be unmapped.
     fd: File,
     mem: GuestMemory,
+    phantom: PhantomData<T>,
 }
 
-pub trait NetT {
+pub trait NetT<T: TapT>: Vhost + AsRawFd + Send + Sized {
+    /// Create a new NetT instance
+    fn new(mem: &GuestMemory) -> Result<Self>;
+
     /// Set the tap file descriptor that will serve as the VHOST_NET backend.
     /// This will start the vhost worker for the given queue.
     ///
     /// # Arguments
     /// * `queue_index` - Index of the queue to modify.
     /// * `fd` - Tap interface that will be used as the backend.
-    fn set_backend(&self, queue_index: usize, fd: &net_util::Tap) -> Result<()>;
+    fn set_backend(&self, queue_index: usize, fd: &T) -> Result<()>;
 }
 
-impl Net {
+impl<T> NetT<T> for Net<T>
+where
+    T: TapT,
+{
     /// Opens /dev/vhost-net and holds a file descriptor open for it.
     ///
     /// # Arguments
     /// * `mem` - Guest memory mapping.
-    pub fn new(mem: &GuestMemory) -> Result<Net> {
-        Ok(Net {
+    fn new(mem: &GuestMemory) -> Result<Net<T>> {
+        Ok(Net::<T> {
             fd: OpenOptions::new()
                 .read(true)
                 .write(true)
@@ -50,12 +58,11 @@ impl Net {
                 .open(DEVICE)
                 .map_err(Error::VhostOpen)?,
             mem: mem.clone(),
+            phantom: PhantomData,
         })
     }
-}
 
-impl NetT for Net {
-    fn set_backend(&self, queue_index: usize, fd: &net_util::Tap) -> Result<()> {
+    fn set_backend(&self, queue_index: usize, fd: &T) -> Result<()> {
         let vring_file = virtio_sys::vhost_vring_file {
             index: queue_index as u32,
             fd: fd.as_raw_fd(),
@@ -72,64 +79,66 @@ impl NetT for Net {
     }
 }
 
-impl Vhost for Net {
+impl<T> Vhost for Net<T> {
     fn mem(&self) -> &GuestMemory {
         &self.mem
     }
 }
 
-impl AsRawFd for Net {
+impl<T> AsRawFd for Net<T> {
     fn as_raw_fd(&self) -> RawFd {
         self.fd.as_raw_fd()
     }
 }
 
-#[cfg(test)]
-pub mod tests {
+pub mod fakes {
     use super::*;
     use std::fs::OpenOptions;
     use std::fs::remove_file;
 
     const TMP_FILE: &str = "/tmp/crosvm_vhost_test_file";
 
-    pub struct FakeNet {
+    pub struct FakeNet<T> {
         fd: File,
         mem: GuestMemory,
+        phantom: PhantomData<T>,
+    }
+
+    impl<T> Drop for FakeNet<T> {
+        fn drop(&mut self) {
+            let _ = remove_file(TMP_FILE);
+        }
     }
 
-    impl FakeNet {
-        pub fn new(mem: &GuestMemory) -> Result<FakeNet> {
-            Ok(FakeNet {
+    impl<T> NetT<T> for FakeNet<T>
+    where
+        T: TapT,
+    {
+        fn new(mem: &GuestMemory) -> Result<FakeNet<T>> {
+            Ok(FakeNet::<T> {
                 fd: OpenOptions::new()
                     .read(true)
                     .append(true)
                     .create(true)
                     .open(TMP_FILE)
                     .unwrap(),
-                mem: mem.clone()
+                mem: mem.clone(),
+                phantom: PhantomData,
             })
         }
-    }
-
-    impl Drop for FakeNet {
-        fn drop(&mut self) {
-            let _ = remove_file(TMP_FILE);
-        }
-    }
 
-    impl NetT for FakeNet {
-        fn set_backend(&self, _queue_index: usize, _fd: &net_util::Tap) -> Result<()> {
+        fn set_backend(&self, _queue_index: usize, _fd: &T) -> Result<()> {
             Ok(())
         }
     }
 
-    impl Vhost for FakeNet {
+    impl<T> Vhost for FakeNet<T> {
         fn mem(&self) -> &GuestMemory {
             &self.mem
         }
     }
 
-    impl AsRawFd for FakeNet {
+    impl<T> AsRawFd for FakeNet<T> {
         fn as_raw_fd(&self) -> RawFd {
             self.fd.as_raw_fd()
         }