summary refs log tree commit diff
path: root/msg_socket2/tests
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-03-25 08:38:01 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:37:12 +0000
commitb6549a605935e29ab0ae4291737f8b0158bca1fb (patch)
tree7f4242993ce003cb787b242a264e3b8ea47e3430 /msg_socket2/tests
parent2885f9ca1a79d30421deeb025e92ae0118fc6d3a (diff)
downloadcrosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar.gz
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar.bz2
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar.lz
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar.xz
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.tar.zst
crosvm-b6549a605935e29ab0ae4291737f8b0158bca1fb.zip
recursive deserialization
Diffstat (limited to 'msg_socket2/tests')
-rw-r--r--msg_socket2/tests/round_trip.rs79
1 files changed, 50 insertions, 29 deletions
diff --git a/msg_socket2/tests/round_trip.rs b/msg_socket2/tests/round_trip.rs
index 1bb6636..a89f414 100644
--- a/msg_socket2/tests/round_trip.rs
+++ b/msg_socket2/tests/round_trip.rs
@@ -1,20 +1,18 @@
-use std::os::unix::prelude::*;
-
 use std::fmt::{self, Formatter};
-use std::marker::PhantomData;
+use std::fs::File;
 
 use msg_socket2::{
+    de::{SeqAccessWithFds, VisitorWithFds},
     ser::{
-        SerializeAdapter, SerializeRawFd, SerializeStruct, SerializeStructFds,
-        SerializeTupleStruct, SerializeTupleStructFds,
+        SerializeAdapter, SerializeStruct, SerializeStructFds, SerializeTupleStruct,
+        SerializeTupleStructFds,
     },
     DeserializeWithFds, DeserializerWithFds, FdSerializer, SerializeWithFds, Serializer, Socket,
 };
-use serde::de::{Deserializer, SeqAccess};
 use sys_util::net::UnixSeqpacket;
 
 #[derive(Debug)]
-struct Inner(RawFd, u16);
+struct Inner(File, u16);
 
 impl SerializeWithFds for Inner {
     fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
@@ -25,14 +23,43 @@ impl SerializeWithFds for Inner {
 
     fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
         let mut state = serializer.serialize_tuple_struct("Inner", 1)?;
-        state.serialize_field(&SerializeRawFd::new(&self.0))?;
+        state.serialize_field(&self.0)?;
         state.end()
     }
 }
 
+impl<'de> DeserializeWithFds<'de> for Inner {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: DeserializerWithFds<'de>,
+    {
+        struct Visitor;
+
+        impl<'de> VisitorWithFds<'de> for Visitor {
+            type Value = Inner;
+
+            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                write!(f, "struct Inner")
+            }
+
+            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+            where
+                A: SeqAccessWithFds<'de>,
+            {
+                Ok(Inner(
+                    seq.next_element()?.unwrap(),
+                    seq.next_element()?.unwrap(),
+                ))
+            }
+        }
+
+        deserializer.deserialize_tuple_struct("Inner", 2, Visitor)
+    }
+}
+
 #[derive(Debug)]
 struct Test {
-    fd: RawFd,
+    fd: File,
     inner: Inner,
 }
 
@@ -46,44 +73,38 @@ impl SerializeWithFds for Test {
 
     fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
         let mut state = serializer.serialize_struct("Test", 2)?;
-        state.serialize_field("fd", &SerializeRawFd::new(&self.fd))?;
+        state.serialize_field("fd", &self.fd)?;
         state.serialize_field("inner", &self.inner)?;
         state.end()
     }
 }
 
 impl<'de> DeserializeWithFds<'de> for Test {
-    fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error>
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     where
-        I: Iterator<Item = RawFd>,
-        De: Deserializer<'de>,
+        D: DeserializerWithFds<'de>,
     {
-        struct Visitor<'iter, 'de, Iter>(&'iter mut Iter, PhantomData<&'de ()>);
+        struct Visitor;
 
-        impl<'iter, 'de, Iter: Iterator<Item = RawFd>> serde::de::Visitor<'de>
-            for Visitor<'iter, 'de, Iter>
-        {
+        impl<'de> VisitorWithFds<'de> for Visitor {
             type Value = Test;
 
             fn expecting(&self, f: &mut Formatter) -> fmt::Result {
                 write!(f, "struct Test")
             }
 
-            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
+            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+            where
+                A: SeqAccessWithFds<'de>,
+            {
                 Ok(Test {
-                    fd: self.0.next().unwrap(),
-                    inner: Inner(self.0.next().unwrap(), seq.next_element()?.unwrap()),
+                    fd: seq.next_element()?.unwrap(),
+                    inner: seq.next_element()?.unwrap(),
                 })
             }
         }
 
-        let DeserializerWithFds {
-            mut fds,
-            deserializer,
-        } = deserializer;
-
-        let visitor = Visitor(&mut fds, PhantomData);
-        deserializer.deserialize_struct("Test", &["fd", "inner"], visitor)
+        deserializer.deserialize_struct("Test", &["fd", "inner"], Visitor)
     }
 }
 
@@ -94,8 +115,8 @@ fn round_trip() {
     let s2: Socket<(), Test> = Socket::new(f2);
 
     s1.send(Test {
-        fd: 0,
-        inner: Inner(1, 0xACAB),
+        fd: File::open("/dev/null").unwrap(),
+        inner: Inner(File::open("/dev/null").unwrap(), 0xACAB),
     })
     .unwrap();