summary refs log tree commit diff
path: root/msg_socket/src/msg_on_socket.rs
diff options
context:
space:
mode:
Diffstat (limited to 'msg_socket/src/msg_on_socket.rs')
-rw-r--r--msg_socket/src/msg_on_socket.rs278
1 files changed, 278 insertions, 0 deletions
diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs
new file mode 100644
index 0000000..aca57e2
--- /dev/null
+++ b/msg_socket/src/msg_on_socket.rs
@@ -0,0 +1,278 @@
+// Copyright 2018 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use data_model::*;
+use std;
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+use std::result;
+use sys_util::{Error as SysError, EventFd};
+
+use std::fs::File;
+use std::net::{TcpListener, TcpStream, UdpSocket};
+use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
+
+#[derive(Debug, PartialEq)]
+/// An error during transaction or serialization/deserialization.
+pub enum MsgError {
+    /// Error while sending a request or response.
+    Send(SysError),
+    /// Error while receiving a request or response.
+    Recv(SysError),
+    /// The type of a received request or response is unknown.
+    InvalidType,
+    /// There was not the expected amount of data when receiving a message. The inner
+    /// value is how much data is needed.
+    BadRecvSize(usize),
+    /// There was no associated file descriptor received for a request that expected it.
+    ExpectFd,
+    /// There was some associated file descriptor received but not used when deserialize.
+    NotExpectFd,
+    /// Trying to serialize/deserialize, but fd buffer size is too small. This typically happens
+    /// when max_fd_count() returns a value that is too small.
+    WrongFdBufferSize,
+    /// Trying to serialize/deserialize, but msg buffer size is too small. This typically happens
+    /// when msg_size() returns a value that is too small.
+    WrongMsgBufferSize,
+}
+
+pub type MsgResult<T> = result::Result<T, MsgError>;
+
+/// A msg that could be serialized to and deserialize from array in little endian.
+///
+/// For structs, we always have fixed size of bytes and fixed count of fds.
+/// For enums, the size needed might be different for each variant.
+///
+/// e.g.
+/// ```
+/// use std::os::unix::io::RawFd;
+/// enum Message {
+///     VariantA(u8),
+///     VariantB(u32, RawFd),
+///     VariantC,
+/// }
+/// ```
+///
+/// For variant A, we need 1 byte to store its inner value.
+/// For variant B, we need 4 bytes and 1 RawFd to store its inner value.
+/// For variant C, we need 0 bytes to store its inner value.
+/// When we serialize Message to (buffer, fd_buffer), we always use fixed number of bytes in
+/// the buffer. Unused buffer bytes will be padded with zero.
+/// However, for fd_buffer, we could not do the same thing. Otherwise, we are essentially sending
+/// fd 0 through the socket.
+/// Thus, read/write functions always the return correct count of fds in this variant. There will be
+/// no padding in fd_buffer.
+pub trait MsgOnSocket: Sized {
+    /// Size of message in bytes.
+    fn msg_size() -> usize;
+    /// Max possible fd count in this type.
+    fn max_fd_count() -> usize {
+        0
+    }
+    /// Returns (self, fd read count).
+    /// This function is safe only when:
+    ///     0. fds contains valid fds, received from socket, serialized by Self::write_to_buffer.
+    ///     1. For enum, fds contains correct fd layout of the particular variant.
+    ///     2. write_to_buffer is implemented correctly(put valid fds into the buffer, has no padding,
+    ///        return correct count).
+    unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)>;
+    /// Serialize self to buffers.
+    fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize>;
+}
+
+impl MsgOnSocket for SysError {
+    fn msg_size() -> usize {
+        u32::msg_size()
+    }
+    unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+        let (v, size) = u32::read_from_buffer(buffer, fds)?;
+        Ok((SysError::new(v as i32), size))
+    }
+    fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
+        let v = self.errno() as u32;
+        v.write_to_buffer(buffer, fds)
+    }
+}
+
+impl MsgOnSocket for RawFd {
+    fn msg_size() -> usize {
+        0
+    }
+    fn max_fd_count() -> usize {
+        1
+    }
+    unsafe fn read_from_buffer(_buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+        if fds.len() < 1 {
+            return Err(MsgError::ExpectFd);
+        }
+        Ok((fds[0], 1))
+    }
+    fn write_to_buffer(&self, _buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
+        if fds.len() < 1 {
+            return Err(MsgError::WrongFdBufferSize);
+        }
+        fds[0] = self.clone();
+        Ok(1)
+    }
+}
+
+macro_rules! rawfd_impl {
+    ($type:ident) => {
+        impl MsgOnSocket for $type {
+            fn msg_size() -> usize {
+                0
+            }
+            fn max_fd_count() -> usize {
+                1
+            }
+            unsafe fn read_from_buffer(_buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+                if fds.len() < 1 {
+                    return Err(MsgError::ExpectFd);
+                }
+                Ok(($type::from_raw_fd(fds[0].clone()), 1))
+            }
+            fn write_to_buffer(&self, _buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
+                if fds.len() < 1 {
+                    return Err(MsgError::WrongFdBufferSize);
+                }
+                fds[0] = self.as_raw_fd();
+                Ok(1)
+            }
+        }
+    };
+}
+
+rawfd_impl!(EventFd);
+rawfd_impl!(File);
+rawfd_impl!(UnixStream);
+rawfd_impl!(TcpStream);
+rawfd_impl!(TcpListener);
+rawfd_impl!(UdpSocket);
+rawfd_impl!(UnixListener);
+rawfd_impl!(UnixDatagram);
+
+// usize could be different sizes on different targets. We always use u64.
+impl MsgOnSocket for usize {
+    fn msg_size() -> usize {
+        std::mem::size_of::<u64>()
+    }
+    unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+        if buffer.len() < std::mem::size_of::<u64>() {
+            return Err(MsgError::WrongMsgBufferSize);
+        }
+        let t: u64 = Le64::from_slice(&buffer[0..Self::msg_size()])
+            .unwrap()
+            .clone()
+            .into();
+        Ok((t as usize, 0))
+    }
+
+    fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawFd]) -> MsgResult<usize> {
+        if buffer.len() < std::mem::size_of::<u64>() {
+            return Err(MsgError::WrongMsgBufferSize);
+        }
+        let t: Le64 = (*self as u64).into();
+        buffer[0..Self::msg_size()].copy_from_slice(t.as_slice());
+        Ok(0)
+    }
+}
+
+macro_rules! le_impl {
+    ($type:ident, $le_type:ident) => {
+        impl MsgOnSocket for $type {
+            fn msg_size() -> usize {
+                std::mem::size_of::<$le_type>()
+            }
+            unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+                if buffer.len() < std::mem::size_of::<$le_type>() {
+                    return Err(MsgError::WrongMsgBufferSize);
+                }
+                let t = $le_type::from_slice(&buffer[0..Self::msg_size()])
+                    .unwrap()
+                    .clone();
+                Ok((t.into(), 0))
+            }
+
+            fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawFd]) -> MsgResult<usize> {
+                if buffer.len() < std::mem::size_of::<$le_type>() {
+                    return Err(MsgError::WrongMsgBufferSize);
+                }
+                let t: $le_type = self.clone().into();
+                buffer[0..Self::msg_size()].copy_from_slice(t.as_slice());
+                Ok(0)
+            }
+        }
+    };
+}
+
+le_impl!(u8, u8);
+le_impl!(u16, Le16);
+le_impl!(u32, Le32);
+le_impl!(u64, Le64);
+
+le_impl!(Le16, Le16);
+le_impl!(Le32, Le32);
+le_impl!(Le64, Le64);
+
+macro_rules! array_impls {
+    ($N:expr, $t: ident $($ts:ident)*)
+    => {
+        impl<T: MsgOnSocket + Clone> MsgOnSocket for [T; $N] {
+            fn msg_size() -> usize {
+                T::msg_size() * $N
+            }
+            fn max_fd_count() -> usize {
+                T::max_fd_count() * $N
+            }
+            unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
+                if buffer.len() < Self::msg_size() {
+                    return Err(MsgError::WrongMsgBufferSize);
+                }
+                let mut offset = 0usize;
+                let mut fd_offset = 0usize;
+                let ($t, fd_size) =
+                    T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?;
+                offset += T::msg_size();
+                fd_offset += fd_size;
+                $(
+                    let ($ts, fd_size) =
+                        T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?;
+                    offset += T::msg_size();
+                    fd_offset += fd_size;
+                    )*
+                assert_eq!(offset, Self::msg_size());
+                Ok(([$t, $($ts),*], fd_offset))
+            }
+
+            fn write_to_buffer(
+                &self,
+                buffer: &mut [u8],
+                fds: &mut [RawFd],
+                ) -> MsgResult<usize> {
+                if buffer.len() < Self::msg_size() {
+                    return Err(MsgError::WrongMsgBufferSize);
+                }
+                let mut offset = 0usize;
+                let mut fd_offset = 0usize;
+                for idx in 0..$N {
+                    let fd_size = self[idx].clone().write_to_buffer(&mut buffer[offset..],
+                                                            &mut fds[fd_offset..])?;
+                    offset += T::msg_size();
+                    fd_offset += fd_size;
+                }
+
+                Ok(fd_offset)
+            }
+        }
+        array_impls!(($N - 1), $($ts)*);
+    };
+    {$N:expr, } => {};
+}
+
+array_impls! {
+    32, tmp1 tmp2 tmp3 tmp4 tmp5 tmp6 tmp7 tmp8 tmp9 tmp10 tmp11 tmp12 tmp13 tmp14 tmp15 tmp16
+        tmp17 tmp18 tmp19 tmp20 tmp21 tmp22 tmp23 tmp24 tmp25 tmp26 tmp27 tmp28 tmp29 tmp30 tmp31
+        tmp32
+}
+
+// TODO(jkwang) Define MsgOnSocket for tuple?