summary refs log tree commit diff
path: root/devices/src/virtio/wl.rs
diff options
context:
space:
mode:
Diffstat (limited to 'devices/src/virtio/wl.rs')
-rw-r--r--devices/src/virtio/wl.rs73
1 files changed, 69 insertions, 4 deletions
diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs
index 6a412f9..0ebb4ef 100644
--- a/devices/src/virtio/wl.rs
+++ b/devices/src/virtio/wl.rs
@@ -70,7 +70,10 @@ use super::resource_bridge::*;
 use super::{
     DescriptorChain, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_WL, VIRTIO_F_VERSION_1,
 };
-use vm_control::{MaybeOwnedFd, VmMemoryControlRequestSocket, VmMemoryRequest, VmMemoryResponse};
+use vm_control::{
+    MaybeOwnedFd, VmMemoryControlRequestSocket, VmMemoryRequest, VmMemoryResponse,
+    WlControlCommand, WlControlResponseSocket, WlControlResult,
+};
 
 const VIRTWL_SEND_MAX_ALLOCS: usize = 28;
 const VIRTIO_WL_CMD_VFD_NEW: u32 = 256;
@@ -993,6 +996,20 @@ impl WlState {
         }
     }
 
+    fn add_path(&mut self, name: String, path: PathBuf) -> Result<(), Error> {
+        if name.bytes().len() > 32 {
+            return Err(Error::new(libc::EINVAL));
+        }
+
+        if self.wayland_paths.contains_key(&name) {
+            return Err(Error::new(libc::EADDRINUSE));
+        }
+
+        self.wayland_paths.insert(name, path);
+
+        Ok(())
+    }
+
     fn process_poll_context(&mut self) {
         let events = match self.poll_ctx.wait_timeout(Duration::from_secs(0)) {
             Ok(v) => v.to_owned(),
@@ -1343,6 +1360,7 @@ pub struct Worker {
     in_queue: Queue,
     out_queue: Queue,
     state: WlState,
+    control_socket: WlControlResponseSocket,
 }
 
 impl Worker {
@@ -1355,6 +1373,7 @@ impl Worker {
         vm_socket: VmMemoryControlRequestSocket,
         use_transition_flags: bool,
         resource_bridge: Option<ResourceRequestSocket>,
+        control_socket: WlControlResponseSocket,
     ) -> Worker {
         Worker {
             interrupt,
@@ -1367,6 +1386,7 @@ impl Worker {
                 use_transition_flags,
                 resource_bridge,
             ),
+            control_socket,
         }
     }
 
@@ -1379,6 +1399,7 @@ impl Worker {
         enum Token {
             InQueue,
             OutQueue,
+            CommandSocket,
             Kill,
             State,
             InterruptResample,
@@ -1387,6 +1408,7 @@ impl Worker {
         let poll_ctx: PollContext<Token> = match PollContext::build_with(&[
             (&in_queue_evt, Token::InQueue),
             (&out_queue_evt, Token::OutQueue),
+            (&self.control_socket, Token::CommandSocket),
             (&kill_evt, Token::Kill),
             (&self.state.poll_ctx, Token::State),
             (self.interrupt.get_resample_evt(), Token::InterruptResample),
@@ -1398,6 +1420,11 @@ impl Worker {
             }
         };
 
+        if let Err(e) = self.control_socket.send(&WlControlResult::Ready) {
+            error!("control socket failed to notify readiness: {}", e);
+            return;
+        }
+
         'poll: loop {
             let mut signal_used_in = false;
             let mut signal_used_out = false;
@@ -1475,6 +1502,25 @@ impl Worker {
                             }
                         }
                     }
+                    Token::CommandSocket => {
+                        let resp = match self.control_socket.recv() {
+                            Ok(WlControlCommand::AddSocket { name, path }) => {
+                                self.state.add_path(name, path).into()
+                            }
+                            Err(MsgError::InvalidData) => {
+                                WlControlResult::Err(Error::new(libc::EINVAL))
+                            }
+                            Err(e) => {
+                                error!("control socket failed recv: {}", e);
+                                break 'poll;
+                            }
+                        };
+
+                        if let Err(e) = self.control_socket.send(&resp) {
+                            error!("control socket failed send: {}", e);
+                            break 'poll;
+                        }
+                    }
                     Token::Kill => break 'poll,
                     Token::State => self.state.process_poll_context(),
                     Token::InterruptResample => {
@@ -1537,6 +1583,7 @@ pub struct Wl {
     vm_socket: Option<VmMemoryControlRequestSocket>,
     resource_bridge: Option<ResourceRequestSocket>,
     use_transition_flags: bool,
+    control_socket: Option<WlControlResponseSocket>,
 }
 
 use msg_socket2::{
@@ -1552,17 +1599,22 @@ pub struct Params {
     pub wayland_paths: Map<String, PathBuf>,
     pub vm_socket: VmMemoryControlRequestSocket,
     pub resource_bridge: Option<ResourceRequestSocket>,
+    pub control_socket: WlControlResponseSocket,
 }
 
 impl SerializeWithFds for Params {
     fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
-        let mut state = serializer.serialize_struct("Params", 3)?;
+        let mut state = serializer.serialize_struct("Params", 4)?;
         state.serialize_field("wayland_paths", &self.wayland_paths)?;
         state.serialize_field("vm_socket", &SerializeAdapter::new(&self.vm_socket))?;
         state.serialize_field(
             "resource_bridge",
             &SerializeAdapter::new(&self.resource_bridge),
         )?;
+        state.serialize_field(
+            "control_socket",
+            &SerializeAdapter::new(&self.control_socket),
+        )?;
         state.end()
     }
 
@@ -1570,10 +1622,11 @@ impl SerializeWithFds for Params {
     where
         S: FdSerializer<'fds>,
     {
-        let mut state = serializer.serialize_struct("Params", 3)?;
+        let mut state = serializer.serialize_struct("Params", 4)?;
         state.serialize_field("wayland_paths", &self.wayland_paths)?;
         state.serialize_field("vm_socket", &self.vm_socket)?;
         state.serialize_field("resource_bridge", &self.resource_bridge)?;
+        state.serialize_field("control_socket", &self.control_socket)?;
         state.end()
     }
 }
@@ -1599,7 +1652,7 @@ impl<'de> DeserializeWithFds<'de> for Params {
                 use serde::de::Error;
 
                 fn too_short_error<E: Error>(len: usize) -> E {
-                    E::invalid_length(len, &"struct Params with 3 elements")
+                    E::invalid_length(len, &"struct Params with 4 elements")
                 }
 
                 Ok(Params {
@@ -1614,6 +1667,11 @@ impl<'de> DeserializeWithFds<'de> for Params {
                         .next_element::<Option<_>>()?
                         .ok_or_else(|| too_short_error(2))?
                         .map(MsgSocket::new),
+
+                    control_socket: seq
+                        .next_element()?
+                        .map(MsgSocket::new)
+                        .ok_or_else(|| too_short_error(3))?,
                 })
             }
         }
@@ -1635,6 +1693,7 @@ impl VirtioDeviceNew for Wl {
             wayland_paths,
             vm_socket,
             resource_bridge,
+            control_socket,
         } = params;
 
         Ok(Self {
@@ -1644,6 +1703,7 @@ impl VirtioDeviceNew for Wl {
             vm_socket: Some(vm_socket),
             resource_bridge,
             use_transition_flags: false,
+            control_socket: Some(control_socket),
         })
     }
 }
@@ -1671,6 +1731,9 @@ impl VirtioDevice for Wl {
         if let Some(resource_bridge) = &self.resource_bridge {
             keep_fds.push(resource_bridge.as_raw_fd());
         }
+        if let Some(control_socket) = &self.control_socket {
+            keep_fds.push(control_socket.as_raw_fd());
+        }
 
         keep_fds
     }
@@ -1717,6 +1780,7 @@ impl VirtioDevice for Wl {
             let wayland_paths = self.wayland_paths.clone();
             let use_transition_flags = self.use_transition_flags;
             let resource_bridge = self.resource_bridge.take();
+            let control_socket = self.control_socket.take().unwrap();
             println!("creating worker");
             let worker_result =
                 thread::Builder::new()
@@ -1731,6 +1795,7 @@ impl VirtioDevice for Wl {
                             vm_socket,
                             use_transition_flags,
                             resource_bridge,
+                            control_socket,
                         )
                         .run(queue_evts, kill_evt);
                     });