diff options
Diffstat (limited to 'msg_socket/src/lib.rs')
-rw-r--r-- | msg_socket/src/lib.rs | 88 |
1 files changed, 77 insertions, 11 deletions
diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs index 5b9f9ce..ea817f0 100644 --- a/msg_socket/src/lib.rs +++ b/msg_socket/src/lib.rs @@ -7,8 +7,17 @@ mod msg_on_socket; use std::io::Result; use std::marker::PhantomData; use std::os::unix::io::{AsRawFd, RawFd}; +use std::pin::Pin; +use std::task::{Context, Poll}; -use sys_util::{handle_eintr, net::UnixSeqpacket, Error as SysError, ScmSocket}; +use futures::Stream; +use libc::{EWOULDBLOCK, O_NONBLOCK}; + +use cros_async::fd_executor::add_read_waker; +use sys_util::{ + add_fd_flags, clear_fd_flags, error, handle_eintr, net::UnixSeqpacket, Error as SysError, + ScmSocket, +}; pub use crate::msg_on_socket::*; pub use msg_on_socket_derive::*; @@ -18,16 +27,8 @@ pub use msg_on_socket_derive::*; pub fn pair<Request: MsgOnSocket, Response: MsgOnSocket>( ) -> Result<(MsgSocket<Request, Response>, MsgSocket<Response, Request>)> { let (sock1, sock2) = UnixSeqpacket::pair()?; - let requester = MsgSocket { - sock: sock1, - _i: PhantomData, - _o: PhantomData, - }; - let responder = MsgSocket { - sock: sock2, - _i: PhantomData, - _o: PhantomData, - }; + let requester = MsgSocket::new(sock1); + let responder = MsgSocket::new(sock2); Ok((requester, responder)) } @@ -47,6 +48,11 @@ impl<I: MsgOnSocket, O: MsgOnSocket> MsgSocket<I, O> { _o: PhantomData, } } + + // Creates an async receiver that implements `futures::Stream`. + pub fn async_receiver(&mut self) -> MsgResult<AsyncReceiver<I, O>> { + AsyncReceiver::new(self) + } } /// One direction socket that only supports sending. @@ -191,3 +197,63 @@ impl<I: MsgOnSocket> MsgSender for Sender<I> { impl<O: MsgOnSocket> MsgReceiver for Receiver<O> { type M = O; } + +/// Asynchronous adaptor for `MsgSocket`. +pub struct AsyncReceiver<'a, I: MsgOnSocket, O: MsgOnSocket> { + inner: &'a mut MsgSocket<I, O>, + done: bool, // Have hit an error and the Stream should return null when polled. +} + +impl<'a, I: MsgOnSocket, O: MsgOnSocket> AsyncReceiver<'a, I, O> { + fn new(msg_socket: &mut MsgSocket<I, O>) -> MsgResult<AsyncReceiver<I, O>> { + add_fd_flags(msg_socket.as_raw_fd(), O_NONBLOCK).map_err(MsgError::SettingFdFlags)?; + Ok(AsyncReceiver { + inner: msg_socket, + done: false, + }) + } +} + +impl<'a, I: MsgOnSocket, O: MsgOnSocket> Drop for AsyncReceiver<'a, I, O> { + fn drop(&mut self) { + if let Err(e) = clear_fd_flags(self.inner.as_raw_fd(), O_NONBLOCK) { + error!( + "Failed to restore non-blocking behavior to message socket: {}", + e + ); + } + } +} + +impl<'a, I: MsgOnSocket, O: MsgOnSocket> Stream for AsyncReceiver<'a, I, O> { + type Item = MsgResult<O>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { + if self.done { + return Poll::Ready(None); + } + + let ret = match self.inner.recv() { + Ok(msg) => Ok(Poll::Ready(Some(Ok(msg)))), + Err(MsgError::Recv(e)) => { + if e.errno() == EWOULDBLOCK { + add_read_waker(self.inner.as_raw_fd(), cx.waker().clone()) + .map(|_| Poll::Pending) + .map_err(MsgError::AddingWaker) + } else { + Err(MsgError::Recv(e)) + } + } + Err(e) => Err(e), + }; + + match ret { + Ok(p) => p, + Err(e) => { + // Indicate something went wrong and no more events will be provided. + self.done = true; + Poll::Ready(Some(Err(e))) + } + } + } +} |