diff options
author | Dylan Reid <dgreid@chromium.org> | 2020-01-13 01:59:25 -0800 |
---|---|---|
committer | Commit Bot <commit-bot@chromium.org> | 2020-02-26 06:20:39 +0000 |
commit | 72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df (patch) | |
tree | d29b02393eedaedde8608ff7a6abc6efb3124c4e /msg_socket | |
parent | dfd0139d7cf2935add342c76cec66702800e95b7 (diff) | |
download | crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.gz crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.bz2 crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.lz crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.xz crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.zst crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.zip |
msg_socket: Add async receiving of messages
Add a member to MsgSocket that effectively returns an async iterator over messages received on the socket. This is done by setting the socket as non-blocking and registering with the async infrastructure when the socket would block. This feature will be used by devices that wish to handle messages in an async fn context. Change-Id: I47c6e83922068820cd19ffd9ef604ed8a16b755e Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/1997243 Reviewed-by: Dylan Reid <dgreid@chromium.org> Tested-by: Dylan Reid <dgreid@chromium.org> Tested-by: kokoro <noreply+kokoro@google.com> Commit-Queue: Dylan Reid <dgreid@chromium.org>
Diffstat (limited to 'msg_socket')
-rw-r--r-- | msg_socket/Cargo.toml | 3 | ||||
-rw-r--r-- | msg_socket/src/lib.rs | 88 | ||||
-rw-r--r-- | msg_socket/src/msg_on_socket.rs | 6 |
3 files changed, 86 insertions, 11 deletions
diff --git a/msg_socket/Cargo.toml b/msg_socket/Cargo.toml index dcfccfc..c803bed 100644 --- a/msg_socket/Cargo.toml +++ b/msg_socket/Cargo.toml @@ -5,6 +5,9 @@ authors = ["The Chromium OS Authors"] edition = "2018" [dependencies] +cros_async = { path = "../cros_async" } data_model = { path = "../data_model" } +futures = "*" +libc = "*" msg_on_socket_derive = { path = "msg_on_socket_derive" } sys_util = { path = "../sys_util" } diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs index 5b9f9ce..ea817f0 100644 --- a/msg_socket/src/lib.rs +++ b/msg_socket/src/lib.rs @@ -7,8 +7,17 @@ mod msg_on_socket; use std::io::Result; use std::marker::PhantomData; use std::os::unix::io::{AsRawFd, RawFd}; +use std::pin::Pin; +use std::task::{Context, Poll}; -use sys_util::{handle_eintr, net::UnixSeqpacket, Error as SysError, ScmSocket}; +use futures::Stream; +use libc::{EWOULDBLOCK, O_NONBLOCK}; + +use cros_async::fd_executor::add_read_waker; +use sys_util::{ + add_fd_flags, clear_fd_flags, error, handle_eintr, net::UnixSeqpacket, Error as SysError, + ScmSocket, +}; pub use crate::msg_on_socket::*; pub use msg_on_socket_derive::*; @@ -18,16 +27,8 @@ pub use msg_on_socket_derive::*; pub fn pair<Request: MsgOnSocket, Response: MsgOnSocket>( ) -> Result<(MsgSocket<Request, Response>, MsgSocket<Response, Request>)> { let (sock1, sock2) = UnixSeqpacket::pair()?; - let requester = MsgSocket { - sock: sock1, - _i: PhantomData, - _o: PhantomData, - }; - let responder = MsgSocket { - sock: sock2, - _i: PhantomData, - _o: PhantomData, - }; + let requester = MsgSocket::new(sock1); + let responder = MsgSocket::new(sock2); Ok((requester, responder)) } @@ -47,6 +48,11 @@ impl<I: MsgOnSocket, O: MsgOnSocket> MsgSocket<I, O> { _o: PhantomData, } } + + // Creates an async receiver that implements `futures::Stream`. + pub fn async_receiver(&mut self) -> MsgResult<AsyncReceiver<I, O>> { + AsyncReceiver::new(self) + } } /// One direction socket that only supports sending. @@ -191,3 +197,63 @@ impl<I: MsgOnSocket> MsgSender for Sender<I> { impl<O: MsgOnSocket> MsgReceiver for Receiver<O> { type M = O; } + +/// Asynchronous adaptor for `MsgSocket`. +pub struct AsyncReceiver<'a, I: MsgOnSocket, O: MsgOnSocket> { + inner: &'a mut MsgSocket<I, O>, + done: bool, // Have hit an error and the Stream should return null when polled. +} + +impl<'a, I: MsgOnSocket, O: MsgOnSocket> AsyncReceiver<'a, I, O> { + fn new(msg_socket: &mut MsgSocket<I, O>) -> MsgResult<AsyncReceiver<I, O>> { + add_fd_flags(msg_socket.as_raw_fd(), O_NONBLOCK).map_err(MsgError::SettingFdFlags)?; + Ok(AsyncReceiver { + inner: msg_socket, + done: false, + }) + } +} + +impl<'a, I: MsgOnSocket, O: MsgOnSocket> Drop for AsyncReceiver<'a, I, O> { + fn drop(&mut self) { + if let Err(e) = clear_fd_flags(self.inner.as_raw_fd(), O_NONBLOCK) { + error!( + "Failed to restore non-blocking behavior to message socket: {}", + e + ); + } + } +} + +impl<'a, I: MsgOnSocket, O: MsgOnSocket> Stream for AsyncReceiver<'a, I, O> { + type Item = MsgResult<O>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { + if self.done { + return Poll::Ready(None); + } + + let ret = match self.inner.recv() { + Ok(msg) => Ok(Poll::Ready(Some(Ok(msg)))), + Err(MsgError::Recv(e)) => { + if e.errno() == EWOULDBLOCK { + add_read_waker(self.inner.as_raw_fd(), cx.waker().clone()) + .map(|_| Poll::Pending) + .map_err(MsgError::AddingWaker) + } else { + Err(MsgError::Recv(e)) + } + } + Err(e) => Err(e), + }; + + match ret { + Ok(p) => p, + Err(e) => { + // Indicate something went wrong and no more events will be provided. + self.done = true; + Poll::Ready(Some(Err(e))) + } + } + } +} diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs index 2924dc6..f03c36f 100644 --- a/msg_socket/src/msg_on_socket.rs +++ b/msg_socket/src/msg_on_socket.rs @@ -15,6 +15,8 @@ use sys_util::{Error as SysError, EventFd}; #[derive(Debug, PartialEq)] /// An error during transaction or serialization/deserialization. pub enum MsgError { + /// Error adding a waker for async read. + AddingWaker(cros_async::fd_executor::Error), /// Error while sending a request or response. Send(SysError), /// Error while receiving a request or response. @@ -28,6 +30,8 @@ pub enum MsgError { ExpectFd, /// There was some associated file descriptor received but not used when deserialize. NotExpectFd, + /// Failed to set flags on the file descriptor. + SettingFdFlags(SysError), /// Trying to serialize/deserialize, but fd buffer size is too small. This typically happens /// when max_fd_count() returns a value that is too small. WrongFdBufferSize, @@ -43,6 +47,7 @@ impl Display for MsgError { use self::MsgError::*; match self { + AddingWaker(e) => write!(f, "failed to add a waker: {}", e), Send(e) => write!(f, "failed to send request or response: {}", e), Recv(e) => write!(f, "failed to receive request or response: {}", e), InvalidType => write!(f, "invalid type"), @@ -53,6 +58,7 @@ impl Display for MsgError { ), ExpectFd => write!(f, "missing associated file descriptor for request"), NotExpectFd => write!(f, "unexpected file descriptor is unused"), + SettingFdFlags(e) => write!(f, "failed setting flags on the message FD: {}", e), WrongFdBufferSize => write!(f, "fd buffer size too small"), WrongMsgBufferSize => write!(f, "msg buffer size too small"), } |