summary refs log tree commit diff
path: root/devices
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-03-25 08:38:01 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:37:12 +0000
commitb6549a605935e29ab0ae4291737f8b0158bca1fb (patch)
tree7f4242993ce003cb787b242a264e3b8ea47e3430 /devices
parent2885f9ca1a79d30421deeb025e92ae0118fc6d3a (diff)
downloadcrosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar.gz
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar.bz2
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar.lz
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar.xz
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar.zst
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.zip
recursive deserialization
Diffstat (limited to 'devices')
-rw-r--r--devices/Cargo.toml1
-rw-r--r--devices/src/lib.rs7
-rw-r--r--devices/src/pci/pci_root.rs15
-rw-r--r--devices/src/virtio/controller.rs316
-rw-r--r--devices/src/virtio/queue.rs5
5 files changed, 119 insertions, 225 deletions
diff --git a/devices/Cargo.toml b/devices/Cargo.toml
index 4d6c4d8..830cb86 100644
--- a/devices/Cargo.toml
+++ b/devices/Cargo.toml
@@ -35,7 +35,6 @@ linux_input_sys = { path = "../linux_input_sys" }
 msg_on_socket_derive = { path = "../msg_socket/msg_on_socket_derive" }
 msg_socket = { path = "../msg_socket" }
 msg_socket2 = { path = "../msg_socket2" }
-msg_socket2_derive = { path = "../msg_socket2/derive" }
 net_sys = { path = "../net_sys" }
 net_util = { path = "../net_util" }
 p9 = { path = "../p9" }
diff --git a/devices/src/lib.rs b/devices/src/lib.rs
index 7df9c62..ce57e61 100644
--- a/devices/src/lib.rs
+++ b/devices/src/lib.rs
@@ -49,12 +49,9 @@ pub use self::usb::xhci::xhci_controller::XhciController;
 pub use self::vfio::{VfioContainer, VfioDevice};
 pub use self::virtio::VirtioPciDevice;
 
-use msg_socket::MsgOnSocket;
-use serde::{Deserialize, Serialize};
+use msg_socket2::{Deserialize, DeserializeWithFds, Serialize, SerializeWithFds};
 
-use msg_socket2_derive::SerializeWithFds;
-
-#[derive(Clone, Copy, Debug, MsgOnSocket, Serialize, SerializeWithFds, Deserialize)]
+#[derive(Clone, Copy, Debug, Serialize, SerializeWithFds, Deserialize, DeserializeWithFds)]
 #[msg_socket2(strategy = "serde")]
 pub struct MemoryParams {
     /// Physical memory size in bytes for the VM.
diff --git a/devices/src/pci/pci_root.rs b/devices/src/pci/pci_root.rs
index 76f9d82..d100941 100644
--- a/devices/src/pci/pci_root.rs
+++ b/devices/src/pci/pci_root.rs
@@ -8,8 +8,7 @@ use std::fmt::{self, Display};
 use std::os::unix::io::RawFd;
 use std::sync::Arc;
 
-use msg_socket2_derive::SerializeWithFds;
-use serde::{Deserialize, Serialize};
+use msg_socket2::{Deserialize, DeserializeWithFds, Serialize, SerializeWithFds};
 use sync::Mutex;
 
 use crate::pci::pci_configuration::{
@@ -45,7 +44,17 @@ impl PciDevice for PciRootConfiguration {
 
 /// PCI Device Address, AKA Bus:Device.Function
 #[derive(
-    Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize, SerializeWithFds,
+    Clone,
+    Copy,
+    Debug,
+    Deserialize,
+    DeserializeWithFds,
+    Eq,
+    Ord,
+    PartialEq,
+    PartialOrd,
+    Serialize,
+    SerializeWithFds,
 )]
 #[msg_socket2(strategy = "serde")]
 pub struct PciAddress {
diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs
index 0a59072..644ece1 100644
--- a/devices/src/virtio/controller.rs
+++ b/devices/src/virtio/controller.rs
@@ -28,8 +28,6 @@
 //! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill
 //! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket.
 
-use std::os::unix::prelude::*;
-
 use std::collections::BTreeMap as Map;
 use std::fmt::{self, Formatter};
 use std::os::unix::io::{AsRawFd, RawFd};
@@ -38,10 +36,9 @@ use std::sync::Arc;
 use std::thread;
 
 use msg_socket::{MsgReceiver, MsgSocket};
+use msg_socket2::de::{EnumAccessWithFds, SeqAccessWithFds, VariantAccessWithFds, VisitorWithFds};
 use msg_socket2::ser::{SerializeStructVariantFds, SerializeTupleVariantFds};
 use msg_socket2::{DeserializeWithFds, DeserializerWithFds, FdSerializer, SerializeWithFds};
-use msg_socket2_derive::SerializeWithFds;
-use serde::de::{Deserializer, EnumAccess, SeqAccess, VariantAccess};
 use serde::ser::{SerializeStructVariant, SerializeTupleVariant, Serializer};
 use serde::{Deserialize, Serialize};
 use sys_util::net::UnixSeqpacket;
@@ -136,11 +133,9 @@ impl SerializeWithFds for Request {
                 vm_socket: _,
                 memory_params,
             } => {
-                let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 1)?;
-
-                sv.skip_field("vm_socket")?;
+                let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 2)?;
+                sv.serialize_field("vm_socket", &())?;
                 sv.serialize_field("memory_params", memory_params)?;
-
                 sv.end()
             }
 
@@ -180,16 +175,14 @@ impl SerializeWithFds for Request {
                 in_queue_evt: _,
                 out_queue_evt: _,
             } => {
-                let mut sv = serializer.serialize_struct_variant("Request", 8, "Activate", 2)?;
-
-                sv.skip_field("shm")?;
-                sv.skip_field("interrupt")?;
-                sv.skip_field("interrupt_resample_evt")?;
+                let mut sv = serializer.serialize_struct_variant("Request", 8, "Activate", 7)?;
+                sv.serialize_field("shm", &())?;
+                sv.serialize_field("interrupt", &())?;
+                sv.serialize_field("interrupt_resample_evt", &())?;
                 sv.serialize_field("in_queue", in_queue)?;
                 sv.serialize_field("out_queue", out_queue)?;
-                sv.skip_field("in_queue_evt")?;
-                sv.skip_field("out_queue_evt")?;
-
+                sv.serialize_field("in_queue_evt", &())?;
+                sv.serialize_field("out_queue_evt", &())?;
                 sv.end()
             }
 
@@ -284,26 +277,20 @@ impl SerializeWithFds for Request {
 }
 
 impl<'de> DeserializeWithFds<'de> for Request {
-    fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error>
-    where
-        I: Iterator<Item = RawFd>,
-        De: Deserializer<'de>,
-    {
-        struct Visitor<'iter, Iter> {
-            fds: &'iter mut Iter,
-        }
-
-        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
-        where
-            Iter: Iterator<Item = RawFd>,
-        {
+    fn deserialize<D: DeserializerWithFds<'de>>(deserializer: D) -> Result<Self, D::Error> {
+        struct Visitor;
+
+        impl<'de> VisitorWithFds<'de> for Visitor {
             type Value = Request;
 
             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                 write!(f, "enum Request")
             }
 
-            fn visit_enum<A: EnumAccess<'de>>(self, data: A) -> Result<Self::Value, A::Error> {
+            fn visit_enum<A: EnumAccessWithFds<'de>>(
+                self,
+                data: A,
+            ) -> Result<Self::Value, A::Error> {
                 #[derive(Debug, Deserialize)]
                 enum Variant {
                     Create,
@@ -323,54 +310,38 @@ impl<'de> DeserializeWithFds<'de> for Request {
 
                 match data.variant()? {
                     (Variant::Create, variant) => {
-                        struct Visitor<'iter, Iter> {
-                            fds: &'iter mut Iter,
-                        }
+                        struct Visitor;
 
-                        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
-                        where
-                            Iter: Iterator<Item = RawFd>,
-                        {
+                        impl<'de> VisitorWithFds<'de> for Visitor {
                             type Value = Request;
 
                             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                                 write!(f, "struct variant Request::Create")
                             }
 
-                            fn visit_seq<A: SeqAccess<'de>>(
+                            fn visit_seq<A: SeqAccessWithFds<'de>>(
                                 self,
                                 mut seq: A,
                             ) -> Result<Request, A::Error> {
                                 use serde::de::Error;
 
+                                fn too_short<E: Error>(len: usize) -> E {
+                                    E::invalid_length(
+                                        len,
+                                        &"struct variant Request::Create with 2 elements",
+                                    )
+                                }
+
                                 Ok(Request::Create {
-                                    vm_socket: match self.fds.next() {
-                                        Some(vm_socket) => MaybeOwnedFd::Owned(unsafe {
-                                            UnixSeqpacket::from_raw_fd(vm_socket)
-                                        }),
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                0,
-                                                &"struct variant Request::Create with 2 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    memory_params: match seq.next_element()? {
-                                        Some(memory_params) => memory_params,
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                1,
-                                                &"struct variant Request::Create with 2 elements",
-                                            ))
-                                        }
-                                    },
+                                    vm_socket: seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                    memory_params: seq
+                                        .next_element()?
+                                        .ok_or_else(|| too_short(1))?,
                                 })
                             }
                         }
 
-                        let visitor = Visitor { fds: self.fds };
-                        variant.struct_variant(&["memory_params"], visitor)
+                        variant.struct_variant(&["vm_socket", "memory_params"], Visitor)
                     }
 
                     (Variant::DebugLabel, variant) => {
@@ -400,35 +371,29 @@ impl<'de> DeserializeWithFds<'de> for Request {
                     (Variant::ReadConfig, variant) => {
                         struct Visitor;
 
-                        impl<'de> serde::de::Visitor<'de> for Visitor {
+                        impl<'de> VisitorWithFds<'de> for Visitor {
                             type Value = Request;
 
                             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                                 write!(f, "struct variant Request::ReadConfig")
                             }
 
-                            fn visit_seq<A: SeqAccess<'de>>(
+                            fn visit_seq<A: SeqAccessWithFds<'de>>(
                                 self,
                                 mut seq: A,
                             ) -> Result<Request, A::Error> {
                                 use serde::de::Error;
 
+                                fn too_short<E: Error>(len: usize) -> E {
+                                    E::invalid_length(
+                                        len,
+                                        &"struct variant Request::ReadConfig with 2 elements",
+                                    )
+                                }
+
                                 Ok(Request::ReadConfig {
-                                    offset: match seq.next_element()? {
-                                        Some(offset) => offset,
-                                        None => return Err(Error::invalid_length(
-                                            0,
-                                            &"struct variant Request::ReadConfig with 2 elements",
-                                        )),
-                                    },
-
-                                    len: match seq.next_element()? {
-                                        Some(len) => len,
-                                        None => return Err(Error::invalid_length(
-                                            1,
-                                            &"struct variant Request::ReadConfig with 2 elements",
-                                        )),
-                                    },
+                                    offset: seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                    len: seq.next_element()?.ok_or_else(|| too_short(1))?,
                                 })
                             }
                         }
@@ -439,35 +404,29 @@ impl<'de> DeserializeWithFds<'de> for Request {
                     (Variant::WriteConfig, variant) => {
                         struct Visitor;
 
-                        impl<'de> serde::de::Visitor<'de> for Visitor {
+                        impl<'de> VisitorWithFds<'de> for Visitor {
                             type Value = Request;
 
                             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                                 write!(f, "struct variant Request::WriteConfig")
                             }
 
-                            fn visit_seq<A: SeqAccess<'de>>(
+                            fn visit_seq<A: SeqAccessWithFds<'de>>(
                                 self,
                                 mut seq: A,
                             ) -> Result<Request, A::Error> {
                                 use serde::de::Error;
 
+                                fn too_short<E: Error>(len: usize) -> E {
+                                    E::invalid_length(
+                                        len,
+                                        &"struct variant Request::WriteConfig with 2 elements",
+                                    )
+                                }
+
                                 Ok(Request::WriteConfig {
-                                    offset: match seq.next_element()? {
-                                        Some(offset) => offset,
-                                        None => return Err(Error::invalid_length(
-                                            0,
-                                            &"struct variant Request::WriteConfig with 2 elements",
-                                        )),
-                                    },
-
-                                    data: match seq.next_element()? {
-                                        Some(data) => data,
-                                        None => return Err(Error::invalid_length(
-                                            1,
-                                            &"struct variant Request::WriteConfig with 2 elements",
-                                        )),
-                                    },
+                                    offset: seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                    data: seq.next_element()?.ok_or_else(|| too_short(1))?,
                                 })
                             }
                         }
@@ -476,114 +435,58 @@ impl<'de> DeserializeWithFds<'de> for Request {
                     }
 
                     (Variant::Activate, variant) => {
-                        struct Visitor<'iter, Iter> {
-                            fds: &'iter mut Iter,
-                        }
+                        struct Visitor;
 
-                        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
-                        where
-                            Iter: Iterator<Item = RawFd>,
-                        {
+                        impl<'de> VisitorWithFds<'de> for Visitor {
                             type Value = Request;
 
                             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                                 write!(f, "struct variant Request::Activate")
                             }
 
-                            fn visit_seq<A: SeqAccess<'de>>(
+                            fn visit_seq<A: SeqAccessWithFds<'de>>(
                                 self,
                                 mut seq: A,
                             ) -> Result<Request, A::Error> {
                                 use serde::de::Error;
 
+                                fn too_short<E: Error>(len: usize) -> E {
+                                    E::invalid_length(
+                                        len,
+                                        &"struct variant Request::Activate with 7 elements",
+                                    )
+                                }
+
                                 Ok(Request::Activate {
-                                    shm: match self.fds.next() {
-                                        Some(shm) => MaybeOwnedFd::Owned(unsafe {
-                                            FromRawFd::from_raw_fd(shm)
-                                        }),
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                0,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    interrupt: match self.fds.next() {
-                                        Some(interrupt) => MaybeOwnedFd::Owned(unsafe {
-                                            UnixSeqpacket::from_raw_fd(interrupt)
-                                        }),
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                1,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    interrupt_resample_evt: match self.fds.next() {
-                                        Some(interrupt_resample_evt) => {
-                                            MaybeOwnedFd::Owned(unsafe {
-                                                EventFd::from_raw_fd(interrupt_resample_evt)
-                                            })
-                                        }
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                2,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    in_queue: match seq.next_element()? {
-                                        Some(in_queue) => in_queue,
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                3,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    out_queue: match seq.next_element()? {
-                                        Some(out_queue) => out_queue,
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                4,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    in_queue_evt: match self.fds.next() {
-                                        Some(in_queue_evt) => MaybeOwnedFd::Owned(unsafe {
-                                            EventFd::from_raw_fd(in_queue_evt)
-                                        }),
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                5,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    out_queue_evt: match self.fds.next() {
-                                        Some(out_queue_evt) => MaybeOwnedFd::Owned(unsafe {
-                                            EventFd::from_raw_fd(out_queue_evt)
-                                        }),
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                6,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
+                                    shm: seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                    interrupt: seq.next_element()?.ok_or_else(|| too_short(1))?,
+                                    interrupt_resample_evt: seq
+                                        .next_element()?
+                                        .ok_or_else(|| too_short(2))?,
+                                    in_queue: seq.next_element()?.ok_or_else(|| too_short(3))?,
+                                    out_queue: seq.next_element()?.ok_or_else(|| too_short(4))?,
+                                    in_queue_evt: seq
+                                        .next_element()?
+                                        .ok_or_else(|| too_short(5))?,
+                                    out_queue_evt: seq
+                                        .next_element()?
+                                        .ok_or_else(|| too_short(6))?,
                                 })
                             }
                         }
 
-                        let visitor = Visitor { fds: self.fds };
-                        variant.struct_variant(&["in_queue", "out_queue"], visitor)
+                        variant.struct_variant(
+                            &[
+                                "shm",
+                                "interrupt",
+                                "interrupt_resample_evt",
+                                "in_queue",
+                                "out_queue",
+                                "in_queue_evt",
+                                "out_queue_evt",
+                            ],
+                            Visitor,
+                        )
                     }
 
                     (Variant::Reset, variant) => {
@@ -594,28 +497,29 @@ impl<'de> DeserializeWithFds<'de> for Request {
                     (Variant::GetDeviceBars, variant) => {
                         struct Visitor;
 
-                        impl<'de> serde::de::Visitor<'de> for Visitor {
+                        impl<'de> VisitorWithFds<'de> for Visitor {
                             type Value = Request;
 
                             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                                 write!(f, "struct variant Request::GetDeviceBars")
                             }
 
-                            fn visit_seq<A: SeqAccess<'de>>(
+                            fn visit_seq<A: SeqAccessWithFds<'de>>(
                                 self,
                                 mut seq: A,
                             ) -> Result<Request, A::Error> {
                                 use serde::de::Error;
 
-                                Ok(Request::GetDeviceBars(match seq.next_element()? {
-                                    Some(address) => address,
-                                    None => {
-                                        return Err(Error::invalid_length(
-                                            0,
-                                            &"struct variant Request::GetDeviceBars with 1 element",
-                                        ))
-                                    }
-                                }))
+                                fn too_short<E: Error>(len: usize) -> E {
+                                    E::invalid_length(
+                                        len,
+                                        &"struct variant Request::GetDeviceBars with 2 elements",
+                                    )
+                                }
+
+                                Ok(Request::GetDeviceBars(
+                                    seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                ))
                             }
                         }
 
@@ -635,11 +539,6 @@ impl<'de> DeserializeWithFds<'de> for Request {
             }
         }
 
-        let DeserializerWithFds {
-            mut fds,
-            deserializer,
-        } = deserializer;
-        let visitor = Visitor { fds: &mut fds };
         deserializer.deserialize_enum(
             "Request",
             &[
@@ -657,12 +556,12 @@ impl<'de> DeserializeWithFds<'de> for Request {
                 "GetDeviceCaps",
                 "Kill",
             ],
-            visitor,
+            Visitor,
         )
     }
 }
 
-#[derive(Debug, Deserialize, Serialize, SerializeWithFds)]
+#[derive(Debug, Deserialize, DeserializeWithFds, Serialize, SerializeWithFds)]
 #[msg_socket2(strategy = "serde")]
 pub enum Response {
     DebugLabel(String),
@@ -676,15 +575,6 @@ pub enum Response {
     Kill,
 }
 
-impl<'de> DeserializeWithFds<'de> for Response {
-    fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error>
-    where
-        De: Deserializer<'de>,
-    {
-        Deserialize::deserialize(deserializer.deserializer)
-    }
-}
-
 type Socket = msg_socket2::Socket<Request, Response>;
 
 // TODO: support arbitrary number of queues
diff --git a/devices/src/virtio/queue.rs b/devices/src/virtio/queue.rs
index f2310fa..c4c5217 100644
--- a/devices/src/virtio/queue.rs
+++ b/devices/src/virtio/queue.rs
@@ -6,8 +6,7 @@ use std::cmp::min;
 use std::num::Wrapping;
 use std::sync::atomic::{fence, Ordering};
 
-use msg_socket::MsgOnSocket;
-use msg_socket2_derive::SerializeWithFds;
+use msg_socket2::{DeserializeWithFds, SerializeWithFds};
 use sys_util::{error, GuestAddress, GuestMemory};
 use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
 
@@ -203,7 +202,7 @@ impl<'a, 'b> Iterator for AvailIter<'a, 'b> {
 
 use serde::{Deserialize, Serialize};
 
-#[derive(Clone, Debug, MsgOnSocket, Serialize, SerializeWithFds, Deserialize)]
+#[derive(Clone, Debug, Serialize, SerializeWithFds, Deserialize, DeserializeWithFds)]
 #[msg_socket2(strategy = "serde")]
 /// A virtio queue's parameters.
 pub struct Queue {