summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--poly_msg_socket/Cargo.toml11
-rw-r--r--poly_msg_socket/src/lib.rs306
2 files changed, 317 insertions, 0 deletions
diff --git a/poly_msg_socket/Cargo.toml b/poly_msg_socket/Cargo.toml
new file mode 100644
index 0000000..acc671e
--- /dev/null
+++ b/poly_msg_socket/Cargo.toml
@@ -0,0 +1,11 @@
+[package]
+name = "poly_msg_socket"
+version = "0.1.0"
+authors = ["Alyssa Ross <hi@alyssa.is>"]
+edition = "2018"
+
+[dependencies]
+msg_socket = { path = "../msg_socket" }
+sys_util = { path = "../sys_util" }
+bincode = "1.2.1"
+serde = { version = "1.0.104", features = ["derive"] }
diff --git a/poly_msg_socket/src/lib.rs b/poly_msg_socket/src/lib.rs
new file mode 100644
index 0000000..f22a44d
--- /dev/null
+++ b/poly_msg_socket/src/lib.rs
@@ -0,0 +1,306 @@
+// Copyright 2020, Alyssa Ross
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//     * Redistributions of source code must retain the above copyright
+//       notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above copyright
+//       notice, this list of conditions and the following disclaimer in the
+//       documentation and/or other materials provided with the distribution.
+//     * Neither the name of the <organization> nor the
+//       names of its contributors may be used to endorse or promote products
+//       derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY
+// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+use std::os::unix::prelude::*;
+
+use std::borrow::Borrow;
+use std::fmt::{self, Display, Formatter};
+use std::marker::PhantomData;
+
+use msg_socket::{MsgError, MsgOnSocket, MsgReceiver, MsgResult, MsgSender};
+use serde::{de::DeserializeOwned, Serialize};
+use sys_util::{net::UnixSeqpacket, Error as SysError};
+
+// All of PolyMsgSocket's internals actually live in this inner
+// struct.  This is because, by implementing MsgSender and MsgReceiver
+// on this struct, PolyMsgSocket can access the trait default
+// implementations of send and recv, while still overriding them to do
+// other things before and after.
+struct Inner<Mi, Mo, Bi, Bo> {
+    sock: UnixSeqpacket,
+
+    _mi: PhantomData<Mi>,
+    _mo: PhantomData<Mo>,
+    _bi: PhantomData<Bi>,
+    _bo: PhantomData<Bo>,
+}
+
+impl<Mi, Mo, Bi, Bo> AsRef<UnixSeqpacket> for Inner<Mi, Mo, Bi, Bo> {
+    fn as_ref(&self) -> &UnixSeqpacket {
+        &self.sock
+    }
+}
+
+impl<Mi: MsgOnSocket, Mo, Bi, Bo> MsgSender for Inner<Mi, Mo, Bi, Bo> {
+    type M = Mi;
+}
+
+impl<Mi, Mo: MsgOnSocket, Bi, Bo> MsgReceiver for Inner<Mi, Mo, Bi, Bo> {
+    type M = Mo;
+}
+
+/// A `MsgSocket`-style interface that can use either msg_socket or
+/// bincode-based serialization.  Each format has an advantage over
+/// the other—msg_socket can send file descriptors, and bincode can
+/// send dynamically sized data like `Vec`s.
+///
+/// It would be possible to accomplish the same result easily using
+/// two seperate sockets, but `PolyMsgSocket` essentially multiplexes
+/// between msg_socket and bincode.  It takes three generic
+/// parameters, input and output types for msg_socket (`Mi` and `Mo`),
+/// and an input type for bincode (`Bo`).  No bincode input type is
+/// required.
+pub struct PolyMsgSocket<Mi, Mo, Bi, Bo>(Inner<Mi, Mo, Bi, Bo>);
+
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+#[repr(u8)]
+enum Format {
+    MsgOnSocket,
+    Bincode,
+}
+
+/// A generic type representing a value in one of the formats
+/// supported by a `PolyMsgSocket`.
+#[derive(Debug)]
+pub enum Value<M, B> {
+    MsgOnSocket(M),
+    Bincode(B),
+}
+
+#[derive(Debug)]
+pub enum Error {
+    BadFormat(Vec<u8>),
+    BincodeError(bincode::Error),
+    IoError(std::io::Error),
+    MsgSocketError(MsgError),
+}
+
+impl Display for Error {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        match self {
+            Self::BadFormat(_) => write!(f, "received message in unexpected format"),
+            Self::BincodeError(e) => write!(f, "serialization error: {}", e),
+            Self::IoError(e) => write!(f, "socket error: {}", e),
+            Self::MsgSocketError(e) => e.fmt(f),
+        }
+    }
+}
+
+impl From<std::io::Error> for Error {
+    fn from(error: std::io::Error) -> Self {
+        Self::IoError(error)
+    }
+}
+
+impl From<bincode::Error> for Error {
+    fn from(error: bincode::Error) -> Self {
+        Self::BincodeError(error)
+    }
+}
+
+impl From<MsgError> for Error {
+    fn from(error: MsgError) -> Self {
+        Self::MsgSocketError(error)
+    }
+}
+
+impl<Mi, Mo, Bi: Serialize, Bo> PolyMsgSocket<Mi, Mo, Bi, Bo> {
+    pub fn new(sock: UnixSeqpacket) -> Self {
+        Self(Inner {
+            sock,
+
+            _mi: PhantomData,
+            _mo: PhantomData,
+            _bi: PhantomData,
+            _bo: PhantomData,
+        })
+    }
+
+    pub fn send_bincode<T: Borrow<Bi>>(&self, data: T) -> Result<(), Error> {
+        self.0.sock.send(&[Format::Bincode as u8])?;
+        self.0.sock.send(&bincode::serialize(data.borrow())?)?;
+
+        Ok(())
+    }
+}
+
+impl<Mi, Mo, Bi, Bo> AsRef<UnixSeqpacket> for PolyMsgSocket<Mi, Mo, Bi, Bo> {
+    fn as_ref(&self) -> &UnixSeqpacket {
+        self.0.as_ref()
+    }
+}
+
+impl<Mi, Mo, Bi, Bo> AsRawFd for PolyMsgSocket<Mi, Mo, Bi, Bo> {
+    fn as_raw_fd(&self) -> RawFd {
+        self.0.sock.as_raw_fd()
+    }
+}
+
+impl<Mi: MsgOnSocket, Mo, Bi: Serialize, Bo> PolyMsgSocket<Mi, Mo, Bi, Bo> {
+    /// A generic "send" operation, taking either the msg_socket or
+    /// bincode type for the socket, contained within a Value.
+    pub fn send<T: Into<Value<Mi, Bi>>>(&self, data: T) -> Result<(), Error> {
+        match data.into() {
+            Value::MsgOnSocket(ref m) => MsgSender::send(self, m)?,
+            Value::Bincode(ref b) => self.send_bincode(b)?,
+        }
+
+        Ok(())
+    }
+}
+
+impl<Mi, Mo: MsgOnSocket, Bi, Bo> PolyMsgSocket<Mi, Mo, Bi, Bo> {
+    pub fn recv_msg_on_socket(&self) -> Result<Mo, Error> {
+        Ok(self.recv()?)
+    }
+}
+
+impl<Mi, Mo, Bi, Bo: DeserializeOwned> PolyMsgSocket<Mi, Mo, Bi, Bo> {
+    pub fn recv_bincode(&self) -> Result<Bo, Error> {
+        let header = self.0.sock.recv_as_vec()?;
+        if header != [Format::Bincode as u8] {
+            return Err(Error::BadFormat(header));
+        }
+
+        Ok(bincode::deserialize(&self.0.sock.recv_as_vec()?)?)
+    }
+}
+
+impl<Mi, Mo: MsgOnSocket, Bi, Bo: DeserializeOwned> PolyMsgSocket<Mi, Mo, Bi, Bo> {
+    /// A generic receive operation, returning either the msg_socket
+    /// or bincode type for the socket.
+    pub fn recv(&self) -> Result<Value<Mo, Bo>, Error> {
+        let mut buf = [0xFF];
+        let size = self.0.sock.recv(&mut buf)?;
+
+        if size != 1 {
+            return Err(Error::BadFormat(vec![]));
+        }
+
+        match buf[0] {
+            v if v == Format::MsgOnSocket as u8 => Ok(Value::MsgOnSocket(self.0.recv()?)),
+            v if v == Format::Bincode as u8 => Ok(Value::Bincode(bincode::deserialize(
+                &self.0.sock.recv_as_vec()?,
+            )?)),
+            _ => Err(Error::BadFormat((Box::new(buf) as Box<[u8]>).into_vec())),
+        }
+    }
+}
+
+fn send_error(error: std::io::Error) -> MsgError {
+    MsgError::Send(SysError::new(error.raw_os_error().unwrap_or(0)))
+}
+
+impl<Mi, Mo: MsgOnSocket, Bi, Bo> MsgReceiver for PolyMsgSocket<Mi, Mo, Bi, Bo> {
+    type M = Mo;
+
+    fn recv(&self) -> MsgResult<Self::M> {
+        if self.0.sock.recv_as_vec().map_err(send_error)? != [Format::MsgOnSocket as u8] {
+            return Err(MsgError::InvalidType);
+        }
+
+        self.0.recv()
+    }
+}
+
+impl<Mi: MsgOnSocket, Mo, Bi, Bo> MsgSender for PolyMsgSocket<Mi, Mo, Bi, Bo> {
+    type M = Mi;
+
+    fn send(&self, msg: &Self::M) -> MsgResult<()> {
+        self.0
+            .sock
+            .send(&[Format::MsgOnSocket as u8])
+            .map_err(send_error)?;
+        self.0.send(msg)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde::Deserialize;
+
+    #[derive(Debug, Eq, MsgOnSocket, PartialEq)]
+    enum MsgOnSocketTest {
+        Case1,
+        Case2,
+    }
+
+    #[derive(Debug, Deserialize, Eq, Serialize, PartialEq)]
+    enum BincodeTest {
+        Case1,
+        Case2,
+    }
+
+    #[test]
+    fn msg_socket_send_recv() {
+        let (i, o) = UnixSeqpacket::pair().unwrap();
+        let mi: PolyMsgSocket<MsgOnSocketTest, (), (), ()> = PolyMsgSocket::new(i);
+        let mo: PolyMsgSocket<(), MsgOnSocketTest, (), ()> = PolyMsgSocket::new(o);
+
+        MsgSender::send(&mi, &MsgOnSocketTest::Case2).unwrap();
+        assert_eq!(MsgReceiver::recv(&mo), Ok(MsgOnSocketTest::Case2));
+    }
+
+    #[test]
+    fn bincode_send_recv() {
+        let (i, o) = UnixSeqpacket::pair().unwrap();
+        let bi: PolyMsgSocket<(), (), BincodeTest, ()> = PolyMsgSocket::new(i);
+        let bo: PolyMsgSocket<(), (), (), BincodeTest> = PolyMsgSocket::new(o);
+
+        bi.send_bincode(BincodeTest::Case2).unwrap();
+        assert_eq!(bo.recv_bincode().unwrap(), BincodeTest::Case2);
+    }
+
+    #[test]
+    fn generic_send_recv_msg_socket() {
+        let (i, o) = UnixSeqpacket::pair().unwrap();
+        let mi: PolyMsgSocket<MsgOnSocketTest, (), (), ()> = PolyMsgSocket::new(i);
+        let mo: PolyMsgSocket<(), MsgOnSocketTest, (), ()> = PolyMsgSocket::new(o);
+
+        mi.send(Value::<_, ()>::MsgOnSocket(MsgOnSocketTest::Case2))
+            .unwrap();
+
+        match mo.recv() {
+            Ok(Value::MsgOnSocket(x)) => assert_eq!(x, MsgOnSocketTest::Case2),
+            other => panic!("{:?}", other),
+        }
+    }
+
+    #[test]
+    fn generic_send_recv_bincode() {
+        let (i, o) = UnixSeqpacket::pair().unwrap();
+        let bi: PolyMsgSocket<(), (), BincodeTest, ()> = PolyMsgSocket::new(i);
+        let bo: PolyMsgSocket<(), (), (), BincodeTest> = PolyMsgSocket::new(o);
+
+        bi.send(Value::<_, BincodeTest>::Bincode(BincodeTest::Case2))
+            .unwrap();
+
+        match bo.recv() {
+            Ok(Value::Bincode(x)) => assert_eq!(x, BincodeTest::Case2),
+            other => panic!("{:?}", other),
+        }
+    }
+}