// Copyright 2018 The Chromium OS Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. mod msg_on_socket; use std::io::{IoSlice, Result}; use std::marker::PhantomData; use std::os::unix::io::{AsRawFd, RawFd}; use std::pin::Pin; use std::task::{Context, Poll}; use futures::Stream; use libc::{EWOULDBLOCK, O_NONBLOCK}; use cros_async::add_read_waker; use sys_util::{ add_fd_flags, clear_fd_flags, error, handle_eintr, net::UnixSeqpacket, Error as SysError, ScmSocket, }; pub use crate::msg_on_socket::*; pub use msg_on_socket_derive::*; /// Create a pair of socket. Request is send in one direction while response is in the other /// direction. pub fn pair( ) -> Result<(MsgSocket, MsgSocket)> { let (sock1, sock2) = UnixSeqpacket::pair()?; let requester = MsgSocket::new(sock1); let responder = MsgSocket::new(sock2); Ok((requester, responder)) } /// Bidirection sock that support both send and recv. pub struct MsgSocket { sock: UnixSeqpacket, _i: PhantomData, _o: PhantomData, } impl MsgSocket { // Create a new MsgSocket. pub fn new(s: UnixSeqpacket) -> MsgSocket { MsgSocket { sock: s, _i: PhantomData, _o: PhantomData, } } // Creates an async receiver that implements `futures::Stream`. pub fn async_receiver(&self) -> MsgResult> { AsyncReceiver::new(self) } } /// One direction socket that only supports sending. pub struct Sender { sock: UnixSeqpacket, _m: PhantomData, } impl Sender { /// Create a new sender sock. pub fn new(s: UnixSeqpacket) -> Sender { Sender { sock: s, _m: PhantomData, } } } /// One direction socket that only supports receiving. pub struct Receiver { sock: UnixSeqpacket, _m: PhantomData, } impl Receiver { /// Create a new receiver sock. pub fn new(s: UnixSeqpacket) -> Receiver { Receiver { sock: s, _m: PhantomData, } } } impl AsRef for MsgSocket { fn as_ref(&self) -> &UnixSeqpacket { &self.sock } } impl AsRawFd for MsgSocket { fn as_raw_fd(&self) -> RawFd { self.sock.as_raw_fd() } } impl AsRef for Sender { fn as_ref(&self) -> &UnixSeqpacket { &self.sock } } impl AsRawFd for Sender { fn as_raw_fd(&self) -> RawFd { self.sock.as_raw_fd() } } impl AsRef for Receiver { fn as_ref(&self) -> &UnixSeqpacket { &self.sock } } impl AsRawFd for Receiver { fn as_raw_fd(&self) -> RawFd { self.sock.as_raw_fd() } } /// Types that could send a message. pub trait MsgSender: AsRef { type M: MsgOnSocket; fn send(&self, msg: &Self::M) -> MsgResult<()> { let msg_size = msg.msg_size(); let fd_size = msg.fd_count(); let mut msg_buffer: Vec = vec![0; msg_size]; let mut fd_buffer: Vec = vec![0; fd_size]; let fd_size = msg.write_to_buffer(&mut msg_buffer, &mut fd_buffer)?; let sock: &UnixSeqpacket = self.as_ref(); if fd_size == 0 { handle_eintr!(sock.send(&msg_buffer)) .map_err(|e| MsgError::Send(SysError::new(e.raw_os_error().unwrap_or(0))))?; } else { let ioslice = IoSlice::new(&msg_buffer[..]); sock.send_with_fds(&[ioslice], &fd_buffer[0..fd_size]) .map_err(MsgError::Send)?; } Ok(()) } } /// Types that could receive a message. pub trait MsgReceiver: AsRef { type M: MsgOnSocket; fn recv(&self) -> MsgResult { let sock: &UnixSeqpacket = self.as_ref(); let (msg_buffer, fd_buffer) = { if Self::M::uses_fd() { sock.recv_as_vec_with_fds() .map_err(|e| MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))))? } else { ( sock.recv_as_vec().map_err(|e| { MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))) })?, vec![], ) } }; if msg_buffer.len() == 0 && Self::M::fixed_size() != Some(0) { return Err(MsgError::RecvZero); } if let Some(fixed_size) = Self::M::fixed_size() { if fixed_size != msg_buffer.len() { return Err(MsgError::BadRecvSize { expected: fixed_size, actual: msg_buffer.len(), }); } } // Safe because fd buffer is read from socket. let (v, read_fd_size) = unsafe { Self::M::read_from_buffer(&msg_buffer, &fd_buffer)? }; if fd_buffer.len() != read_fd_size { return Err(MsgError::NotExpectFd); } Ok(v) } } impl MsgSender for MsgSocket { type M = I; } impl MsgReceiver for MsgSocket { type M = O; } impl MsgSender for Sender { type M = I; } impl MsgReceiver for Receiver { type M = O; } /// Asynchronous adaptor for `MsgSocket`. pub struct AsyncReceiver<'a, I: MsgOnSocket, O: MsgOnSocket> { inner: &'a MsgSocket, done: bool, // Have hit an error and the Stream should return null when polled. } impl<'a, I: MsgOnSocket, O: MsgOnSocket> AsyncReceiver<'a, I, O> { fn new(msg_socket: &MsgSocket) -> MsgResult> { add_fd_flags(msg_socket.as_raw_fd(), O_NONBLOCK).map_err(MsgError::SettingFdFlags)?; Ok(AsyncReceiver { inner: msg_socket, done: false, }) } } impl<'a, I: MsgOnSocket, O: MsgOnSocket> Drop for AsyncReceiver<'a, I, O> { fn drop(&mut self) { if let Err(e) = clear_fd_flags(self.inner.as_raw_fd(), O_NONBLOCK) { error!( "Failed to restore non-blocking behavior to message socket: {}", e ); } } } impl<'a, I: MsgOnSocket, O: MsgOnSocket> Stream for AsyncReceiver<'a, I, O> { type Item = MsgResult; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { if self.done { return Poll::Ready(None); } let ret = match self.inner.recv() { Ok(msg) => Ok(Poll::Ready(Some(Ok(msg)))), Err(MsgError::Recv(e)) => { if e.errno() == EWOULDBLOCK { add_read_waker(self.inner.as_raw_fd(), cx.waker().clone()) .map(|_| Poll::Pending) .map_err(MsgError::AddingWaker) } else { Err(MsgError::Recv(e)) } } Err(e) => Err(e), }; match ret { Ok(p) => p, Err(e) => { // Indicate something went wrong and no more events will be provided. self.done = true; Poll::Ready(Some(Err(e))) } } } }