summary refs log tree commit diff
path: root/devices/src/virtio/controller.rs
diff options
context:
space:
mode:
Diffstat (limited to 'devices/src/virtio/controller.rs')
-rw-r--r--devices/src/virtio/controller.rs851
1 files changed, 770 insertions, 81 deletions
diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs
index 07f5cbd..9b0784c 100644
--- a/devices/src/virtio/controller.rs
+++ b/devices/src/virtio/controller.rs
@@ -28,25 +28,31 @@
 //! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill
 //! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket.
 
+use std::os::unix::prelude::*;
+
 use std::collections::BTreeMap as Map;
+use std::fmt::{self, Formatter};
 use std::os::unix::io::{AsRawFd, RawFd};
 use std::path::PathBuf;
 use std::sync::Arc;
 use std::thread;
 
+use msg_socket::{MsgReceiver, MsgSocket};
+use msg_socket2::{DeserializeWithFds, DeserializerWithFds, SerializeWithFds, SerializerWithFds};
+use serde::de::{Deserializer, EnumAccess, SeqAccess, VariantAccess};
+use serde::ser::{SerializeStructVariant, SerializeTupleVariant, Serializer};
+use serde::{Deserialize, Serialize};
+use sys_util::net::UnixSeqpacket;
+use sys_util::{error, EventFd, GuestMemory, PollContext, PollToken, SharedMemory};
+
 use super::resource_bridge::*;
-use super::{Interrupt, InterruptProxyEvent, Queue, VirtioDevice, TYPE_WL, VIRTIO_F_VERSION_1};
+use super::{Interrupt, InterruptProxyEvent, Queue, VirtioDevice};
 use crate::{
     pci::{PciAddress, PciBarConfiguration, PciCapability, PciCapabilityID},
     MemoryParams,
 };
 use vm_control::{MaybeOwnedFd, VmMemoryControlRequestSocket};
 
-use msg_socket::{MsgOnSocket, MsgReceiver, MsgSocket};
-use serde::{Deserialize, Serialize};
-use sys_util::net::UnixSeqpacket;
-use sys_util::{error, EventFd, GuestMemory, PollContext, PollToken, SharedMemory};
-
 // As far as I can tell, these never change on the other side, so it's
 // fine to just copy them over.
 #[derive(Clone, Debug, Deserialize, Serialize)]
@@ -74,8 +80,8 @@ impl PciCapability for RemotePciCapability {
     }
 }
 
-#[derive(Debug, MsgOnSocket)]
-pub enum MsgOnSocketRequest {
+#[derive(Debug)]
+pub enum Request {
     Create {
         // wayland_paths: Map<String, PathBuf>,
         vm_socket: MaybeOwnedFd<UnixSeqpacket>,
@@ -83,11 +89,24 @@ pub enum MsgOnSocketRequest {
         memory_params: MemoryParams,
     },
 
+    DebugLabel,
+
     DeviceType,
 
+    QueueMaxSizes,
+
     Features,
     AckFeatures(u64),
 
+    ReadConfig {
+        offset: u64,
+        len: usize,
+    },
+    WriteConfig {
+        offset: u64,
+        data: Vec<u8>,
+    },
+
     Activate {
         shm: MaybeOwnedFd<SharedMemory>,
         interrupt: MaybeOwnedFd<UnixSeqpacket>,
@@ -99,75 +118,748 @@ pub enum MsgOnSocketRequest {
     },
 
     Reset,
+
+    GetDeviceBars(PciAddress),
+    GetDeviceCaps,
+
     Kill,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-pub enum BincodeRequest {
-    DebugLabel,
+impl SerializeWithFds for Request {
+    fn serialize<S>(&self, serializer: SerializerWithFds<S>) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        use Request::*;
 
-    QueueMaxSizes,
+        match self {
+            Create {
+                vm_socket,
+                memory_params,
+            } => {
+                let mut sv = serializer
+                    .serializer
+                    .serialize_struct_variant("Request", 0, "Create", 1)?;
 
-    ReadConfig { offset: u64, len: usize },
-    WriteConfig { offset: u64, data: Vec<u8> },
+                sv.skip_field("vm_socket")?;
+                serializer.fds.push(vm_socket.as_raw_fd());
 
-    GetDeviceBars(PciAddress),
-    GetDeviceCaps,
-}
+                sv.serialize_field("memory_params", memory_params)?;
+
+                sv.end()
+            }
+
+            DebugLabel => serializer
+                .serializer
+                .serialize_unit_variant("Request", 1, "DebugLabel"),
+
+            DeviceType => serializer
+                .serializer
+                .serialize_unit_variant("Request", 2, "DeviceType"),
+            QueueMaxSizes => {
+                serializer
+                    .serializer
+                    .serialize_unit_variant("Request", 3, "QueueMaxSizes")
+            }
+
+            Features => serializer
+                .serializer
+                .serialize_unit_variant("Request", 4, "Features"),
+
+            AckFeatures(features) => {
+                let mut tv = serializer.serializer.serialize_tuple_variant(
+                    "Request",
+                    5,
+                    "AckFeatures",
+                    1,
+                )?;
+                tv.serialize_field(features)?;
+                tv.end()
+            }
+
+            ReadConfig { offset, len } => {
+                let mut sv = serializer.serializer.serialize_struct_variant(
+                    "Request",
+                    6,
+                    "ReadConfig",
+                    2,
+                )?;
+                sv.serialize_field("offset", offset)?;
+                sv.serialize_field("len", len)?;
+                sv.end()
+            }
+
+            WriteConfig { offset, data } => {
+                let mut sv = serializer.serializer.serialize_struct_variant(
+                    "Request",
+                    7,
+                    "WriteConfig",
+                    2,
+                )?;
+                sv.serialize_field("offset", offset)?;
+                sv.serialize_field("data", data)?;
+                sv.end()
+            }
+
+            Activate {
+                shm,
+                interrupt,
+                interrupt_resample_evt,
+                in_queue,
+                out_queue,
+                in_queue_evt,
+                out_queue_evt,
+            } => {
+                let mut sv = serializer
+                    .serializer
+                    .serialize_struct_variant("Request", 8, "Activate", 2)?;
 
-pub type Request = poly_msg_socket::Value<MsgOnSocketRequest, BincodeRequest>;
+                sv.skip_field("shm")?;
+                serializer.fds.push(shm.as_raw_fd());
 
-impl From<MsgOnSocketRequest> for Request {
-    fn from(request: MsgOnSocketRequest) -> Self {
-        Self::MsgOnSocket(request)
+                sv.skip_field("interrupt")?;
+                serializer.fds.push(interrupt.as_raw_fd());
+
+                sv.skip_field("interrupt_resample_evt")?;
+                serializer.fds.push(interrupt_resample_evt.as_raw_fd());
+
+                sv.serialize_field("in_queue", in_queue)?;
+                sv.serialize_field("out_queue", out_queue)?;
+
+                sv.skip_field("in_queue_evt")?;
+                serializer.fds.push(in_queue_evt.as_raw_fd());
+
+                sv.skip_field("out_queue_evt")?;
+                serializer.fds.push(out_queue_evt.as_raw_fd());
+
+                sv.end()
+            }
+
+            Reset => serializer
+                .serializer
+                .serialize_unit_variant("Request", 9, "Reset"),
+
+            GetDeviceBars(address) => {
+                let mut sv = serializer.serializer.serialize_struct_variant(
+                    "Request",
+                    10,
+                    "GetDeviceBars",
+                    1,
+                )?;
+                sv.serialize_field("address", address)?;
+                sv.end()
+            }
+
+            GetDeviceCaps => {
+                serializer
+                    .serializer
+                    .serialize_unit_variant("Request", 11, "GetDeviceCaps")
+            }
+
+            Kill => serializer
+                .serializer
+                .serialize_unit_variant("Request", 12, "Kill"),
+        }
     }
 }
 
-impl From<BincodeRequest> for Request {
-    fn from(request: BincodeRequest) -> Self {
-        Self::Bincode(request)
+impl<'de> DeserializeWithFds<'de> for Request {
+    fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error>
+    where
+        I: Iterator<Item = RawFd>,
+        De: Deserializer<'de>,
+    {
+        struct Visitor<'iter, Iter> {
+            fds: &'iter mut Iter,
+        }
+
+        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
+        where
+            Iter: Iterator<Item = RawFd>,
+        {
+            type Value = Request;
+
+            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                write!(f, "enum Request")
+            }
+
+            fn visit_enum<A: EnumAccess<'de>>(self, data: A) -> Result<Self::Value, A::Error> {
+                #[derive(Debug, Deserialize)]
+                enum Variant {
+                    Create,
+                    DebugLabel,
+                    DeviceType,
+                    QueueMaxSizes,
+                    Features,
+                    AckFeatures,
+                    ReadConfig,
+                    WriteConfig,
+                    Activate,
+                    Reset,
+                    GetDeviceBars,
+                    GetDeviceCaps,
+                    Kill,
+                }
+
+                match data.variant()? {
+                    (Variant::Create, variant) => {
+                        struct Visitor<'iter, Iter> {
+                            fds: &'iter mut Iter,
+                        }
+
+                        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
+                        where
+                            Iter: Iterator<Item = RawFd>,
+                        {
+                            type Value = Request;
+
+                            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                                write!(f, "struct variant Request::Create")
+                            }
+
+                            fn visit_seq<A: SeqAccess<'de>>(
+                                self,
+                                mut seq: A,
+                            ) -> Result<Request, A::Error> {
+                                use serde::de::Error;
+
+                                Ok(Request::Create {
+                                    vm_socket: match self.fds.next() {
+                                        Some(vm_socket) => MaybeOwnedFd::Owned(unsafe {
+                                            UnixSeqpacket::from_raw_fd(vm_socket)
+                                        }),
+                                        None => {
+                                            return Err(Error::invalid_length(
+                                                0,
+                                                &"struct variant Request::Create with 2 elements",
+                                            ))
+                                        }
+                                    },
+
+                                    memory_params: match seq.next_element()? {
+                                        Some(memory_params) => memory_params,
+                                        None => {
+                                            return Err(Error::invalid_length(
+                                                1,
+                                                &"struct variant Request::Create with 2 elements",
+                                            ))
+                                        }
+                                    },
+                                })
+                            }
+                        }
+
+                        let visitor = Visitor { fds: self.fds };
+                        variant.struct_variant(&["memory_params"], visitor)
+                    }
+
+                    (Variant::DebugLabel, variant) => {
+                        variant.unit_variant()?;
+                        Ok(Request::DebugLabel)
+                    }
+
+                    (Variant::DeviceType, variant) => {
+                        variant.unit_variant()?;
+                        Ok(Request::DeviceType)
+                    }
+
+                    (Variant::QueueMaxSizes, variant) => {
+                        variant.unit_variant()?;
+                        Ok(Request::QueueMaxSizes)
+                    }
+
+                    (Variant::Features, variant) => {
+                        variant.unit_variant()?;
+                        Ok(Request::Features)
+                    }
+
+                    (Variant::AckFeatures, variant) => {
+                        Ok(Request::AckFeatures(variant.newtype_variant()?))
+                    }
+
+                    (Variant::ReadConfig, variant) => {
+                        struct Visitor<'iter, Iter> {
+                            fds: &'iter mut Iter,
+                        }
+
+                        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
+                        where
+                            Iter: Iterator<Item = RawFd>,
+                        {
+                            type Value = Request;
+
+                            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                                write!(f, "struct variant Request::ReadConfig")
+                            }
+
+                            fn visit_seq<A: SeqAccess<'de>>(
+                                self,
+                                mut seq: A,
+                            ) -> Result<Request, A::Error> {
+                                use serde::de::Error;
+
+                                Ok(Request::ReadConfig {
+                                    offset: match seq.next_element()? {
+                                        Some(offset) => offset,
+                                        None => return Err(Error::invalid_length(
+                                            0,
+                                            &"struct variant Request::ReadConfig with 2 elements",
+                                        )),
+                                    },
+
+                                    len: match seq.next_element()? {
+                                        Some(len) => len,
+                                        None => return Err(Error::invalid_length(
+                                            1,
+                                            &"struct variant Request::ReadConfig with 2 elements",
+                                        )),
+                                    },
+                                })
+                            }
+                        }
+
+                        let visitor = Visitor { fds: self.fds };
+                        variant.struct_variant(&["offset", "len"], visitor)
+                    }
+
+                    (Variant::WriteConfig, variant) => {
+                        struct Visitor<'iter, Iter> {
+                            fds: &'iter mut Iter,
+                        }
+
+                        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
+                        where
+                            Iter: Iterator<Item = RawFd>,
+                        {
+                            type Value = Request;
+
+                            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                                write!(f, "struct variant Request::WriteConfig")
+                            }
+
+                            fn visit_seq<A: SeqAccess<'de>>(
+                                self,
+                                mut seq: A,
+                            ) -> Result<Request, A::Error> {
+                                use serde::de::Error;
+
+                                Ok(Request::WriteConfig {
+                                    offset: match seq.next_element()? {
+                                        Some(offset) => offset,
+                                        None => return Err(Error::invalid_length(
+                                            0,
+                                            &"struct variant Request::WriteConfig with 2 elements",
+                                        )),
+                                    },
+
+                                    data: match seq.next_element()? {
+                                        Some(data) => data,
+                                        None => return Err(Error::invalid_length(
+                                            1,
+                                            &"struct variant Request::WriteConfig with 2 elements",
+                                        )),
+                                    },
+                                })
+                            }
+                        }
+
+                        let visitor = Visitor { fds: self.fds };
+                        variant.struct_variant(&["offset", "data"], visitor)
+                    }
+
+                    (Variant::Activate, variant) => {
+                        struct Visitor<'iter, Iter> {
+                            fds: &'iter mut Iter,
+                        }
+
+                        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
+                        where
+                            Iter: Iterator<Item = RawFd>,
+                        {
+                            type Value = Request;
+
+                            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                                write!(f, "struct variant Request::Activate")
+                            }
+
+                            fn visit_seq<A: SeqAccess<'de>>(
+                                self,
+                                mut seq: A,
+                            ) -> Result<Request, A::Error> {
+                                use serde::de::Error;
+
+                                Ok(Request::Activate {
+                                    shm: match self.fds.next() {
+                                        Some(shm) => MaybeOwnedFd::Owned(unsafe {
+                                            FromRawFd::from_raw_fd(shm)
+                                        }),
+                                        None => {
+                                            return Err(Error::invalid_length(
+                                                0,
+                                                &"struct variant Request::Activate with 7 elements",
+                                            ))
+                                        }
+                                    },
+
+                                    interrupt: match self.fds.next() {
+                                        Some(interrupt) => MaybeOwnedFd::Owned(unsafe {
+                                            UnixSeqpacket::from_raw_fd(interrupt)
+                                        }),
+                                        None => {
+                                            return Err(Error::invalid_length(
+                                                1,
+                                                &"struct variant Request::Activate with 7 elements",
+                                            ))
+                                        }
+                                    },
+
+                                    interrupt_resample_evt: match self.fds.next() {
+                                        Some(interrupt_resample_evt) => {
+                                            MaybeOwnedFd::Owned(unsafe {
+                                                EventFd::from_raw_fd(interrupt_resample_evt)
+                                            })
+                                        }
+                                        None => {
+                                            return Err(Error::invalid_length(
+                                                2,
+                                                &"struct variant Request::Activate with 7 elements",
+                                            ))
+                                        }
+                                    },
+
+                                    in_queue: match seq.next_element()? {
+                                        Some(in_queue) => in_queue,
+                                        None => {
+                                            return Err(Error::invalid_length(
+                                                3,
+                                                &"struct variant Request::Activate with 7 elements",
+                                            ))
+                                        }
+                                    },
+
+                                    out_queue: match seq.next_element()? {
+                                        Some(out_queue) => out_queue,
+                                        None => {
+                                            return Err(Error::invalid_length(
+                                                4,
+                                                &"struct variant Request::Activate with 7 elements",
+                                            ))
+                                        }
+                                    },
+
+                                    in_queue_evt: match self.fds.next() {
+                                        Some(in_queue_evt) => MaybeOwnedFd::Owned(unsafe {
+                                            EventFd::from_raw_fd(in_queue_evt)
+                                        }),
+                                        None => {
+                                            return Err(Error::invalid_length(
+                                                5,
+                                                &"struct variant Request::Activate with 7 elements",
+                                            ))
+                                        }
+                                    },
+
+                                    out_queue_evt: match self.fds.next() {
+                                        Some(out_queue_evt) => MaybeOwnedFd::Owned(unsafe {
+                                            EventFd::from_raw_fd(out_queue_evt)
+                                        }),
+                                        None => {
+                                            return Err(Error::invalid_length(
+                                                6,
+                                                &"struct variant Request::Activate with 7 elements",
+                                            ))
+                                        }
+                                    },
+                                })
+                            }
+                        }
+
+                        let visitor = Visitor { fds: self.fds };
+                        variant.struct_variant(&["in_queue", "out_queue"], visitor)
+                    }
+
+                    (Variant::Reset, variant) => {
+                        variant.unit_variant()?;
+                        Ok(Request::Reset)
+                    }
+
+                    (Variant::GetDeviceBars, variant) => {
+                        struct Visitor<'iter, Iter> {
+                            fds: &'iter mut Iter,
+                        }
+
+                        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
+                        where
+                            Iter: Iterator<Item = RawFd>,
+                        {
+                            type Value = Request;
+
+                            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                                write!(f, "struct variant Request::GetDeviceBars")
+                            }
+
+                            fn visit_seq<A: SeqAccess<'de>>(
+                                self,
+                                mut seq: A,
+                            ) -> Result<Request, A::Error> {
+                                use serde::de::Error;
+
+                                Ok(Request::GetDeviceBars(match seq.next_element()? {
+                                    Some(address) => address,
+                                    None => {
+                                        return Err(Error::invalid_length(
+                                            0,
+                                            &"struct variant Request::GetDeviceBars with 1 element",
+                                        ))
+                                    }
+                                }))
+                            }
+                        }
+
+                        let visitor = Visitor { fds: self.fds };
+                        variant.struct_variant(&["bus", "dev"], visitor)
+                    }
+
+                    (Variant::GetDeviceCaps, variant) => {
+                        variant.unit_variant()?;
+                        Ok(Request::GetDeviceCaps)
+                    }
+
+                    (Variant::Kill, variant) => {
+                        variant.unit_variant()?;
+                        Ok(Request::Kill)
+                    }
+                }
+            }
+        }
+
+        let DeserializerWithFds {
+            mut fds,
+            deserializer,
+        } = deserializer;
+        let visitor = Visitor { fds: &mut fds };
+        deserializer.deserialize_enum(
+            "Request",
+            &[
+                "Create",
+                "DebugLabel",
+                "DeviceType",
+                "QueueMaxSizes",
+                "Features",
+                "AckFeatures",
+                "ReadConfig",
+                "WriteConfig",
+                "Activate",
+                "Reset",
+                "GetDeviceBars",
+                "GetDeviceCaps",
+                "Kill",
+            ],
+            visitor,
+        )
     }
 }
 
-#[derive(Debug, MsgOnSocket)]
-pub enum MsgOnSocketResponse {
+#[derive(Debug)]
+pub enum Response {
+    DebugLabel(String),
     DeviceType(u32),
+    QueueMaxSizes(Vec<u16>),
     Features(u64),
+    ReadConfig(Vec<u8>),
     Reset(bool),
+    GetDeviceBars(Vec<PciBarConfiguration>),
+    GetDeviceCaps(Vec<RemotePciCapability>),
     Kill,
 }
 
-#[derive(Debug, Deserialize, Serialize)]
-pub enum BincodeResponse {
-    DebugLabel(String),
+impl SerializeWithFds for Response {
+    fn serialize<S>(&self, serializer: SerializerWithFds<S>) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        use Response::*;
+
+        match self {
+            DebugLabel(label) => {
+                let mut tv = serializer.serializer.serialize_tuple_variant(
+                    "Response",
+                    0,
+                    "DebugLabel",
+                    1,
+                )?;
+                tv.serialize_field(label)?;
+                tv.end()
+            }
 
-    QueueMaxSizes(Vec<u16>),
+            DeviceType(device_type) => {
+                let mut tv = serializer.serializer.serialize_tuple_variant(
+                    "Response",
+                    1,
+                    "DeviceType",
+                    1,
+                )?;
+                tv.serialize_field(device_type)?;
+                tv.end()
+            }
 
-    ReadConfig(Vec<u8>),
+            QueueMaxSizes(sizes) => {
+                let mut tv = serializer.serializer.serialize_tuple_variant(
+                    "Response",
+                    2,
+                    "QueueMaxSizes",
+                    1,
+                )?;
+                tv.serialize_field(sizes)?;
+                tv.end()
+            }
 
-    GetDeviceBars(Vec<PciBarConfiguration>),
-    GetDeviceCaps(Vec<RemotePciCapability>),
-}
+            Features(features) => {
+                let mut tv = serializer
+                    .serializer
+                    .serialize_tuple_variant("Response", 3, "Features", 1)?;
+                tv.serialize_field(features)?;
+                tv.end()
+            }
+
+            ReadConfig(config) => {
+                let mut tv = serializer.serializer.serialize_tuple_variant(
+                    "Response",
+                    4,
+                    "ReadConfig",
+                    1,
+                )?;
+                tv.serialize_field(config)?;
+                tv.end()
+            }
+
+            Reset(success) => {
+                let mut tv = serializer
+                    .serializer
+                    .serialize_tuple_variant("Response", 5, "Reset", 1)?;
+                tv.serialize_field(success)?;
+                tv.end()
+            }
 
-pub type Response = poly_msg_socket::Value<MsgOnSocketResponse, BincodeResponse>;
+            GetDeviceBars(bars) => {
+                let mut tv = serializer.serializer.serialize_tuple_variant(
+                    "Response",
+                    6,
+                    "GetDeviceBars",
+                    1,
+                )?;
+                tv.serialize_field(bars)?;
+                tv.end()
+            }
+
+            GetDeviceCaps(caps) => {
+                let mut tv = serializer.serializer.serialize_tuple_variant(
+                    "Response",
+                    7,
+                    "GetDeviceCaps",
+                    1,
+                )?;
+                tv.serialize_field(caps)?;
+                tv.end()
+            }
 
-impl From<MsgOnSocketResponse> for Response {
-    fn from(response: MsgOnSocketResponse) -> Self {
-        Self::MsgOnSocket(response)
+            Kill => serializer
+                .serializer
+                .serialize_unit_variant("Response", 8, "Kill"),
+        }
     }
 }
 
-impl From<BincodeResponse> for Response {
-    fn from(response: BincodeResponse) -> Self {
-        Self::Bincode(response)
+impl<'de> DeserializeWithFds<'de> for Response {
+    fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error>
+    where
+        I: Iterator<Item = RawFd>,
+        De: Deserializer<'de>,
+    {
+        struct Visitor<'iter, Iter> {
+            fds: &'iter mut Iter,
+        }
+
+        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
+        where
+            Iter: Iterator<Item = RawFd>,
+        {
+            type Value = Response;
+
+            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                write!(f, "enum Response")
+            }
+
+            fn visit_enum<A: EnumAccess<'de>>(self, data: A) -> Result<Self::Value, A::Error> {
+                #[derive(Debug, Deserialize)]
+                enum Variant {
+                    DebugLabel,
+                    DeviceType,
+                    QueueMaxSizes,
+                    Features,
+                    ReadConfig,
+                    Reset,
+                    GetDeviceBars,
+                    GetDeviceCaps,
+                    Kill,
+                }
+
+                match data.variant()? {
+                    (Variant::DebugLabel, variant) => {
+                        Ok(Response::DebugLabel(variant.newtype_variant()?))
+                    }
+                    (Variant::DeviceType, variant) => {
+                        Ok(Response::DeviceType(variant.newtype_variant()?))
+                    }
+                    (Variant::QueueMaxSizes, variant) => {
+                        Ok(Response::QueueMaxSizes(variant.newtype_variant()?))
+                    }
+                    (Variant::Features, variant) => {
+                        Ok(Response::Features(variant.newtype_variant()?))
+                    }
+                    (Variant::ReadConfig, variant) => {
+                        Ok(Response::ReadConfig(variant.newtype_variant()?))
+                    }
+                    (Variant::Reset, variant) => Ok(Response::Reset(variant.newtype_variant()?)),
+                    (Variant::GetDeviceBars, variant) => {
+                        Ok(Response::GetDeviceBars(variant.newtype_variant()?))
+                    }
+                    (Variant::GetDeviceCaps, variant) => {
+                        Ok(Response::GetDeviceCaps(variant.newtype_variant()?))
+                    }
+
+                    (Variant::Kill, variant) => {
+                        variant.unit_variant()?;
+                        Ok(Response::Kill)
+                    }
+                }
+            }
+        }
+
+        let DeserializerWithFds {
+            mut fds,
+            deserializer,
+        } = deserializer;
+        let visitor = Visitor { fds: &mut fds };
+        deserializer.deserialize_enum(
+            "Response",
+            &[
+                "DebugLabel",
+                "DeviceType",
+                "QueueMaxSizes",
+                "Features",
+                "ReadConfig",
+                "Reset",
+                "GetDeviceBars",
+                "GetDeviceCaps",
+                "Kill",
+            ],
+            visitor,
+        )
     }
 }
 
-use poly_msg_socket::PolyMsgSocket;
-type Socket =
-    PolyMsgSocket<MsgOnSocketRequest, MsgOnSocketResponse, BincodeRequest, BincodeResponse>;
-
-const VIRTIO_WL_F_TRANS_FLAGS: u32 = 0x01;
+type Socket = msg_socket2::Socket<Request, Response>;
 
 // TODO: support arbitrary number of queues
 const QUEUE_SIZE: u16 = 16;
@@ -196,9 +888,8 @@ impl Worker {
     }
 
     fn handle_response(&mut self) {
-        use poly_msg_socket::Value::*;
         match self.device_socket.recv() {
-            Ok(MsgOnSocket(MsgOnSocketResponse::Kill)) => {
+            Ok(Response::Kill) => {
                 self.shutdown = true;
             }
 
@@ -224,9 +915,7 @@ impl Worker {
     }
 
     fn kill(&self) {
-        if let Err(e) = self.device_socket.send(poly_msg_socket::Value::MsgOnSocket(
-            MsgOnSocketRequest::Kill,
-        )) {
+        if let Err(e) = self.device_socket.send(Request::Kill) {
             error!("failed to send Kill: {}", e);
         }
     }
@@ -282,8 +971,8 @@ impl Controller {
         resource_bridge: Option<ResourceRequestSocket>,
         memory_params: MemoryParams,
         socket: Socket,
-    ) -> Result<Controller, poly_msg_socket::Error> {
-        socket.send(MsgOnSocketRequest::Create {
+    ) -> Result<Controller, msg_socket2::Error> {
+        socket.send(Request::Create {
             // wayland_paths,
             vm_socket: MaybeOwnedFd::new_borrowed(&vm_socket),
             // resource_bridge,
@@ -313,12 +1002,12 @@ impl Drop for Controller {
 
 impl VirtioDevice for Controller {
     fn debug_label(&self) -> String {
-        if let Err(e) = self.socket.send(BincodeRequest::DebugLabel) {
+        if let Err(e) = self.socket.send(Request::DebugLabel) {
             return format!("remote virtio (unknown type; {})", e);
         }
 
-        let label = match self.socket.recv_bincode() {
-            Ok(BincodeResponse::DebugLabel(label)) => label,
+        let label = match self.socket.recv() {
+            Ok(Response::DebugLabel(label)) => label,
             response => panic!("bad response to DebugLabel: {:?}", response),
         };
 
@@ -338,12 +1027,12 @@ impl VirtioDevice for Controller {
     }
 
     fn device_type(&self) -> u32 {
-        if let Err(e) = self.socket.send(MsgOnSocketRequest::DeviceType) {
+        if let Err(e) = self.socket.send(Request::DeviceType) {
             panic!("failed to send DeviceType: {}", e);
         }
 
-        match self.socket.recv_msg_on_socket() {
-            Ok(MsgOnSocketResponse::DeviceType(device_type)) => device_type,
+        match self.socket.recv() {
+            Ok(Response::DeviceType(device_type)) => device_type,
             response => {
                 panic!("bad response to Reset: {:?}", response);
             }
@@ -351,25 +1040,25 @@ impl VirtioDevice for Controller {
     }
 
     fn queue_max_sizes(&self) -> Vec<u16> {
-        if let Err(e) = self.socket.send(BincodeRequest::QueueMaxSizes) {
+        if let Err(e) = self.socket.send(Request::QueueMaxSizes) {
             panic!("failed to send QueueMaxSizes: {}", e);
         }
 
-        match self.socket.recv_bincode() {
-            Ok(BincodeResponse::QueueMaxSizes(sizes)) => sizes,
+        match self.socket.recv() {
+            Ok(Response::QueueMaxSizes(sizes)) => sizes,
             response => {
-                panic!("bad response to Reset: {:?}", response);
+                panic!("bad response to QueueMaxSizes: {:?}", response);
             }
         }
     }
 
     fn features(&self) -> u64 {
-        if let Err(e) = self.socket.send(MsgOnSocketRequest::Features) {
+        if let Err(e) = self.socket.send(Request::Features) {
             panic!("failed to send Features: {}", e);
         }
 
-        match self.socket.recv_msg_on_socket() {
-            Ok(MsgOnSocketResponse::Features(features)) => features,
+        match self.socket.recv() {
+            Ok(Response::Features(features)) => features,
             response => {
                 panic!("bad response to Reset: {:?}", response);
             }
@@ -377,7 +1066,7 @@ impl VirtioDevice for Controller {
     }
 
     fn ack_features(&mut self, value: u64) {
-        if let Err(e) = self.socket.send(MsgOnSocketRequest::AckFeatures(value)) {
+        if let Err(e) = self.socket.send(Request::AckFeatures(value)) {
             panic!("failed to send AckFeatures: {}", e);
         }
     }
@@ -385,12 +1074,12 @@ impl VirtioDevice for Controller {
     fn read_config(&self, offset: u64, data: &mut [u8]) {
         let len = data.len();
 
-        if let Err(e) = self.socket.send(BincodeRequest::ReadConfig { offset, len }) {
+        if let Err(e) = self.socket.send(Request::ReadConfig { offset, len }) {
             panic!("failed to send ReadConfig: {}", e);
         }
 
-        match self.socket.recv_bincode() {
-            Ok(BincodeResponse::ReadConfig(response)) => {
+        match self.socket.recv() {
+            Ok(Response::ReadConfig(response)) => {
                 data.copy_from_slice(&response[..len]); // TODO: test no panic
             }
             response => panic!("bad response to ReadConfig: {:?}", response),
@@ -398,7 +1087,7 @@ impl VirtioDevice for Controller {
     }
 
     fn write_config(&mut self, offset: u64, data: &[u8]) {
-        if let Err(e) = self.socket.send(BincodeRequest::WriteConfig {
+        if let Err(e) = self.socket.send(Request::WriteConfig {
             offset,
             data: data.to_vec(),
         }) {
@@ -431,7 +1120,7 @@ impl VirtioDevice for Controller {
 
         let (ours, theirs) = UnixSeqpacket::pair().expect("pair failed");
 
-        if let Err(e) = self.socket.send(MsgOnSocketRequest::Activate {
+        if let Err(e) = self.socket.send(Request::Activate {
             shm: MaybeOwnedFd::new_borrowed(&mem),
             interrupt: MaybeOwnedFd::new_borrowed(&theirs),
             interrupt_resample_evt: MaybeOwnedFd::new_borrowed(interrupt.get_resample_evt()),
@@ -461,13 +1150,13 @@ impl VirtioDevice for Controller {
     }
 
     fn reset(&mut self) -> bool {
-        if let Err(e) = self.socket.send(MsgOnSocketRequest::Reset) {
+        if let Err(e) = self.socket.send(Request::Reset) {
             error!("failed to send Reset: {}", e);
             return false;
         }
 
-        match self.socket.recv_msg_on_socket() {
-            Ok(MsgOnSocketResponse::Reset(result)) => result,
+        match self.socket.recv() {
+            Ok(Response::Reset(result)) => result,
             response => {
                 error!("bad response to Reset: {:?}", response);
                 false
@@ -476,12 +1165,12 @@ impl VirtioDevice for Controller {
     }
 
     fn get_device_bars(&mut self, address: PciAddress) -> Vec<PciBarConfiguration> {
-        if let Err(e) = self.socket.send(BincodeRequest::GetDeviceBars(address)) {
+        if let Err(e) = self.socket.send(Request::GetDeviceBars(address)) {
             panic!("failed to send GetDeviceBars: {}", e);
         }
 
-        match self.socket.recv_bincode() {
-            Ok(BincodeResponse::GetDeviceBars(bars)) => bars,
+        match self.socket.recv() {
+            Ok(Response::GetDeviceBars(bars)) => bars,
             response => {
                 panic!("bad response to GetDeviceBars: {:?}", response);
             }
@@ -489,12 +1178,12 @@ impl VirtioDevice for Controller {
     }
 
     fn get_device_caps(&self) -> Vec<Box<dyn PciCapability>> {
-        if let Err(e) = self.socket.send(BincodeRequest::GetDeviceCaps) {
+        if let Err(e) = self.socket.send(Request::GetDeviceCaps) {
             panic!("failed to send GetDeviceCaps: {}", e);
         }
 
-        match self.socket.recv_bincode() {
-            Ok(BincodeResponse::GetDeviceCaps(caps)) => caps
+        match self.socket.recv() {
+            Ok(Response::GetDeviceCaps(caps)) => caps
                 .into_iter()
                 .map(|cap| Box::new(cap) as Box<dyn PciCapability>)
                 .collect(),