From c612facef385949daec552c2679491e71d2311e5 Mon Sep 17 00:00:00 2001 From: Alyssa Ross Date: Mon, 6 Jul 2020 16:17:26 +0000 Subject: crosvm: make wl add take socket as fd --- devices/src/virtio/wl.rs | 147 +++++++++++++++++++++++++++++++++++++++-------- src/main.rs | 7 ++- vm_control/src/lib.rs | 15 +---- 3 files changed, 130 insertions(+), 39 deletions(-) diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs index 12f5012..7b93405 100644 --- a/devices/src/virtio/wl.rs +++ b/devices/src/virtio/wl.rs @@ -53,8 +53,9 @@ use data_model::VolatileMemoryError; use data_model::*; use msg_socket::{MsgError, MsgReceiver, MsgSender}; -use msg_socket2::de::VisitorWithFds; +use msg_socket2::de::{EnumAccessWithFds, VariantAccessWithFds, VisitorWithFds}; use msg_socket2::ser::SerializeAdapter; +use msg_socket2::Deserialize; #[cfg(feature = "wl-dmabuf")] use resources::GpuMemoryDesc; #[cfg(feature = "wl-dmabuf")] @@ -612,6 +613,12 @@ impl WlVfd { Ok(vfd) } + fn from_socket(socket: UnixStream) -> WlVfd { + let mut vfd = WlVfd::default(); + vfd.socket = Some(socket); + vfd + } + fn allocate(vm: VmRequester, size: u64) -> WlResult { let size_page_aligned = round_up_to_page_size(size as usize) as u64; let mut vfd_shm = SharedMemory::named("virtwl_alloc").map_err(WlError::NewAlloc)?; @@ -879,8 +886,79 @@ enum WlRecv { Hup, } +#[derive(Debug)] +enum WaylandSocket { + Listening(PathBuf), + NonListening(UnixStream), +} + +impl SerializeWithFds for WaylandSocket { + fn serialize(&self, serializer: S) -> Result { + use WaylandSocket::*; + match self { + Listening(path) => { + serializer.serialize_newtype_variant("WaylandSocket", 0, "Listening", path) + } + NonListening(socket) => serializer.serialize_newtype_variant( + "WaylandSocket", + 1, + "NonListening", + &SerializeAdapter::new(socket), + ), + } + } + + fn serialize_fds<'fds, S: FdSerializer<'fds>>( + &'fds self, + serializer: S, + ) -> Result { + use WaylandSocket::*; + match self { + Listening(path) => { + serializer.serialize_newtype_variant("WaylandSocket", 0, "Listening", path) + } + NonListening(socket) => { + serializer.serialize_newtype_variant("WaylandSocket", 1, "NonListening", socket) + } + } + } +} + +impl<'de> DeserializeWithFds<'de> for WaylandSocket { + fn deserialize>(deserializer: D) -> Result { + struct Visitor; + + impl<'de> VisitorWithFds<'de> for Visitor { + type Value = WaylandSocket; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "enum WaylandSocket") + } + + fn visit_enum(self, data: A) -> Result + where + A: EnumAccessWithFds<'de>, + { + #[derive(Debug, Deserialize)] + enum Variant { + Listening, + NonListening, + } + + match data.variant()? { + (Variant::Listening, variant) => variant.newtype_variant(), + + (Variant::NonListening, variant) => variant.newtype_variant(), + } + } + } + + deserializer.deserialize_enum("WaylandSocket", &["Listening", "NonListening"], Visitor) + } +} + struct WlState { - wayland_paths: Map, PathBuf>, + wayland_sockets: Map, WaylandSocket>, vm: VmRequester, resource_bridge: Option, use_transition_flags: bool, @@ -895,13 +973,13 @@ struct WlState { impl WlState { fn new( - wayland_paths: Map, PathBuf>, + wayland_sockets: Map, WaylandSocket>, vm_socket: VmMemoryControlRequestSocket, use_transition_flags: bool, resource_bridge: Option, ) -> WlState { WlState { - wayland_paths, + wayland_sockets, vm: VmRequester::new(vm_socket), resource_bridge, poll_ctx: PollContext::new().expect("failed to create PollContext"), @@ -1035,12 +1113,18 @@ impl WlState { match self.vfds.entry(id) { Entry::Vacant(entry) => { - let vfd = entry.insert(WlVfd::connect( - &self - .wayland_paths - .get(name) - .ok_or(WlError::UnknownSocketName(name.to_vec()))?, - )?); + let vfd = + if let Some(WaylandSocket::Listening(path)) = self.wayland_sockets.get(name) { + WlVfd::connect(path)? + } else if let Some(WaylandSocket::NonListening(socket)) = + self.wayland_sockets.remove(name) + { + WlVfd::from_socket(socket) + } else { + return Err(WlError::UnknownSocketName(name.to_vec())); + }; + + let vfd = entry.insert(vfd); self.poll_ctx .add(vfd.poll_fd().unwrap(), id) .map_err(WlError::PollContextAdd)?; @@ -1056,16 +1140,17 @@ impl WlState { } } - fn add_path(&mut self, name: Vec, path: PathBuf) -> Result<(), Error> { + fn add_socket(&mut self, name: Vec, socket: UnixStream) -> Result<(), Error> { if name.len() > 32 { return Err(Error::new(libc::EINVAL)); } - if self.wayland_paths.contains_key(&name) { + if self.wayland_sockets.contains_key(&name) { return Err(Error::new(libc::EADDRINUSE)); } - self.wayland_paths.insert(name, path); + self.wayland_sockets + .insert(name, WaylandSocket::NonListening(socket)); Ok(()) } @@ -1413,7 +1498,7 @@ impl WlState { } } -pub struct Worker { +struct Worker { interrupt: Interrupt, mem: GuestMemory, in_queue: Queue, @@ -1423,12 +1508,12 @@ pub struct Worker { } impl Worker { - pub fn new( + fn new( mem: GuestMemory, interrupt: Interrupt, in_queue: Queue, out_queue: Queue, - wayland_paths: Map, PathBuf>, + wayland_sockets: Map, WaylandSocket>, vm_socket: VmMemoryControlRequestSocket, use_transition_flags: bool, resource_bridge: Option, @@ -1440,7 +1525,7 @@ impl Worker { in_queue, out_queue, state: WlState::new( - wayland_paths, + wayland_sockets, vm_socket, use_transition_flags, resource_bridge, @@ -1565,14 +1650,14 @@ impl Worker { } } Token::CommandSocket => { - let resp = match dbg!(self.control_socket.recv()) { - Ok(WlControlCommand::AddSocket { name, path }) => { + let resp: WlControlResult = match self.control_socket.recv() { + Ok(WlControlCommand::AddSocket { name, socket }) => { match unique_name(Cow::Owned(name), |name| !self .state .wayland_sockets .contains_key(name)) { - Some(name) => match self.state.add_path(name.clone(), socket) + Some(name) => match self.state.add_socket(name.clone(), socket) { Ok(()) => WlControlResult::SocketAdded(name), Err(e) => WlControlResult::Err(e), @@ -1653,7 +1738,7 @@ impl Worker { pub struct Wl { kill_evt: Option, worker_thread: Option>, - wayland_paths: Map, PathBuf>, + wayland_sockets: Option, WaylandSocket>>, vm_socket: Option, resource_bridge: Option, use_transition_flags: bool, @@ -1752,7 +1837,7 @@ impl<'de> DeserializeWithFds<'de> for Params { deserializer.deserialize_struct( "Params", - &["wayland_paths", "vm_socket", "resource_bridge"], + &["wayland_sockets", "vm_socket", "resource_bridge"], Visitor, ) } @@ -1770,10 +1855,15 @@ impl VirtioDeviceNew for Wl { control_socket, } = params; + let wayland_sockets = wayland_paths + .into_iter() + .map(|(n, path)| (n, WaylandSocket::Listening(path))) + .collect(); + Ok(Self { kill_evt: None, worker_thread: None, - wayland_paths, + wayland_sockets: Some(wayland_sockets), vm_socket: Some(vm_socket), resource_bridge, use_transition_flags: false, @@ -1799,6 +1889,13 @@ impl VirtioDevice for Wl { fn keep_fds(&self) -> Vec { let mut keep_fds = Vec::new(); + if let Some(ref wayland_sockets) = self.wayland_sockets { + for (_, socket) in wayland_sockets { + if let WaylandSocket::NonListening(socket) = socket { + keep_fds.push(socket.as_raw_fd()); + } + } + } if let Some(vm_socket) = &self.vm_socket { keep_fds.push(vm_socket.as_raw_fd()); } @@ -1853,7 +1950,7 @@ impl VirtioDevice for Wl { self.kill_evt = Some(self_kill_evt); if let Some(vm_socket) = self.vm_socket.take() { - let wayland_paths = self.wayland_paths.clone(); + let wayland_sockets = self.wayland_sockets.take().unwrap(); let use_transition_flags = self.use_transition_flags; let resource_bridge = self.resource_bridge.take(); let control_socket = self.control_socket.take().unwrap(); @@ -1867,7 +1964,7 @@ impl VirtioDevice for Wl { interrupt, queues.remove(0), queues.remove(0), - wayland_paths, + wayland_sockets, vm_socket, use_transition_flags, resource_bridge, diff --git a/src/main.rs b/src/main.rs index 7b47d3c..611dd29 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,6 +12,7 @@ use std::fs::{File, OpenOptions}; use std::io::{BufRead, BufReader}; use std::num::ParseIntError; use std::os::unix::io::{FromRawFd, RawFd}; +use std::os::unix::net::UnixStream; use std::path::{Path, PathBuf}; use std::string::String; use std::thread::sleep; @@ -2191,8 +2192,10 @@ fn wl_cmd(mut args: std::env::Args) -> std::result::Result<(), ()> { let request = match subcommand { "add" => { let name = args.next().unwrap().as_bytes().to_vec(); - let path = args.next().unwrap().into(); - VmRequest::WlCommand(WlControlCommand::AddSocket { name, path }) + // Safe because we're taking ownership of descriptor 3, and won't use it for anything + // else. + let socket = unsafe { UnixStream::from_raw_fd(3) }; + VmRequest::WlCommand(WlControlCommand::AddSocket { name, socket }) } _ => { error!("Unknown wl subcommand '{}'", subcommand); diff --git a/vm_control/src/lib.rs b/vm_control/src/lib.rs index b2e328c..f8ffe21 100644 --- a/vm_control/src/lib.rs +++ b/vm_control/src/lib.rs @@ -15,7 +15,7 @@ use std::fs::File; use std::io::{Seek, SeekFrom}; use std::mem::ManuallyDrop; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; -use std::path::PathBuf; +use std::os::unix::net::UnixStream; use libc::{EINVAL, EIO, ENODEV}; @@ -514,25 +514,16 @@ impl VmMsyncRequest { #[derive(MsgOnSocket, Debug)] pub enum WlControlCommand { - AddSocket { name: Vec, path: PathBuf }, + AddSocket { name: Vec, socket: UnixStream }, } #[derive(MsgOnSocket, Debug)] pub enum WlControlResult { Ready, - Ok, + SocketAdded(Vec), Err(SysError), } -impl From> for WlControlResult { - fn from(result: Result<()>) -> Self { - match result { - Ok(()) => Self::Ok, - Err(e) => Self::Err(e), - } - } -} - pub type BalloonControlRequestSocket = MsgSocket; pub type BalloonControlResponseSocket = MsgSocket; -- cgit 1.4.1