summary refs log tree commit diff
diff options
context:
space:
mode:
authorAlyssa Ross <hi@alyssa.is>2020-03-25 09:20:05 +0000
committerAlyssa Ross <hi@alyssa.is>2020-06-15 09:37:15 +0000
commit616c8e59099101ec20eaf8633be929164a68dee4 (patch)
tree4fe8bbca411c0d81fbe37af295ce32620d4cf73f
parentb6549a605935e29ab0ae4291737f8b0158bca1fb (diff)
downloadcrosvm-616c8e59099101ec20eaf8633be929164a68dee4.tar
crosvm-616c8e59099101ec20eaf8633be929164a68dee4.tar.gz
crosvm-616c8e59099101ec20eaf8633be929164a68dee4.tar.bz2
crosvm-616c8e59099101ec20eaf8633be929164a68dee4.tar.lz
crosvm-616c8e59099101ec20eaf8633be929164a68dee4.tar.xz
crosvm-616c8e59099101ec20eaf8633be929164a68dee4.tar.zst
crosvm-616c8e59099101ec20eaf8633be929164a68dee4.zip
devices: VirtioDeviceNew
-rw-r--r--devices/src/virtio/virtio_device.rs8
-rw-r--r--devices/src/virtio/wl.rs121
-rw-r--r--src/wl.rs10
3 files changed, 127 insertions, 12 deletions
diff --git a/devices/src/virtio/virtio_device.rs b/devices/src/virtio/virtio_device.rs
index 58b6886..1d6e4bc 100644
--- a/devices/src/virtio/virtio_device.rs
+++ b/devices/src/virtio/virtio_device.rs
@@ -4,11 +4,19 @@
 
 use std::os::unix::io::RawFd;
 
+use msg_socket2::SerializeWithFds;
 use sys_util::{EventFd, GuestMemory};
 
 use super::*;
 use crate::pci::{MsixStatus, PciAddress, PciBarConfiguration, PciCapability};
 
+pub trait VirtioDeviceNew: Sized {
+    type Params: SerializeWithFds;
+    type Error;
+
+    fn new(params: Self::Params) -> Result<Self, Self::Error>;
+}
+
 /// Trait for virtio devices to be driven by a virtio transport.
 ///
 /// The lifecycle of a virtio device is to be moved to a virtio transport, which will then query the
diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs
index 2c94d3c..7a436f5 100644
--- a/devices/src/virtio/wl.rs
+++ b/devices/src/virtio/wl.rs
@@ -42,7 +42,6 @@ use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
 use std::os::unix::net::UnixStream;
 use std::path::{Path, PathBuf};
 use std::rc::Rc;
-use std::result;
 use std::thread;
 use std::time::Duration;
 
@@ -53,13 +52,14 @@ use data_model::VolatileMemoryError;
 use data_model::*;
 
 use msg_socket::{MsgError, MsgReceiver, MsgSender};
+use msg_socket2::de::VisitorWithFds;
 #[cfg(feature = "wl-dmabuf")]
 use resources::GpuMemoryDesc;
 #[cfg(feature = "wl-dmabuf")]
 use sys_util::ioctl_iow_nr;
 use sys_util::{
     error, pipe, round_up_to_page_size, warn, Error, EventFd, FileFlags, GuestMemory,
-    GuestMemoryError, PollContext, PollToken, Result, ScmSocket, SharedMemory,
+    GuestMemoryError, PollContext, PollToken, ScmSocket, SharedMemory,
 };
 
 #[cfg(feature = "wl-dmabuf")]
@@ -310,7 +310,7 @@ impl Display for WlError {
 
 impl std::error::Error for WlError {}
 
-type WlResult<T> = result::Result<T, WlError>;
+type WlResult<T> = Result<T, WlError>;
 
 impl From<GuestMemoryError> for WlError {
     fn from(e: GuestMemoryError) -> WlError {
@@ -1538,13 +1538,114 @@ pub struct Wl {
     use_transition_flags: bool,
 }
 
-impl Wl {
-    pub fn new(
-        wayland_paths: Map<String, PathBuf>,
-        vm_socket: VmMemoryControlRequestSocket,
-        resource_bridge: Option<ResourceRequestSocket>,
-    ) -> Result<Wl> {
-        Ok(Wl {
+use msg_socket2::{
+    de::{DeserializeWithFds, DeserializerWithFds},
+    ser::{
+        FdSerializer, SerializeRawFd, SerializeStruct, SerializeStructFds, SerializeWithFds,
+        Serializer,
+    },
+};
+use std::fmt::Formatter;
+
+use super::VirtioDeviceNew;
+
+pub struct Params {
+    pub wayland_paths: Map<String, PathBuf>,
+    pub vm_socket: VmMemoryControlRequestSocket,
+    pub resource_bridge: Option<ResourceRequestSocket>,
+}
+
+impl SerializeWithFds for Params {
+    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        let mut state = serializer.serialize_struct("Params", 3)?;
+        state.serialize_field("wayland_paths", &self.wayland_paths)?;
+        state.serialize_field("vm_socket", &())?;
+
+        let resource_bridge_marker = self.resource_bridge.as_ref().map(|_| ());
+        state.serialize_field("resource_bridge", &resource_bridge_marker)?;
+
+        state.end()
+    }
+
+    fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        let mut state = serializer.serialize_struct("Params", 3)?;
+
+        state.serialize_field("wayland_paths", &self.wayland_paths)?;
+        state.serialize_field(
+            "vm_socket",
+            &SerializeRawFd::new(&self.vm_socket.as_raw_fd()),
+        )?;
+
+        let bridge_fd = self.resource_bridge.as_ref().map(AsRawFd::as_raw_fd);
+        state.serialize_field(
+            "resource_bridge",
+            &bridge_fd.as_ref().map(SerializeRawFd::new),
+        )?;
+
+        state.end()
+    }
+}
+
+use msg_socket::MsgSocket;
+use msg_socket2::de::SeqAccessWithFds;
+
+impl<'de> DeserializeWithFds<'de> for Params {
+    fn deserialize<D: DeserializerWithFds<'de>>(deserializer: D) -> Result<Self, D::Error> {
+        struct Visitor;
+
+        impl<'de> VisitorWithFds<'de> for Visitor {
+            type Value = Params;
+
+            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
+                write!(f, "struct Params")
+            }
+
+            fn visit_seq<A: SeqAccessWithFds<'de>>(
+                self,
+                mut seq: A,
+            ) -> Result<Self::Value, A::Error> {
+                use serde::de::Error;
+
+                fn too_short_error<E: Error>(len: usize) -> E {
+                    E::invalid_length(len, &"struct Params with 3 elements")
+                }
+
+                Ok(Params {
+                    wayland_paths: seq.next_element()?.ok_or_else(|| too_short_error(0))?,
+
+                    vm_socket: seq
+                        .next_element()?
+                        .map(MsgSocket::new)
+                        .ok_or_else(|| too_short_error(1))?,
+
+                    resource_bridge: seq
+                        .next_element::<Option<_>>()?
+                        .ok_or_else(|| too_short_error(2))?
+                        .map(MsgSocket::new),
+                })
+            }
+        }
+
+        deserializer.deserialize_struct(
+            "Params",
+            &["wayland_paths", "vm_socket", "resource_bridge"],
+            Visitor,
+        )
+    }
+}
+
+impl VirtioDeviceNew for Wl {
+    type Params = Params;
+    type Error = ();
+
+    fn new(params: Params) -> Result<Self, ()> {
+        let Params {
+            wayland_paths,
+            vm_socket,
+            resource_bridge,
+        } = params;
+
+        Ok(Self {
             kill_evt: None,
             worker_thread: None,
             wayland_paths,
diff --git a/src/wl.rs b/src/wl.rs
index be2ca2e..45dfbee 100644
--- a/src/wl.rs
+++ b/src/wl.rs
@@ -1,7 +1,8 @@
 // SPDX-License-Identifier: BSD-3-Clause
 
 use devices::virtio::{
-    InterruptProxy, InterruptProxyEvent, RemotePciCapability, Request, Response, VirtioDevice, Wl,
+    InterruptProxy, InterruptProxyEvent, Params, RemotePciCapability, Request, Response,
+    VirtioDevice, VirtioDeviceNew, Wl,
 };
 use msg_socket::MsgSocket;
 use std::collections::BTreeMap;
@@ -47,7 +48,12 @@ fn main() {
     let mut wayland_paths = BTreeMap::new();
     wayland_paths.insert("".into(), "/run/user/1000/wayland-0".into());
 
-    let mut wl = Wl::new(wayland_paths, vm_socket, None).unwrap();
+    let mut wl = Wl::new(Params {
+        wayland_paths,
+        vm_socket,
+        resource_bridge: None,
+    })
+    .unwrap();
 
     loop {
         match msg_socket.recv() {