summary refs log tree commit diff
path: root/devices/src/virtio/wl.rs
diff options
context:
space:
mode:
Diffstat (limited to 'devices/src/virtio/wl.rs')
-rw-r--r--devices/src/virtio/wl.rs147
1 files changed, 122 insertions, 25 deletions
diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs
index 12f5012..7b93405 100644
--- a/devices/src/virtio/wl.rs
+++ b/devices/src/virtio/wl.rs
@@ -53,8 +53,9 @@ use data_model::VolatileMemoryError;
 use data_model::*;
 
 use msg_socket::{MsgError, MsgReceiver, MsgSender};
-use msg_socket2::de::VisitorWithFds;
+use msg_socket2::de::{EnumAccessWithFds, VariantAccessWithFds, VisitorWithFds};
 use msg_socket2::ser::SerializeAdapter;
+use msg_socket2::Deserialize;
 #[cfg(feature = "wl-dmabuf")]
 use resources::GpuMemoryDesc;
 #[cfg(feature = "wl-dmabuf")]
@@ -612,6 +613,12 @@ impl WlVfd {
         Ok(vfd)
     }
 
+    fn from_socket(socket: UnixStream) -> WlVfd {
+        let mut vfd = WlVfd::default();
+        vfd.socket = Some(socket);
+        vfd
+    }
+
     fn allocate(vm: VmRequester, size: u64) -> WlResult<WlVfd> {
         let size_page_aligned = round_up_to_page_size(size as usize) as u64;
         let mut vfd_shm = SharedMemory::named("virtwl_alloc").map_err(WlError::NewAlloc)?;
@@ -879,8 +886,79 @@ enum WlRecv {
     Hup,
 }
 
+#[derive(Debug)]
+enum WaylandSocket {
+    Listening(PathBuf),
+    NonListening(UnixStream),
+}
+
+impl SerializeWithFds for WaylandSocket {
+    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        use WaylandSocket::*;
+        match self {
+            Listening(path) => {
+                serializer.serialize_newtype_variant("WaylandSocket", 0, "Listening", path)
+            }
+            NonListening(socket) => serializer.serialize_newtype_variant(
+                "WaylandSocket",
+                1,
+                "NonListening",
+                &SerializeAdapter::new(socket),
+            ),
+        }
+    }
+
+    fn serialize_fds<'fds, S: FdSerializer<'fds>>(
+        &'fds self,
+        serializer: S,
+    ) -> Result<S::Ok, S::Error> {
+        use WaylandSocket::*;
+        match self {
+            Listening(path) => {
+                serializer.serialize_newtype_variant("WaylandSocket", 0, "Listening", path)
+            }
+            NonListening(socket) => {
+                serializer.serialize_newtype_variant("WaylandSocket", 1, "NonListening", socket)
+            }
+        }
+    }
+}
+
+impl<'de> DeserializeWithFds<'de> for WaylandSocket {
+    fn deserialize<D: DeserializerWithFds<'de>>(deserializer: D) -> Result<Self, D::Error> {
+        struct Visitor;
+
+        impl<'de> VisitorWithFds<'de> for Visitor {
+            type Value = WaylandSocket;
+
+            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                write!(f, "enum WaylandSocket")
+            }
+
+            fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
+            where
+                A: EnumAccessWithFds<'de>,
+            {
+                #[derive(Debug, Deserialize)]
+                enum Variant {
+                    Listening,
+                    NonListening,
+                }
+
+                match data.variant()? {
+                    (Variant::Listening, variant) => variant.newtype_variant(),
+
+                    (Variant::NonListening, variant) => variant.newtype_variant(),
+                }
+            }
+        }
+
+        deserializer.deserialize_enum("WaylandSocket", &["Listening", "NonListening"], Visitor)
+    }
+}
+
 struct WlState {
-    wayland_paths: Map<Vec<u8>, PathBuf>,
+    wayland_sockets: Map<Vec<u8>, WaylandSocket>,
     vm: VmRequester,
     resource_bridge: Option<ResourceRequestSocket>,
     use_transition_flags: bool,
@@ -895,13 +973,13 @@ struct WlState {
 
 impl WlState {
     fn new(
-        wayland_paths: Map<Vec<u8>, PathBuf>,
+        wayland_sockets: Map<Vec<u8>, WaylandSocket>,
         vm_socket: VmMemoryControlRequestSocket,
         use_transition_flags: bool,
         resource_bridge: Option<ResourceRequestSocket>,
     ) -> WlState {
         WlState {
-            wayland_paths,
+            wayland_sockets,
             vm: VmRequester::new(vm_socket),
             resource_bridge,
             poll_ctx: PollContext::new().expect("failed to create PollContext"),
@@ -1035,12 +1113,18 @@ impl WlState {
 
         match self.vfds.entry(id) {
             Entry::Vacant(entry) => {
-                let vfd = entry.insert(WlVfd::connect(
-                    &self
-                        .wayland_paths
-                        .get(name)
-                        .ok_or(WlError::UnknownSocketName(name.to_vec()))?,
-                )?);
+                let vfd =
+                    if let Some(WaylandSocket::Listening(path)) = self.wayland_sockets.get(name) {
+                        WlVfd::connect(path)?
+                    } else if let Some(WaylandSocket::NonListening(socket)) =
+                        self.wayland_sockets.remove(name)
+                    {
+                        WlVfd::from_socket(socket)
+                    } else {
+                        return Err(WlError::UnknownSocketName(name.to_vec()));
+                    };
+
+                let vfd = entry.insert(vfd);
                 self.poll_ctx
                     .add(vfd.poll_fd().unwrap(), id)
                     .map_err(WlError::PollContextAdd)?;
@@ -1056,16 +1140,17 @@ impl WlState {
         }
     }
 
-    fn add_path(&mut self, name: Vec<u8>, path: PathBuf) -> Result<(), Error> {
+    fn add_socket(&mut self, name: Vec<u8>, socket: UnixStream) -> Result<(), Error> {
         if name.len() > 32 {
             return Err(Error::new(libc::EINVAL));
         }
 
-        if self.wayland_paths.contains_key(&name) {
+        if self.wayland_sockets.contains_key(&name) {
             return Err(Error::new(libc::EADDRINUSE));
         }
 
-        self.wayland_paths.insert(name, path);
+        self.wayland_sockets
+            .insert(name, WaylandSocket::NonListening(socket));
 
         Ok(())
     }
@@ -1413,7 +1498,7 @@ impl WlState {
     }
 }
 
-pub struct Worker {
+struct Worker {
     interrupt: Interrupt,
     mem: GuestMemory,
     in_queue: Queue,
@@ -1423,12 +1508,12 @@ pub struct Worker {
 }
 
 impl Worker {
-    pub fn new(
+    fn new(
         mem: GuestMemory,
         interrupt: Interrupt,
         in_queue: Queue,
         out_queue: Queue,
-        wayland_paths: Map<Vec<u8>, PathBuf>,
+        wayland_sockets: Map<Vec<u8>, WaylandSocket>,
         vm_socket: VmMemoryControlRequestSocket,
         use_transition_flags: bool,
         resource_bridge: Option<ResourceRequestSocket>,
@@ -1440,7 +1525,7 @@ impl Worker {
             in_queue,
             out_queue,
             state: WlState::new(
-                wayland_paths,
+                wayland_sockets,
                 vm_socket,
                 use_transition_flags,
                 resource_bridge,
@@ -1565,14 +1650,14 @@ impl Worker {
                         }
                     }
                     Token::CommandSocket => {
-                        let resp = match dbg!(self.control_socket.recv()) {
-                            Ok(WlControlCommand::AddSocket { name, path }) => {
+                        let resp: WlControlResult = match self.control_socket.recv() {
+                            Ok(WlControlCommand::AddSocket { name, socket }) => {
                                 match unique_name(Cow::Owned(name), |name| !self
                                     .state
                                     .wayland_sockets
                                     .contains_key(name))
                                 {
-                                    Some(name) => match self.state.add_path(name.clone(), socket)
+                                    Some(name) => match self.state.add_socket(name.clone(), socket)
                                     {
                                         Ok(()) => WlControlResult::SocketAdded(name),
                                         Err(e) => WlControlResult::Err(e),
@@ -1653,7 +1738,7 @@ impl Worker {
 pub struct Wl {
     kill_evt: Option<EventFd>,
     worker_thread: Option<thread::JoinHandle<()>>,
-    wayland_paths: Map<Vec<u8>, PathBuf>,
+    wayland_sockets: Option<Map<Vec<u8>, WaylandSocket>>,
     vm_socket: Option<VmMemoryControlRequestSocket>,
     resource_bridge: Option<ResourceRequestSocket>,
     use_transition_flags: bool,
@@ -1752,7 +1837,7 @@ impl<'de> DeserializeWithFds<'de> for Params {
 
         deserializer.deserialize_struct(
             "Params",
-            &["wayland_paths", "vm_socket", "resource_bridge"],
+            &["wayland_sockets", "vm_socket", "resource_bridge"],
             Visitor,
         )
     }
@@ -1770,10 +1855,15 @@ impl VirtioDeviceNew for Wl {
             control_socket,
         } = params;
 
+        let wayland_sockets = wayland_paths
+            .into_iter()
+            .map(|(n, path)| (n, WaylandSocket::Listening(path)))
+            .collect();
+
         Ok(Self {
             kill_evt: None,
             worker_thread: None,
-            wayland_paths,
+            wayland_sockets: Some(wayland_sockets),
             vm_socket: Some(vm_socket),
             resource_bridge,
             use_transition_flags: false,
@@ -1799,6 +1889,13 @@ impl VirtioDevice for Wl {
     fn keep_fds(&self) -> Vec<RawFd> {
         let mut keep_fds = Vec::new();
 
+        if let Some(ref wayland_sockets) = self.wayland_sockets {
+            for (_, socket) in wayland_sockets {
+                if let WaylandSocket::NonListening(socket) = socket {
+                    keep_fds.push(socket.as_raw_fd());
+                }
+            }
+        }
         if let Some(vm_socket) = &self.vm_socket {
             keep_fds.push(vm_socket.as_raw_fd());
         }
@@ -1853,7 +1950,7 @@ impl VirtioDevice for Wl {
         self.kill_evt = Some(self_kill_evt);
 
         if let Some(vm_socket) = self.vm_socket.take() {
-            let wayland_paths = self.wayland_paths.clone();
+            let wayland_sockets = self.wayland_sockets.take().unwrap();
             let use_transition_flags = self.use_transition_flags;
             let resource_bridge = self.resource_bridge.take();
             let control_socket = self.control_socket.take().unwrap();
@@ -1867,7 +1964,7 @@ impl VirtioDevice for Wl {
                             interrupt,
                             queues.remove(0),
                             queues.remove(0),
-                            wayland_paths,
+                            wayland_sockets,
                             vm_socket,
                             use_transition_flags,
                             resource_bridge,