summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-03-26 11:54:48 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:37:22 +0000
commit98d69a42870030ad533dd8eda5da817430c2b71c (patch)
tree2f75316c682b4d0588c7c87faf142c064fb37f3c
parent353b1d9091b9095282463f36e26643506e2d2897 (diff)
downloadcrosvm-98d69a42870030ad533dd8eda5da817430c2b71c.tar
crosvm-98d69a42870030ad533dd8eda5da817430c2b71c.tar.gz
crosvm-98d69a42870030ad533dd8eda5da817430c2b71c.tar.bz2
crosvm-98d69a42870030ad533dd8eda5da817430c2b71c.tar.lz
crosvm-98d69a42870030ad533dd8eda5da817430c2b71c.tar.xz
crosvm-98d69a42870030ad533dd8eda5da817430c2b71c.tar.zst
crosvm-98d69a42870030ad533dd8eda5da817430c2b71c.zip
send wl::Params over socket
-rw-r--r--devices/src/virtio/controller.rs33
-rw-r--r--devices/src/virtio/wl.rs1
-rw-r--r--msg_socket/src/lib.rs12
-rw-r--r--msg_socket2/src/de.rs72
-rw-r--r--msg_socket2/src/ser.rs2
-rw-r--r--msg_socket2/tests/option.rs12
-rw-r--r--src/linux.rs10
-rw-r--r--src/wl.rs21
8 files changed, 120 insertions, 43 deletions
diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs
index ba2543f..4930e44 100644
--- a/devices/src/virtio/controller.rs
+++ b/devices/src/virtio/controller.rs
@@ -28,29 +28,26 @@
 //! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill
 //! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket.
 
-use std::collections::BTreeMap as Map;
 use std::fmt::{self, Formatter};
 use std::os::unix::io::{AsRawFd, RawFd};
-use std::path::PathBuf;
 use std::sync::Arc;
 use std::thread;
 
 use msg_socket::{MsgReceiver, MsgSocket};
 use msg_socket2::de::{EnumAccessWithFds, SeqAccessWithFds, VariantAccessWithFds, VisitorWithFds};
-use msg_socket2::ser::{SerializeStructVariantFds, SerializeTupleVariantFds};
+use msg_socket2::ser::{SerializeAdapter, SerializeStructVariantFds, SerializeTupleVariantFds};
 use msg_socket2::{DeserializeWithFds, DeserializerWithFds, FdSerializer, SerializeWithFds};
 use serde::ser::{SerializeStructVariant, SerializeTupleVariant, Serializer};
 use serde::{Deserialize, Serialize};
 use sys_util::net::UnixSeqpacket;
 use sys_util::{error, EventFd, GuestMemory, PollContext, PollToken, SharedMemory};
 
-use super::resource_bridge::*;
-use super::{Interrupt, InterruptProxyEvent, Queue, VirtioDevice};
+use super::{Interrupt, InterruptProxyEvent, Params, Queue, VirtioDevice};
 use crate::{
     pci::{PciAddress, PciBarConfiguration, PciCapability, PciCapabilityID},
     MemoryParams,
 };
-use vm_control::{MaybeOwnedFd, VmMemoryControlRequestSocket};
+use vm_control::MaybeOwnedFd;
 
 // As far as I can tell, these never change on the other side, so it's
 // fine to just copy them over.
@@ -82,9 +79,7 @@ impl PciCapability for RemotePciCapability {
 #[derive(Debug)]
 pub enum Request {
     Create {
-        // wayland_paths: Map<String, PathBuf>,
-        vm_socket: MaybeOwnedFd<UnixSeqpacket>,
-        // resource_bridge: Option<ResourceRequestSocket>,
+        device_params: Params,
         memory_params: MemoryParams,
     },
 
@@ -130,11 +125,11 @@ impl SerializeWithFds for Request {
 
         match self {
             Create {
-                vm_socket: _,
+                device_params,
                 memory_params,
             } => {
                 let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 2)?;
-                sv.serialize_field("vm_socket", &())?;
+                sv.serialize_field("device_params", &SerializeAdapter::new(device_params))?;
                 sv.serialize_field("memory_params", memory_params)?;
                 sv.end()
             }
@@ -209,11 +204,11 @@ impl SerializeWithFds for Request {
 
         match self {
             Create {
-                vm_socket,
+                device_params,
                 memory_params,
             } => {
                 let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 2)?;
-                sv.serialize_field("vm_socket", vm_socket)?;
+                sv.serialize_field("device_params", device_params)?;
                 sv.serialize_field("memory_params", memory_params)?;
                 sv.end()
             }
@@ -336,7 +331,9 @@ impl<'de> DeserializeWithFds<'de> for Request {
                                 }
 
                                 Ok(Request::Create {
-                                    vm_socket: seq.next_element()?.ok_or_else(|| too_short(0))?,
+                                    device_params: seq
+                                        .next_element()?
+                                        .ok_or_else(|| too_short(0))?,
                                     memory_params: seq
                                         .next_element()?
                                         .ok_or_else(|| too_short(1))?,
@@ -685,16 +682,12 @@ pub struct Controller {
 
 impl Controller {
     pub fn create(
-        wayland_paths: Map<String, PathBuf>,
-        vm_socket: VmMemoryControlRequestSocket,
-        resource_bridge: Option<ResourceRequestSocket>,
+        device_params: Params,
         memory_params: MemoryParams,
         socket: Socket,
     ) -> Result<Controller, msg_socket2::Error> {
         socket.send(Request::Create {
-            // wayland_paths,
-            vm_socket: MaybeOwnedFd::new_borrowed(&vm_socket),
-            // resource_bridge,
+            device_params,
             memory_params,
         })?;
 
diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs
index 5a3505b..6671693 100644
--- a/devices/src/virtio/wl.rs
+++ b/devices/src/virtio/wl.rs
@@ -1547,6 +1547,7 @@ use std::fmt::Formatter;
 
 use super::VirtioDeviceNew;
 
+#[derive(Debug)]
 pub struct Params {
     pub wayland_paths: Map<String, PathBuf>,
     pub vm_socket: VmMemoryControlRequestSocket,
diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs
index 5dedac8..f9b1ca0 100644
--- a/msg_socket/src/lib.rs
+++ b/msg_socket/src/lib.rs
@@ -4,6 +4,7 @@
 
 mod msg_on_socket;
 
+use std::fmt::{self, Debug, Formatter};
 use std::io::{IoSlice, Result};
 use std::marker::PhantomData;
 use std::os::unix::prelude::*;
@@ -34,7 +35,7 @@ pub fn pair<Request: MsgOnSocket, Response: MsgOnSocket>(
 }
 
 /// Bidirection sock that support both send and recv.
-#[derive(SerializeWithFds, DeserializeWithFds)]
+#[derive(DeserializeWithFds, SerializeWithFds)]
 #[msg_socket2(strategy = "AsRawFd")]
 pub struct MsgSocket<I: MsgOnSocket, O: MsgOnSocket> {
     sock: UnixSeqpacket,
@@ -42,6 +43,15 @@ pub struct MsgSocket<I: MsgOnSocket, O: MsgOnSocket> {
     _o: PhantomData<O>,
 }
 
+// Implement Debug manually because the derivable implementation only
+// works when I and O are Debug, even though they're only used as
+// PhantomData type parameters.
+impl<I: MsgOnSocket, O: MsgOnSocket> Debug for MsgSocket<I, O> {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        write!(f, "MsgSocket {{ sock: {:?}, .. }}", self.sock)
+    }
+}
+
 impl<I: MsgOnSocket, O: MsgOnSocket> MsgSocket<I, O> {
     // Create a new MsgSocket.
     pub fn new(s: UnixSeqpacket) -> MsgSocket<I, O> {
diff --git a/msg_socket2/src/de.rs b/msg_socket2/src/de.rs
index fa6582a..67073c0 100644
--- a/msg_socket2/src/de.rs
+++ b/msg_socket2/src/de.rs
@@ -17,7 +17,7 @@ use std::fmt::{self, Formatter};
 use std::marker::PhantomData;
 
 pub use serde::de::{
-    Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error, SeqAccess, StdError,
+    Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error, MapAccess, SeqAccess, StdError,
     VariantAccess, Visitor,
 };
 
@@ -352,7 +352,7 @@ pub trait MapAccessWithFds<'de>: Sized {
 
     /// Like `MapAccess::next_value_seed`, but `seed` is a
     /// `DeserializeWithFdsSeed` instead of a `DeserializeSeed`.
-    fn next_value_seed<V>(&mut self, seed: V) -> Result<Option<V::Value>, Self::Error>
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
     where
         V: DeserializeWithFdsSeed<'de>;
 
@@ -362,7 +362,15 @@ pub trait MapAccessWithFds<'de>: Sized {
         &mut self,
         kseed: K,
         vseed: V,
-    ) -> Result<Option<(K::Value, V::Value)>, Self::Error>;
+    ) -> Result<Option<(K::Value, V::Value)>, Self::Error> {
+        match self.next_key_seed(kseed)? {
+            Some(key) => {
+                let value = self.next_value_seed(vseed)?;
+                Ok(Some((key, value)))
+            }
+            None => Ok(None),
+        }
+    }
 
     /// Like `MapAccess::next_key`, but returns a `DeserializeWithFds`
     /// instead of a `Deserialize`.
@@ -400,6 +408,46 @@ pub trait MapAccessWithFds<'de>: Sized {
     fn invite<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error>;
 }
 
+impl<'de, 'fds, A, F> MapAccessWithFds<'de> for WithFds<'fds, A, F>
+where
+    A: MapAccess<'de>,
+    F: Iterator<Item = Fd>,
+{
+    type Error = A::Error;
+
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
+    where
+        K: DeserializeWithFdsSeed<'de>,
+    {
+        let wrapper = WithFds {
+            inner: seed,
+            fds: self.fds,
+        };
+
+        self.inner.next_key_seed(wrapper)
+    }
+
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
+    where
+        V: DeserializeWithFdsSeed<'de>,
+    {
+        let wrapper = WithFds {
+            inner: seed,
+            fds: self.fds,
+        };
+
+        self.inner.next_value_seed(wrapper)
+    }
+
+    fn size_hint(&self) -> Option<usize> {
+        self.inner.size_hint()
+    }
+
+    fn invite<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
+        visitor.visit_map(self.inner)
+    }
+}
+
 /// Like `EnumAccess`, but variants provide access to file descriptors
 /// as well as data.
 pub trait EnumAccessWithFds<'de>: Sized {
@@ -578,6 +626,17 @@ where
         self.inner.expecting(f)
     }
 
+    fn visit_none<E: Error>(self) -> Result<Self::Value, E> {
+        self.inner.visit_none()
+    }
+
+    fn visit_some<D: Deserializer<'de>>(self, deserializer: D) -> Result<Self::Value, D::Error> {
+        self.inner.visit_some(WithFds {
+            inner: deserializer,
+            fds: self.fds,
+        })
+    }
+
     fn visit_seq<A: SeqAccess<'de>>(self, data: A) -> Result<Self::Value, A::Error> {
         self.inner.visit_seq(WithFds {
             inner: data,
@@ -585,6 +644,13 @@ where
         })
     }
 
+    fn visit_map<A: MapAccess<'de>>(self, map: A) -> Result<Self::Value, A::Error> {
+        self.inner.visit_map(WithFds {
+            inner: map,
+            fds: self.fds,
+        })
+    }
+
     fn visit_enum<A: EnumAccess<'de>>(self, data: A) -> Result<Self::Value, A::Error> {
         self.inner.visit_enum(WithFds {
             inner: data,
diff --git a/msg_socket2/src/ser.rs b/msg_socket2/src/ser.rs
index 1391d12..8be85aa 100644
--- a/msg_socket2/src/ser.rs
+++ b/msg_socket2/src/ser.rs
@@ -555,6 +555,8 @@ macro_rules! fd_impl {
     };
 }
 
+serialize_impl!(());
+
 serialize_impl!(u8);
 serialize_impl!(u16);
 serialize_impl!(u32);
diff --git a/msg_socket2/tests/option.rs b/msg_socket2/tests/option.rs
new file mode 100644
index 0000000..130dd7d
--- /dev/null
+++ b/msg_socket2/tests/option.rs
@@ -0,0 +1,12 @@
+use msg_socket2::Socket;
+use sys_util::net::UnixSeqpacket;
+
+#[test]
+fn option() {
+    let (f1, f2) = UnixSeqpacket::pair().unwrap();
+    let s1: Socket<_, ()> = Socket::new(f1);
+    let s2: Socket<(), Option<String>> = Socket::new(f2);
+
+    s1.send(Some("hello world".to_string())).unwrap();
+    assert_eq!(s2.recv().unwrap(), Some("hello world".to_string()));
+}
diff --git a/src/linux.rs b/src/linux.rs
index 7bd4679..a416a3e 100644
--- a/src/linux.rs
+++ b/src/linux.rs
@@ -31,7 +31,7 @@ use acpi_tables::sdt::SDT;
 
 #[cfg(feature = "gpu")]
 use devices::virtio::EventDevice;
-use devices::virtio::{self, Console, VirtioDevice};
+use devices::virtio::{self, Console, Params, VirtioDevice};
 use devices::{
     self, Ac97Backend, Ac97Dev, Bus, HostBackendDeviceProvider, PciDevice, VfioContainer,
     VfioDevice, VfioPciDevice, VirtioPciDevice, XhciController,
@@ -772,9 +772,11 @@ fn create_wayland_device(
     let seq_socket = UnixSeqpacket::connect(&path).expect("connect failed");
     let msg_socket = msg_socket2::Socket::new(seq_socket);
     let dev = virtio::Controller::create(
-        cfg.wayland_socket_paths.clone(),
-        socket,
-        resource_bridge,
+        Params {
+            wayland_paths: cfg.wayland_socket_paths.clone(),
+            vm_socket: socket,
+            resource_bridge,
+        },
         memory_params,
         msg_socket,
     )
diff --git a/src/wl.rs b/src/wl.rs
index 45dfbee..8ae8856 100644
--- a/src/wl.rs
+++ b/src/wl.rs
@@ -1,11 +1,10 @@
 // SPDX-License-Identifier: BSD-3-Clause
 
 use devices::virtio::{
-    InterruptProxy, InterruptProxyEvent, Params, RemotePciCapability, Request, Response,
-    VirtioDevice, VirtioDeviceNew, Wl,
+    InterruptProxy, InterruptProxyEvent, RemotePciCapability, Request, Response, VirtioDevice,
+    VirtioDeviceNew, Wl,
 };
 use msg_socket::MsgSocket;
-use std::collections::BTreeMap;
 use std::fs::remove_file;
 use sys_util::{error, net::UnixSeqpacketListener, GuestMemory};
 
@@ -30,11 +29,11 @@ fn main() {
     let conn = server.accept().expect("accept failed");
     let msg_socket: Socket = msg_socket2::Socket::new(conn);
 
-    let (vm_socket, memory_params) = match msg_socket.recv() {
+    let (device_params, memory_params) = match msg_socket.recv() {
         Ok(Request::Create {
-            vm_socket,
+            device_params,
             memory_params,
-        }) => (MsgSocket::new(vm_socket.owned()), memory_params),
+        }) => (device_params, memory_params),
 
         Ok(msg) => {
             panic!("received unexpected message: {:?}", msg);
@@ -45,15 +44,7 @@ fn main() {
         }
     };
 
-    let mut wayland_paths = BTreeMap::new();
-    wayland_paths.insert("".into(), "/run/user/1000/wayland-0".into());
-
-    let mut wl = Wl::new(Params {
-        wayland_paths,
-        vm_socket,
-        resource_bridge: None,
-    })
-    .unwrap();
+    let mut wl = Wl::new(device_params).unwrap();
 
     loop {
         match msg_socket.recv() {