summary refs log tree commit diff
path: root/msg_socket2/tests
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-03-20 05:48:28 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:37:05 +0000
commit8214c4c64fbdbf6ae84634bb822a90959271cad5 (patch)
tree6d46db38cb233ae7a7cf592b485608af96accf12 /msg_socket2/tests
parentb76f0d1043ffde3c6525abaecb421c0a4dc4c277 (diff)
downloadcrosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar.gz
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar.bz2
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar.lz
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar.xz
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar.zst
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.zip
msg_socket2: initial commit
Diffstat (limited to 'msg_socket2/tests')
-rw-r--r--msg_socket2/tests/round_trip.rs99
1 files changed, 99 insertions, 0 deletions
diff --git a/msg_socket2/tests/round_trip.rs b/msg_socket2/tests/round_trip.rs
new file mode 100644
index 0000000..08e1aff
--- /dev/null
+++ b/msg_socket2/tests/round_trip.rs
@@ -0,0 +1,99 @@
+use std::os::unix::prelude::*;
+
+use std::fmt::{self, Formatter};
+use std::marker::PhantomData;
+use std::mem::size_of;
+
+use msg_socket2::*;
+use serde::de::*;
+use serde::ser::*;
+use sys_util::net::UnixSeqpacket;
+
+#[derive(Debug)]
+struct Inner(RawFd, u16);
+
+#[derive(Debug)]
+struct Test {
+    fd: RawFd,
+    inner: Inner,
+}
+
+impl SerializeWithFds for Test {
+    fn serialize<Ser>(&self, serializer: SerializerWithFds<Ser>) -> Result<Ser::Ok, Ser::Error>
+    where
+        Ser: Serializer,
+    {
+        let mut state = serializer
+            .serializer
+            .serialize_struct("Test", size_of::<Test>())?;
+        serializer.fds.push(self.fd);
+        state.skip_field("fd")?;
+
+        struct SerializableInner<'a>(&'a Inner);
+
+        impl<'a> Serialize for SerializableInner<'a> {
+            fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+                let mut state = serializer
+                    .serialize_tuple_struct("Inner", size_of::<Inner>() - size_of::<RawFd>() * 1)?;
+                state.serialize_field(&(self.0).1)?;
+                state.end()
+            }
+        }
+
+        serializer.fds.push(self.inner.0);
+        state.serialize_field("inner", &SerializableInner(&self.inner))?;
+
+        state.end()
+    }
+}
+
+impl<'de> DeserializeWithFds<'de> for Test {
+    fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error>
+    where
+        I: Iterator<Item = RawFd>,
+        De: Deserializer<'de>,
+    {
+        struct Visitor<'iter, 'de, Iter>(&'iter mut Iter, PhantomData<&'de ()>);
+
+        impl<'iter, 'de, Iter: Iterator<Item = RawFd>> serde::de::Visitor<'de>
+            for Visitor<'iter, 'de, Iter>
+        {
+            type Value = Test;
+
+            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                write!(f, "struct Test")
+            }
+
+            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
+                Ok(Test {
+                    fd: self.0.next().unwrap(),
+                    inner: Inner(self.0.next().unwrap(), seq.next_element()?.unwrap()),
+                })
+            }
+        }
+
+        let DeserializerWithFds {
+            mut fds,
+            deserializer,
+        } = deserializer;
+
+        let visitor = Visitor(&mut fds, PhantomData);
+        deserializer.deserialize_struct("Test", &["fd", "inner"], visitor)
+    }
+}
+
+#[test]
+fn round_trip() {
+    let (f1, f2) = UnixSeqpacket::pair().unwrap();
+    let s1: Socket<_, ()> = Socket::new(f1);
+    let s2: Socket<(), Test> = Socket::new(f2);
+
+    s1.send(Test {
+        fd: 0,
+        inner: Inner(1, 0xACAB),
+    })
+    .unwrap();
+
+    let result = s2.recv().unwrap();
+    assert_eq!(result.inner.1, 0xACAB);
+}