summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-07-06 16:13:01 +0000
committerAlyssa Ross <hi@alyssa.is>2020-07-06 16:22:23 +0000
commitaf15d3fa14dd82f14f747de17c30184db133effd (patch)
tree563d08faefaad8e8a0c9767a0c74d89586a72c73
parent727757c19b6ca4c3f5c83717fcc9635a16526a46 (diff)
downloadcrosvm-af15d3fa14dd82f14f747de17c30184db133effd.tar
crosvm-af15d3fa14dd82f14f747de17c30184db133effd.tar.gz
crosvm-af15d3fa14dd82f14f747de17c30184db133effd.tar.bz2
crosvm-af15d3fa14dd82f14f747de17c30184db133effd.tar.lz
crosvm-af15d3fa14dd82f14f747de17c30184db133effd.tar.xz
crosvm-af15d3fa14dd82f14f747de17c30184db133effd.tar.zst
crosvm-af15d3fa14dd82f14f747de17c30184db133effd.zip
devices: unique wl socket name generation
-rw-r--r--devices/src/virtio/wl.rs65
1 files changed, 64 insertions, 1 deletions
diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs
index 955ca76..9fe00ad 100644
--- a/devices/src/virtio/wl.rs
+++ b/devices/src/virtio/wl.rs
@@ -28,6 +28,7 @@
 //! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill
 //! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket.
 
+use std::borrow::Cow;
 use std::collections::btree_map::Entry;
 use std::collections::{BTreeMap as Map, BTreeSet as Set, VecDeque};
 use std::convert::{From, Infallible};
@@ -135,6 +136,57 @@ ioctl_iow_nr!(DMA_BUF_IOCTL_SYNC, DMA_BUF_IOCTL_BASE, 0, dma_buf_sync);
 const VIRTIO_WL_CTRL_VFD_SEND_KIND_LOCAL: u32 = 0;
 const VIRTIO_WL_CTRL_VFD_SEND_KIND_VIRTGPU: u32 = 1;
 
+/// If `name` contains "%d", tries replacing "%d" with successive
+/// integers starting from 0, until name satisfies predicate.
+///
+/// `name` can only contain one "%", and it must be followed by "d".
+fn unique_name<'a, S, P>(name: S, predicate: P) -> Option<Vec<u8>>
+where
+    S: Into<Cow<'a, [u8]>>,
+    P: Fn(&[u8]) -> bool,
+{
+    let name = name.into();
+    if let Some(pos) = name.iter().position(|b| *b == b'%') {
+        if name.get(pos + 1) != Some(&b'd') {
+            None
+        } else if name[(pos + 1)..].contains(&b'%') {
+            None
+        } else {
+            let mut i = 0;
+            let mut resolved_name = Vec::with_capacity(name.len());
+            loop {
+                resolved_name.clear();
+                resolved_name.extend(&name[0..pos]);
+                resolved_name.extend(i.to_string().bytes());
+                resolved_name.extend(&name[(pos + 2)..]);
+
+                eprintln!("Trying {}", String::from_utf8_lossy(&resolved_name));
+
+                if predicate(&resolved_name) {
+                    break;
+                }
+
+                i += 1;
+            }
+            Some(resolved_name)
+        }
+    } else {
+        Some(name.into_owned())
+    }
+}
+
+#[test]
+fn test_unique_name() {
+    assert_eq!(unique_name(&b"%d"[..], |_| true).unwrap(), b"0");
+    assert_eq!(
+        unique_name(&b"-%d-"[..], |name| name != b"-0-").unwrap(),
+        b"-1-"
+    );
+    assert_eq!(unique_name(&b"%"[..], |_| true), None);
+    assert_eq!(unique_name(&b"%%"[..], |_| true), None);
+    assert_eq!(unique_name(&b"%d%d"[..], |_| true), None);
+}
+
 fn encode_vfd_new(
     writer: &mut Writer,
     resp: bool,
@@ -1506,7 +1558,18 @@ impl Worker {
                     Token::CommandSocket => {
                         let resp = match dbg!(self.control_socket.recv()) {
                             Ok(WlControlCommand::AddSocket { name, path }) => {
-                                self.state.add_path(name, path).into()
+                                match unique_name(Cow::Owned(name), |name| !self
+                                    .state
+                                    .wayland_sockets
+                                    .contains_key(name))
+                                {
+                                    Some(name) => match self.state.add_path(name.clone(), socket)
+                                    {
+                                        Ok(()) => WlControlResult::SocketAdded(name),
+                                        Err(e) => WlControlResult::Err(e),
+                                    },
+                                    None => WlControlResult::Err(Error::new(libc::EINVAL)),
+                                }
                             }
                             Err(MsgError::InvalidData) => {
                                 WlControlResult::Err(Error::new(libc::EINVAL))