summary refs log tree commit diff
path: root/msg_socket/src/lib.rs
diff options
context:
space:
mode:
authorDylan Reid <dgreid@chromium.org>2020-01-13 01:59:25 -0800
committerCommit Bot <commit-bot@chromium.org>2020-02-26 06:20:39 +0000
commit72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df (patch)
treed29b02393eedaedde8608ff7a6abc6efb3124c4e /msg_socket/src/lib.rs
parentdfd0139d7cf2935add342c76cec66702800e95b7 (diff)
downloadcrosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.gz
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.bz2
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.lz
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.xz
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.tar.zst
crosvm-72ccaefe0f384f708b3d2fd71aa3f3b40ab4e3df.zip
msg_socket: Add async receiving of messages
Add a member to MsgSocket that effectively returns an async iterator
over messages received on the socket. This is done by setting the socket
as non-blocking and registering with the async infrastructure when the
socket would block.

This feature will be used by devices that wish to handle messages in an
async fn context.

Change-Id: I47c6e83922068820cd19ffd9ef604ed8a16b755e
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/1997243
Reviewed-by: Dylan Reid <dgreid@chromium.org>
Tested-by: Dylan Reid <dgreid@chromium.org>
Tested-by: kokoro <noreply+kokoro@google.com>
Commit-Queue: Dylan Reid <dgreid@chromium.org>
Diffstat (limited to 'msg_socket/src/lib.rs')
-rw-r--r--msg_socket/src/lib.rs88
1 files changed, 77 insertions, 11 deletions
diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs
index 5b9f9ce..ea817f0 100644
--- a/msg_socket/src/lib.rs
+++ b/msg_socket/src/lib.rs
@@ -7,8 +7,17 @@ mod msg_on_socket;
 use std::io::Result;
 use std::marker::PhantomData;
 use std::os::unix::io::{AsRawFd, RawFd};
+use std::pin::Pin;
+use std::task::{Context, Poll};
 
-use sys_util::{handle_eintr, net::UnixSeqpacket, Error as SysError, ScmSocket};
+use futures::Stream;
+use libc::{EWOULDBLOCK, O_NONBLOCK};
+
+use cros_async::fd_executor::add_read_waker;
+use sys_util::{
+    add_fd_flags, clear_fd_flags, error, handle_eintr, net::UnixSeqpacket, Error as SysError,
+    ScmSocket,
+};
 
 pub use crate::msg_on_socket::*;
 pub use msg_on_socket_derive::*;
@@ -18,16 +27,8 @@ pub use msg_on_socket_derive::*;
 pub fn pair<Request: MsgOnSocket, Response: MsgOnSocket>(
 ) -> Result<(MsgSocket<Request, Response>, MsgSocket<Response, Request>)> {
     let (sock1, sock2) = UnixSeqpacket::pair()?;
-    let requester = MsgSocket {
-        sock: sock1,
-        _i: PhantomData,
-        _o: PhantomData,
-    };
-    let responder = MsgSocket {
-        sock: sock2,
-        _i: PhantomData,
-        _o: PhantomData,
-    };
+    let requester = MsgSocket::new(sock1);
+    let responder = MsgSocket::new(sock2);
     Ok((requester, responder))
 }
 
@@ -47,6 +48,11 @@ impl<I: MsgOnSocket, O: MsgOnSocket> MsgSocket<I, O> {
             _o: PhantomData,
         }
     }
+
+    // Creates an async receiver that implements `futures::Stream`.
+    pub fn async_receiver(&mut self) -> MsgResult<AsyncReceiver<I, O>> {
+        AsyncReceiver::new(self)
+    }
 }
 
 /// One direction socket that only supports sending.
@@ -191,3 +197,63 @@ impl<I: MsgOnSocket> MsgSender for Sender<I> {
 impl<O: MsgOnSocket> MsgReceiver for Receiver<O> {
     type M = O;
 }
+
+/// Asynchronous adaptor for `MsgSocket`.
+pub struct AsyncReceiver<'a, I: MsgOnSocket, O: MsgOnSocket> {
+    inner: &'a mut MsgSocket<I, O>,
+    done: bool, // Have hit an error and the Stream should return null when polled.
+}
+
+impl<'a, I: MsgOnSocket, O: MsgOnSocket> AsyncReceiver<'a, I, O> {
+    fn new(msg_socket: &mut MsgSocket<I, O>) -> MsgResult<AsyncReceiver<I, O>> {
+        add_fd_flags(msg_socket.as_raw_fd(), O_NONBLOCK).map_err(MsgError::SettingFdFlags)?;
+        Ok(AsyncReceiver {
+            inner: msg_socket,
+            done: false,
+        })
+    }
+}
+
+impl<'a, I: MsgOnSocket, O: MsgOnSocket> Drop for AsyncReceiver<'a, I, O> {
+    fn drop(&mut self) {
+        if let Err(e) = clear_fd_flags(self.inner.as_raw_fd(), O_NONBLOCK) {
+            error!(
+                "Failed to restore non-blocking behavior to message socket: {}",
+                e
+            );
+        }
+    }
+}
+
+impl<'a, I: MsgOnSocket, O: MsgOnSocket> Stream for AsyncReceiver<'a, I, O> {
+    type Item = MsgResult<O>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        if self.done {
+            return Poll::Ready(None);
+        }
+
+        let ret = match self.inner.recv() {
+            Ok(msg) => Ok(Poll::Ready(Some(Ok(msg)))),
+            Err(MsgError::Recv(e)) => {
+                if e.errno() == EWOULDBLOCK {
+                    add_read_waker(self.inner.as_raw_fd(), cx.waker().clone())
+                        .map(|_| Poll::Pending)
+                        .map_err(MsgError::AddingWaker)
+                } else {
+                    Err(MsgError::Recv(e))
+                }
+            }
+            Err(e) => Err(e),
+        };
+
+        match ret {
+            Ok(p) => p,
+            Err(e) => {
+                // Indicate something went wrong and no more events will be provided.
+                self.done = true;
+                Poll::Ready(Some(Err(e)))
+            }
+        }
+    }
+}