diff options
author | Alyssa Ross <hi@alyssa.is> | 2020-06-14 20:18:48 +0000 |
---|---|---|
committer | Alyssa Ross <hi@alyssa.is> | 2020-06-15 09:37:07 +0000 |
commit | 4d2e22e374e8ac93be227f0357efd2c0d7a9f699 (patch) | |
tree | 738ba443359adaf7311abf2fdddb750263d18587 | |
parent | 8214c4c64fbdbf6ae84634bb822a90959271cad5 (diff) | |
download | crosvm-4d2e22e374e8ac93be227f0357efd2c0d7a9f699.tar crosvm-4d2e22e374e8ac93be227f0357efd2c0d7a9f699.tar.gz crosvm-4d2e22e374e8ac93be227f0357efd2c0d7a9f699.tar.bz2 crosvm-4d2e22e374e8ac93be227f0357efd2c0d7a9f699.tar.lz crosvm-4d2e22e374e8ac93be227f0357efd2c0d7a9f699.tar.xz crosvm-4d2e22e374e8ac93be227f0357efd2c0d7a9f699.tar.zst crosvm-4d2e22e374e8ac93be227f0357efd2c0d7a9f699.zip |
switch from poly_msg_socket to msg_socket2
-rw-r--r-- | Cargo.lock | 13 | ||||
-rw-r--r-- | Cargo.toml | 9 | ||||
-rw-r--r-- | devices/Cargo.toml | 1 | ||||
-rw-r--r-- | devices/src/lib.rs | 3 | ||||
-rw-r--r-- | devices/src/virtio/controller.rs | 851 | ||||
-rw-r--r-- | devices/src/virtio/queue.rs | 4 | ||||
-rw-r--r-- | msg_socket2/src/error.rs | 12 | ||||
-rw-r--r-- | msg_socket2/src/socket.rs | 13 | ||||
-rw-r--r-- | msg_socket2/tests/round_trip.rs | 8 | ||||
-rw-r--r-- | src/linux.rs | 5 | ||||
-rw-r--r-- | src/wl.rs | 66 | ||||
-rw-r--r-- | sys_util/Cargo.toml | 1 | ||||
-rw-r--r-- | sys_util/src/guest_address.rs | 4 |
13 files changed, 854 insertions, 136 deletions
diff --git a/Cargo.lock b/Cargo.lock index 3b175fe..5b0f218 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -146,9 +146,9 @@ dependencies = [ "libcras 0.1.0", "minijail-sys 0.0.11", "msg_socket 0.1.0", + "msg_socket2 0.1.0", "net_util 0.1.0", "p9 0.1.0", - "poly_msg_socket 0.1.0", "protobuf 2.8.1 (registry+https://github.com/rust-lang/crates.io-index)", "protos 0.1.0", "rand_ish 0.1.0", @@ -205,6 +205,7 @@ dependencies = [ "linux_input_sys 0.1.0", "msg_on_socket_derive 0.1.0", "msg_socket 0.1.0", + "msg_socket2 0.1.0", "net_sys 0.1.0", "net_util 0.1.0", "p9 0.1.0", @@ -499,6 +500,15 @@ dependencies = [ ] [[package]] +name = "msg_socket2" +version = "0.1.0" +dependencies = [ + "bincode 1.3.0 (git+https://github.com/alyssais/bincode?branch=from_slice)", + "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", + "sys_util 0.1.0", +] + +[[package]] name = "net_sys" version = "0.1.0" dependencies = [ @@ -730,6 +740,7 @@ dependencies = [ "data_model 0.1.0", "libc 0.2.44 (registry+https://github.com/rust-lang/crates.io-index)", "poll_token_derive 0.1.0", + "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", "sync 0.1.0", "syscall_defines 0.1.0", "tempfile 3.0.7", diff --git a/Cargo.toml b/Cargo.toml index 3123c3d..baf401b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,12 @@ panic = 'abort' overflow-checks = true [workspace] -members = ["qcow_utils"] +members = ["qcow_utils", + +# TEMP +# "data_socket", + +] exclude = [ "assertions", "async_core", @@ -68,9 +73,9 @@ libc = "0.2.44" libcras = "*" minijail-sys = "*" # provided by ebuild msg_socket = { path = "msg_socket" } +msg_socket2 = { path = "msg_socket2" } net_util = { path = "net_util" } p9 = { path = "p9" } -poly_msg_socket = { path = "poly_msg_socket" } protobuf = { version = "2.3", optional = true } protos = { path = "protos", optional = true } rand_ish = { path = "rand_ish" } diff --git a/devices/Cargo.toml b/devices/Cargo.toml index 939364a..830cb86 100644 --- a/devices/Cargo.toml +++ b/devices/Cargo.toml @@ -34,6 +34,7 @@ libvda = { version = "*", optional = true } linux_input_sys = { path = "../linux_input_sys" } msg_on_socket_derive = { path = "../msg_socket/msg_on_socket_derive" } msg_socket = { path = "../msg_socket" } +msg_socket2 = { path = "../msg_socket2" } net_sys = { path = "../net_sys" } net_util = { path = "../net_util" } p9 = { path = "../p9" } diff --git a/devices/src/lib.rs b/devices/src/lib.rs index 9d39fbd..febb18f 100644 --- a/devices/src/lib.rs +++ b/devices/src/lib.rs @@ -50,8 +50,9 @@ pub use self::vfio::{VfioContainer, VfioDevice}; pub use self::virtio::VirtioPciDevice; use msg_socket::MsgOnSocket; +use serde::{Deserialize, Serialize}; -#[derive(Clone, Copy, Debug, MsgOnSocket)] +#[derive(Clone, Copy, Debug, MsgOnSocket, Serialize, Deserialize)] pub struct MemoryParams { /// Physical memory size in bytes for the VM. pub size: u64, diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs index 07f5cbd..9b0784c 100644 --- a/devices/src/virtio/controller.rs +++ b/devices/src/virtio/controller.rs @@ -28,25 +28,31 @@ //! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill //! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket. +use std::os::unix::prelude::*; + use std::collections::BTreeMap as Map; +use std::fmt::{self, Formatter}; use std::os::unix::io::{AsRawFd, RawFd}; use std::path::PathBuf; use std::sync::Arc; use std::thread; +use msg_socket::{MsgReceiver, MsgSocket}; +use msg_socket2::{DeserializeWithFds, DeserializerWithFds, SerializeWithFds, SerializerWithFds}; +use serde::de::{Deserializer, EnumAccess, SeqAccess, VariantAccess}; +use serde::ser::{SerializeStructVariant, SerializeTupleVariant, Serializer}; +use serde::{Deserialize, Serialize}; +use sys_util::net::UnixSeqpacket; +use sys_util::{error, EventFd, GuestMemory, PollContext, PollToken, SharedMemory}; + use super::resource_bridge::*; -use super::{Interrupt, InterruptProxyEvent, Queue, VirtioDevice, TYPE_WL, VIRTIO_F_VERSION_1}; +use super::{Interrupt, InterruptProxyEvent, Queue, VirtioDevice}; use crate::{ pci::{PciAddress, PciBarConfiguration, PciCapability, PciCapabilityID}, MemoryParams, }; use vm_control::{MaybeOwnedFd, VmMemoryControlRequestSocket}; -use msg_socket::{MsgOnSocket, MsgReceiver, MsgSocket}; -use serde::{Deserialize, Serialize}; -use sys_util::net::UnixSeqpacket; -use sys_util::{error, EventFd, GuestMemory, PollContext, PollToken, SharedMemory}; - // As far as I can tell, these never change on the other side, so it's // fine to just copy them over. #[derive(Clone, Debug, Deserialize, Serialize)] @@ -74,8 +80,8 @@ impl PciCapability for RemotePciCapability { } } -#[derive(Debug, MsgOnSocket)] -pub enum MsgOnSocketRequest { +#[derive(Debug)] +pub enum Request { Create { // wayland_paths: Map<String, PathBuf>, vm_socket: MaybeOwnedFd<UnixSeqpacket>, @@ -83,11 +89,24 @@ pub enum MsgOnSocketRequest { memory_params: MemoryParams, }, + DebugLabel, + DeviceType, + QueueMaxSizes, + Features, AckFeatures(u64), + ReadConfig { + offset: u64, + len: usize, + }, + WriteConfig { + offset: u64, + data: Vec<u8>, + }, + Activate { shm: MaybeOwnedFd<SharedMemory>, interrupt: MaybeOwnedFd<UnixSeqpacket>, @@ -99,75 +118,748 @@ pub enum MsgOnSocketRequest { }, Reset, + + GetDeviceBars(PciAddress), + GetDeviceCaps, + Kill, } -#[derive(Debug, Serialize, Deserialize)] -pub enum BincodeRequest { - DebugLabel, +impl SerializeWithFds for Request { + fn serialize<S>(&self, serializer: SerializerWithFds<S>) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + use Request::*; - QueueMaxSizes, + match self { + Create { + vm_socket, + memory_params, + } => { + let mut sv = serializer + .serializer + .serialize_struct_variant("Request", 0, "Create", 1)?; - ReadConfig { offset: u64, len: usize }, - WriteConfig { offset: u64, data: Vec<u8> }, + sv.skip_field("vm_socket")?; + serializer.fds.push(vm_socket.as_raw_fd()); - GetDeviceBars(PciAddress), - GetDeviceCaps, -} + sv.serialize_field("memory_params", memory_params)?; + + sv.end() + } + + DebugLabel => serializer + .serializer + .serialize_unit_variant("Request", 1, "DebugLabel"), + + DeviceType => serializer + .serializer + .serialize_unit_variant("Request", 2, "DeviceType"), + QueueMaxSizes => { + serializer + .serializer + .serialize_unit_variant("Request", 3, "QueueMaxSizes") + } + + Features => serializer + .serializer + .serialize_unit_variant("Request", 4, "Features"), + + AckFeatures(features) => { + let mut tv = serializer.serializer.serialize_tuple_variant( + "Request", + 5, + "AckFeatures", + 1, + )?; + tv.serialize_field(features)?; + tv.end() + } + + ReadConfig { offset, len } => { + let mut sv = serializer.serializer.serialize_struct_variant( + "Request", + 6, + "ReadConfig", + 2, + )?; + sv.serialize_field("offset", offset)?; + sv.serialize_field("len", len)?; + sv.end() + } + + WriteConfig { offset, data } => { + let mut sv = serializer.serializer.serialize_struct_variant( + "Request", + 7, + "WriteConfig", + 2, + )?; + sv.serialize_field("offset", offset)?; + sv.serialize_field("data", data)?; + sv.end() + } + + Activate { + shm, + interrupt, + interrupt_resample_evt, + in_queue, + out_queue, + in_queue_evt, + out_queue_evt, + } => { + let mut sv = serializer + .serializer + .serialize_struct_variant("Request", 8, "Activate", 2)?; -pub type Request = poly_msg_socket::Value<MsgOnSocketRequest, BincodeRequest>; + sv.skip_field("shm")?; + serializer.fds.push(shm.as_raw_fd()); -impl From<MsgOnSocketRequest> for Request { - fn from(request: MsgOnSocketRequest) -> Self { - Self::MsgOnSocket(request) + sv.skip_field("interrupt")?; + serializer.fds.push(interrupt.as_raw_fd()); + + sv.skip_field("interrupt_resample_evt")?; + serializer.fds.push(interrupt_resample_evt.as_raw_fd()); + + sv.serialize_field("in_queue", in_queue)?; + sv.serialize_field("out_queue", out_queue)?; + + sv.skip_field("in_queue_evt")?; + serializer.fds.push(in_queue_evt.as_raw_fd()); + + sv.skip_field("out_queue_evt")?; + serializer.fds.push(out_queue_evt.as_raw_fd()); + + sv.end() + } + + Reset => serializer + .serializer + .serialize_unit_variant("Request", 9, "Reset"), + + GetDeviceBars(address) => { + let mut sv = serializer.serializer.serialize_struct_variant( + "Request", + 10, + "GetDeviceBars", + 1, + )?; + sv.serialize_field("address", address)?; + sv.end() + } + + GetDeviceCaps => { + serializer + .serializer + .serialize_unit_variant("Request", 11, "GetDeviceCaps") + } + + Kill => serializer + .serializer + .serialize_unit_variant("Request", 12, "Kill"), + } } } -impl From<BincodeRequest> for Request { - fn from(request: BincodeRequest) -> Self { - Self::Bincode(request) +impl<'de> DeserializeWithFds<'de> for Request { + fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error> + where + I: Iterator<Item = RawFd>, + De: Deserializer<'de>, + { + struct Visitor<'iter, Iter> { + fds: &'iter mut Iter, + } + + impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> + where + Iter: Iterator<Item = RawFd>, + { + type Value = Request; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "enum Request") + } + + fn visit_enum<A: EnumAccess<'de>>(self, data: A) -> Result<Self::Value, A::Error> { + #[derive(Debug, Deserialize)] + enum Variant { + Create, + DebugLabel, + DeviceType, + QueueMaxSizes, + Features, + AckFeatures, + ReadConfig, + WriteConfig, + Activate, + Reset, + GetDeviceBars, + GetDeviceCaps, + Kill, + } + + match data.variant()? { + (Variant::Create, variant) => { + struct Visitor<'iter, Iter> { + fds: &'iter mut Iter, + } + + impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> + where + Iter: Iterator<Item = RawFd>, + { + type Value = Request; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "struct variant Request::Create") + } + + fn visit_seq<A: SeqAccess<'de>>( + self, + mut seq: A, + ) -> Result<Request, A::Error> { + use serde::de::Error; + + Ok(Request::Create { + vm_socket: match self.fds.next() { + Some(vm_socket) => MaybeOwnedFd::Owned(unsafe { + UnixSeqpacket::from_raw_fd(vm_socket) + }), + None => { + return Err(Error::invalid_length( + 0, + &"struct variant Request::Create with 2 elements", + )) + } + }, + + memory_params: match seq.next_element()? { + Some(memory_params) => memory_params, + None => { + return Err(Error::invalid_length( + 1, + &"struct variant Request::Create with 2 elements", + )) + } + }, + }) + } + } + + let visitor = Visitor { fds: self.fds }; + variant.struct_variant(&["memory_params"], visitor) + } + + (Variant::DebugLabel, variant) => { + variant.unit_variant()?; + Ok(Request::DebugLabel) + } + + (Variant::DeviceType, variant) => { + variant.unit_variant()?; + Ok(Request::DeviceType) + } + + (Variant::QueueMaxSizes, variant) => { + variant.unit_variant()?; + Ok(Request::QueueMaxSizes) + } + + (Variant::Features, variant) => { + variant.unit_variant()?; + Ok(Request::Features) + } + + (Variant::AckFeatures, variant) => { + Ok(Request::AckFeatures(variant.newtype_variant()?)) + } + + (Variant::ReadConfig, variant) => { + struct Visitor<'iter, Iter> { + fds: &'iter mut Iter, + } + + impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> + where + Iter: Iterator<Item = RawFd>, + { + type Value = Request; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "struct variant Request::ReadConfig") + } + + fn visit_seq<A: SeqAccess<'de>>( + self, + mut seq: A, + ) -> Result<Request, A::Error> { + use serde::de::Error; + + Ok(Request::ReadConfig { + offset: match seq.next_element()? { + Some(offset) => offset, + None => return Err(Error::invalid_length( + 0, + &"struct variant Request::ReadConfig with 2 elements", + )), + }, + + len: match seq.next_element()? { + Some(len) => len, + None => return Err(Error::invalid_length( + 1, + &"struct variant Request::ReadConfig with 2 elements", + )), + }, + }) + } + } + + let visitor = Visitor { fds: self.fds }; + variant.struct_variant(&["offset", "len"], visitor) + } + + (Variant::WriteConfig, variant) => { + struct Visitor<'iter, Iter> { + fds: &'iter mut Iter, + } + + impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> + where + Iter: Iterator<Item = RawFd>, + { + type Value = Request; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "struct variant Request::WriteConfig") + } + + fn visit_seq<A: SeqAccess<'de>>( + self, + mut seq: A, + ) -> Result<Request, A::Error> { + use serde::de::Error; + + Ok(Request::WriteConfig { + offset: match seq.next_element()? { + Some(offset) => offset, + None => return Err(Error::invalid_length( + 0, + &"struct variant Request::WriteConfig with 2 elements", + )), + }, + + data: match seq.next_element()? { + Some(data) => data, + None => return Err(Error::invalid_length( + 1, + &"struct variant Request::WriteConfig with 2 elements", + )), + }, + }) + } + } + + let visitor = Visitor { fds: self.fds }; + variant.struct_variant(&["offset", "data"], visitor) + } + + (Variant::Activate, variant) => { + struct Visitor<'iter, Iter> { + fds: &'iter mut Iter, + } + + impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> + where + Iter: Iterator<Item = RawFd>, + { + type Value = Request; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "struct variant Request::Activate") + } + + fn visit_seq<A: SeqAccess<'de>>( + self, + mut seq: A, + ) -> Result<Request, A::Error> { + use serde::de::Error; + + Ok(Request::Activate { + shm: match self.fds.next() { + Some(shm) => MaybeOwnedFd::Owned(unsafe { + FromRawFd::from_raw_fd(shm) + }), + None => { + return Err(Error::invalid_length( + 0, + &"struct variant Request::Activate with 7 elements", + )) + } + }, + + interrupt: match self.fds.next() { + Some(interrupt) => MaybeOwnedFd::Owned(unsafe { + UnixSeqpacket::from_raw_fd(interrupt) + }), + None => { + return Err(Error::invalid_length( + 1, + &"struct variant Request::Activate with 7 elements", + )) + } + }, + + interrupt_resample_evt: match self.fds.next() { + Some(interrupt_resample_evt) => { + MaybeOwnedFd::Owned(unsafe { + EventFd::from_raw_fd(interrupt_resample_evt) + }) + } + None => { + return Err(Error::invalid_length( + 2, + &"struct variant Request::Activate with 7 elements", + )) + } + }, + + in_queue: match seq.next_element()? { + Some(in_queue) => in_queue, + None => { + return Err(Error::invalid_length( + 3, + &"struct variant Request::Activate with 7 elements", + )) + } + }, + + out_queue: match seq.next_element()? { + Some(out_queue) => out_queue, + None => { + return Err(Error::invalid_length( + 4, + &"struct variant Request::Activate with 7 elements", + )) + } + }, + + in_queue_evt: match self.fds.next() { + Some(in_queue_evt) => MaybeOwnedFd::Owned(unsafe { + EventFd::from_raw_fd(in_queue_evt) + }), + None => { + return Err(Error::invalid_length( + 5, + &"struct variant Request::Activate with 7 elements", + )) + } + }, + + out_queue_evt: match self.fds.next() { + Some(out_queue_evt) => MaybeOwnedFd::Owned(unsafe { + EventFd::from_raw_fd(out_queue_evt) + }), + None => { + return Err(Error::invalid_length( + 6, + &"struct variant Request::Activate with 7 elements", + )) + } + }, + }) + } + } + + let visitor = Visitor { fds: self.fds }; + variant.struct_variant(&["in_queue", "out_queue"], visitor) + } + + (Variant::Reset, variant) => { + variant.unit_variant()?; + Ok(Request::Reset) + } + + (Variant::GetDeviceBars, variant) => { + struct Visitor<'iter, Iter> { + fds: &'iter mut Iter, + } + + impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> + where + Iter: Iterator<Item = RawFd>, + { + type Value = Request; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "struct variant Request::GetDeviceBars") + } + + fn visit_seq<A: SeqAccess<'de>>( + self, + mut seq: A, + ) -> Result<Request, A::Error> { + use serde::de::Error; + + Ok(Request::GetDeviceBars(match seq.next_element()? { + Some(address) => address, + None => { + return Err(Error::invalid_length( + 0, + &"struct variant Request::GetDeviceBars with 1 element", + )) + } + })) + } + } + + let visitor = Visitor { fds: self.fds }; + variant.struct_variant(&["bus", "dev"], visitor) + } + + (Variant::GetDeviceCaps, variant) => { + variant.unit_variant()?; + Ok(Request::GetDeviceCaps) + } + + (Variant::Kill, variant) => { + variant.unit_variant()?; + Ok(Request::Kill) + } + } + } + } + + let DeserializerWithFds { + mut fds, + deserializer, + } = deserializer; + let visitor = Visitor { fds: &mut fds }; + deserializer.deserialize_enum( + "Request", + &[ + "Create", + "DebugLabel", + "DeviceType", + "QueueMaxSizes", + "Features", + "AckFeatures", + "ReadConfig", + "WriteConfig", + "Activate", + "Reset", + "GetDeviceBars", + "GetDeviceCaps", + "Kill", + ], + visitor, + ) } } -#[derive(Debug, MsgOnSocket)] -pub enum MsgOnSocketResponse { +#[derive(Debug)] +pub enum Response { + DebugLabel(String), DeviceType(u32), + QueueMaxSizes(Vec<u16>), Features(u64), + ReadConfig(Vec<u8>), Reset(bool), + GetDeviceBars(Vec<PciBarConfiguration>), + GetDeviceCaps(Vec<RemotePciCapability>), Kill, } -#[derive(Debug, Deserialize, Serialize)] -pub enum BincodeResponse { - DebugLabel(String), +impl SerializeWithFds for Response { + fn serialize<S>(&self, serializer: SerializerWithFds<S>) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + use Response::*; + + match self { + DebugLabel(label) => { + let mut tv = serializer.serializer.serialize_tuple_variant( + "Response", + 0, + "DebugLabel", + 1, + )?; + tv.serialize_field(label)?; + tv.end() + } - QueueMaxSizes(Vec<u16>), + DeviceType(device_type) => { + let mut tv = serializer.serializer.serialize_tuple_variant( + "Response", + 1, + "DeviceType", + 1, + )?; + tv.serialize_field(device_type)?; + tv.end() + } - ReadConfig(Vec<u8>), + QueueMaxSizes(sizes) => { + let mut tv = serializer.serializer.serialize_tuple_variant( + "Response", + 2, + "QueueMaxSizes", + 1, + )?; + tv.serialize_field(sizes)?; + tv.end() + } - GetDeviceBars(Vec<PciBarConfiguration>), - GetDeviceCaps(Vec<RemotePciCapability>), -} + Features(features) => { + let mut tv = serializer + .serializer + .serialize_tuple_variant("Response", 3, "Features", 1)?; + tv.serialize_field(features)?; + tv.end() + } + + ReadConfig(config) => { + let mut tv = serializer.serializer.serialize_tuple_variant( + "Response", + 4, + "ReadConfig", + 1, + )?; + tv.serialize_field(config)?; + tv.end() + } + + Reset(success) => { + let mut tv = serializer + .serializer + .serialize_tuple_variant("Response", 5, "Reset", 1)?; + tv.serialize_field(success)?; + tv.end() + } -pub type Response = poly_msg_socket::Value<MsgOnSocketResponse, BincodeResponse>; + GetDeviceBars(bars) => { + let mut tv = serializer.serializer.serialize_tuple_variant( + "Response", + 6, + "GetDeviceBars", + 1, + )?; + tv.serialize_field(bars)?; + tv.end() + } + + GetDeviceCaps(caps) => { + let mut tv = serializer.serializer.serialize_tuple_variant( + "Response", + 7, + "GetDeviceCaps", + 1, + )?; + tv.serialize_field(caps)?; + tv.end() + } -impl From<MsgOnSocketResponse> for Response { - fn from(response: MsgOnSocketResponse) -> Self { - Self::MsgOnSocket(response) + Kill => serializer + .serializer + .serialize_unit_variant("Response", 8, "Kill"), + } } } -impl From<BincodeResponse> for Response { - fn from(response: BincodeResponse) -> Self { - Self::Bincode(response) +impl<'de> DeserializeWithFds<'de> for Response { + fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error> + where + I: Iterator<Item = RawFd>, + De: Deserializer<'de>, + { + struct Visitor<'iter, Iter> { + fds: &'iter mut Iter, + } + + impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> + where + Iter: Iterator<Item = RawFd>, + { + type Value = Response; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "enum Response") + } + + fn visit_enum<A: EnumAccess<'de>>(self, data: A) -> Result<Self::Value, A::Error> { + #[derive(Debug, Deserialize)] + enum Variant { + DebugLabel, + DeviceType, + QueueMaxSizes, + Features, + ReadConfig, + Reset, + GetDeviceBars, + GetDeviceCaps, + Kill, + } + + match data.variant()? { + (Variant::DebugLabel, variant) => { + Ok(Response::DebugLabel(variant.newtype_variant()?)) + } + (Variant::DeviceType, variant) => { + Ok(Response::DeviceType(variant.newtype_variant()?)) + } + (Variant::QueueMaxSizes, variant) => { + Ok(Response::QueueMaxSizes(variant.newtype_variant()?)) + } + (Variant::Features, variant) => { + Ok(Response::Features(variant.newtype_variant()?)) + } + (Variant::ReadConfig, variant) => { + Ok(Response::ReadConfig(variant.newtype_variant()?)) + } + (Variant::Reset, variant) => Ok(Response::Reset(variant.newtype_variant()?)), + (Variant::GetDeviceBars, variant) => { + Ok(Response::GetDeviceBars(variant.newtype_variant()?)) + } + (Variant::GetDeviceCaps, variant) => { + Ok(Response::GetDeviceCaps(variant.newtype_variant()?)) + } + + (Variant::Kill, variant) => { + variant.unit_variant()?; + Ok(Response::Kill) + } + } + } + } + + let DeserializerWithFds { + mut fds, + deserializer, + } = deserializer; + let visitor = Visitor { fds: &mut fds }; + deserializer.deserialize_enum( + "Response", + &[ + "DebugLabel", + "DeviceType", + "QueueMaxSizes", + "Features", + "ReadConfig", + "Reset", + "GetDeviceBars", + "GetDeviceCaps", + "Kill", + ], + visitor, + ) } } -use poly_msg_socket::PolyMsgSocket; -type Socket = - PolyMsgSocket<MsgOnSocketRequest, MsgOnSocketResponse, BincodeRequest, BincodeResponse>; - -const VIRTIO_WL_F_TRANS_FLAGS: u32 = 0x01; +type Socket = msg_socket2::Socket<Request, Response>; // TODO: support arbitrary number of queues const QUEUE_SIZE: u16 = 16; @@ -196,9 +888,8 @@ impl Worker { } fn handle_response(&mut self) { - use poly_msg_socket::Value::*; match self.device_socket.recv() { - Ok(MsgOnSocket(MsgOnSocketResponse::Kill)) => { + Ok(Response::Kill) => { self.shutdown = true; } @@ -224,9 +915,7 @@ impl Worker { } fn kill(&self) { - if let Err(e) = self.device_socket.send(poly_msg_socket::Value::MsgOnSocket( - MsgOnSocketRequest::Kill, - )) { + if let Err(e) = self.device_socket.send(Request::Kill) { error!("failed to send Kill: {}", e); } } @@ -282,8 +971,8 @@ impl Controller { resource_bridge: Option<ResourceRequestSocket>, memory_params: MemoryParams, socket: Socket, - ) -> Result<Controller, poly_msg_socket::Error> { - socket.send(MsgOnSocketRequest::Create { + ) -> Result<Controller, msg_socket2::Error> { + socket.send(Request::Create { // wayland_paths, vm_socket: MaybeOwnedFd::new_borrowed(&vm_socket), // resource_bridge, @@ -313,12 +1002,12 @@ impl Drop for Controller { impl VirtioDevice for Controller { fn debug_label(&self) -> String { - if let Err(e) = self.socket.send(BincodeRequest::DebugLabel) { + if let Err(e) = self.socket.send(Request::DebugLabel) { return format!("remote virtio (unknown type; {})", e); } - let label = match self.socket.recv_bincode() { - Ok(BincodeResponse::DebugLabel(label)) => label, + let label = match self.socket.recv() { + Ok(Response::DebugLabel(label)) => label, response => panic!("bad response to DebugLabel: {:?}", response), }; @@ -338,12 +1027,12 @@ impl VirtioDevice for Controller { } fn device_type(&self) -> u32 { - if let Err(e) = self.socket.send(MsgOnSocketRequest::DeviceType) { + if let Err(e) = self.socket.send(Request::DeviceType) { panic!("failed to send DeviceType: {}", e); } - match self.socket.recv_msg_on_socket() { - Ok(MsgOnSocketResponse::DeviceType(device_type)) => device_type, + match self.socket.recv() { + Ok(Response::DeviceType(device_type)) => device_type, response => { panic!("bad response to Reset: {:?}", response); } @@ -351,25 +1040,25 @@ impl VirtioDevice for Controller { } fn queue_max_sizes(&self) -> Vec<u16> { - if let Err(e) = self.socket.send(BincodeRequest::QueueMaxSizes) { + if let Err(e) = self.socket.send(Request::QueueMaxSizes) { panic!("failed to send QueueMaxSizes: {}", e); } - match self.socket.recv_bincode() { - Ok(BincodeResponse::QueueMaxSizes(sizes)) => sizes, + match self.socket.recv() { + Ok(Response::QueueMaxSizes(sizes)) => sizes, response => { - panic!("bad response to Reset: {:?}", response); + panic!("bad response to QueueMaxSizes: {:?}", response); } } } fn features(&self) -> u64 { - if let Err(e) = self.socket.send(MsgOnSocketRequest::Features) { + if let Err(e) = self.socket.send(Request::Features) { panic!("failed to send Features: {}", e); } - match self.socket.recv_msg_on_socket() { - Ok(MsgOnSocketResponse::Features(features)) => features, + match self.socket.recv() { + Ok(Response::Features(features)) => features, response => { panic!("bad response to Reset: {:?}", response); } @@ -377,7 +1066,7 @@ impl VirtioDevice for Controller { } fn ack_features(&mut self, value: u64) { - if let Err(e) = self.socket.send(MsgOnSocketRequest::AckFeatures(value)) { + if let Err(e) = self.socket.send(Request::AckFeatures(value)) { panic!("failed to send AckFeatures: {}", e); } } @@ -385,12 +1074,12 @@ impl VirtioDevice for Controller { fn read_config(&self, offset: u64, data: &mut [u8]) { let len = data.len(); - if let Err(e) = self.socket.send(BincodeRequest::ReadConfig { offset, len }) { + if let Err(e) = self.socket.send(Request::ReadConfig { offset, len }) { panic!("failed to send ReadConfig: {}", e); } - match self.socket.recv_bincode() { - Ok(BincodeResponse::ReadConfig(response)) => { + match self.socket.recv() { + Ok(Response::ReadConfig(response)) => { data.copy_from_slice(&response[..len]); // TODO: test no panic } response => panic!("bad response to ReadConfig: {:?}", response), @@ -398,7 +1087,7 @@ impl VirtioDevice for Controller { } fn write_config(&mut self, offset: u64, data: &[u8]) { - if let Err(e) = self.socket.send(BincodeRequest::WriteConfig { + if let Err(e) = self.socket.send(Request::WriteConfig { offset, data: data.to_vec(), }) { @@ -431,7 +1120,7 @@ impl VirtioDevice for Controller { let (ours, theirs) = UnixSeqpacket::pair().expect("pair failed"); - if let Err(e) = self.socket.send(MsgOnSocketRequest::Activate { + if let Err(e) = self.socket.send(Request::Activate { shm: MaybeOwnedFd::new_borrowed(&mem), interrupt: MaybeOwnedFd::new_borrowed(&theirs), interrupt_resample_evt: MaybeOwnedFd::new_borrowed(interrupt.get_resample_evt()), @@ -461,13 +1150,13 @@ impl VirtioDevice for Controller { } fn reset(&mut self) -> bool { - if let Err(e) = self.socket.send(MsgOnSocketRequest::Reset) { + if let Err(e) = self.socket.send(Request::Reset) { error!("failed to send Reset: {}", e); return false; } - match self.socket.recv_msg_on_socket() { - Ok(MsgOnSocketResponse::Reset(result)) => result, + match self.socket.recv() { + Ok(Response::Reset(result)) => result, response => { error!("bad response to Reset: {:?}", response); false @@ -476,12 +1165,12 @@ impl VirtioDevice for Controller { } fn get_device_bars(&mut self, address: PciAddress) -> Vec<PciBarConfiguration> { - if let Err(e) = self.socket.send(BincodeRequest::GetDeviceBars(address)) { + if let Err(e) = self.socket.send(Request::GetDeviceBars(address)) { panic!("failed to send GetDeviceBars: {}", e); } - match self.socket.recv_bincode() { - Ok(BincodeResponse::GetDeviceBars(bars)) => bars, + match self.socket.recv() { + Ok(Response::GetDeviceBars(bars)) => bars, response => { panic!("bad response to GetDeviceBars: {:?}", response); } @@ -489,12 +1178,12 @@ impl VirtioDevice for Controller { } fn get_device_caps(&self) -> Vec<Box<dyn PciCapability>> { - if let Err(e) = self.socket.send(BincodeRequest::GetDeviceCaps) { + if let Err(e) = self.socket.send(Request::GetDeviceCaps) { panic!("failed to send GetDeviceCaps: {}", e); } - match self.socket.recv_bincode() { - Ok(BincodeResponse::GetDeviceCaps(caps)) => caps + match self.socket.recv() { + Ok(Response::GetDeviceCaps(caps)) => caps .into_iter() .map(|cap| Box::new(cap) as Box<dyn PciCapability>) .collect(), diff --git a/devices/src/virtio/queue.rs b/devices/src/virtio/queue.rs index 793246d..e57d4d3 100644 --- a/devices/src/virtio/queue.rs +++ b/devices/src/virtio/queue.rs @@ -200,7 +200,9 @@ impl<'a, 'b> Iterator for AvailIter<'a, 'b> { } } -#[derive(Clone, Debug, MsgOnSocket)] +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, MsgOnSocket, Serialize, Deserialize)] /// A virtio queue's parameters. pub struct Queue { /// The maximal size in elements offered by the device diff --git a/msg_socket2/src/error.rs b/msg_socket2/src/error.rs index 902684b..2daa450 100644 --- a/msg_socket2/src/error.rs +++ b/msg_socket2/src/error.rs @@ -1,9 +1,21 @@ +use std::fmt::{self, Display, Formatter}; + #[derive(Debug)] pub enum Error { DataError(bincode::Error), IoError(sys_util::Error), } +impl Display for Error { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + use Error::*; + match self { + DataError(error) => write!(f, "{}", error), + IoError(error) => write!(f, "{}", error), + } + } +} + impl From<bincode::Error> for Error { fn from(error: bincode::Error) -> Self { Self::DataError(error) diff --git a/msg_socket2/src/socket.rs b/msg_socket2/src/socket.rs index bce587a..dc0733d 100644 --- a/msg_socket2/src/socket.rs +++ b/msg_socket2/src/socket.rs @@ -1,9 +1,10 @@ -use bincode::{DefaultOptions, Serializer, Deserializer}; -use std::marker::PhantomData; +use bincode::{DefaultOptions, Deserializer, Serializer}; use std::io::IoSlice; +use std::marker::PhantomData; +use std::os::unix::prelude::*; use sys_util::{net::UnixSeqpacket, ScmSocket}; -use crate::{DeserializerWithFds, DeserializeWithFds, Error, SerializeWithFds, SerializerWithFds}; +use crate::{DeserializeWithFds, DeserializerWithFds, Error, SerializeWithFds, SerializerWithFds}; #[derive(Debug)] pub struct Socket<Send, Recv> { @@ -46,3 +47,9 @@ impl<Send, Recv: for<'de> DeserializeWithFds<'de>> Socket<Send, Recv> { Ok(Recv::deserialize(deserializer_with_fds)?) } } + +impl<Send, Recv> AsRawFd for Socket<Send, Recv> { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} diff --git a/msg_socket2/tests/round_trip.rs b/msg_socket2/tests/round_trip.rs index 08e1aff..1c6b5a3 100644 --- a/msg_socket2/tests/round_trip.rs +++ b/msg_socket2/tests/round_trip.rs @@ -2,7 +2,6 @@ use std::os::unix::prelude::*; use std::fmt::{self, Formatter}; use std::marker::PhantomData; -use std::mem::size_of; use msg_socket2::*; use serde::de::*; @@ -23,9 +22,7 @@ impl SerializeWithFds for Test { where Ser: Serializer, { - let mut state = serializer - .serializer - .serialize_struct("Test", size_of::<Test>())?; + let mut state = serializer.serializer.serialize_struct("Test", 1)?; serializer.fds.push(self.fd); state.skip_field("fd")?; @@ -33,8 +30,7 @@ impl SerializeWithFds for Test { impl<'a> Serialize for SerializableInner<'a> { fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { - let mut state = serializer - .serialize_tuple_struct("Inner", size_of::<Inner>() - size_of::<RawFd>() * 1)?; + let mut state = serializer.serialize_tuple_struct("Inner", 1)?; state.serialize_field(&(self.0).1)?; state.end() } diff --git a/src/linux.rs b/src/linux.rs index 91f2f8c..7bd4679 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -40,7 +40,6 @@ use io_jail::{self, Minijail}; use kvm::*; use msg_socket::{MsgError, MsgReceiver, MsgResult, MsgSender, MsgSocket}; use net_util::{Error as NetError, MacAddress, Tap}; -use poly_msg_socket::PolyMsgSocket; use remain::sorted; use resources::{Alloc, MmioType, SystemAllocator}; use sync::{Condvar, Mutex}; @@ -88,7 +87,7 @@ pub enum Error { BuildVm(<Arch as LinuxArch>::Error), ChownTpmStorage(sys_util::Error), CloneEventFd(sys_util::Error), - ControllerCreate(poly_msg_socket::Error), + ControllerCreate(msg_socket2::Error), CreateAc97(devices::PciDeviceError), CreateConsole(arch::serial::Error), CreateDiskError(disk::Error), @@ -771,7 +770,7 @@ fn create_wayland_device( let mut path = std::env::var("XDG_RUNTIME_DIR").expect("XDG_RUNTIME_DIR missing"); path.push_str("/crosvm-wl.sock"); let seq_socket = UnixSeqpacket::connect(&path).expect("connect failed"); - let msg_socket = PolyMsgSocket::new(seq_socket); + let msg_socket = msg_socket2::Socket::new(seq_socket); let dev = virtio::Controller::create( cfg.wayland_socket_paths.clone(), socket, diff --git a/src/wl.rs b/src/wl.rs index c04ec06..be2ca2e 100644 --- a/src/wl.rs +++ b/src/wl.rs @@ -1,22 +1,19 @@ // SPDX-License-Identifier: BSD-3-Clause use devices::virtio::{ - BincodeRequest, BincodeResponse, InterruptProxy, InterruptProxyEvent, MsgOnSocketRequest, - MsgOnSocketResponse, RemotePciCapability, VirtioDevice, Wl, + InterruptProxy, InterruptProxyEvent, RemotePciCapability, Request, Response, VirtioDevice, Wl, }; use msg_socket::MsgSocket; -use poly_msg_socket::PolyMsgSocket; use std::collections::BTreeMap; use std::fs::remove_file; -use sys_util::{error, net::UnixSeqpacketListener, warn, GuestMemory}; +use sys_util::{error, net::UnixSeqpacketListener, GuestMemory}; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] pub use aarch64::arch_memory_regions; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] pub use x86_64::arch_memory_regions; -type Socket = - PolyMsgSocket<MsgOnSocketResponse, MsgOnSocketRequest, BincodeResponse, BincodeRequest>; +type Socket = msg_socket2::Socket<Response, Request>; fn main() { eprintln!("hello world"); @@ -30,13 +27,13 @@ fn main() { // Receive connection from crosvm. let conn = server.accept().expect("accept failed"); - let msg_socket: Socket = PolyMsgSocket::new(conn); + let msg_socket: Socket = msg_socket2::Socket::new(conn); let (vm_socket, memory_params) = match msg_socket.recv() { - Ok(poly_msg_socket::Value::MsgOnSocket(MsgOnSocketRequest::Create { + Ok(Request::Create { vm_socket, memory_params, - })) => (MsgSocket::new(vm_socket.owned()), memory_params), + }) => (MsgSocket::new(vm_socket.owned()), memory_params), Ok(msg) => { panic!("received unexpected message: {:?}", msg); @@ -53,51 +50,48 @@ fn main() { let mut wl = Wl::new(wayland_paths, vm_socket, None).unwrap(); loop { - use poly_msg_socket::Value::*; match msg_socket.recv() { - Ok(Bincode(BincodeRequest::DebugLabel)) => { + Ok(Request::DebugLabel) => { let result = wl.debug_label(); - if let Err(e) = msg_socket.send(BincodeResponse::DebugLabel(result)) { + if let Err(e) = msg_socket.send(Response::DebugLabel(result)) { panic!("responding to DebugLabel failed: {}", e); } } - Ok(MsgOnSocket(MsgOnSocketRequest::DeviceType)) => { + Ok(Request::DeviceType) => { let result = wl.device_type(); - if let Err(e) = msg_socket.send(MsgOnSocketResponse::DeviceType(result)) { + if let Err(e) = msg_socket.send(Response::DeviceType(result)) { panic!("responding to DeviceType failed: {}", e); } } - Ok(Bincode(BincodeRequest::QueueMaxSizes)) => { + Ok(Request::QueueMaxSizes) => { let result = wl.queue_max_sizes(); - if let Err(e) = msg_socket.send(BincodeResponse::QueueMaxSizes(result)) { + if let Err(e) = msg_socket.send(Response::QueueMaxSizes(result)) { panic!("responding to QueueMaxSizes failed: {}", e); } } - Ok(MsgOnSocket(MsgOnSocketRequest::Features)) => { + Ok(Request::Features) => { let result = wl.features(); - if let Err(e) = msg_socket.send(MsgOnSocketResponse::Features(result)) { + if let Err(e) = msg_socket.send(Response::Features(result)) { panic!("responding to Features failed: {}", e); } } - Ok(MsgOnSocket(MsgOnSocketRequest::AckFeatures(value))) => wl.ack_features(value), + Ok(Request::AckFeatures(value)) => wl.ack_features(value), - Ok(Bincode(BincodeRequest::ReadConfig { offset, len })) => { + Ok(Request::ReadConfig { offset, len }) => { let mut data = vec![0; len]; wl.read_config(offset, &mut data); - if let Err(e) = msg_socket.send(BincodeResponse::ReadConfig(data)) { + if let Err(e) = msg_socket.send(Response::ReadConfig(data)) { panic!("responding to ReadConfig failed: {}", e); } } - Ok(Bincode(BincodeRequest::WriteConfig { offset, ref data })) => { - wl.write_config(offset, data) - } + Ok(Request::WriteConfig { offset, ref data }) => wl.write_config(offset, data), - Ok(MsgOnSocket(MsgOnSocketRequest::Activate { + Ok(Request::Activate { shm, interrupt, interrupt_resample_evt, @@ -105,7 +99,7 @@ fn main() { out_queue, in_queue_evt, out_queue_evt, - })) => { + }) => { let shm = shm.owned(); let regions = arch_memory_regions(memory_params); @@ -128,45 +122,43 @@ fn main() { println!("activated Wl"); } - Ok(MsgOnSocket(MsgOnSocketRequest::Reset)) => { + Ok(Request::Reset) => { let result = wl.reset(); - if let Err(e) = msg_socket.send(MsgOnSocketResponse::Reset(result)) { + if let Err(e) = msg_socket.send(Response::Reset(result)) { panic!("responding to Reset failed: {}", e); } } - Ok(Bincode(BincodeRequest::GetDeviceBars(address))) => { + Ok(Request::GetDeviceBars(address)) => { let result = wl.get_device_bars(address); - if let Err(e) = msg_socket.send(BincodeResponse::GetDeviceBars(result)) { + if let Err(e) = msg_socket.send(Response::GetDeviceBars(result)) { panic!("responding to GetDeviceBars failed: {}", e); } } - Ok(Bincode(BincodeRequest::GetDeviceCaps)) => { + Ok(Request::GetDeviceCaps) => { let result = wl .get_device_caps() .into_iter() .map(|c| RemotePciCapability::from(&*c)) .collect(); - if let Err(e) = msg_socket.send(BincodeResponse::GetDeviceCaps(result)) { + if let Err(e) = msg_socket.send(Response::GetDeviceCaps(result)) { panic!("responding to GetDeviceCaps failed: {}", e); } } - Ok(MsgOnSocket(MsgOnSocketRequest::Kill)) => { + Ok(Request::Kill) => { // Will block until worker shuts down. drop(wl); - if let Err(e) = msg_socket.send(MsgOnSocketResponse::Kill) { + if let Err(e) = msg_socket.send(Response::Kill) { error!("responding to Kill failed: {}", e); } break; } - Ok(MsgOnSocket(msg @ MsgOnSocketRequest::Create { .. })) => { - panic!("unexpected message {:?}", msg) - } + Ok(msg @ Request::Create { .. }) => panic!("unexpected message {:?}", msg), Err(e) => panic!("recv failed: {}", e), } diff --git a/sys_util/Cargo.toml b/sys_util/Cargo.toml index beba374..d51d2ff 100644 --- a/sys_util/Cargo.toml +++ b/sys_util/Cargo.toml @@ -12,5 +12,6 @@ poll_token_derive = { version = "*", path = "poll_token_derive" } sync = { path = "../sync" } # provided by ebuild syscall_defines = { path = "../syscall_defines" } # provided by ebuild tempfile = { path = "../tempfile" } # provided by ebuild +serde = { version = "*", features = ["derive"] } [workspace] diff --git a/sys_util/src/guest_address.rs b/sys_util/src/guest_address.rs index 1b78e71..48772b2 100644 --- a/sys_util/src/guest_address.rs +++ b/sys_util/src/guest_address.rs @@ -8,8 +8,10 @@ use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd}; use std::fmt::{self, Display}; use std::ops::{BitAnd, BitOr}; +use serde::{Deserialize, Serialize}; + /// Represents an Address in the guest's memory. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] pub struct GuestAddress(pub u64); impl GuestAddress { |