From 8214c4c64fbdbf6ae84634bb822a90959271cad5 Mon Sep 17 00:00:00 2001 From: Alyssa Ross Date: Fri, 20 Mar 2020 05:48:28 +0000 Subject: msg_socket2: initial commit --- Cargo.lock | 8 ++-- msg_socket2/Cargo.toml | 11 +++++ msg_socket2/src/de.rs | 21 +++++++++ msg_socket2/src/error.rs | 23 ++++++++++ msg_socket2/src/lib.rs | 34 ++++++++++++++ msg_socket2/src/ser.rs | 20 +++++++++ msg_socket2/src/socket.rs | 48 ++++++++++++++++++++ msg_socket2/tests/round_trip.rs | 99 +++++++++++++++++++++++++++++++++++++++++ poly_msg_socket/Cargo.toml | 4 +- 9 files changed, 263 insertions(+), 5 deletions(-) create mode 100644 msg_socket2/Cargo.toml create mode 100644 msg_socket2/src/de.rs create mode 100644 msg_socket2/src/error.rs create mode 100644 msg_socket2/src/lib.rs create mode 100644 msg_socket2/src/ser.rs create mode 100644 msg_socket2/src/socket.rs create mode 100644 msg_socket2/tests/round_trip.rs diff --git a/Cargo.lock b/Cargo.lock index 3db8e4e..3b175fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,8 +57,8 @@ dependencies = [ [[package]] name = "bincode" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" +version = "1.3.0" +source = "git+https://github.com/alyssais/bincode?branch=from_slice#ed85a66d69be7073f8b484fcc65101149bb31acc" dependencies = [ "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", @@ -574,7 +574,7 @@ dependencies = [ name = "poly_msg_socket" version = "0.1.0" dependencies = [ - "bincode 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "bincode 1.3.0 (git+https://github.com/alyssais/bincode?branch=from_slice)", "msg_socket 0.1.0", "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", "sys_util 0.1.0", @@ -859,7 +859,7 @@ dependencies = [ ] [metadata] -"checksum bincode 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "5753e2a71534719bf3f4e57006c3a4f0d2c672a4b676eec84161f763eca87dbf" +"checksum bincode 1.3.0 (git+https://github.com/alyssais/bincode?branch=from_slice)" = "" "checksum bitflags 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3d155346769a6855b86399e9bc3814ab343cd3d62c7e985113d46a0ec3c281fd" "checksum byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "08c48aae112d48ed9f069b33538ea9e3e90aa263cfa3d1c24309612b1f7472de" "checksum cc 1.0.25 (registry+https://github.com/rust-lang/crates.io-index)" = "f159dfd43363c4d08055a07703eb7a3406b0dac4d0584d96965a3262db3c9d16" diff --git a/msg_socket2/Cargo.toml b/msg_socket2/Cargo.toml new file mode 100644 index 0000000..ba167d8 --- /dev/null +++ b/msg_socket2/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "msg_socket2" +version = "0.1.0" +authors = ["Alyssa Ross "] +edition = "2018" + +[dependencies] +serde = "1.0.104" +sys_util = { path = "../sys_util" } + +bincode = { git = "https://github.com/alyssais/bincode", branch = "from_slice" } diff --git a/msg_socket2/src/de.rs b/msg_socket2/src/de.rs new file mode 100644 index 0000000..1d3a9e1 --- /dev/null +++ b/msg_socket2/src/de.rs @@ -0,0 +1,21 @@ +use serde::Deserializer; +use std::os::unix::prelude::*; + +pub trait DeserializeWithFds<'de>: Sized { + fn deserialize(deserializer: DeserializerWithFds) -> Result + where + I: Iterator, + De: Deserializer<'de>; +} + +#[derive(Debug)] +pub struct DeserializerWithFds<'iter, Iter, De> { + pub deserializer: De, + pub fds: &'iter mut Iter, +} + +impl<'iter, Iter, De> DeserializerWithFds<'iter, Iter, De> { + pub fn new(fds: &'iter mut Iter, deserializer: De) -> Self { + Self { deserializer, fds } + } +} diff --git a/msg_socket2/src/error.rs b/msg_socket2/src/error.rs new file mode 100644 index 0000000..902684b --- /dev/null +++ b/msg_socket2/src/error.rs @@ -0,0 +1,23 @@ +#[derive(Debug)] +pub enum Error { + DataError(bincode::Error), + IoError(sys_util::Error), +} + +impl From for Error { + fn from(error: bincode::Error) -> Self { + Self::DataError(error) + } +} + +impl From for Error { + fn from(error: sys_util::Error) -> Self { + Self::IoError(error) + } +} + +impl From for Error { + fn from(error: std::io::Error) -> Self { + Self::IoError(error.into()) + } +} diff --git a/msg_socket2/src/lib.rs b/msg_socket2/src/lib.rs new file mode 100644 index 0000000..748a9f7 --- /dev/null +++ b/msg_socket2/src/lib.rs @@ -0,0 +1,34 @@ +// Copyright 2020, Alyssa Ross +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of the nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +mod de; +mod error; +mod ser; +mod socket; + +pub use de::{DeserializeWithFds, DeserializerWithFds}; +pub use error::Error; +pub use ser::{SerializeWithFds, SerializerWithFds}; +pub use socket::Socket; diff --git a/msg_socket2/src/ser.rs b/msg_socket2/src/ser.rs new file mode 100644 index 0000000..0a60ea8 --- /dev/null +++ b/msg_socket2/src/ser.rs @@ -0,0 +1,20 @@ +use serde::Serializer; +use std::os::unix::prelude::*; + +pub trait SerializeWithFds { + fn serialize(&self, serializer: SerializerWithFds) -> Result + where + Ser: Serializer; +} + +#[derive(Debug)] +pub struct SerializerWithFds<'fds, Ser> { + pub serializer: Ser, + pub fds: &'fds mut Vec, +} + +impl<'fds, Ser> SerializerWithFds<'fds, Ser> { + pub fn new(fds: &'fds mut Vec, serializer: Ser) -> Self { + Self { serializer, fds } + } +} diff --git a/msg_socket2/src/socket.rs b/msg_socket2/src/socket.rs new file mode 100644 index 0000000..bce587a --- /dev/null +++ b/msg_socket2/src/socket.rs @@ -0,0 +1,48 @@ +use bincode::{DefaultOptions, Serializer, Deserializer}; +use std::marker::PhantomData; +use std::io::IoSlice; +use sys_util::{net::UnixSeqpacket, ScmSocket}; + +use crate::{DeserializerWithFds, DeserializeWithFds, Error, SerializeWithFds, SerializerWithFds}; + +#[derive(Debug)] +pub struct Socket { + sock: UnixSeqpacket, + __: PhantomData<(Send, Recv)>, +} + +impl Socket { + pub fn new(sock: UnixSeqpacket) -> Self { + Self { + sock, + __: PhantomData, + } + } +} + +impl Socket { + pub fn send(&self, value: Send) -> Result<(), Error> { + let mut bytes: Vec = vec![]; + let mut fds: Vec = vec![]; + + let mut serializer = Serializer::new(&mut bytes, DefaultOptions::new()); + let serializer_with_fds = SerializerWithFds::new(&mut fds, &mut serializer); + value.serialize(serializer_with_fds)?; + + self.sock.send_with_fds(&[IoSlice::new(&bytes)], &fds)?; + + Ok(()) + } +} + +impl DeserializeWithFds<'de>> Socket { + pub fn recv(&self) -> Result { + let (bytes, fds) = self.sock.recv_as_vec_with_fds()?; + let mut fds_iter = fds.into_iter(); + + let mut deserializer = Deserializer::from_slice(&bytes, DefaultOptions::new()); + let deserializer_with_fds = DeserializerWithFds::new(&mut fds_iter, &mut deserializer); + + Ok(Recv::deserialize(deserializer_with_fds)?) + } +} diff --git a/msg_socket2/tests/round_trip.rs b/msg_socket2/tests/round_trip.rs new file mode 100644 index 0000000..08e1aff --- /dev/null +++ b/msg_socket2/tests/round_trip.rs @@ -0,0 +1,99 @@ +use std::os::unix::prelude::*; + +use std::fmt::{self, Formatter}; +use std::marker::PhantomData; +use std::mem::size_of; + +use msg_socket2::*; +use serde::de::*; +use serde::ser::*; +use sys_util::net::UnixSeqpacket; + +#[derive(Debug)] +struct Inner(RawFd, u16); + +#[derive(Debug)] +struct Test { + fd: RawFd, + inner: Inner, +} + +impl SerializeWithFds for Test { + fn serialize(&self, serializer: SerializerWithFds) -> Result + where + Ser: Serializer, + { + let mut state = serializer + .serializer + .serialize_struct("Test", size_of::())?; + serializer.fds.push(self.fd); + state.skip_field("fd")?; + + struct SerializableInner<'a>(&'a Inner); + + impl<'a> Serialize for SerializableInner<'a> { + fn serialize(&self, serializer: S) -> Result { + let mut state = serializer + .serialize_tuple_struct("Inner", size_of::() - size_of::() * 1)?; + state.serialize_field(&(self.0).1)?; + state.end() + } + } + + serializer.fds.push(self.inner.0); + state.serialize_field("inner", &SerializableInner(&self.inner))?; + + state.end() + } +} + +impl<'de> DeserializeWithFds<'de> for Test { + fn deserialize(deserializer: DeserializerWithFds) -> Result + where + I: Iterator, + De: Deserializer<'de>, + { + struct Visitor<'iter, 'de, Iter>(&'iter mut Iter, PhantomData<&'de ()>); + + impl<'iter, 'de, Iter: Iterator> serde::de::Visitor<'de> + for Visitor<'iter, 'de, Iter> + { + type Value = Test; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "struct Test") + } + + fn visit_seq>(self, mut seq: A) -> Result { + Ok(Test { + fd: self.0.next().unwrap(), + inner: Inner(self.0.next().unwrap(), seq.next_element()?.unwrap()), + }) + } + } + + let DeserializerWithFds { + mut fds, + deserializer, + } = deserializer; + + let visitor = Visitor(&mut fds, PhantomData); + deserializer.deserialize_struct("Test", &["fd", "inner"], visitor) + } +} + +#[test] +fn round_trip() { + let (f1, f2) = UnixSeqpacket::pair().unwrap(); + let s1: Socket<_, ()> = Socket::new(f1); + let s2: Socket<(), Test> = Socket::new(f2); + + s1.send(Test { + fd: 0, + inner: Inner(1, 0xACAB), + }) + .unwrap(); + + let result = s2.recv().unwrap(); + assert_eq!(result.inner.1, 0xACAB); +} diff --git a/poly_msg_socket/Cargo.toml b/poly_msg_socket/Cargo.toml index acc671e..131e791 100644 --- a/poly_msg_socket/Cargo.toml +++ b/poly_msg_socket/Cargo.toml @@ -7,5 +7,7 @@ edition = "2018" [dependencies] msg_socket = { path = "../msg_socket" } sys_util = { path = "../sys_util" } -bincode = "1.2.1" serde = { version = "1.0.104", features = ["derive"] } + +# Match msg_socket2's bincode. +bincode = { git = "https://github.com/alyssais/bincode", branch = "from_slice" } -- cgit 1.4.1