From 8214c4c64fbdbf6ae84634bb822a90959271cad5 Mon Sep 17 00:00:00 2001 From: Alyssa Ross Date: Fri, 20 Mar 2020 05:48:28 +0000 Subject: msg_socket2: initial commit --- msg_socket2/tests/round_trip.rs | 99 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 msg_socket2/tests/round_trip.rs (limited to 'msg_socket2/tests') diff --git a/msg_socket2/tests/round_trip.rs b/msg_socket2/tests/round_trip.rs new file mode 100644 index 0000000..08e1aff --- /dev/null +++ b/msg_socket2/tests/round_trip.rs @@ -0,0 +1,99 @@ +use std::os::unix::prelude::*; + +use std::fmt::{self, Formatter}; +use std::marker::PhantomData; +use std::mem::size_of; + +use msg_socket2::*; +use serde::de::*; +use serde::ser::*; +use sys_util::net::UnixSeqpacket; + +#[derive(Debug)] +struct Inner(RawFd, u16); + +#[derive(Debug)] +struct Test { + fd: RawFd, + inner: Inner, +} + +impl SerializeWithFds for Test { + fn serialize(&self, serializer: SerializerWithFds) -> Result + where + Ser: Serializer, + { + let mut state = serializer + .serializer + .serialize_struct("Test", size_of::())?; + serializer.fds.push(self.fd); + state.skip_field("fd")?; + + struct SerializableInner<'a>(&'a Inner); + + impl<'a> Serialize for SerializableInner<'a> { + fn serialize(&self, serializer: S) -> Result { + let mut state = serializer + .serialize_tuple_struct("Inner", size_of::() - size_of::() * 1)?; + state.serialize_field(&(self.0).1)?; + state.end() + } + } + + serializer.fds.push(self.inner.0); + state.serialize_field("inner", &SerializableInner(&self.inner))?; + + state.end() + } +} + +impl<'de> DeserializeWithFds<'de> for Test { + fn deserialize(deserializer: DeserializerWithFds) -> Result + where + I: Iterator, + De: Deserializer<'de>, + { + struct Visitor<'iter, 'de, Iter>(&'iter mut Iter, PhantomData<&'de ()>); + + impl<'iter, 'de, Iter: Iterator> serde::de::Visitor<'de> + for Visitor<'iter, 'de, Iter> + { + type Value = Test; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "struct Test") + } + + fn visit_seq>(self, mut seq: A) -> Result { + Ok(Test { + fd: self.0.next().unwrap(), + inner: Inner(self.0.next().unwrap(), seq.next_element()?.unwrap()), + }) + } + } + + let DeserializerWithFds { + mut fds, + deserializer, + } = deserializer; + + let visitor = Visitor(&mut fds, PhantomData); + deserializer.deserialize_struct("Test", &["fd", "inner"], visitor) + } +} + +#[test] +fn round_trip() { + let (f1, f2) = UnixSeqpacket::pair().unwrap(); + let s1: Socket<_, ()> = Socket::new(f1); + let s2: Socket<(), Test> = Socket::new(f2); + + s1.send(Test { + fd: 0, + inner: Inner(1, 0xACAB), + }) + .unwrap(); + + let result = s2.recv().unwrap(); + assert_eq!(result.inner.1, 0xACAB); +} -- cgit 1.4.1