summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-03-21 13:50:24 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:37:08 +0000
commit81f96554ebf490c83a8993065de9b1636b26f916 (patch)
tree43f9efe6eb507b449918b510b7d7a72056e053ab
parent4d2e22e374e8ac93be227f0357efd2c0d7a9f699 (diff)
downloadcrosvm-81f96554ebf490c83a8993065de9b1636b26f916.tar
crosvm-81f96554ebf490c83a8993065de9b1636b26f916.tar.gz
crosvm-81f96554ebf490c83a8993065de9b1636b26f916.tar.bz2
crosvm-81f96554ebf490c83a8993065de9b1636b26f916.tar.lz
crosvm-81f96554ebf490c83a8993065de9b1636b26f916.tar.xz
crosvm-81f96554ebf490c83a8993065de9b1636b26f916.tar.zst
crosvm-81f96554ebf490c83a8993065de9b1636b26f916.zip
SerializerWithFds trait
-rw-r--r--devices/src/virtio/controller.rs92
-rw-r--r--msg_socket2/src/lib.rs2
-rw-r--r--msg_socket2/src/ser.rs34
-rw-r--r--msg_socket2/src/socket.rs6
-rw-r--r--msg_socket2/tests/round_trip.rs17
5 files changed, 87 insertions, 64 deletions
diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs
index 9b0784c..0b10aac 100644
--- a/devices/src/virtio/controller.rs
+++ b/devices/src/virtio/controller.rs
@@ -126,10 +126,10 @@ pub enum Request {
 }
 
 impl SerializeWithFds for Request {
-    fn serialize<S>(&self, serializer: SerializerWithFds<S>) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
+    fn serialize<S: SerializerWithFds>(
+        &self,
+        mut serializer: S,
+    ) -> Result<<S::Ser as Serializer>::Ok, <S::Ser as Serializer>::Error> {
         use Request::*;
 
         match self {
@@ -137,37 +137,41 @@ impl SerializeWithFds for Request {
                 vm_socket,
                 memory_params,
             } => {
+                serializer.fds().push(vm_socket.as_raw_fd());
+
                 let mut sv = serializer
-                    .serializer
+                    .serializer()
                     .serialize_struct_variant("Request", 0, "Create", 1)?;
 
                 sv.skip_field("vm_socket")?;
-                serializer.fds.push(vm_socket.as_raw_fd());
-
                 sv.serialize_field("memory_params", memory_params)?;
 
                 sv.end()
             }
 
-            DebugLabel => serializer
-                .serializer
-                .serialize_unit_variant("Request", 1, "DebugLabel"),
+            DebugLabel => {
+                serializer
+                    .serializer()
+                    .serialize_unit_variant("Request", 1, "DebugLabel")
+            }
 
-            DeviceType => serializer
-                .serializer
-                .serialize_unit_variant("Request", 2, "DeviceType"),
+            DeviceType => {
+                serializer
+                    .serializer()
+                    .serialize_unit_variant("Request", 2, "DeviceType")
+            }
             QueueMaxSizes => {
                 serializer
-                    .serializer
+                    .serializer()
                     .serialize_unit_variant("Request", 3, "QueueMaxSizes")
             }
 
             Features => serializer
-                .serializer
+                .serializer()
                 .serialize_unit_variant("Request", 4, "Features"),
 
             AckFeatures(features) => {
-                let mut tv = serializer.serializer.serialize_tuple_variant(
+                let mut tv = serializer.serializer().serialize_tuple_variant(
                     "Request",
                     5,
                     "AckFeatures",
@@ -178,7 +182,7 @@ impl SerializeWithFds for Request {
             }
 
             ReadConfig { offset, len } => {
-                let mut sv = serializer.serializer.serialize_struct_variant(
+                let mut sv = serializer.serializer().serialize_struct_variant(
                     "Request",
                     6,
                     "ReadConfig",
@@ -190,7 +194,7 @@ impl SerializeWithFds for Request {
             }
 
             WriteConfig { offset, data } => {
-                let mut sv = serializer.serializer.serialize_struct_variant(
+                let mut sv = serializer.serializer().serialize_struct_variant(
                     "Request",
                     7,
                     "WriteConfig",
@@ -210,37 +214,33 @@ impl SerializeWithFds for Request {
                 in_queue_evt,
                 out_queue_evt,
             } => {
+                serializer.fds().push(shm.as_raw_fd());
+                serializer.fds().push(interrupt.as_raw_fd());
+                serializer.fds().push(interrupt_resample_evt.as_raw_fd());
+                serializer.fds().push(in_queue_evt.as_raw_fd());
+                serializer.fds().push(out_queue_evt.as_raw_fd());
+
                 let mut sv = serializer
-                    .serializer
+                    .serializer()
                     .serialize_struct_variant("Request", 8, "Activate", 2)?;
 
                 sv.skip_field("shm")?;
-                serializer.fds.push(shm.as_raw_fd());
-
                 sv.skip_field("interrupt")?;
-                serializer.fds.push(interrupt.as_raw_fd());
-
                 sv.skip_field("interrupt_resample_evt")?;
-                serializer.fds.push(interrupt_resample_evt.as_raw_fd());
-
                 sv.serialize_field("in_queue", in_queue)?;
                 sv.serialize_field("out_queue", out_queue)?;
-
                 sv.skip_field("in_queue_evt")?;
-                serializer.fds.push(in_queue_evt.as_raw_fd());
-
                 sv.skip_field("out_queue_evt")?;
-                serializer.fds.push(out_queue_evt.as_raw_fd());
 
                 sv.end()
             }
 
             Reset => serializer
-                .serializer
+                .serializer()
                 .serialize_unit_variant("Request", 9, "Reset"),
 
             GetDeviceBars(address) => {
-                let mut sv = serializer.serializer.serialize_struct_variant(
+                let mut sv = serializer.serializer().serialize_struct_variant(
                     "Request",
                     10,
                     "GetDeviceBars",
@@ -252,12 +252,12 @@ impl SerializeWithFds for Request {
 
             GetDeviceCaps => {
                 serializer
-                    .serializer
+                    .serializer()
                     .serialize_unit_variant("Request", 11, "GetDeviceCaps")
             }
 
             Kill => serializer
-                .serializer
+                .serializer()
                 .serialize_unit_variant("Request", 12, "Kill"),
         }
     }
@@ -674,15 +674,15 @@ pub enum Response {
 }
 
 impl SerializeWithFds for Response {
-    fn serialize<S>(&self, serializer: SerializerWithFds<S>) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
+    fn serialize<S: SerializerWithFds>(
+        &self,
+        serializer: S,
+    ) -> Result<<S::Ser as Serializer>::Ok, <S::Ser as Serializer>::Error> {
         use Response::*;
 
         match self {
             DebugLabel(label) => {
-                let mut tv = serializer.serializer.serialize_tuple_variant(
+                let mut tv = serializer.serializer().serialize_tuple_variant(
                     "Response",
                     0,
                     "DebugLabel",
@@ -693,7 +693,7 @@ impl SerializeWithFds for Response {
             }
 
             DeviceType(device_type) => {
-                let mut tv = serializer.serializer.serialize_tuple_variant(
+                let mut tv = serializer.serializer().serialize_tuple_variant(
                     "Response",
                     1,
                     "DeviceType",
@@ -704,7 +704,7 @@ impl SerializeWithFds for Response {
             }
 
             QueueMaxSizes(sizes) => {
-                let mut tv = serializer.serializer.serialize_tuple_variant(
+                let mut tv = serializer.serializer().serialize_tuple_variant(
                     "Response",
                     2,
                     "QueueMaxSizes",
@@ -716,14 +716,14 @@ impl SerializeWithFds for Response {
 
             Features(features) => {
                 let mut tv = serializer
-                    .serializer
+                    .serializer()
                     .serialize_tuple_variant("Response", 3, "Features", 1)?;
                 tv.serialize_field(features)?;
                 tv.end()
             }
 
             ReadConfig(config) => {
-                let mut tv = serializer.serializer.serialize_tuple_variant(
+                let mut tv = serializer.serializer().serialize_tuple_variant(
                     "Response",
                     4,
                     "ReadConfig",
@@ -735,14 +735,14 @@ impl SerializeWithFds for Response {
 
             Reset(success) => {
                 let mut tv = serializer
-                    .serializer
+                    .serializer()
                     .serialize_tuple_variant("Response", 5, "Reset", 1)?;
                 tv.serialize_field(success)?;
                 tv.end()
             }
 
             GetDeviceBars(bars) => {
-                let mut tv = serializer.serializer.serialize_tuple_variant(
+                let mut tv = serializer.serializer().serialize_tuple_variant(
                     "Response",
                     6,
                     "GetDeviceBars",
@@ -753,7 +753,7 @@ impl SerializeWithFds for Response {
             }
 
             GetDeviceCaps(caps) => {
-                let mut tv = serializer.serializer.serialize_tuple_variant(
+                let mut tv = serializer.serializer().serialize_tuple_variant(
                     "Response",
                     7,
                     "GetDeviceCaps",
@@ -764,7 +764,7 @@ impl SerializeWithFds for Response {
             }
 
             Kill => serializer
-                .serializer
+                .serializer()
                 .serialize_unit_variant("Response", 8, "Kill"),
         }
     }
diff --git a/msg_socket2/src/lib.rs b/msg_socket2/src/lib.rs
index 748a9f7..a1d8ceb 100644
--- a/msg_socket2/src/lib.rs
+++ b/msg_socket2/src/lib.rs
@@ -28,6 +28,8 @@ mod error;
 mod ser;
 mod socket;
 
+pub(crate) use ser::SerializerWithFdsImpl;
+
 pub use de::{DeserializeWithFds, DeserializerWithFds};
 pub use error::Error;
 pub use ser::{SerializeWithFds, SerializerWithFds};
diff --git a/msg_socket2/src/ser.rs b/msg_socket2/src/ser.rs
index 0a60ea8..7bffa22 100644
--- a/msg_socket2/src/ser.rs
+++ b/msg_socket2/src/ser.rs
@@ -1,20 +1,40 @@
 use serde::Serializer;
 use std::os::unix::prelude::*;
 
+pub trait SerializerWithFds {
+    type Ser: Serializer;
+
+    fn serializer(self) -> Self::Ser;
+    fn fds(&mut self) -> &mut Vec<RawFd>;
+}
+
 pub trait SerializeWithFds {
-    fn serialize<Ser>(&self, serializer: SerializerWithFds<Ser>) -> Result<Ser::Ok, Ser::Error>
-    where
-        Ser: Serializer;
+    fn serialize<Ser: SerializerWithFds>(
+        &self,
+        serializer: Ser,
+    ) -> Result<<Ser::Ser as Serializer>::Ok, <Ser::Ser as Serializer>::Error>;
 }
 
 #[derive(Debug)]
-pub struct SerializerWithFds<'fds, Ser> {
-    pub serializer: Ser,
-    pub fds: &'fds mut Vec<RawFd>,
+pub struct SerializerWithFdsImpl<'fds, Ser> {
+    serializer: Ser,
+    fds: &'fds mut Vec<RawFd>,
 }
 
-impl<'fds, Ser> SerializerWithFds<'fds, Ser> {
+impl<'fds, Ser> SerializerWithFdsImpl<'fds, Ser> {
     pub fn new(fds: &'fds mut Vec<RawFd>, serializer: Ser) -> Self {
         Self { serializer, fds }
     }
 }
+
+impl<'fds, Ser: Serializer> SerializerWithFds for SerializerWithFdsImpl<'fds, Ser> {
+    type Ser = Ser;
+
+    fn serializer(self) -> Self::Ser {
+        self.serializer
+    }
+
+    fn fds(&mut self) -> &mut Vec<RawFd> {
+        &mut self.fds
+    }
+}
diff --git a/msg_socket2/src/socket.rs b/msg_socket2/src/socket.rs
index dc0733d..4e75e82 100644
--- a/msg_socket2/src/socket.rs
+++ b/msg_socket2/src/socket.rs
@@ -4,7 +4,9 @@ use std::marker::PhantomData;
 use std::os::unix::prelude::*;
 use sys_util::{net::UnixSeqpacket, ScmSocket};
 
-use crate::{DeserializeWithFds, DeserializerWithFds, Error, SerializeWithFds, SerializerWithFds};
+use crate::{
+    DeserializeWithFds, DeserializerWithFds, Error, SerializeWithFds, SerializerWithFdsImpl,
+};
 
 #[derive(Debug)]
 pub struct Socket<Send, Recv> {
@@ -27,7 +29,7 @@ impl<Send: SerializeWithFds, Recv> Socket<Send, Recv> {
         let mut fds: Vec<RawFd> = vec![];
 
         let mut serializer = Serializer::new(&mut bytes, DefaultOptions::new());
-        let serializer_with_fds = SerializerWithFds::new(&mut fds, &mut serializer);
+        let serializer_with_fds = SerializerWithFdsImpl::new(&mut fds, &mut serializer);
         value.serialize(serializer_with_fds)?;
 
         self.sock.send_with_fds(&[IoSlice::new(&bytes)], &fds)?;
diff --git a/msg_socket2/tests/round_trip.rs b/msg_socket2/tests/round_trip.rs
index 1c6b5a3..efec94b 100644
--- a/msg_socket2/tests/round_trip.rs
+++ b/msg_socket2/tests/round_trip.rs
@@ -18,13 +18,12 @@ struct Test {
 }
 
 impl SerializeWithFds for Test {
-    fn serialize<Ser>(&self, serializer: SerializerWithFds<Ser>) -> Result<Ser::Ok, Ser::Error>
-    where
-        Ser: Serializer,
-    {
-        let mut state = serializer.serializer.serialize_struct("Test", 1)?;
-        serializer.fds.push(self.fd);
-        state.skip_field("fd")?;
+    fn serialize<Ser: SerializerWithFds>(
+        &self,
+        mut serializer: Ser,
+    ) -> Result<<Ser::Ser as Serializer>::Ok, <Ser::Ser as Serializer>::Error> {
+        serializer.fds().push(self.fd);
+        serializer.fds().push(self.inner.0);
 
         struct SerializableInner<'a>(&'a Inner);
 
@@ -36,9 +35,9 @@ impl SerializeWithFds for Test {
             }
         }
 
-        serializer.fds.push(self.inner.0);
+        let mut state = serializer.serializer().serialize_struct("Test", 1)?;
+        state.skip_field("fd")?;
         state.serialize_field("inner", &SerializableInner(&self.inner))?;
-
         state.end()
     }
 }