summary refs log tree commit diff
path: root/devices/src/virtio/wl.rs
diff options
context:
space:
mode:
authorZach Reizner <zachr@google.com>2018-04-06 12:12:12 -0700
committerchrome-bot <chrome-bot@chromium.org>2018-04-06 19:50:33 -0700
commitd86e698ec800db139edee03a45140078850abfad (patch)
treed04b83bbadee12d2123fac8619e75c4deb6f68cf /devices/src/virtio/wl.rs
parentc1b74eb8b1d123a940cabefc7be864cf33d74d00 (diff)
downloadcrosvm-d86e698ec800db139edee03a45140078850abfad.tar
crosvm-d86e698ec800db139edee03a45140078850abfad.tar.gz
crosvm-d86e698ec800db139edee03a45140078850abfad.tar.bz2
crosvm-d86e698ec800db139edee03a45140078850abfad.tar.lz
crosvm-d86e698ec800db139edee03a45140078850abfad.tar.xz
crosvm-d86e698ec800db139edee03a45140078850abfad.tar.zst
crosvm-d86e698ec800db139edee03a45140078850abfad.zip
devices: use nested PollContext in wayland device
The wl device was the last user of the old Poller.

BUG=chromium:816692
TEST=run wayland under crosvm

Change-Id: I6c1c1db2774a6e783b7bd1109288328d75ad2223
Reviewed-on: https://chromium-review.googlesource.com/1000102
Commit-Ready: Zach Reizner <zachr@chromium.org>
Tested-by: Zach Reizner <zachr@chromium.org>
Reviewed-by: Dylan Reid <dgreid@chromium.org>
Diffstat (limited to 'devices/src/virtio/wl.rs')
-rw-r--r--devices/src/virtio/wl.rs165
1 files changed, 92 insertions, 73 deletions
diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs
index 00b3a16..32ef2f0 100644
--- a/devices/src/virtio/wl.rs
+++ b/devices/src/virtio/wl.rs
@@ -46,12 +46,13 @@ use std::result;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::thread;
+use std::time::Duration;
 
 use data_model::*;
 use data_model::VolatileMemoryError;
 
-use sys_util::{Error, Result, EventFd, Poller, Pollable, Scm, SharedMemory, GuestAddress,
-               GuestMemory, GuestMemoryError, FileFlags, pipe};
+use sys_util::{Error, Result, EventFd, Scm, SharedMemory, GuestAddress, GuestMemory,
+               GuestMemoryError, PollContext, PollToken, FileFlags, pipe};
 
 use vm_control::{VmControlError, VmRequest, VmResponse, MaybeOwnedFd};
 use super::{VirtioDevice, Queue, DescriptorChain, INTERRUPT_STATUS_USED_RING, TYPE_WL};
@@ -78,11 +79,6 @@ const VIRTIO_WL_VFD_MAP: u32 = 0x2;
 const VIRTIO_WL_VFD_CONTROL: u32 = 0x4;
 const VIRTIO_WL_F_TRANS_FLAGS: u32 = 0x01;
 
-const Q_IN: u32 = 0;
-const Q_OUT: u32 = 1;
-const KILL: u32 = 2;
-const VFD_BASE_TOKEN: u32 = 0x100;
-
 const QUEUE_SIZE: u16 = 16;
 const QUEUE_SIZES: &'static [u16] = &[QUEUE_SIZE, QUEUE_SIZE];
 
@@ -278,6 +274,7 @@ enum WlError {
     WritePipe(io::Error),
     RecvVfd(Error),
     ReadPipe(io::Error),
+    PollContextAdd(Error),
 }
 
 impl fmt::Display for WlError {
@@ -303,6 +300,7 @@ impl error::Error for WlError {
             WlError::WritePipe(_) => "Failed to write to a pipe",
             WlError::RecvVfd(_) => "Failed to recv on a socket",
             WlError::ReadPipe(_) => "Failed to read a pipe",
+            WlError::PollContextAdd(_) => "Failed to listen to FD on poll context",
         }
     }
 }
@@ -591,7 +589,7 @@ impl WlVfd {
     }
 
     // The FD that gets sent if this VFD is sent over a socket.
-    fn fd(&self) -> Option<RawFd> {
+    fn send_fd(&self) -> Option<RawFd> {
         self.guest_shared_memory
             .as_ref()
             .map(|&(_, ref fd)| fd.as_raw_fd())
@@ -599,6 +597,17 @@ impl WlVfd {
             .or(self.remote_pipe.as_ref().map(|p| p.as_raw_fd()))
     }
 
+    // The FD that is used for polling for events on this VFD.
+    fn poll_fd(&self) -> Option<&AsRawFd> {
+        self.socket
+            .as_ref()
+            .map(|s| s as &AsRawFd)
+            .or(self.local_pipe
+                    .as_ref()
+                    .map(|&(_, ref p)| p as &AsRawFd))
+
+    }
+
     // Sends data/files from the guest to the host over this VFD.
     fn send(&mut self, scm: &mut Scm, fds: &[RawFd], data: VolatileSlice) -> WlResult<WlResp> {
         if let Some(ref socket) = self.socket {
@@ -687,6 +696,7 @@ struct WlState {
     wayland_path: PathBuf,
     vm: VmRequester,
     use_transition_flags: bool,
+    poll_ctx: PollContext<u32>,
     vfds: Map<u32, WlVfd>,
     next_vfd_id: u32,
     scm: Scm,
@@ -701,6 +711,7 @@ impl WlState {
         WlState {
             wayland_path: wayland_path,
             vm: VmRequester::new(vm_socket),
+            poll_ctx: PollContext::new().expect("failed to create PollContext"),
             use_transition_flags,
             scm: Scm::new(VIRTWL_SEND_MAX_ALLOCS),
             vfds: Map::new(),
@@ -734,6 +745,9 @@ impl WlState {
                 } else {
                     return Ok(WlResp::InvalidFlags);
                 };
+                self.poll_ctx
+                    .add(vfd.poll_fd().unwrap(), id)
+                    .map_err(WlError::PollContextAdd)?;
                 let resp = WlResp::VfdNew {
                     id: id,
                     flags: 0,
@@ -791,7 +805,10 @@ impl WlState {
 
         match self.vfds.entry(id) {
             Entry::Vacant(entry) => {
-                entry.insert(WlVfd::connect(&self.wayland_path)?);
+                let vfd = entry.insert(WlVfd::connect(&self.wayland_path)?);
+                self.poll_ctx
+                    .add(vfd.poll_fd().unwrap(), id)
+                    .map_err(WlError::PollContextAdd)?;
                 Ok(WlResp::VfdNew {
                        id: id,
                        flags,
@@ -804,6 +821,30 @@ impl WlState {
         }
     }
 
+    fn process_poll_context(&mut self) {
+        let events = match self.poll_ctx.wait_timeout(Duration::from_secs(0)) {
+            Ok(v) => v.to_owned(),
+            Err(e) => {
+                error!("failed polling for vfd evens: {:?}", e);
+                return;
+            }
+        };
+
+        for event in events.as_ref().iter_readable() {
+            if let Err(e) = self.recv(event.token()) {
+                error!("failed to recv from vfd: {:?}", e)
+            }
+        }
+
+        for event in events.as_ref().iter_hungup() {
+            if !event.readable() {
+                if let Err(e) = self.close(event.token()) {
+                    warn!("failed to close vfd: {:?}", e)
+                }
+            }
+        }
+    }
+
     fn close(&mut self, vfd_id: u32) -> WlResult<WlResp> {
         let mut to_delete = Set::new();
         for &(dest_vfd_id, ref q) in self.in_queue.iter() {
@@ -835,7 +876,7 @@ impl WlState {
         for (&id, fd) in vfd_ids[..vfd_count].iter().zip(fds.iter_mut()) {
             match self.vfds.get(&id.into()) {
                 Some(vfd) => {
-                    match vfd.fd() {
+                    match vfd.send_fd() {
                         Some(vfd_fd) => *fd = vfd_fd,
                         None => return Ok(WlResp::InvalidType),
                     }
@@ -873,8 +914,13 @@ impl WlState {
             return Ok(());
         }
         for file in self.in_file_queue.drain(..) {
-            self.vfds
-                .insert(self.next_vfd_id, WlVfd::from_file(self.vm.clone(), file)?);
+            let vfd = WlVfd::from_file(self.vm.clone(), file)?;
+            if let Some(poll_fd) = vfd.poll_fd() {
+                self.poll_ctx
+                    .add(poll_fd, self.next_vfd_id)
+                    .map_err(WlError::PollContextAdd)?;
+            }
+            self.vfds.insert(self.next_vfd_id, vfd);
             self.in_queue
                 .push_back((vfd_id, WlRecv::Vfd { id: self.next_vfd_id }));
             self.next_vfd_id += 1;
@@ -994,27 +1040,6 @@ impl WlState {
         }
         self.in_queue.pop_front();
     }
-
-    fn iter_sockets<'a, F>(&'a self, mut f: F)
-        where F: FnMut(u32, &'a UnixStream)
-    {
-        for (id, socket) in self.vfds
-                .iter()
-                .filter_map(|(&k, v)| v.socket.as_ref().map(|s| (k, s))) {
-            f(id, socket);
-        }
-    }
-
-    fn iter_pipes<'a, F>(&'a self, mut f: F)
-        where F: FnMut(u32, &'a File)
-    {
-        for (id, local_pipe) in
-            self.vfds
-                .iter()
-                .filter_map(|(&k, v)| v.local_pipe.as_ref().map(|&(_, ref p)| (k, p))) {
-            f(id, local_pipe);
-        }
-    }
 }
 
 struct Worker {
@@ -1057,39 +1082,40 @@ impl Worker {
     fn run(&mut self, mut queue_evts: Vec<EventFd>, kill_evt: EventFd) {
         let in_queue_evt = queue_evts.remove(0);
         let out_queue_evt = queue_evts.remove(0);
-        let mut token_vfd_id_map = Map::new();
-        let mut poller = Poller::new(3);
-        'poll: loop {
-            let tokens = {
-                // TODO(zachr): somehow keep pollables from allocating every loop
-                // The capacity is always the 3 static eventfds plus the number of vfd sockets. To
-                // estimate the number of vfd sockets, we use the previous poll's vfd id map size,
-                // which was equal to the number of vfd sockets.
-                let mut pollables = Vec::with_capacity(3 + token_vfd_id_map.len());
-                pollables.push((Q_IN, &in_queue_evt as &Pollable));
-                pollables.push((Q_OUT, &out_queue_evt as &Pollable));
-                pollables.push((KILL, &kill_evt as &Pollable));
-                token_vfd_id_map.clear();
-                // TODO(zachr): leave these out if there is no Q_IN to use
-                self.state
-                    .iter_sockets(|id, socket| {
-                                      let token = VFD_BASE_TOKEN + token_vfd_id_map.len() as u32;
-                                      token_vfd_id_map.insert(token, id);
-                                      pollables.push((token, socket));
-                                  });
-                self.state
-                    .iter_pipes(|id, pipe| {
-                                    let token = VFD_BASE_TOKEN + token_vfd_id_map.len() as u32;
-                                    token_vfd_id_map.insert(token, id);
-                                    pollables.push((token, pipe));
-                                });
-                poller.poll(&pollables[..]).expect("error: failed poll")
+        #[derive(PollToken)]
+        enum Token {
+            InQueue,
+            OutQueue,
+            Kill,
+            State,
+        }
+
+        let poll_ctx: PollContext<Token> =
+            match PollContext::new()
+                      .and_then(|pc| pc.add(&in_queue_evt, Token::InQueue).and(Ok(pc)))
+                      .and_then(|pc| pc.add(&out_queue_evt, Token::OutQueue).and(Ok(pc)))
+                      .and_then(|pc| pc.add(&kill_evt, Token::Kill).and(Ok(pc)))
+                      .and_then(|pc| pc.add(&self.state.poll_ctx, Token::State).and(Ok(pc))) {
+                Ok(pc) => pc,
+                Err(e) => {
+                    error!("failed creating PollContext: {:?}", e);
+                    return;
+                }
             };
 
+        'poll: loop {
             let mut signal_used = false;
-            for &token in tokens {
-                match token {
-                    Q_IN => {
+            let events = match poll_ctx.wait() {
+                Ok(v) => v,
+                Err(e) => {
+                    error!("failed polling for events: {:?}", e);
+                    break;
+                }
+            };
+
+            for event in events.iter() {
+                match event.token() {
+                    Token::InQueue => {
                         let _ = in_queue_evt.read();
                         // Used to buffer descriptor indexes that are invalid for our uses.
                         let mut rejects = [0u16; QUEUE_SIZE as usize];
@@ -1114,7 +1140,7 @@ impl Worker {
                             self.in_queue.add_used(&self.mem, reject, 0);
                         }
                     }
-                    Q_OUT => {
+                    Token::OutQueue => {
                         let _ = out_queue_evt.read();
                         // Used to buffer filled in descriptors that will be added to the used queue
                         // after iterating the available queue.
@@ -1159,15 +1185,8 @@ impl Worker {
                             self.out_queue.add_used(&self.mem, index, len);
                         }
                     }
-                    KILL => break 'poll,
-                    v => {
-                        if let Some(&id) = token_vfd_id_map.get(&v) {
-                            let res = self.state.recv(id);
-                            if let Err(e) = res {
-                                error!("failed to receive vfd {}: {:?}", id, e);
-                            }
-                        }
-                    }
+                    Token::Kill => break 'poll,
+                    Token::State => self.state.process_poll_context(),
                 }
             }