summary refs log tree commit diff
path: root/devices/src/virtio/controller.rs
diff options
context:
space:
mode:
Diffstat (limited to 'devices/src/virtio/controller.rs')
-rw-r--r--devices/src/virtio/controller.rs316
1 files changed, 103 insertions, 213 deletions
diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs
index 0a59072..644ece1 100644
--- a/devices/src/virtio/controller.rs
+++ b/devices/src/virtio/controller.rs
@@ -28,8 +28,6 @@
 //! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill
 //! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket.
 
-use std::os::unix::prelude::*;
-
 use std::collections::BTreeMap as Map;
 use std::fmt::{self, Formatter};
 use std::os::unix::io::{AsRawFd, RawFd};
@@ -38,10 +36,9 @@ use std::sync::Arc;
 use std::thread;
 
 use msg_socket::{MsgReceiver, MsgSocket};
+use msg_socket2::de::{EnumAccessWithFds, SeqAccessWithFds, VariantAccessWithFds, VisitorWithFds};
 use msg_socket2::ser::{SerializeStructVariantFds, SerializeTupleVariantFds};
 use msg_socket2::{DeserializeWithFds, DeserializerWithFds, FdSerializer, SerializeWithFds};
-use msg_socket2_derive::SerializeWithFds;
-use serde::de::{Deserializer, EnumAccess, SeqAccess, VariantAccess};
 use serde::ser::{SerializeStructVariant, SerializeTupleVariant, Serializer};
 use serde::{Deserialize, Serialize};
 use sys_util::net::UnixSeqpacket;
@@ -136,11 +133,9 @@ impl SerializeWithFds for Request {
                 vm_socket: _,
                 memory_params,
             } => {
-                let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 1)?;
-
-                sv.skip_field("vm_socket")?;
+                let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 2)?;
+                sv.serialize_field("vm_socket", &())?;
                 sv.serialize_field("memory_params", memory_params)?;
-
                 sv.end()
             }
 
@@ -180,16 +175,14 @@ impl SerializeWithFds for Request {
                 in_queue_evt: _,
                 out_queue_evt: _,
             } => {
-                let mut sv = serializer.serialize_struct_variant("Request", 8, "Activate", 2)?;
-
-                sv.skip_field("shm")?;
-                sv.skip_field("interrupt")?;
-                sv.skip_field("interrupt_resample_evt")?;
+                let mut sv = serializer.serialize_struct_variant("Request", 8, "Activate", 7)?;
+                sv.serialize_field("shm", &())?;
+                sv.serialize_field("interrupt", &())?;
+                sv.serialize_field("interrupt_resample_evt", &())?;
                 sv.serialize_field("in_queue", in_queue)?;
                 sv.serialize_field("out_queue", out_queue)?;
-                sv.skip_field("in_queue_evt")?;
-                sv.skip_field("out_queue_evt")?;
-
+                sv.serialize_field("in_queue_evt", &())?;
+                sv.serialize_field("out_queue_evt", &())?;
                 sv.end()
             }
 
@@ -284,26 +277,20 @@ impl SerializeWithFds for Request {
 }
 
 impl<'de> DeserializeWithFds<'de> for Request {
-    fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error>
-    where
-        I: Iterator<Item = RawFd>,
-        De: Deserializer<'de>,
-    {
-        struct Visitor<'iter, Iter> {
-            fds: &'iter mut Iter,
-        }
-
-        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
-        where
-            Iter: Iterator<Item = RawFd>,
-        {
+    fn deserialize<D: DeserializerWithFds<'de>>(deserializer: D) -> Result<Self, D::Error> {
+        struct Visitor;
+
+        impl<'de> VisitorWithFds<'de> for Visitor {
             type Value = Request;
 
             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                 write!(f, "enum Request")
             }
 
-            fn visit_enum<A: EnumAccess<'de>>(self, data: A) -> Result<Self::Value, A::Error> {
+            fn visit_enum<A: EnumAccessWithFds<'de>>(
+                self,
+                data: A,
+            ) -> Result<Self::Value, A::Error> {
                 #[derive(Debug, Deserialize)]
                 enum Variant {
                     Create,
@@ -323,54 +310,38 @@ impl<'de> DeserializeWithFds<'de> for Request {
 
                 match data.variant()? {
                     (Variant::Create, variant) => {
-                        struct Visitor<'iter, Iter> {
-                            fds: &'iter mut Iter,
-                        }
+                        struct Visitor;
 
-                        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
-                        where
-                            Iter: Iterator<Item = RawFd>,
-                        {
+                        impl<'de> VisitorWithFds<'de> for Visitor {
                             type Value = Request;
 
                             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                                 write!(f, "struct variant Request::Create")
                             }
 
-                            fn visit_seq<A: SeqAccess<'de>>(
+                            fn visit_seq<A: SeqAccessWithFds<'de>>(
                                 self,
                                 mut seq: A,
                             ) -> Result<Request, A::Error> {
                                 use serde::de::Error;
 
+                                fn too_short<E: Error>(len: usize) -> E {
+                                    E::invalid_length(
+                                        len,
+                                        &"struct variant Request::Create with 2 elements",
+                                    )
+                                }
+
                                 Ok(Request::Create {
-                                    vm_socket: match self.fds.next() {
-                                        Some(vm_socket) => MaybeOwnedFd::Owned(unsafe {
-                                            UnixSeqpacket::from_raw_fd(vm_socket)
-                                        }),
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                0,
-                                                &"struct variant Request::Create with 2 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    memory_params: match seq.next_element()? {
-                                        Some(memory_params) => memory_params,
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                1,
-                                                &"struct variant Request::Create with 2 elements",
-                                            ))
-                                        }
-                                    },
+                                    vm_socket: seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                    memory_params: seq
+                                        .next_element()?
+                                        .ok_or_else(|| too_short(1))?,
                                 })
                             }
                         }
 
-                        let visitor = Visitor { fds: self.fds };
-                        variant.struct_variant(&["memory_params"], visitor)
+                        variant.struct_variant(&["vm_socket", "memory_params"], Visitor)
                     }
 
                     (Variant::DebugLabel, variant) => {
@@ -400,35 +371,29 @@ impl<'de> DeserializeWithFds<'de> for Request {
                     (Variant::ReadConfig, variant) => {
                         struct Visitor;
 
-                        impl<'de> serde::de::Visitor<'de> for Visitor {
+                        impl<'de> VisitorWithFds<'de> for Visitor {
                             type Value = Request;
 
                             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                                 write!(f, "struct variant Request::ReadConfig")
                             }
 
-                            fn visit_seq<A: SeqAccess<'de>>(
+                            fn visit_seq<A: SeqAccessWithFds<'de>>(
                                 self,
                                 mut seq: A,
                             ) -> Result<Request, A::Error> {
                                 use serde::de::Error;
 
+                                fn too_short<E: Error>(len: usize) -> E {
+                                    E::invalid_length(
+                                        len,
+                                        &"struct variant Request::ReadConfig with 2 elements",
+                                    )
+                                }
+
                                 Ok(Request::ReadConfig {
-                                    offset: match seq.next_element()? {
-                                        Some(offset) => offset,
-                                        None => return Err(Error::invalid_length(
-                                            0,
-                                            &"struct variant Request::ReadConfig with 2 elements",
-                                        )),
-                                    },
-
-                                    len: match seq.next_element()? {
-                                        Some(len) => len,
-                                        None => return Err(Error::invalid_length(
-                                            1,
-                                            &"struct variant Request::ReadConfig with 2 elements",
-                                        )),
-                                    },
+                                    offset: seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                    len: seq.next_element()?.ok_or_else(|| too_short(1))?,
                                 })
                             }
                         }
@@ -439,35 +404,29 @@ impl<'de> DeserializeWithFds<'de> for Request {
                     (Variant::WriteConfig, variant) => {
                         struct Visitor;
 
-                        impl<'de> serde::de::Visitor<'de> for Visitor {
+                        impl<'de> VisitorWithFds<'de> for Visitor {
                             type Value = Request;
 
                             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                                 write!(f, "struct variant Request::WriteConfig")
                             }
 
-                            fn visit_seq<A: SeqAccess<'de>>(
+                            fn visit_seq<A: SeqAccessWithFds<'de>>(
                                 self,
                                 mut seq: A,
                             ) -> Result<Request, A::Error> {
                                 use serde::de::Error;
 
+                                fn too_short<E: Error>(len: usize) -> E {
+                                    E::invalid_length(
+                                        len,
+                                        &"struct variant Request::WriteConfig with 2 elements",
+                                    )
+                                }
+
                                 Ok(Request::WriteConfig {
-                                    offset: match seq.next_element()? {
-                                        Some(offset) => offset,
-                                        None => return Err(Error::invalid_length(
-                                            0,
-                                            &"struct variant Request::WriteConfig with 2 elements",
-                                        )),
-                                    },
-
-                                    data: match seq.next_element()? {
-                                        Some(data) => data,
-                                        None => return Err(Error::invalid_length(
-                                            1,
-                                            &"struct variant Request::WriteConfig with 2 elements",
-                                        )),
-                                    },
+                                    offset: seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                    data: seq.next_element()?.ok_or_else(|| too_short(1))?,
                                 })
                             }
                         }
@@ -476,114 +435,58 @@ impl<'de> DeserializeWithFds<'de> for Request {
                     }
 
                     (Variant::Activate, variant) => {
-                        struct Visitor<'iter, Iter> {
-                            fds: &'iter mut Iter,
-                        }
+                        struct Visitor;
 
-                        impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter>
-                        where
-                            Iter: Iterator<Item = RawFd>,
-                        {
+                        impl<'de> VisitorWithFds<'de> for Visitor {
                             type Value = Request;
 
                             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                                 write!(f, "struct variant Request::Activate")
                             }
 
-                            fn visit_seq<A: SeqAccess<'de>>(
+                            fn visit_seq<A: SeqAccessWithFds<'de>>(
                                 self,
                                 mut seq: A,
                             ) -> Result<Request, A::Error> {
                                 use serde::de::Error;
 
+                                fn too_short<E: Error>(len: usize) -> E {
+                                    E::invalid_length(
+                                        len,
+                                        &"struct variant Request::Activate with 7 elements",
+                                    )
+                                }
+
                                 Ok(Request::Activate {
-                                    shm: match self.fds.next() {
-                                        Some(shm) => MaybeOwnedFd::Owned(unsafe {
-                                            FromRawFd::from_raw_fd(shm)
-                                        }),
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                0,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    interrupt: match self.fds.next() {
-                                        Some(interrupt) => MaybeOwnedFd::Owned(unsafe {
-                                            UnixSeqpacket::from_raw_fd(interrupt)
-                                        }),
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                1,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    interrupt_resample_evt: match self.fds.next() {
-                                        Some(interrupt_resample_evt) => {
-                                            MaybeOwnedFd::Owned(unsafe {
-                                                EventFd::from_raw_fd(interrupt_resample_evt)
-                                            })
-                                        }
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                2,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    in_queue: match seq.next_element()? {
-                                        Some(in_queue) => in_queue,
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                3,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    out_queue: match seq.next_element()? {
-                                        Some(out_queue) => out_queue,
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                4,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    in_queue_evt: match self.fds.next() {
-                                        Some(in_queue_evt) => MaybeOwnedFd::Owned(unsafe {
-                                            EventFd::from_raw_fd(in_queue_evt)
-                                        }),
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                5,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
-
-                                    out_queue_evt: match self.fds.next() {
-                                        Some(out_queue_evt) => MaybeOwnedFd::Owned(unsafe {
-                                            EventFd::from_raw_fd(out_queue_evt)
-                                        }),
-                                        None => {
-                                            return Err(Error::invalid_length(
-                                                6,
-                                                &"struct variant Request::Activate with 7 elements",
-                                            ))
-                                        }
-                                    },
+                                    shm: seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                    interrupt: seq.next_element()?.ok_or_else(|| too_short(1))?,
+                                    interrupt_resample_evt: seq
+                                        .next_element()?
+                                        .ok_or_else(|| too_short(2))?,
+                                    in_queue: seq.next_element()?.ok_or_else(|| too_short(3))?,
+                                    out_queue: seq.next_element()?.ok_or_else(|| too_short(4))?,
+                                    in_queue_evt: seq
+                                        .next_element()?
+                                        .ok_or_else(|| too_short(5))?,
+                                    out_queue_evt: seq
+                                        .next_element()?
+                                        .ok_or_else(|| too_short(6))?,
                                 })
                             }
                         }
 
-                        let visitor = Visitor { fds: self.fds };
-                        variant.struct_variant(&["in_queue", "out_queue"], visitor)
+                        variant.struct_variant(
+                            &[
+                                "shm",
+                                "interrupt",
+                                "interrupt_resample_evt",
+                                "in_queue",
+                                "out_queue",
+                                "in_queue_evt",
+                                "out_queue_evt",
+                            ],
+                            Visitor,
+                        )
                     }
 
                     (Variant::Reset, variant) => {
@@ -594,28 +497,29 @@ impl<'de> DeserializeWithFds<'de> for Request {
                     (Variant::GetDeviceBars, variant) => {
                         struct Visitor;
 
-                        impl<'de> serde::de::Visitor<'de> for Visitor {
+                        impl<'de> VisitorWithFds<'de> for Visitor {
                             type Value = Request;
 
                             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                                 write!(f, "struct variant Request::GetDeviceBars")
                             }
 
-                            fn visit_seq<A: SeqAccess<'de>>(
+                            fn visit_seq<A: SeqAccessWithFds<'de>>(
                                 self,
                                 mut seq: A,
                             ) -> Result<Request, A::Error> {
                                 use serde::de::Error;
 
-                                Ok(Request::GetDeviceBars(match seq.next_element()? {
-                                    Some(address) => address,
-                                    None => {
-                                        return Err(Error::invalid_length(
-                                            0,
-                                            &"struct variant Request::GetDeviceBars with 1 element",
-                                        ))
-                                    }
-                                }))
+                                fn too_short<E: Error>(len: usize) -> E {
+                                    E::invalid_length(
+                                        len,
+                                        &"struct variant Request::GetDeviceBars with 2 elements",
+                                    )
+                                }
+
+                                Ok(Request::GetDeviceBars(
+                                    seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                ))
                             }
                         }
 
@@ -635,11 +539,6 @@ impl<'de> DeserializeWithFds<'de> for Request {
             }
         }
 
-        let DeserializerWithFds {
-            mut fds,
-            deserializer,
-        } = deserializer;
-        let visitor = Visitor { fds: &mut fds };
         deserializer.deserialize_enum(
             "Request",
             &[
@@ -657,12 +556,12 @@ impl<'de> DeserializeWithFds<'de> for Request {
                 "GetDeviceCaps",
                 "Kill",
             ],
-            visitor,
+            Visitor,
         )
     }
 }
 
-#[derive(Debug, Deserialize, Serialize, SerializeWithFds)]
+#[derive(Debug, Deserialize, DeserializeWithFds, Serialize, SerializeWithFds)]
 #[msg_socket2(strategy = "serde")]
 pub enum Response {
     DebugLabel(String),
@@ -676,15 +575,6 @@ pub enum Response {
     Kill,
 }
 
-impl<'de> DeserializeWithFds<'de> for Response {
-    fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error>
-    where
-        De: Deserializer<'de>,
-    {
-        Deserialize::deserialize(deserializer.deserializer)
-    }
-}
-
 type Socket = msg_socket2::Socket<Request, Response>;
 
 // TODO: support arbitrary number of queues