summary refs log tree commit diff
path: root/msg_socket
diff options
context:
space:
mode:
authorDylan Reid <dgreid@chromium.org>2020-01-13 01:59:25 -0800
committerCommit Bot <commit-bot@chromium.org>2020-02-26 06:20:39 +0000
commit72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df (patch)
treed29b02393eedaedde8608ff7a6abc6efb3124c4e /msg_socket
parentdfd0139d7cf2935add342c76cec66702800e95b7 (diff)
downloadcrosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.gz
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.bz2
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.lz
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.xz
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.zst
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.zip
msg_socket: Add async receiving of messages
Add a member to MsgSocket that effectively returns an async iterator
over messages received on the socket. This is done by setting the socket
as non-blocking and registering with the async infrastructure when the
socket would block.

This feature will be used by devices that wish to handle messages in an
async fn context.

Change-Id: I47c6e83922068820cd19ffd9ef604ed8a16b755e
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/1997243
Reviewed-by: Dylan Reid <dgreid@chromium.org>
Tested-by: Dylan Reid <dgreid@chromium.org>
Tested-by: kokoro <noreply+kokoro@google.com>
Commit-Queue: Dylan Reid <dgreid@chromium.org>
Diffstat (limited to 'msg_socket')
-rw-r--r--msg_socket/Cargo.toml3
-rw-r--r--msg_socket/src/lib.rs88
-rw-r--r--msg_socket/src/msg_on_socket.rs6
3 files changed, 86 insertions, 11 deletions
diff --git a/msg_socket/Cargo.toml b/msg_socket/Cargo.toml
index dcfccfc..c803bed 100644
--- a/msg_socket/Cargo.toml
+++ b/msg_socket/Cargo.toml
@@ -5,6 +5,9 @@ authors = ["The Chromium OS Authors"]
 edition = "2018"
 
 [dependencies]
+cros_async = { path = "../cros_async" }
 data_model = { path = "../data_model" }
+futures = "*"
+libc = "*"
 msg_on_socket_derive = { path = "msg_on_socket_derive" }
 sys_util = { path = "../sys_util"  }
diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs
index 5b9f9ce..ea817f0 100644
--- a/msg_socket/src/lib.rs
+++ b/msg_socket/src/lib.rs
@@ -7,8 +7,17 @@ mod msg_on_socket;
 use std::io::Result;
 use std::marker::PhantomData;
 use std::os::unix::io::{AsRawFd, RawFd};
+use std::pin::Pin;
+use std::task::{Context, Poll};
 
-use sys_util::{handle_eintr, net::UnixSeqpacket, Error as SysError, ScmSocket};
+use futures::Stream;
+use libc::{EWOULDBLOCK, O_NONBLOCK};
+
+use cros_async::fd_executor::add_read_waker;
+use sys_util::{
+    add_fd_flags, clear_fd_flags, error, handle_eintr, net::UnixSeqpacket, Error as SysError,
+    ScmSocket,
+};
 
 pub use crate::msg_on_socket::*;
 pub use msg_on_socket_derive::*;
@@ -18,16 +27,8 @@ pub use msg_on_socket_derive::*;
 pub fn pair<Request: MsgOnSocket, Response: MsgOnSocket>(
 ) -> Result<(MsgSocket<Request, Response>, MsgSocket<Response, Request>)> {
     let (sock1, sock2) = UnixSeqpacket::pair()?;
-    let requester = MsgSocket {
-        sock: sock1,
-        _i: PhantomData,
-        _o: PhantomData,
-    };
-    let responder = MsgSocket {
-        sock: sock2,
-        _i: PhantomData,
-        _o: PhantomData,
-    };
+    let requester = MsgSocket::new(sock1);
+    let responder = MsgSocket::new(sock2);
     Ok((requester, responder))
 }
 
@@ -47,6 +48,11 @@ impl<I: MsgOnSocket, O: MsgOnSocket> MsgSocket<I, O> {
             _o: PhantomData,
         }
     }
+
+    // Creates an async receiver that implements `futures::Stream`.
+    pub fn async_receiver(&mut self) -> MsgResult<AsyncReceiver<I, O>> {
+        AsyncReceiver::new(self)
+    }
 }
 
 /// One direction socket that only supports sending.
@@ -191,3 +197,63 @@ impl<I: MsgOnSocket> MsgSender for Sender<I> {
 impl<O: MsgOnSocket> MsgReceiver for Receiver<O> {
     type M = O;
 }
+
+/// Asynchronous adaptor for `MsgSocket`.
+pub struct AsyncReceiver<'a, I: MsgOnSocket, O: MsgOnSocket> {
+    inner: &'a mut MsgSocket<I, O>,
+    done: bool, // Have hit an error and the Stream should return null when polled.
+}
+
+impl<'a, I: MsgOnSocket, O: MsgOnSocket> AsyncReceiver<'a, I, O> {
+    fn new(msg_socket: &mut MsgSocket<I, O>) -> MsgResult<AsyncReceiver<I, O>> {
+        add_fd_flags(msg_socket.as_raw_fd(), O_NONBLOCK).map_err(MsgError::SettingFdFlags)?;
+        Ok(AsyncReceiver {
+            inner: msg_socket,
+            done: false,
+        })
+    }
+}
+
+impl<'a, I: MsgOnSocket, O: MsgOnSocket> Drop for AsyncReceiver<'a, I, O> {
+    fn drop(&mut self) {
+        if let Err(e) = clear_fd_flags(self.inner.as_raw_fd(), O_NONBLOCK) {
+            error!(
+                "Failed to restore non-blocking behavior to message socket: {}",
+                e
+            );
+        }
+    }
+}
+
+impl<'a, I: MsgOnSocket, O: MsgOnSocket> Stream for AsyncReceiver<'a, I, O> {
+    type Item = MsgResult<O>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        if self.done {
+            return Poll::Ready(None);
+        }
+
+        let ret = match self.inner.recv() {
+            Ok(msg) => Ok(Poll::Ready(Some(Ok(msg)))),
+            Err(MsgError::Recv(e)) => {
+                if e.errno() == EWOULDBLOCK {
+                    add_read_waker(self.inner.as_raw_fd(), cx.waker().clone())
+                        .map(|_| Poll::Pending)
+                        .map_err(MsgError::AddingWaker)
+                } else {
+                    Err(MsgError::Recv(e))
+                }
+            }
+            Err(e) => Err(e),
+        };
+
+        match ret {
+            Ok(p) => p,
+            Err(e) => {
+                // Indicate something went wrong and no more events will be provided.
+                self.done = true;
+                Poll::Ready(Some(Err(e)))
+            }
+        }
+    }
+}
diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs
index 2924dc6..f03c36f 100644
--- a/msg_socket/src/msg_on_socket.rs
+++ b/msg_socket/src/msg_on_socket.rs
@@ -15,6 +15,8 @@ use sys_util::{Error as SysError, EventFd};
 #[derive(Debug, PartialEq)]
 /// An error during transaction or serialization/deserialization.
 pub enum MsgError {
+    /// Error adding a waker for async read.
+    AddingWaker(cros_async::fd_executor::Error),
     /// Error while sending a request or response.
     Send(SysError),
     /// Error while receiving a request or response.
@@ -28,6 +30,8 @@ pub enum MsgError {
     ExpectFd,
     /// There was some associated file descriptor received but not used when deserialize.
     NotExpectFd,
+    /// Failed to set flags on the file descriptor.
+    SettingFdFlags(SysError),
     /// Trying to serialize/deserialize, but fd buffer size is too small. This typically happens
     /// when max_fd_count() returns a value that is too small.
     WrongFdBufferSize,
@@ -43,6 +47,7 @@ impl Display for MsgError {
         use self::MsgError::*;
 
         match self {
+            AddingWaker(e) => write!(f, "failed to add a waker: {}", e),
             Send(e) => write!(f, "failed to send request or response: {}", e),
             Recv(e) => write!(f, "failed to receive request or response: {}", e),
             InvalidType => write!(f, "invalid type"),
@@ -53,6 +58,7 @@ impl Display for MsgError {
             ),
             ExpectFd => write!(f, "missing associated file descriptor for request"),
             NotExpectFd => write!(f, "unexpected file descriptor is unused"),
+            SettingFdFlags(e) => write!(f, "failed setting flags on the message FD: {}", e),
             WrongFdBufferSize => write!(f, "fd buffer size too small"),
             WrongMsgBufferSize => write!(f, "msg buffer size too small"),
         }