From 98d69a42870030ad533dd8eda5da817430c2b71c Mon Sep 17 00:00:00 2001 From: Alyssa Ross Date: Thu, 26 Mar 2020 11:54:48 +0000 Subject: send wl::Params over socket --- devices/src/virtio/controller.rs | 33 ++++++++---------- devices/src/virtio/wl.rs | 1 + msg_socket/src/lib.rs | 12 ++++++- msg_socket2/src/de.rs | 72 ++++++++++++++++++++++++++++++++++++++-- msg_socket2/src/ser.rs | 2 ++ msg_socket2/tests/option.rs | 12 +++++++ src/linux.rs | 10 +++--- src/wl.rs | 21 ++++-------- 8 files changed, 120 insertions(+), 43 deletions(-) create mode 100644 msg_socket2/tests/option.rs diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs index ba2543f..4930e44 100644 --- a/devices/src/virtio/controller.rs +++ b/devices/src/virtio/controller.rs @@ -28,29 +28,26 @@ //! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill //! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket. -use std::collections::BTreeMap as Map; use std::fmt::{self, Formatter}; use std::os::unix::io::{AsRawFd, RawFd}; -use std::path::PathBuf; use std::sync::Arc; use std::thread; use msg_socket::{MsgReceiver, MsgSocket}; use msg_socket2::de::{EnumAccessWithFds, SeqAccessWithFds, VariantAccessWithFds, VisitorWithFds}; -use msg_socket2::ser::{SerializeStructVariantFds, SerializeTupleVariantFds}; +use msg_socket2::ser::{SerializeAdapter, SerializeStructVariantFds, SerializeTupleVariantFds}; use msg_socket2::{DeserializeWithFds, DeserializerWithFds, FdSerializer, SerializeWithFds}; use serde::ser::{SerializeStructVariant, SerializeTupleVariant, Serializer}; use serde::{Deserialize, Serialize}; use sys_util::net::UnixSeqpacket; use sys_util::{error, EventFd, GuestMemory, PollContext, PollToken, SharedMemory}; -use super::resource_bridge::*; -use super::{Interrupt, InterruptProxyEvent, Queue, VirtioDevice}; +use super::{Interrupt, InterruptProxyEvent, Params, Queue, VirtioDevice}; use crate::{ pci::{PciAddress, PciBarConfiguration, PciCapability, PciCapabilityID}, MemoryParams, }; -use vm_control::{MaybeOwnedFd, VmMemoryControlRequestSocket}; +use vm_control::MaybeOwnedFd; // As far as I can tell, these never change on the other side, so it's // fine to just copy them over. @@ -82,9 +79,7 @@ impl PciCapability for RemotePciCapability { #[derive(Debug)] pub enum Request { Create { - // wayland_paths: Map, - vm_socket: MaybeOwnedFd, - // resource_bridge: Option, + device_params: Params, memory_params: MemoryParams, }, @@ -130,11 +125,11 @@ impl SerializeWithFds for Request { match self { Create { - vm_socket: _, + device_params, memory_params, } => { let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 2)?; - sv.serialize_field("vm_socket", &())?; + sv.serialize_field("device_params", &SerializeAdapter::new(device_params))?; sv.serialize_field("memory_params", memory_params)?; sv.end() } @@ -209,11 +204,11 @@ impl SerializeWithFds for Request { match self { Create { - vm_socket, + device_params, memory_params, } => { let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 2)?; - sv.serialize_field("vm_socket", vm_socket)?; + sv.serialize_field("device_params", device_params)?; sv.serialize_field("memory_params", memory_params)?; sv.end() } @@ -336,7 +331,9 @@ impl<'de> DeserializeWithFds<'de> for Request { } Ok(Request::Create { - vm_socket: seq.next_element()?.ok_or_else(|| too_short(0))?, + device_params: seq + .next_element()? + .ok_or_else(|| too_short(0))?, memory_params: seq .next_element()? .ok_or_else(|| too_short(1))?, @@ -685,16 +682,12 @@ pub struct Controller { impl Controller { pub fn create( - wayland_paths: Map, - vm_socket: VmMemoryControlRequestSocket, - resource_bridge: Option, + device_params: Params, memory_params: MemoryParams, socket: Socket, ) -> Result { socket.send(Request::Create { - // wayland_paths, - vm_socket: MaybeOwnedFd::new_borrowed(&vm_socket), - // resource_bridge, + device_params, memory_params, })?; diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs index 5a3505b..6671693 100644 --- a/devices/src/virtio/wl.rs +++ b/devices/src/virtio/wl.rs @@ -1547,6 +1547,7 @@ use std::fmt::Formatter; use super::VirtioDeviceNew; +#[derive(Debug)] pub struct Params { pub wayland_paths: Map, pub vm_socket: VmMemoryControlRequestSocket, diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs index 5dedac8..f9b1ca0 100644 --- a/msg_socket/src/lib.rs +++ b/msg_socket/src/lib.rs @@ -4,6 +4,7 @@ mod msg_on_socket; +use std::fmt::{self, Debug, Formatter}; use std::io::{IoSlice, Result}; use std::marker::PhantomData; use std::os::unix::prelude::*; @@ -34,7 +35,7 @@ pub fn pair( } /// Bidirection sock that support both send and recv. -#[derive(SerializeWithFds, DeserializeWithFds)] +#[derive(DeserializeWithFds, SerializeWithFds)] #[msg_socket2(strategy = "AsRawFd")] pub struct MsgSocket { sock: UnixSeqpacket, @@ -42,6 +43,15 @@ pub struct MsgSocket { _o: PhantomData, } +// Implement Debug manually because the derivable implementation only +// works when I and O are Debug, even though they're only used as +// PhantomData type parameters. +impl Debug for MsgSocket { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "MsgSocket {{ sock: {:?}, .. }}", self.sock) + } +} + impl MsgSocket { // Create a new MsgSocket. pub fn new(s: UnixSeqpacket) -> MsgSocket { diff --git a/msg_socket2/src/de.rs b/msg_socket2/src/de.rs index fa6582a..67073c0 100644 --- a/msg_socket2/src/de.rs +++ b/msg_socket2/src/de.rs @@ -17,7 +17,7 @@ use std::fmt::{self, Formatter}; use std::marker::PhantomData; pub use serde::de::{ - Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error, SeqAccess, StdError, + Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error, MapAccess, SeqAccess, StdError, VariantAccess, Visitor, }; @@ -352,7 +352,7 @@ pub trait MapAccessWithFds<'de>: Sized { /// Like `MapAccess::next_value_seed`, but `seed` is a /// `DeserializeWithFdsSeed` instead of a `DeserializeSeed`. - fn next_value_seed(&mut self, seed: V) -> Result, Self::Error> + fn next_value_seed(&mut self, seed: V) -> Result where V: DeserializeWithFdsSeed<'de>; @@ -362,7 +362,15 @@ pub trait MapAccessWithFds<'de>: Sized { &mut self, kseed: K, vseed: V, - ) -> Result, Self::Error>; + ) -> Result, Self::Error> { + match self.next_key_seed(kseed)? { + Some(key) => { + let value = self.next_value_seed(vseed)?; + Ok(Some((key, value))) + } + None => Ok(None), + } + } /// Like `MapAccess::next_key`, but returns a `DeserializeWithFds` /// instead of a `Deserialize`. @@ -400,6 +408,46 @@ pub trait MapAccessWithFds<'de>: Sized { fn invite>(self, visitor: V) -> Result; } +impl<'de, 'fds, A, F> MapAccessWithFds<'de> for WithFds<'fds, A, F> +where + A: MapAccess<'de>, + F: Iterator, +{ + type Error = A::Error; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeWithFdsSeed<'de>, + { + let wrapper = WithFds { + inner: seed, + fds: self.fds, + }; + + self.inner.next_key_seed(wrapper) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeWithFdsSeed<'de>, + { + let wrapper = WithFds { + inner: seed, + fds: self.fds, + }; + + self.inner.next_value_seed(wrapper) + } + + fn size_hint(&self) -> Option { + self.inner.size_hint() + } + + fn invite>(self, visitor: V) -> Result { + visitor.visit_map(self.inner) + } +} + /// Like `EnumAccess`, but variants provide access to file descriptors /// as well as data. pub trait EnumAccessWithFds<'de>: Sized { @@ -578,6 +626,17 @@ where self.inner.expecting(f) } + fn visit_none(self) -> Result { + self.inner.visit_none() + } + + fn visit_some>(self, deserializer: D) -> Result { + self.inner.visit_some(WithFds { + inner: deserializer, + fds: self.fds, + }) + } + fn visit_seq>(self, data: A) -> Result { self.inner.visit_seq(WithFds { inner: data, @@ -585,6 +644,13 @@ where }) } + fn visit_map>(self, map: A) -> Result { + self.inner.visit_map(WithFds { + inner: map, + fds: self.fds, + }) + } + fn visit_enum>(self, data: A) -> Result { self.inner.visit_enum(WithFds { inner: data, diff --git a/msg_socket2/src/ser.rs b/msg_socket2/src/ser.rs index 1391d12..8be85aa 100644 --- a/msg_socket2/src/ser.rs +++ b/msg_socket2/src/ser.rs @@ -555,6 +555,8 @@ macro_rules! fd_impl { }; } +serialize_impl!(()); + serialize_impl!(u8); serialize_impl!(u16); serialize_impl!(u32); diff --git a/msg_socket2/tests/option.rs b/msg_socket2/tests/option.rs new file mode 100644 index 0000000..130dd7d --- /dev/null +++ b/msg_socket2/tests/option.rs @@ -0,0 +1,12 @@ +use msg_socket2::Socket; +use sys_util::net::UnixSeqpacket; + +#[test] +fn option() { + let (f1, f2) = UnixSeqpacket::pair().unwrap(); + let s1: Socket<_, ()> = Socket::new(f1); + let s2: Socket<(), Option> = Socket::new(f2); + + s1.send(Some("hello world".to_string())).unwrap(); + assert_eq!(s2.recv().unwrap(), Some("hello world".to_string())); +} diff --git a/src/linux.rs b/src/linux.rs index 7bd4679..a416a3e 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -31,7 +31,7 @@ use acpi_tables::sdt::SDT; #[cfg(feature = "gpu")] use devices::virtio::EventDevice; -use devices::virtio::{self, Console, VirtioDevice}; +use devices::virtio::{self, Console, Params, VirtioDevice}; use devices::{ self, Ac97Backend, Ac97Dev, Bus, HostBackendDeviceProvider, PciDevice, VfioContainer, VfioDevice, VfioPciDevice, VirtioPciDevice, XhciController, @@ -772,9 +772,11 @@ fn create_wayland_device( let seq_socket = UnixSeqpacket::connect(&path).expect("connect failed"); let msg_socket = msg_socket2::Socket::new(seq_socket); let dev = virtio::Controller::create( - cfg.wayland_socket_paths.clone(), - socket, - resource_bridge, + Params { + wayland_paths: cfg.wayland_socket_paths.clone(), + vm_socket: socket, + resource_bridge, + }, memory_params, msg_socket, ) diff --git a/src/wl.rs b/src/wl.rs index 45dfbee..8ae8856 100644 --- a/src/wl.rs +++ b/src/wl.rs @@ -1,11 +1,10 @@ // SPDX-License-Identifier: BSD-3-Clause use devices::virtio::{ - InterruptProxy, InterruptProxyEvent, Params, RemotePciCapability, Request, Response, - VirtioDevice, VirtioDeviceNew, Wl, + InterruptProxy, InterruptProxyEvent, RemotePciCapability, Request, Response, VirtioDevice, + VirtioDeviceNew, Wl, }; use msg_socket::MsgSocket; -use std::collections::BTreeMap; use std::fs::remove_file; use sys_util::{error, net::UnixSeqpacketListener, GuestMemory}; @@ -30,11 +29,11 @@ fn main() { let conn = server.accept().expect("accept failed"); let msg_socket: Socket = msg_socket2::Socket::new(conn); - let (vm_socket, memory_params) = match msg_socket.recv() { + let (device_params, memory_params) = match msg_socket.recv() { Ok(Request::Create { - vm_socket, + device_params, memory_params, - }) => (MsgSocket::new(vm_socket.owned()), memory_params), + }) => (device_params, memory_params), Ok(msg) => { panic!("received unexpected message: {:?}", msg); @@ -45,15 +44,7 @@ fn main() { } }; - let mut wayland_paths = BTreeMap::new(); - wayland_paths.insert("".into(), "/run/user/1000/wayland-0".into()); - - let mut wl = Wl::new(Params { - wayland_paths, - vm_socket, - resource_bridge: None, - }) - .unwrap(); + let mut wl = Wl::new(device_params).unwrap(); loop { match msg_socket.recv() { -- cgit 1.4.1