summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-03-20 05:48:28 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:37:05 +0000
commit8214c4c64fbdbf6ae84634bb822a90959271cad5 (patch)
tree6d46db38cb233ae7a7cf592b485608af96accf12
parentb76f0d1043ffde3c6525abaecb421c0a4dc4c277 (diff)
downloadcrosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar.gz
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar.bz2
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar.lz
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar.xz
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.tar.zst
crosvm-8214c4c64fbdbf6ae84634bb822a90959271cad5.zip
msg_socket2: initial commit
-rw-r--r--Cargo.lock8
-rw-r--r--msg_socket2/Cargo.toml11
-rw-r--r--msg_socket2/src/de.rs21
-rw-r--r--msg_socket2/src/error.rs23
-rw-r--r--msg_socket2/src/lib.rs34
-rw-r--r--msg_socket2/src/ser.rs20
-rw-r--r--msg_socket2/src/socket.rs48
-rw-r--r--msg_socket2/tests/round_trip.rs99
-rw-r--r--poly_msg_socket/Cargo.toml4
9 files changed, 263 insertions, 5 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 3db8e4e..3b175fe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -57,8 +57,8 @@ dependencies = [
 
 [[package]]
 name = "bincode"
-version = "1.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
+version = "1.3.0"
+source = "git+https://github.com/alyssais/bincode?branch=from_slice#ed85a66d69be7073f8b484fcc65101149bb31acc"
 dependencies = [
  "byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)",
  "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -574,7 +574,7 @@ dependencies = [
 name = "poly_msg_socket"
 version = "0.1.0"
 dependencies = [
- "bincode 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "bincode 1.3.0 (git+https://github.com/alyssais/bincode?branch=from_slice)",
  "msg_socket 0.1.0",
  "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
  "sys_util 0.1.0",
@@ -859,7 +859,7 @@ dependencies = [
 ]
 
 [metadata]
-"checksum bincode 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "5753e2a71534719bf3f4e57006c3a4f0d2c672a4b676eec84161f763eca87dbf"
+"checksum bincode 1.3.0 (git+https://github.com/alyssais/bincode?branch=from_slice)" = "<none>"
 "checksum bitflags 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3d155346769a6855b86399e9bc3814ab343cd3d62c7e985113d46a0ec3c281fd"
 "checksum byteorder 1.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "08c48aae112d48ed9f069b33538ea9e3e90aa263cfa3d1c24309612b1f7472de"
 "checksum cc 1.0.25 (registry+https://github.com/rust-lang/crates.io-index)" = "f159dfd43363c4d08055a07703eb7a3406b0dac4d0584d96965a3262db3c9d16"
diff --git a/msg_socket2/Cargo.toml b/msg_socket2/Cargo.toml
new file mode 100644
index 0000000..ba167d8
--- /dev/null
+++ b/msg_socket2/Cargo.toml
@@ -0,0 +1,11 @@
+[package]
+name = "msg_socket2"
+version = "0.1.0"
+authors = ["Alyssa Ross <hi@alyssa.is>"]
+edition = "2018"
+
+[dependencies]
+serde = "1.0.104"
+sys_util = { path = "../sys_util" }
+
+bincode = { git = "https://github.com/alyssais/bincode", branch = "from_slice" }
diff --git a/msg_socket2/src/de.rs b/msg_socket2/src/de.rs
new file mode 100644
index 0000000..1d3a9e1
--- /dev/null
+++ b/msg_socket2/src/de.rs
@@ -0,0 +1,21 @@
+use serde::Deserializer;
+use std::os::unix::prelude::*;
+
+pub trait DeserializeWithFds<'de>: Sized {
+    fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error>
+    where
+        I: Iterator<Item = RawFd>,
+        De: Deserializer<'de>;
+}
+
+#[derive(Debug)]
+pub struct DeserializerWithFds<'iter, Iter, De> {
+    pub deserializer: De,
+    pub fds: &'iter mut Iter,
+}
+
+impl<'iter, Iter, De> DeserializerWithFds<'iter, Iter, De> {
+    pub fn new(fds: &'iter mut Iter, deserializer: De) -> Self {
+        Self { deserializer, fds }
+    }
+}
diff --git a/msg_socket2/src/error.rs b/msg_socket2/src/error.rs
new file mode 100644
index 0000000..902684b
--- /dev/null
+++ b/msg_socket2/src/error.rs
@@ -0,0 +1,23 @@
+#[derive(Debug)]
+pub enum Error {
+    DataError(bincode::Error),
+    IoError(sys_util::Error),
+}
+
+impl From<bincode::Error> for Error {
+    fn from(error: bincode::Error) -> Self {
+        Self::DataError(error)
+    }
+}
+
+impl From<sys_util::Error> for Error {
+    fn from(error: sys_util::Error) -> Self {
+        Self::IoError(error)
+    }
+}
+
+impl From<std::io::Error> for Error {
+    fn from(error: std::io::Error) -> Self {
+        Self::IoError(error.into())
+    }
+}
diff --git a/msg_socket2/src/lib.rs b/msg_socket2/src/lib.rs
new file mode 100644
index 0000000..748a9f7
--- /dev/null
+++ b/msg_socket2/src/lib.rs
@@ -0,0 +1,34 @@
+// Copyright 2020, Alyssa Ross
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//     * Redistributions of source code must retain the above copyright
+//       notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above copyright
+//       notice, this list of conditions and the following disclaimer in the
+//       documentation and/or other materials provided with the distribution.
+//     * Neither the name of the <organization> nor the
+//       names of its contributors may be used to endorse or promote products
+//       derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY
+// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+mod de;
+mod error;
+mod ser;
+mod socket;
+
+pub use de::{DeserializeWithFds, DeserializerWithFds};
+pub use error::Error;
+pub use ser::{SerializeWithFds, SerializerWithFds};
+pub use socket::Socket;
diff --git a/msg_socket2/src/ser.rs b/msg_socket2/src/ser.rs
new file mode 100644
index 0000000..0a60ea8
--- /dev/null
+++ b/msg_socket2/src/ser.rs
@@ -0,0 +1,20 @@
+use serde::Serializer;
+use std::os::unix::prelude::*;
+
+pub trait SerializeWithFds {
+    fn serialize<Ser>(&self, serializer: SerializerWithFds<Ser>) -> Result<Ser::Ok, Ser::Error>
+    where
+        Ser: Serializer;
+}
+
+#[derive(Debug)]
+pub struct SerializerWithFds<'fds, Ser> {
+    pub serializer: Ser,
+    pub fds: &'fds mut Vec<RawFd>,
+}
+
+impl<'fds, Ser> SerializerWithFds<'fds, Ser> {
+    pub fn new(fds: &'fds mut Vec<RawFd>, serializer: Ser) -> Self {
+        Self { serializer, fds }
+    }
+}
diff --git a/msg_socket2/src/socket.rs b/msg_socket2/src/socket.rs
new file mode 100644
index 0000000..bce587a
--- /dev/null
+++ b/msg_socket2/src/socket.rs
@@ -0,0 +1,48 @@
+use bincode::{DefaultOptions, Serializer, Deserializer};
+use std::marker::PhantomData;
+use std::io::IoSlice;
+use sys_util::{net::UnixSeqpacket, ScmSocket};
+
+use crate::{DeserializerWithFds, DeserializeWithFds, Error, SerializeWithFds, SerializerWithFds};
+
+#[derive(Debug)]
+pub struct Socket<Send, Recv> {
+    sock: UnixSeqpacket,
+    __: PhantomData<(Send, Recv)>,
+}
+
+impl<Send, Recv> Socket<Send, Recv> {
+    pub fn new(sock: UnixSeqpacket) -> Self {
+        Self {
+            sock,
+            __: PhantomData,
+        }
+    }
+}
+
+impl<Send: SerializeWithFds, Recv> Socket<Send, Recv> {
+    pub fn send(&self, value: Send) -> Result<(), Error> {
+        let mut bytes: Vec<u8> = vec![];
+        let mut fds: Vec<RawFd> = vec![];
+
+        let mut serializer = Serializer::new(&mut bytes, DefaultOptions::new());
+        let serializer_with_fds = SerializerWithFds::new(&mut fds, &mut serializer);
+        value.serialize(serializer_with_fds)?;
+
+        self.sock.send_with_fds(&[IoSlice::new(&bytes)], &fds)?;
+
+        Ok(())
+    }
+}
+
+impl<Send, Recv: for<'de> DeserializeWithFds<'de>> Socket<Send, Recv> {
+    pub fn recv(&self) -> Result<Recv, Error> {
+        let (bytes, fds) = self.sock.recv_as_vec_with_fds()?;
+        let mut fds_iter = fds.into_iter();
+
+        let mut deserializer = Deserializer::from_slice(&bytes, DefaultOptions::new());
+        let deserializer_with_fds = DeserializerWithFds::new(&mut fds_iter, &mut deserializer);
+
+        Ok(Recv::deserialize(deserializer_with_fds)?)
+    }
+}
diff --git a/msg_socket2/tests/round_trip.rs b/msg_socket2/tests/round_trip.rs
new file mode 100644
index 0000000..08e1aff
--- /dev/null
+++ b/msg_socket2/tests/round_trip.rs
@@ -0,0 +1,99 @@
+use std::os::unix::prelude::*;
+
+use std::fmt::{self, Formatter};
+use std::marker::PhantomData;
+use std::mem::size_of;
+
+use msg_socket2::*;
+use serde::de::*;
+use serde::ser::*;
+use sys_util::net::UnixSeqpacket;
+
+#[derive(Debug)]
+struct Inner(RawFd, u16);
+
+#[derive(Debug)]
+struct Test {
+    fd: RawFd,
+    inner: Inner,
+}
+
+impl SerializeWithFds for Test {
+    fn serialize<Ser>(&self, serializer: SerializerWithFds<Ser>) -> Result<Ser::Ok, Ser::Error>
+    where
+        Ser: Serializer,
+    {
+        let mut state = serializer
+            .serializer
+            .serialize_struct("Test", size_of::<Test>())?;
+        serializer.fds.push(self.fd);
+        state.skip_field("fd")?;
+
+        struct SerializableInner<'a>(&'a Inner);
+
+        impl<'a> Serialize for SerializableInner<'a> {
+            fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+                let mut state = serializer
+                    .serialize_tuple_struct("Inner", size_of::<Inner>() - size_of::<RawFd>() * 1)?;
+                state.serialize_field(&(self.0).1)?;
+                state.end()
+            }
+        }
+
+        serializer.fds.push(self.inner.0);
+        state.serialize_field("inner", &SerializableInner(&self.inner))?;
+
+        state.end()
+    }
+}
+
+impl<'de> DeserializeWithFds<'de> for Test {
+    fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error>
+    where
+        I: Iterator<Item = RawFd>,
+        De: Deserializer<'de>,
+    {
+        struct Visitor<'iter, 'de, Iter>(&'iter mut Iter, PhantomData<&'de ()>);
+
+        impl<'iter, 'de, Iter: Iterator<Item = RawFd>> serde::de::Visitor<'de>
+            for Visitor<'iter, 'de, Iter>
+        {
+            type Value = Test;
+
+            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                write!(f, "struct Test")
+            }
+
+            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
+                Ok(Test {
+                    fd: self.0.next().unwrap(),
+                    inner: Inner(self.0.next().unwrap(), seq.next_element()?.unwrap()),
+                })
+            }
+        }
+
+        let DeserializerWithFds {
+            mut fds,
+            deserializer,
+        } = deserializer;
+
+        let visitor = Visitor(&mut fds, PhantomData);
+        deserializer.deserialize_struct("Test", &["fd", "inner"], visitor)
+    }
+}
+
+#[test]
+fn round_trip() {
+    let (f1, f2) = UnixSeqpacket::pair().unwrap();
+    let s1: Socket<_, ()> = Socket::new(f1);
+    let s2: Socket<(), Test> = Socket::new(f2);
+
+    s1.send(Test {
+        fd: 0,
+        inner: Inner(1, 0xACAB),
+    })
+    .unwrap();
+
+    let result = s2.recv().unwrap();
+    assert_eq!(result.inner.1, 0xACAB);
+}
diff --git a/poly_msg_socket/Cargo.toml b/poly_msg_socket/Cargo.toml
index acc671e..131e791 100644
--- a/poly_msg_socket/Cargo.toml
+++ b/poly_msg_socket/Cargo.toml
@@ -7,5 +7,7 @@ edition = "2018"
 [dependencies]
 msg_socket = { path = "../msg_socket" }
 sys_util = { path = "../sys_util" }
-bincode = "1.2.1"
 serde = { version = "1.0.104", features = ["derive"] }
+
+# Match msg_socket2's bincode.
+bincode = { git = "https://github.com/alyssais/bincode", branch = "from_slice" }