diff options
-rw-r--r-- | poly_msg_socket/Cargo.toml | 11 | ||||
-rw-r--r-- | poly_msg_socket/src/lib.rs | 306 |
2 files changed, 317 insertions, 0 deletions
diff --git a/poly_msg_socket/Cargo.toml b/poly_msg_socket/Cargo.toml new file mode 100644 index 0000000..acc671e --- /dev/null +++ b/poly_msg_socket/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "poly_msg_socket" +version = "0.1.0" +authors = ["Alyssa Ross <hi@alyssa.is>"] +edition = "2018" + +[dependencies] +msg_socket = { path = "../msg_socket" } +sys_util = { path = "../sys_util" } +bincode = "1.2.1" +serde = { version = "1.0.104", features = ["derive"] } diff --git a/poly_msg_socket/src/lib.rs b/poly_msg_socket/src/lib.rs new file mode 100644 index 0000000..f22a44d --- /dev/null +++ b/poly_msg_socket/src/lib.rs @@ -0,0 +1,306 @@ +// Copyright 2020, Alyssa Ross +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of the <organization> nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use std::os::unix::prelude::*; + +use std::borrow::Borrow; +use std::fmt::{self, Display, Formatter}; +use std::marker::PhantomData; + +use msg_socket::{MsgError, MsgOnSocket, MsgReceiver, MsgResult, MsgSender}; +use serde::{de::DeserializeOwned, Serialize}; +use sys_util::{net::UnixSeqpacket, Error as SysError}; + +// All of PolyMsgSocket's internals actually live in this inner +// struct. This is because, by implementing MsgSender and MsgReceiver +// on this struct, PolyMsgSocket can access the trait default +// implementations of send and recv, while still overriding them to do +// other things before and after. +struct Inner<Mi, Mo, Bi, Bo> { + sock: UnixSeqpacket, + + _mi: PhantomData<Mi>, + _mo: PhantomData<Mo>, + _bi: PhantomData<Bi>, + _bo: PhantomData<Bo>, +} + +impl<Mi, Mo, Bi, Bo> AsRef<UnixSeqpacket> for Inner<Mi, Mo, Bi, Bo> { + fn as_ref(&self) -> &UnixSeqpacket { + &self.sock + } +} + +impl<Mi: MsgOnSocket, Mo, Bi, Bo> MsgSender for Inner<Mi, Mo, Bi, Bo> { + type M = Mi; +} + +impl<Mi, Mo: MsgOnSocket, Bi, Bo> MsgReceiver for Inner<Mi, Mo, Bi, Bo> { + type M = Mo; +} + +/// A `MsgSocket`-style interface that can use either msg_socket or +/// bincode-based serialization. Each format has an advantage over +/// the other—msg_socket can send file descriptors, and bincode can +/// send dynamically sized data like `Vec`s. +/// +/// It would be possible to accomplish the same result easily using +/// two seperate sockets, but `PolyMsgSocket` essentially multiplexes +/// between msg_socket and bincode. It takes three generic +/// parameters, input and output types for msg_socket (`Mi` and `Mo`), +/// and an input type for bincode (`Bo`). No bincode input type is +/// required. +pub struct PolyMsgSocket<Mi, Mo, Bi, Bo>(Inner<Mi, Mo, Bi, Bo>); + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u8)] +enum Format { + MsgOnSocket, + Bincode, +} + +/// A generic type representing a value in one of the formats +/// supported by a `PolyMsgSocket`. +#[derive(Debug)] +pub enum Value<M, B> { + MsgOnSocket(M), + Bincode(B), +} + +#[derive(Debug)] +pub enum Error { + BadFormat(Vec<u8>), + BincodeError(bincode::Error), + IoError(std::io::Error), + MsgSocketError(MsgError), +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + Self::BadFormat(_) => write!(f, "received message in unexpected format"), + Self::BincodeError(e) => write!(f, "serialization error: {}", e), + Self::IoError(e) => write!(f, "socket error: {}", e), + Self::MsgSocketError(e) => e.fmt(f), + } + } +} + +impl From<std::io::Error> for Error { + fn from(error: std::io::Error) -> Self { + Self::IoError(error) + } +} + +impl From<bincode::Error> for Error { + fn from(error: bincode::Error) -> Self { + Self::BincodeError(error) + } +} + +impl From<MsgError> for Error { + fn from(error: MsgError) -> Self { + Self::MsgSocketError(error) + } +} + +impl<Mi, Mo, Bi: Serialize, Bo> PolyMsgSocket<Mi, Mo, Bi, Bo> { + pub fn new(sock: UnixSeqpacket) -> Self { + Self(Inner { + sock, + + _mi: PhantomData, + _mo: PhantomData, + _bi: PhantomData, + _bo: PhantomData, + }) + } + + pub fn send_bincode<T: Borrow<Bi>>(&self, data: T) -> Result<(), Error> { + self.0.sock.send(&[Format::Bincode as u8])?; + self.0.sock.send(&bincode::serialize(data.borrow())?)?; + + Ok(()) + } +} + +impl<Mi, Mo, Bi, Bo> AsRef<UnixSeqpacket> for PolyMsgSocket<Mi, Mo, Bi, Bo> { + fn as_ref(&self) -> &UnixSeqpacket { + self.0.as_ref() + } +} + +impl<Mi, Mo, Bi, Bo> AsRawFd for PolyMsgSocket<Mi, Mo, Bi, Bo> { + fn as_raw_fd(&self) -> RawFd { + self.0.sock.as_raw_fd() + } +} + +impl<Mi: MsgOnSocket, Mo, Bi: Serialize, Bo> PolyMsgSocket<Mi, Mo, Bi, Bo> { + /// A generic "send" operation, taking either the msg_socket or + /// bincode type for the socket, contained within a Value. + pub fn send<T: Into<Value<Mi, Bi>>>(&self, data: T) -> Result<(), Error> { + match data.into() { + Value::MsgOnSocket(ref m) => MsgSender::send(self, m)?, + Value::Bincode(ref b) => self.send_bincode(b)?, + } + + Ok(()) + } +} + +impl<Mi, Mo: MsgOnSocket, Bi, Bo> PolyMsgSocket<Mi, Mo, Bi, Bo> { + pub fn recv_msg_on_socket(&self) -> Result<Mo, Error> { + Ok(self.recv()?) + } +} + +impl<Mi, Mo, Bi, Bo: DeserializeOwned> PolyMsgSocket<Mi, Mo, Bi, Bo> { + pub fn recv_bincode(&self) -> Result<Bo, Error> { + let header = self.0.sock.recv_as_vec()?; + if header != [Format::Bincode as u8] { + return Err(Error::BadFormat(header)); + } + + Ok(bincode::deserialize(&self.0.sock.recv_as_vec()?)?) + } +} + +impl<Mi, Mo: MsgOnSocket, Bi, Bo: DeserializeOwned> PolyMsgSocket<Mi, Mo, Bi, Bo> { + /// A generic receive operation, returning either the msg_socket + /// or bincode type for the socket. + pub fn recv(&self) -> Result<Value<Mo, Bo>, Error> { + let mut buf = [0xFF]; + let size = self.0.sock.recv(&mut buf)?; + + if size != 1 { + return Err(Error::BadFormat(vec![])); + } + + match buf[0] { + v if v == Format::MsgOnSocket as u8 => Ok(Value::MsgOnSocket(self.0.recv()?)), + v if v == Format::Bincode as u8 => Ok(Value::Bincode(bincode::deserialize( + &self.0.sock.recv_as_vec()?, + )?)), + _ => Err(Error::BadFormat((Box::new(buf) as Box<[u8]>).into_vec())), + } + } +} + +fn send_error(error: std::io::Error) -> MsgError { + MsgError::Send(SysError::new(error.raw_os_error().unwrap_or(0))) +} + +impl<Mi, Mo: MsgOnSocket, Bi, Bo> MsgReceiver for PolyMsgSocket<Mi, Mo, Bi, Bo> { + type M = Mo; + + fn recv(&self) -> MsgResult<Self::M> { + if self.0.sock.recv_as_vec().map_err(send_error)? != [Format::MsgOnSocket as u8] { + return Err(MsgError::InvalidType); + } + + self.0.recv() + } +} + +impl<Mi: MsgOnSocket, Mo, Bi, Bo> MsgSender for PolyMsgSocket<Mi, Mo, Bi, Bo> { + type M = Mi; + + fn send(&self, msg: &Self::M) -> MsgResult<()> { + self.0 + .sock + .send(&[Format::MsgOnSocket as u8]) + .map_err(send_error)?; + self.0.send(msg) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::Deserialize; + + #[derive(Debug, Eq, MsgOnSocket, PartialEq)] + enum MsgOnSocketTest { + Case1, + Case2, + } + + #[derive(Debug, Deserialize, Eq, Serialize, PartialEq)] + enum BincodeTest { + Case1, + Case2, + } + + #[test] + fn msg_socket_send_recv() { + let (i, o) = UnixSeqpacket::pair().unwrap(); + let mi: PolyMsgSocket<MsgOnSocketTest, (), (), ()> = PolyMsgSocket::new(i); + let mo: PolyMsgSocket<(), MsgOnSocketTest, (), ()> = PolyMsgSocket::new(o); + + MsgSender::send(&mi, &MsgOnSocketTest::Case2).unwrap(); + assert_eq!(MsgReceiver::recv(&mo), Ok(MsgOnSocketTest::Case2)); + } + + #[test] + fn bincode_send_recv() { + let (i, o) = UnixSeqpacket::pair().unwrap(); + let bi: PolyMsgSocket<(), (), BincodeTest, ()> = PolyMsgSocket::new(i); + let bo: PolyMsgSocket<(), (), (), BincodeTest> = PolyMsgSocket::new(o); + + bi.send_bincode(BincodeTest::Case2).unwrap(); + assert_eq!(bo.recv_bincode().unwrap(), BincodeTest::Case2); + } + + #[test] + fn generic_send_recv_msg_socket() { + let (i, o) = UnixSeqpacket::pair().unwrap(); + let mi: PolyMsgSocket<MsgOnSocketTest, (), (), ()> = PolyMsgSocket::new(i); + let mo: PolyMsgSocket<(), MsgOnSocketTest, (), ()> = PolyMsgSocket::new(o); + + mi.send(Value::<_, ()>::MsgOnSocket(MsgOnSocketTest::Case2)) + .unwrap(); + + match mo.recv() { + Ok(Value::MsgOnSocket(x)) => assert_eq!(x, MsgOnSocketTest::Case2), + other => panic!("{:?}", other), + } + } + + #[test] + fn generic_send_recv_bincode() { + let (i, o) = UnixSeqpacket::pair().unwrap(); + let bi: PolyMsgSocket<(), (), BincodeTest, ()> = PolyMsgSocket::new(i); + let bo: PolyMsgSocket<(), (), (), BincodeTest> = PolyMsgSocket::new(o); + + bi.send(Value::<_, BincodeTest>::Bincode(BincodeTest::Case2)) + .unwrap(); + + match bo.recv() { + Ok(Value::Bincode(x)) => assert_eq!(x, BincodeTest::Case2), + other => panic!("{:?}", other), + } + } +} |