diff options
-rw-r--r-- | Cargo.lock | 13 | ||||
-rw-r--r-- | devices/Cargo.toml | 1 | ||||
-rw-r--r-- | devices/src/lib.rs | 5 | ||||
-rw-r--r-- | devices/src/pci/pci_root.rs | 6 | ||||
-rw-r--r-- | devices/src/virtio/controller.rs | 392 | ||||
-rw-r--r-- | devices/src/virtio/queue.rs | 4 | ||||
-rw-r--r-- | msg_socket2/Cargo.toml | 1 | ||||
-rw-r--r-- | msg_socket2/derive/Cargo.toml | 19 | ||||
-rw-r--r-- | msg_socket2/derive/lib.rs | 264 | ||||
-rw-r--r-- | msg_socket2/src/lib.rs | 6 | ||||
-rw-r--r-- | msg_socket2/src/ser.rs | 600 | ||||
-rw-r--r-- | msg_socket2/src/socket.rs | 11 | ||||
-rw-r--r-- | msg_socket2/tests/round_trip.rs | 54 | ||||
-rw-r--r-- | vm_control/Cargo.toml | 2 | ||||
-rw-r--r-- | vm_control/src/lib.rs | 8 |
15 files changed, 1045 insertions, 341 deletions
diff --git a/Cargo.lock b/Cargo.lock index 5b0f218..9ad31d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -206,6 +206,7 @@ dependencies = [ "msg_on_socket_derive 0.1.0", "msg_socket 0.1.0", "msg_socket2 0.1.0", + "msg_socket2_derive 0.1.0", "net_sys 0.1.0", "net_util 0.1.0", "p9 0.1.0", @@ -504,11 +505,21 @@ name = "msg_socket2" version = "0.1.0" dependencies = [ "bincode 1.3.0 (git+https://github.com/alyssais/bincode?branch=from_slice)", + "msg_socket2_derive 0.1.0", "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", "sys_util 0.1.0", ] [[package]] +name = "msg_socket2_derive" +version = "0.1.0" +dependencies = [ + "proc-macro2 1.0.8 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 1.0.14 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] name = "net_sys" version = "0.1.0" dependencies = [ @@ -834,6 +845,8 @@ dependencies = [ "kvm 0.1.0", "libc 0.2.44 (registry+https://github.com/rust-lang/crates.io-index)", "msg_socket 0.1.0", + "msg_socket2 0.1.0", + "msg_socket2_derive 0.1.0", "resources 0.1.0", "sys_util 0.1.0", ] diff --git a/devices/Cargo.toml b/devices/Cargo.toml index 830cb86..4d6c4d8 100644 --- a/devices/Cargo.toml +++ b/devices/Cargo.toml @@ -35,6 +35,7 @@ linux_input_sys = { path = "../linux_input_sys" } msg_on_socket_derive = { path = "../msg_socket/msg_on_socket_derive" } msg_socket = { path = "../msg_socket" } msg_socket2 = { path = "../msg_socket2" } +msg_socket2_derive = { path = "../msg_socket2/derive" } net_sys = { path = "../net_sys" } net_util = { path = "../net_util" } p9 = { path = "../p9" } diff --git a/devices/src/lib.rs b/devices/src/lib.rs index febb18f..7df9c62 100644 --- a/devices/src/lib.rs +++ b/devices/src/lib.rs @@ -52,7 +52,10 @@ pub use self::virtio::VirtioPciDevice; use msg_socket::MsgOnSocket; use serde::{Deserialize, Serialize}; -#[derive(Clone, Copy, Debug, MsgOnSocket, Serialize, Deserialize)] +use msg_socket2_derive::SerializeWithFds; + +#[derive(Clone, Copy, Debug, MsgOnSocket, Serialize, SerializeWithFds, Deserialize)] +#[msg_socket2(strategy = "serde")] pub struct MemoryParams { /// Physical memory size in bytes for the VM. pub size: u64, diff --git a/devices/src/pci/pci_root.rs b/devices/src/pci/pci_root.rs index 0169793..76f9d82 100644 --- a/devices/src/pci/pci_root.rs +++ b/devices/src/pci/pci_root.rs @@ -8,6 +8,7 @@ use std::fmt::{self, Display}; use std::os::unix::io::RawFd; use std::sync::Arc; +use msg_socket2_derive::SerializeWithFds; use serde::{Deserialize, Serialize}; use sync::Mutex; @@ -43,7 +44,10 @@ impl PciDevice for PciRootConfiguration { } /// PCI Device Address, AKA Bus:Device.Function -#[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)] +#[derive( + Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize, SerializeWithFds, +)] +#[msg_socket2(strategy = "serde")] pub struct PciAddress { pub bus: u8, pub dev: u8, /* u5 */ diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs index 0b10aac..0a59072 100644 --- a/devices/src/virtio/controller.rs +++ b/devices/src/virtio/controller.rs @@ -38,7 +38,9 @@ use std::sync::Arc; use std::thread; use msg_socket::{MsgReceiver, MsgSocket}; -use msg_socket2::{DeserializeWithFds, DeserializerWithFds, SerializeWithFds, SerializerWithFds}; +use msg_socket2::ser::{SerializeStructVariantFds, SerializeTupleVariantFds}; +use msg_socket2::{DeserializeWithFds, DeserializerWithFds, FdSerializer, SerializeWithFds}; +use msg_socket2_derive::SerializeWithFds; use serde::de::{Deserializer, EnumAccess, SeqAccess, VariantAccess}; use serde::ser::{SerializeStructVariant, SerializeTupleVariant, Serializer}; use serde::{Deserialize, Serialize}; @@ -126,22 +128,15 @@ pub enum Request { } impl SerializeWithFds for Request { - fn serialize<S: SerializerWithFds>( - &self, - mut serializer: S, - ) -> Result<<S::Ser as Serializer>::Ok, <S::Ser as Serializer>::Error> { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { use Request::*; match self { Create { - vm_socket, + vm_socket: _, memory_params, } => { - serializer.fds().push(vm_socket.as_raw_fd()); - - let mut sv = serializer - .serializer() - .serialize_struct_variant("Request", 0, "Create", 1)?; + let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 1)?; sv.skip_field("vm_socket")?; sv.serialize_field("memory_params", memory_params)?; @@ -149,80 +144,43 @@ impl SerializeWithFds for Request { sv.end() } - DebugLabel => { - serializer - .serializer() - .serialize_unit_variant("Request", 1, "DebugLabel") - } + DebugLabel => serializer.serialize_unit_variant("Request", 1, "DebugLabel"), - DeviceType => { - serializer - .serializer() - .serialize_unit_variant("Request", 2, "DeviceType") - } - QueueMaxSizes => { - serializer - .serializer() - .serialize_unit_variant("Request", 3, "QueueMaxSizes") - } + DeviceType => serializer.serialize_unit_variant("Request", 2, "DeviceType"), + QueueMaxSizes => serializer.serialize_unit_variant("Request", 3, "QueueMaxSizes"), - Features => serializer - .serializer() - .serialize_unit_variant("Request", 4, "Features"), + Features => serializer.serialize_unit_variant("Request", 4, "Features"), AckFeatures(features) => { - let mut tv = serializer.serializer().serialize_tuple_variant( - "Request", - 5, - "AckFeatures", - 1, - )?; + let mut tv = serializer.serialize_tuple_variant("Request", 5, "AckFeatures", 1)?; tv.serialize_field(features)?; tv.end() } ReadConfig { offset, len } => { - let mut sv = serializer.serializer().serialize_struct_variant( - "Request", - 6, - "ReadConfig", - 2, - )?; + let mut sv = serializer.serialize_struct_variant("Request", 6, "ReadConfig", 2)?; sv.serialize_field("offset", offset)?; sv.serialize_field("len", len)?; sv.end() } WriteConfig { offset, data } => { - let mut sv = serializer.serializer().serialize_struct_variant( - "Request", - 7, - "WriteConfig", - 2, - )?; + let mut sv = serializer.serialize_struct_variant("Request", 7, "WriteConfig", 2)?; sv.serialize_field("offset", offset)?; sv.serialize_field("data", data)?; sv.end() } Activate { - shm, - interrupt, - interrupt_resample_evt, + shm: _, + interrupt: _, + interrupt_resample_evt: _, in_queue, out_queue, - in_queue_evt, - out_queue_evt, + in_queue_evt: _, + out_queue_evt: _, } => { - serializer.fds().push(shm.as_raw_fd()); - serializer.fds().push(interrupt.as_raw_fd()); - serializer.fds().push(interrupt_resample_evt.as_raw_fd()); - serializer.fds().push(in_queue_evt.as_raw_fd()); - serializer.fds().push(out_queue_evt.as_raw_fd()); - - let mut sv = serializer - .serializer() - .serialize_struct_variant("Request", 8, "Activate", 2)?; + let mut sv = serializer.serialize_struct_variant("Request", 8, "Activate", 2)?; sv.skip_field("shm")?; sv.skip_field("interrupt")?; @@ -235,30 +193,92 @@ impl SerializeWithFds for Request { sv.end() } - Reset => serializer - .serializer() - .serialize_unit_variant("Request", 9, "Reset"), + Reset => serializer.serialize_unit_variant("Request", 9, "Reset"), GetDeviceBars(address) => { - let mut sv = serializer.serializer().serialize_struct_variant( - "Request", - 10, - "GetDeviceBars", - 1, - )?; + let mut sv = + serializer.serialize_struct_variant("Request", 10, "GetDeviceBars", 1)?; sv.serialize_field("address", address)?; sv.end() } - GetDeviceCaps => { - serializer - .serializer() - .serialize_unit_variant("Request", 11, "GetDeviceCaps") + GetDeviceCaps => serializer.serialize_unit_variant("Request", 11, "GetDeviceCaps"), + + Kill => serializer.serialize_unit_variant("Request", 12, "Kill"), + } + } + + fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + use Request::*; + + match self { + Create { + vm_socket, + memory_params, + } => { + let mut sv = serializer.serialize_struct_variant("Request", 0, "Create", 2)?; + sv.serialize_field("vm_socket", vm_socket)?; + sv.serialize_field("memory_params", memory_params)?; + sv.end() + } + + DebugLabel => serializer.serialize_unit_variant("Request", 1, "DebugLabel"), + DeviceType => serializer.serialize_unit_variant("Request", 2, "DeviceType"), + QueueMaxSizes => serializer.serialize_unit_variant("Request", 3, "QueueMaxSizes"), + Features => serializer.serialize_unit_variant("Request", 4, "Features"), + + AckFeatures(features) => { + let mut tv = serializer.serialize_tuple_variant("Request", 5, "AckFeatures", 1)?; + tv.serialize_field(features)?; + tv.end() + } + + ReadConfig { offset, len } => { + let mut sv = serializer.serialize_struct_variant("Request", 6, "ReadConfig", 2)?; + sv.serialize_field("offset", offset)?; + sv.serialize_field("len", len)?; + sv.end() + } + + WriteConfig { offset, data } => { + let mut sv = serializer.serialize_struct_variant("Request", 7, "WriteConfig", 2)?; + sv.serialize_field("offset", offset)?; + sv.serialize_field("data", data)?; + sv.end() + } + + Activate { + shm, + interrupt, + interrupt_resample_evt, + in_queue, + out_queue, + in_queue_evt, + out_queue_evt, + } => { + let mut sv = serializer.serialize_struct_variant("Request", 8, "Activate", 2)?; + sv.serialize_field("shm", shm)?; + sv.serialize_field("interrupt", interrupt)?; + sv.serialize_field("interrupt_resample_evt", interrupt_resample_evt)?; + sv.serialize_field("in_queue", in_queue)?; + sv.serialize_field("out_queue", out_queue)?; + sv.serialize_field("in_queue_evt", in_queue_evt)?; + sv.serialize_field("out_queue_evt", out_queue_evt)?; + sv.end() + } + + Reset => serializer.serialize_unit_variant("Request", 9, "Reset"), + + GetDeviceBars(address) => { + let mut tv = + serializer.serialize_tuple_variant("Request", 10, "GetDeviceBars", 1)?; + tv.serialize_field(address)?; + tv.end() } - Kill => serializer - .serializer() - .serialize_unit_variant("Request", 12, "Kill"), + GetDeviceCaps => serializer.serialize_unit_variant("Request", 11, "GetDeviceCaps"), + + Kill => serializer.serialize_unit_variant("Request", 12, "Kill"), } } } @@ -378,14 +398,9 @@ impl<'de> DeserializeWithFds<'de> for Request { } (Variant::ReadConfig, variant) => { - struct Visitor<'iter, Iter> { - fds: &'iter mut Iter, - } + struct Visitor; - impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> - where - Iter: Iterator<Item = RawFd>, - { + impl<'de> serde::de::Visitor<'de> for Visitor { type Value = Request; fn expecting(&self, f: &mut Formatter) -> fmt::Result { @@ -418,19 +433,13 @@ impl<'de> DeserializeWithFds<'de> for Request { } } - let visitor = Visitor { fds: self.fds }; - variant.struct_variant(&["offset", "len"], visitor) + variant.struct_variant(&["offset", "len"], Visitor) } (Variant::WriteConfig, variant) => { - struct Visitor<'iter, Iter> { - fds: &'iter mut Iter, - } + struct Visitor; - impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> - where - Iter: Iterator<Item = RawFd>, - { + impl<'de> serde::de::Visitor<'de> for Visitor { type Value = Request; fn expecting(&self, f: &mut Formatter) -> fmt::Result { @@ -463,8 +472,7 @@ impl<'de> DeserializeWithFds<'de> for Request { } } - let visitor = Visitor { fds: self.fds }; - variant.struct_variant(&["offset", "data"], visitor) + variant.struct_variant(&["offset", "data"], Visitor) } (Variant::Activate, variant) => { @@ -584,14 +592,9 @@ impl<'de> DeserializeWithFds<'de> for Request { } (Variant::GetDeviceBars, variant) => { - struct Visitor<'iter, Iter> { - fds: &'iter mut Iter, - } + struct Visitor; - impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> - where - Iter: Iterator<Item = RawFd>, - { + impl<'de> serde::de::Visitor<'de> for Visitor { type Value = Request; fn expecting(&self, f: &mut Formatter) -> fmt::Result { @@ -616,8 +619,7 @@ impl<'de> DeserializeWithFds<'de> for Request { } } - let visitor = Visitor { fds: self.fds }; - variant.struct_variant(&["bus", "dev"], visitor) + variant.struct_variant(&["bus", "dev"], Visitor) } (Variant::GetDeviceCaps, variant) => { @@ -660,7 +662,8 @@ impl<'de> DeserializeWithFds<'de> for Request { } } -#[derive(Debug)] +#[derive(Debug, Deserialize, Serialize, SerializeWithFds)] +#[msg_socket2(strategy = "serde")] pub enum Response { DebugLabel(String), DeviceType(u32), @@ -673,189 +676,12 @@ pub enum Response { Kill, } -impl SerializeWithFds for Response { - fn serialize<S: SerializerWithFds>( - &self, - serializer: S, - ) -> Result<<S::Ser as Serializer>::Ok, <S::Ser as Serializer>::Error> { - use Response::*; - - match self { - DebugLabel(label) => { - let mut tv = serializer.serializer().serialize_tuple_variant( - "Response", - 0, - "DebugLabel", - 1, - )?; - tv.serialize_field(label)?; - tv.end() - } - - DeviceType(device_type) => { - let mut tv = serializer.serializer().serialize_tuple_variant( - "Response", - 1, - "DeviceType", - 1, - )?; - tv.serialize_field(device_type)?; - tv.end() - } - - QueueMaxSizes(sizes) => { - let mut tv = serializer.serializer().serialize_tuple_variant( - "Response", - 2, - "QueueMaxSizes", - 1, - )?; - tv.serialize_field(sizes)?; - tv.end() - } - - Features(features) => { - let mut tv = serializer - .serializer() - .serialize_tuple_variant("Response", 3, "Features", 1)?; - tv.serialize_field(features)?; - tv.end() - } - - ReadConfig(config) => { - let mut tv = serializer.serializer().serialize_tuple_variant( - "Response", - 4, - "ReadConfig", - 1, - )?; - tv.serialize_field(config)?; - tv.end() - } - - Reset(success) => { - let mut tv = serializer - .serializer() - .serialize_tuple_variant("Response", 5, "Reset", 1)?; - tv.serialize_field(success)?; - tv.end() - } - - GetDeviceBars(bars) => { - let mut tv = serializer.serializer().serialize_tuple_variant( - "Response", - 6, - "GetDeviceBars", - 1, - )?; - tv.serialize_field(bars)?; - tv.end() - } - - GetDeviceCaps(caps) => { - let mut tv = serializer.serializer().serialize_tuple_variant( - "Response", - 7, - "GetDeviceCaps", - 1, - )?; - tv.serialize_field(caps)?; - tv.end() - } - - Kill => serializer - .serializer() - .serialize_unit_variant("Response", 8, "Kill"), - } - } -} - impl<'de> DeserializeWithFds<'de> for Response { fn deserialize<I, De>(deserializer: DeserializerWithFds<I, De>) -> Result<Self, De::Error> where - I: Iterator<Item = RawFd>, De: Deserializer<'de>, { - struct Visitor<'iter, Iter> { - fds: &'iter mut Iter, - } - - impl<'iter, 'de, Iter> serde::de::Visitor<'de> for Visitor<'iter, Iter> - where - Iter: Iterator<Item = RawFd>, - { - type Value = Response; - - fn expecting(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "enum Response") - } - - fn visit_enum<A: EnumAccess<'de>>(self, data: A) -> Result<Self::Value, A::Error> { - #[derive(Debug, Deserialize)] - enum Variant { - DebugLabel, - DeviceType, - QueueMaxSizes, - Features, - ReadConfig, - Reset, - GetDeviceBars, - GetDeviceCaps, - Kill, - } - - match data.variant()? { - (Variant::DebugLabel, variant) => { - Ok(Response::DebugLabel(variant.newtype_variant()?)) - } - (Variant::DeviceType, variant) => { - Ok(Response::DeviceType(variant.newtype_variant()?)) - } - (Variant::QueueMaxSizes, variant) => { - Ok(Response::QueueMaxSizes(variant.newtype_variant()?)) - } - (Variant::Features, variant) => { - Ok(Response::Features(variant.newtype_variant()?)) - } - (Variant::ReadConfig, variant) => { - Ok(Response::ReadConfig(variant.newtype_variant()?)) - } - (Variant::Reset, variant) => Ok(Response::Reset(variant.newtype_variant()?)), - (Variant::GetDeviceBars, variant) => { - Ok(Response::GetDeviceBars(variant.newtype_variant()?)) - } - (Variant::GetDeviceCaps, variant) => { - Ok(Response::GetDeviceCaps(variant.newtype_variant()?)) - } - - (Variant::Kill, variant) => { - variant.unit_variant()?; - Ok(Response::Kill) - } - } - } - } - - let DeserializerWithFds { - mut fds, - deserializer, - } = deserializer; - let visitor = Visitor { fds: &mut fds }; - deserializer.deserialize_enum( - "Response", - &[ - "DebugLabel", - "DeviceType", - "QueueMaxSizes", - "Features", - "ReadConfig", - "Reset", - "GetDeviceBars", - "GetDeviceCaps", - "Kill", - ], - visitor, - ) + Deserialize::deserialize(deserializer.deserializer) } } diff --git a/devices/src/virtio/queue.rs b/devices/src/virtio/queue.rs index e57d4d3..f2310fa 100644 --- a/devices/src/virtio/queue.rs +++ b/devices/src/virtio/queue.rs @@ -7,6 +7,7 @@ use std::num::Wrapping; use std::sync::atomic::{fence, Ordering}; use msg_socket::MsgOnSocket; +use msg_socket2_derive::SerializeWithFds; use sys_util::{error, GuestAddress, GuestMemory}; use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX; @@ -202,7 +203,8 @@ impl<'a, 'b> Iterator for AvailIter<'a, 'b> { use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, MsgOnSocket, Serialize, Deserialize)] +#[derive(Clone, Debug, MsgOnSocket, Serialize, SerializeWithFds, Deserialize)] +#[msg_socket2(strategy = "serde")] /// A virtio queue's parameters. pub struct Queue { /// The maximal size in elements offered by the device diff --git a/msg_socket2/Cargo.toml b/msg_socket2/Cargo.toml index ba167d8..7a7cf73 100644 --- a/msg_socket2/Cargo.toml +++ b/msg_socket2/Cargo.toml @@ -5,6 +5,7 @@ authors = ["Alyssa Ross <hi@alyssa.is>"] edition = "2018" [dependencies] +msg_socket2_derive = { path = "derive" } serde = "1.0.104" sys_util = { path = "../sys_util" } diff --git a/msg_socket2/derive/Cargo.toml b/msg_socket2/derive/Cargo.toml new file mode 100644 index 0000000..19badea --- /dev/null +++ b/msg_socket2/derive/Cargo.toml @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: MIT OR Apache-2.0 +# Copyright 2020, Alyssa Ross + +[package] +name = "msg_socket2_derive" +version = "0.1.0" +authors = ["Alyssa Ross <hi@alyssa.is>"] +license = "MIT OR Apache-2.0" +edition = "2018" +description = "Derive macros for msg_socket2" + +[dependencies] +proc-macro2 = "^1" +quote = "^1" +syn = "^1" + +[lib] +proc-macro = true +path = "lib.rs" diff --git a/msg_socket2/derive/lib.rs b/msg_socket2/derive/lib.rs new file mode 100644 index 0000000..42c17c3 --- /dev/null +++ b/msg_socket2/derive/lib.rs @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// Copyright 2020, Alyssa Ross + +use proc_macro; +use proc_macro2::{Span, TokenStream}; +use quote::{quote, ToTokens}; +use std::convert::{TryFrom, TryInto}; +use std::fmt::{self, Display, Formatter}; +use syn::spanned::Spanned; +use syn::{ + parse_macro_input, Attribute, DeriveInput, GenericParam, Generics, Lifetime, LifetimeDef, Lit, + Meta, MetaList, MetaNameValue, NestedMeta, +}; + +#[proc_macro_derive(SerializeWithFds, attributes(msg_socket2))] +pub fn serialize_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let impl_for_input = impl_serialize_with_fds(input); + // println!("{}", impl_for_input); + impl_for_input.into() +} + +#[proc_macro_derive(DeserializeWithFds, attributes(msg_socket2))] +pub fn deserialize_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let impl_for_input = impl_deserialize_with_fds(input); + // println!("{}", impl_for_input); + impl_for_input.into() +} + +#[derive(Debug)] +enum Strategy { + AsRawFd, + Serde, +} + +impl TryFrom<Lit> for Strategy { + type Error = Error; + + fn try_from(lit: Lit) -> Result<Self, Self::Error> { + match lit { + Lit::Str(value) if value.value() == "AsRawFd" => Ok(Strategy::AsRawFd), + Lit::Str(value) if value.value() == "serde" => Ok(Strategy::Serde), + _ => Err(Error::Str { + message: "invalid strategy", + span: Some(lit.span()), + }), + } + } +} + +#[derive(Debug)] +struct Config { + strategy: Strategy, +} + +#[derive(Debug)] +enum Error { + Syn(syn::Error), + Str { + message: &'static str, + span: Option<Span>, + }, +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + use Error::*; + match self { + Syn(error) => write!(f, "{}", error), + Str { message, .. } => write!(f, "{}", message), + } + } +} + +impl Spanned for Error { + fn span(&self) -> Span { + use Error::*; + match self { + Syn(error) => error.span(), + Str { span, .. } => span.unwrap_or(Span::call_site()), + } + } +} + +impl From<syn::Error> for Error { + fn from(error: syn::Error) -> Self { + Self::Syn(error) + } +} + +fn get_strategy(attrs: &[Attribute]) -> Result<Option<Strategy>, Error> { + for attr in attrs { + let (attr_path, nested) = match attr.parse_meta()? { + Meta::List(MetaList { path, nested, .. }) => (path, nested), + _ => continue, + }; + + match attr_path.get_ident() { + Some(ident) if ident == "msg_socket2" => (), + _ => continue, + } + + for meta in nested { + let (name_path, value) = match meta { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => (path, lit), + _ => continue, + }; + + match name_path.get_ident() { + Some(ident) if ident == "strategy" => (), + _ => continue, + } + + return value.try_into().map(Some); + } + } + + Ok(None) +} + +fn get_config(attrs: &[Attribute]) -> Result<Config, Error> { + let strategy = get_strategy(&attrs)?.ok_or_else(|| Error::Str { + message: "missing strategy", + span: None, + })?; + + Ok(Config { strategy }) +} + +fn split_for_deserialize_impl(mut generics: Generics) -> (TokenStream, TokenStream, TokenStream) { + let (_, ty_generics, where_clause) = generics.split_for_impl(); + + // Convert to token streams to take ownership, so generics can be + // mutated to generate impl_generics. + let ty_generics = ty_generics.into_token_stream(); + let where_clause = where_clause.into_token_stream(); + + // Add the deserializer lifetime to the list of generics, then + // calculate the impl_generics. Can't do this before calculating + // ty_generics because then the deserializer lifetime would be + // passed as a lifetime parameter to the type. + // + // FIXME: there must be a better way of doing this. + let lifetime = LifetimeDef::new(Lifetime::new("'__de", Span::call_site())); + generics.params.insert(0, GenericParam::Lifetime(lifetime)); + let (impl_generics, _, _) = generics.split_for_impl(); + + (impl_generics.into_token_stream(), ty_generics, where_clause) +} + +fn as_raw_fd_serialize_impl(input: DeriveInput) -> TokenStream { + let name = input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + quote! { + impl #impl_generics ::msg_socket2::SerializeWithFds for #name #ty_generics #where_clause { + fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> + where + S: ::msg_socket2::Serializer, + { + serializer.serialize_unit() + } + + fn serialize_fds<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> + where + S: ::msg_socket2::FdSerializer, + { + serializer.serialize_raw_fd(::std::os::unix::io::AsRawFd::as_raw_fd(self)) + } + } + } +} + +fn as_raw_fd_deserialize_impl(input: DeriveInput) -> TokenStream { + let name = input.ident; + let (impl_generics, ty_generics, where_clause) = split_for_deserialize_impl(input.generics); + + quote! { + impl #impl_generics ::msg_socket2::DeserializeWithFds<'__de> for #name #ty_generics + #where_clause + { + fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error> + where + D: ::msg_socket2::DeserializerWithFds<'__de>, + { + deserializer.deserialize_fd().map(|x| x.specialize()) + } + } + } +} + +fn serde_serialize_impl(input: DeriveInput) -> TokenStream { + let name = input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + quote! { + impl #impl_generics ::msg_socket2::SerializeWithFds for #name #ty_generics #where_clause { + fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> + where + S: ::msg_socket2::Serializer, + { + ::msg_socket2::Serialize::serialize(&self, serializer) + } + + fn serialize_fds<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> + where + S: ::msg_socket2::FdSerializer, + { + serializer.serialize_unit() + } + } + } +} + +fn serde_deserialize_impl(input: DeriveInput) -> TokenStream { + let name = input.ident; + let (impl_generics, ty_generics, where_clause) = split_for_deserialize_impl(input.generics); + + quote! { + impl #impl_generics ::msg_socket2::DeserializeWithFds<'__de> for #name #ty_generics + #where_clause + { + fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error> + where + D: ::msg_socket2::DeserializerWithFds<'__de>, + { + deserializer.invite_deserialize() + } + } + } +} + +fn impl_serialize_with_fds(input: DeriveInput) -> TokenStream { + let config = match get_config(&input.attrs) { + Ok(config) => config, + Err(error) => { + // FIXME: Tried using compile_error! but it just got + // swallowed. Is that not possible with derive macros? + panic!("{}", error); + } + }; + + match config.strategy { + Strategy::AsRawFd => as_raw_fd_serialize_impl(input), + Strategy::Serde => serde_serialize_impl(input), + } +} + +fn impl_deserialize_with_fds(input: DeriveInput) -> TokenStream { + let config = match get_config(&input.attrs) { + Ok(config) => config, + Err(error) => { + // FIXME: Tried using compile_error! but it just got + // swallowed. Is that not possible with derive macros? + panic!("{}", error); + } + }; + + match config.strategy { + Strategy::AsRawFd => as_raw_fd_deserialize_impl(input), + Strategy::Serde => serde_deserialize_impl(input), + } +} diff --git a/msg_socket2/src/lib.rs b/msg_socket2/src/lib.rs index a1d8ceb..9bb2526 100644 --- a/msg_socket2/src/lib.rs +++ b/msg_socket2/src/lib.rs @@ -25,12 +25,12 @@ mod de; mod error; -mod ser; +pub mod ser; mod socket; -pub(crate) use ser::SerializerWithFdsImpl; +pub(crate) use ser::FdSerializerImpl; pub use de::{DeserializeWithFds, DeserializerWithFds}; pub use error::Error; -pub use ser::{SerializeWithFds, SerializerWithFds}; +pub use ser::{FdSerializer, Serialize, SerializeWithFds, Serializer}; pub use socket::Socket; diff --git a/msg_socket2/src/ser.rs b/msg_socket2/src/ser.rs index 7bffa22..e2b9900 100644 --- a/msg_socket2/src/ser.rs +++ b/msg_socket2/src/ser.rs @@ -1,40 +1,598 @@ -use serde::Serializer; +// Copyright 2020, Alyssa Ross +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of the <organization> nor the +// names of its contributors may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY +// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Portions Copyright Erick Tryzelaar and David Tolnay, +// Licensed under either of Apache License, Version 2.0 +// or MIT license at your option. + +//! Data structure serialization framework for Unix domain sockets. +//! +//! Much like serde::ser, except sending file descriptors is also +//! supported. + +use serde::ser::*; +use std::collections::BTreeMap; +use std::fmt::{self, Display, Formatter}; use std::os::unix::prelude::*; -pub trait SerializerWithFds { - type Ser: Serializer; +pub use serde::ser::{ + Serialize, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTupleStruct, + SerializeTupleVariant, Serializer, +}; + +#[derive(Debug)] +pub(crate) enum Never {} + +impl Display for Never { + fn fmt(&self, _: &mut Formatter) -> fmt::Result { + unreachable!() + } +} + +impl StdError for Never {} + +impl Error for Never { + fn custom<T>(_: T) -> Self { + unreachable!() + } +} + +#[derive(Debug)] +pub(crate) struct Composite<'ser, S> { + serializer: &'ser mut S, +} + +/// Returned from `SerializerWithFds::serialize_seq`. +pub trait SerializeSeqFds { + type Ok; + type Error: Error; + + fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized; + + fn end(self) -> Result<Self::Ok, Self::Error>; +} + +impl<'ser, 'fds> SerializeSeqFds for Composite<'ser, FdSerializerImpl<'fds>> { + type Ok = (); + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; - fn serializer(self) -> Self::Ser; - fn fds(&mut self) -> &mut Vec<RawFd>; + fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized, + { + value.serialize_fds(&mut *self.serializer)?; + Ok(()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Ok(()) + } +} + +/// Returned from `SerializerWithFds::serialize_tuple_struct`. +pub trait SerializeTupleStructFds { + type Ok; + type Error: Error; + + fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized; + + fn end(self) -> Result<Self::Ok, Self::Error>; +} + +impl<'ser, 'fds> SerializeTupleStructFds for Composite<'ser, FdSerializerImpl<'fds>> { + type Ok = (); + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + + fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized, + { + value.serialize_fds(&mut *self.serializer)?; + Ok(()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Ok(()) + } +} + +/// Returned from `SerializerWithFds::serialize_tuple_variant`. +pub trait SerializeTupleVariantFds { + type Ok; + type Error: Error; + + fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized; + + fn end(self) -> Result<Self::Ok, Self::Error>; } +impl<'ser, 'fds> SerializeTupleVariantFds for Composite<'ser, FdSerializerImpl<'fds>> { + type Ok = (); + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + + fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized, + { + value.serialize_fds(&mut *self.serializer)?; + Ok(()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Ok(()) + } +} + +/// Returned from `SerializerWithFds::serialize_map`. +pub trait SerializeMapFds { + type Ok; + type Error: Error; + + fn serialize_key<T: SerializeWithFds + ?Sized>(&mut self, key: &T) -> Result<(), Self::Error>; + + fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized; + + fn serialize_entry<K: SerializeWithFds + ?Sized, V: SerializeWithFds + ?Sized>( + &mut self, + key: &K, + value: &V, + ) -> Result<(), Self::Error> { + self.serialize_key(key)?; + self.serialize_value(value) + } + + fn end(self) -> Result<Self::Ok, Self::Error>; +} + +impl<'ser, 'fds> SerializeMapFds for Composite<'ser, FdSerializerImpl<'fds>> { + type Ok = (); + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + + fn serialize_key<T: SerializeWithFds + ?Sized>(&mut self, key: &T) -> Result<(), Self::Error> { + key.serialize_fds(&mut *self.serializer)?; + Ok(()) + } + + fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized, + { + value.serialize_fds(&mut *self.serializer)?; + Ok(()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Ok(()) + } +} + +/// Returned from `SerializerWithFds::serialize_struct`. +pub trait SerializeStructFds { + type Ok; + type Error: Error; + + fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized; + + fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> { + let _ = key; + Ok(()) + } + + fn end(self) -> Result<Self::Ok, Self::Error>; +} + +impl<'ser, 'fds> SerializeStructFds for Composite<'ser, FdSerializerImpl<'fds>> { + type Ok = (); + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + + fn serialize_field<T>(&mut self, _key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized, + { + value.serialize_fds(&mut *self.serializer)?; + Ok(()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Ok(()) + } +} + +/// Returned from `SerializerWithFds::serialize_struct_variant`. +pub trait SerializeStructVariantFds { + type Ok; + type Error: Error; + + fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized; + + fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> { + let _ = key; + Ok(()) + } + + fn end(self) -> Result<Self::Ok, Self::Error>; +} + +impl<'ser, 'fds> SerializeStructVariantFds for Composite<'ser, FdSerializerImpl<'fds>> { + type Ok = (); + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + + fn serialize_field<T>(&mut self, _key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized, + { + value.serialize_fds(&mut *self.serializer)?; + Ok(()) + } + + fn end(self) -> Result<Self::Ok, Self::Error> { + Ok(()) + } +} + +/// A **data format** that can send file descriptors between Unix processes. +pub trait FdSerializer: Sized { + type Ok; + type Error: Error; + + type SerializeSeqFds: SerializeSeqFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeTupleStructFds: SerializeTupleStructFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeTupleVariantFds: SerializeTupleVariantFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeMapFds: SerializeMapFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeStructFds: SerializeStructFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeStructVariantFds: SerializeStructVariantFds<Ok = Self::Ok, Error = Self::Error>; + + fn serialize_raw_fd(self, value: RawFd) -> Result<Self::Ok, Self::Error>; + + fn serialize_none(self) -> Result<Self::Ok, Self::Error>; + + fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error> + where + T: SerializeWithFds + ?Sized; + + fn serialize_unit(self) -> Result<Self::Ok, Self::Error>; + + fn serialize_unit_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + ) -> Result<Self::Ok, Self::Error>; + + fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeqFds, Self::Error>; + + fn serialize_tuple_struct( + self, + name: &'static str, + len: usize, + ) -> Result<Self::SerializeTupleStructFds, Self::Error>; + + fn serialize_tuple_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result<Self::SerializeTupleVariantFds, Self::Error>; + + fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMapFds, Self::Error>; + + fn serialize_struct( + self, + name: &'static str, + len: usize, + ) -> Result<Self::SerializeStructFds, Self::Error>; + + fn serialize_struct_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result<Self::SerializeStructVariantFds, Self::Error>; +} + +#[derive(Debug)] +pub(crate) struct FdSerializerImpl<'fds> { + pub fds: &'fds mut Vec<RawFd>, +} + +impl<'ser, 'fds> FdSerializer for &'ser mut FdSerializerImpl<'fds> { + type Ok = (); + type Error = Never; + + type SerializeSeqFds = Composite<'ser, FdSerializerImpl<'fds>>; + type SerializeTupleStructFds = Composite<'ser, FdSerializerImpl<'fds>>; + type SerializeTupleVariantFds = Composite<'ser, FdSerializerImpl<'fds>>; + type SerializeMapFds = Composite<'ser, FdSerializerImpl<'fds>>; + type SerializeStructFds = Composite<'ser, FdSerializerImpl<'fds>>; + type SerializeStructVariantFds = Composite<'ser, FdSerializerImpl<'fds>>; + + fn serialize_raw_fd(self, value: RawFd) -> Result<Self::Ok, Self::Error> { + self.fds.push(value); + Ok(()) + } + + fn serialize_none(self) -> Result<Self::Ok, Self::Error> { + Ok(()) + } + + fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error> + where + T: SerializeWithFds + ?Sized, + { + value.serialize_fds(self) + } + + fn serialize_unit(self) -> Result<Self::Ok, Self::Error> { + Ok(()) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result<Self::Ok, Self::Error> { + Ok(()) + } + + fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeqFds, Self::Error> { + Ok(Composite { serializer: self }) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result<Self::SerializeTupleStructFds, Self::Error> { + Ok(Composite { serializer: self }) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result<Self::SerializeStructVariantFds, Self::Error> { + Ok(Composite { serializer: self }) + } + + fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeSeqFds, Self::Error> { + Ok(Composite { serializer: self }) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result<Self::SerializeTupleStructFds, Self::Error> { + Ok(Composite { serializer: self }) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result<Self::SerializeStructVariantFds, Self::Error> { + Ok(Composite { serializer: self }) + } +} + +/// A **data structure** that can be sent over a Unix domain socket. pub trait SerializeWithFds { - fn serialize<Ser: SerializerWithFds>( - &self, - serializer: Ser, - ) -> Result<<Ser::Ser as Serializer>::Ok, <Ser::Ser as Serializer>::Error>; + fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>; + fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>; +} + +macro_rules! serialize_impl { + ($type:ty) => { + impl SerializeWithFds for $type { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + Serialize::serialize(&self, serializer) + } + + fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_unit() + } + } + }; +} + +macro_rules! fd_impl_body { + () => { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_unit() + } + + fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_raw_fd(AsRawFd::as_raw_fd(self)) + } + }; +} + +// FIXME: it would be nice if these two cases could be consolidated +// with a better pattern, and fd_impl_body!() could be inlined. +macro_rules! fd_impl { + ($type:ty) => { + impl SerializeWithFds for $type { + fd_impl_body!(); + } + }; + + ($type:ty, $($args:tt),+) => { + impl<$($args),*> SerializeWithFds for $type { + fd_impl_body!(); + } + }; } +serialize_impl!(u8); +serialize_impl!(u16); +serialize_impl!(u32); +serialize_impl!(u64); +serialize_impl!(usize); + +serialize_impl!(String); +serialize_impl!(std::path::PathBuf); + +fd_impl!(std::fs::File); +fd_impl!(std::io::Stderr); +fd_impl!(std::io::StderrLock<'a>, 'a); +fd_impl!(std::io::Stdin); +fd_impl!(std::io::StdinLock<'a>, 'a); +fd_impl!(std::io::Stdout); +fd_impl!(std::io::StdoutLock<'a>, 'a); +fd_impl!(std::net::TcpListener); +fd_impl!(std::net::TcpStream); +fd_impl!(std::net::UdpSocket); +fd_impl!(std::os::unix::net::UnixDatagram); +fd_impl!(std::os::unix::net::UnixListener); +fd_impl!(std::os::unix::net::UnixStream); +fd_impl!(std::process::ChildStderr); +fd_impl!(std::process::ChildStdin); +fd_impl!(std::process::ChildStdout); + +impl<T: SerializeWithFds> SerializeWithFds for Option<T> { + fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + match self { + Some(value) => serializer.serialize_some(&SerializeAdapter(value)), + None => serializer.serialize_none(), + } + } + + fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + match self { + Some(value) => serializer.serialize_some(value), + None => serializer.serialize_none(), + } + } +} + +impl<T: SerializeWithFds> SerializeWithFds for Vec<T> { + fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + let mut seq = serializer.serialize_seq(Some(self.len()))?; + for element in self { + seq.serialize_element(&SerializeAdapter(element))?; + } + seq.end() + } + + fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + let mut seq = serializer.serialize_seq(Some(self.len()))?; + for element in self { + seq.serialize_element(element)?; + } + seq.end() + } +} + +impl<K: SerializeWithFds, V: SerializeWithFds> SerializeWithFds for BTreeMap<K, V> { + fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + let mut map = serializer.serialize_map(Some(self.len()))?; + for (k, v) in self { + map.serialize_entry(&SerializeAdapter(k), &SerializeAdapter(v))?; + } + map.end() + } + + fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + let mut map = serializer.serialize_map(Some(self.len()))?; + for (k, v) in self { + map.serialize_entry(k, v)?; + } + map.end() + } +} + +/// A convenience for serializing a `RawFd` field in a composite type, +/// using a `serialize_raw_fd` method. +/// +/// `RawFd` cannot directly implement `SerializeWithFds` because it is +/// a type alias to `i16`, and so it would be ambiguous whether to +/// serialize an `i16` value as an integer or as a file descriptor. +/// +/// In general, using `RawFd` should be avoided in favour of types +/// like `File` that represent fd ownership. `SerializeRawFd` borrows +/// a `RawFd` even though `RawFd` is `Copy` to prevent its misuse as a +/// generic container for a `RawFd`, which it is not. #[derive(Debug)] -pub struct SerializerWithFdsImpl<'fds, Ser> { - serializer: Ser, - fds: &'fds mut Vec<RawFd>, +pub struct SerializeRawFd<'a>(&'a RawFd); + +impl<'a> SerializeRawFd<'a> { + pub fn new(fd: &'a RawFd) -> Self { + Self(fd) + } } -impl<'fds, Ser> SerializerWithFdsImpl<'fds, Ser> { - pub fn new(fds: &'fds mut Vec<RawFd>, serializer: Ser) -> Self { - Self { serializer, fds } +impl<'a> SerializeWithFds for SerializeRawFd<'a> { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_unit() + } + + fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + serializer.serialize_raw_fd(*self.0) } } -impl<'fds, Ser: Serializer> SerializerWithFds for SerializerWithFdsImpl<'fds, Ser> { - type Ser = Ser; +/// A wrapper struct to prevent `SerializeWithFds` implementations +/// directly implementing `Serialize` for their `serialize` method. +/// +/// Just doing `Serialize` on a `SerializeWithFds` implmentation would +/// return a partially serialized value, because the file descriptors +/// won't have been serialized. +/// +/// This should therefore only be used to implement +/// `SerializeWithFds`, where the file descriptors are then seperately +/// serialized in `serialize_fds`. +#[derive(Debug)] +pub struct SerializeAdapter<'a, T: ?Sized>(&'a T); - fn serializer(self) -> Self::Ser { - self.serializer +impl<'a, T: ?Sized> SerializeAdapter<'a, T> { + pub fn new(value: &'a T) -> Self { + Self(value) } +} - fn fds(&mut self) -> &mut Vec<RawFd> { - &mut self.fds +impl<'a, T: SerializeWithFds + ?Sized> Serialize for SerializeAdapter<'a, T> { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + self.0.serialize(serializer) } } diff --git a/msg_socket2/src/socket.rs b/msg_socket2/src/socket.rs index 4e75e82..b836365 100644 --- a/msg_socket2/src/socket.rs +++ b/msg_socket2/src/socket.rs @@ -4,9 +4,7 @@ use std::marker::PhantomData; use std::os::unix::prelude::*; use sys_util::{net::UnixSeqpacket, ScmSocket}; -use crate::{ - DeserializeWithFds, DeserializerWithFds, Error, SerializeWithFds, SerializerWithFdsImpl, -}; +use crate::{DeserializeWithFds, DeserializerWithFds, Error, FdSerializerImpl, SerializeWithFds}; #[derive(Debug)] pub struct Socket<Send, Recv> { @@ -28,9 +26,10 @@ impl<Send: SerializeWithFds, Recv> Socket<Send, Recv> { let mut bytes: Vec<u8> = vec![]; let mut fds: Vec<RawFd> = vec![]; - let mut serializer = Serializer::new(&mut bytes, DefaultOptions::new()); - let serializer_with_fds = SerializerWithFdsImpl::new(&mut fds, &mut serializer); - value.serialize(serializer_with_fds)?; + value.serialize(&mut Serializer::new(&mut bytes, DefaultOptions::new()))?; + value + .serialize_fds(&mut FdSerializerImpl { fds: &mut fds }) + .unwrap(); self.sock.send_with_fds(&[IoSlice::new(&bytes)], &fds)?; diff --git a/msg_socket2/tests/round_trip.rs b/msg_socket2/tests/round_trip.rs index efec94b..1bb6636 100644 --- a/msg_socket2/tests/round_trip.rs +++ b/msg_socket2/tests/round_trip.rs @@ -3,14 +3,33 @@ use std::os::unix::prelude::*; use std::fmt::{self, Formatter}; use std::marker::PhantomData; -use msg_socket2::*; -use serde::de::*; -use serde::ser::*; +use msg_socket2::{ + ser::{ + SerializeAdapter, SerializeRawFd, SerializeStruct, SerializeStructFds, + SerializeTupleStruct, SerializeTupleStructFds, + }, + DeserializeWithFds, DeserializerWithFds, FdSerializer, SerializeWithFds, Serializer, Socket, +}; +use serde::de::{Deserializer, SeqAccess}; use sys_util::net::UnixSeqpacket; #[derive(Debug)] struct Inner(RawFd, u16); +impl SerializeWithFds for Inner { + fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + let mut state = serializer.serialize_tuple_struct("Inner", 1)?; + state.serialize_field(&self.1)?; + state.end() + } + + fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + let mut state = serializer.serialize_tuple_struct("Inner", 1)?; + state.serialize_field(&SerializeRawFd::new(&self.0))?; + state.end() + } +} + #[derive(Debug)] struct Test { fd: RawFd, @@ -18,26 +37,17 @@ struct Test { } impl SerializeWithFds for Test { - fn serialize<Ser: SerializerWithFds>( - &self, - mut serializer: Ser, - ) -> Result<<Ser::Ser as Serializer>::Ok, <Ser::Ser as Serializer>::Error> { - serializer.fds().push(self.fd); - serializer.fds().push(self.inner.0); - - struct SerializableInner<'a>(&'a Inner); - - impl<'a> Serialize for SerializableInner<'a> { - fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { - let mut state = serializer.serialize_tuple_struct("Inner", 1)?; - state.serialize_field(&(self.0).1)?; - state.end() - } - } - - let mut state = serializer.serializer().serialize_struct("Test", 1)?; + fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + let mut state = serializer.serialize_struct("Test", 1)?; state.skip_field("fd")?; - state.serialize_field("inner", &SerializableInner(&self.inner))?; + state.serialize_field("inner", &SerializeAdapter::new(&self.inner))?; + state.end() + } + + fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + let mut state = serializer.serialize_struct("Test", 2)?; + state.serialize_field("fd", &SerializeRawFd::new(&self.fd))?; + state.serialize_field("inner", &self.inner)?; state.end() } } diff --git a/vm_control/Cargo.toml b/vm_control/Cargo.toml index 564aea1..3a6ec3f 100644 --- a/vm_control/Cargo.toml +++ b/vm_control/Cargo.toml @@ -9,5 +9,7 @@ data_model = { path = "../data_model" } kvm = { path = "../kvm" } libc = "*" msg_socket = { path = "../msg_socket" } +msg_socket2 = { path = "../msg_socket2" } +msg_socket2_derive = { path = "../msg_socket2/derive" } resources = { path = "../resources" } sys_util = { path = "../sys_util" } diff --git a/vm_control/src/lib.rs b/vm_control/src/lib.rs index 3d92c8b..3ec6be7 100644 --- a/vm_control/src/lib.rs +++ b/vm_control/src/lib.rs @@ -20,20 +20,22 @@ use libc::{EINVAL, EIO, ENODEV}; use kvm::{IrqRoute, IrqSource, Vm}; use msg_socket::{MsgError, MsgOnSocket, MsgResult, MsgSocket, UnixSeqpacketExt}; +use msg_socket2_derive::SerializeWithFds; use resources::{Alloc, GpuMemoryDesc, MmioType, SystemAllocator}; use sys_util::net::UnixSeqpacket; use sys_util::{error, Error as SysError, EventFd, GuestAddress, MemoryMapping, MmapError, Result}; /// A data structure that either owns or borrows a file descriptor. -#[derive(Debug)] -pub enum MaybeOwnedFd<Owned> { +#[derive(SerializeWithFds, Debug)] +#[msg_socket2(strategy = "AsRawFd")] +pub enum MaybeOwnedFd<Owned: AsRawFd> { /// Owned by this enum variant, and will be destructed automatically if not moved out. Owned(Owned), /// A file descriptor borrwed by this enum. Borrowed(RawFd), } -impl<Owned> MaybeOwnedFd<Owned> { +impl<Owned: AsRawFd> MaybeOwnedFd<Owned> { pub fn new_borrowed(fd: &dyn AsRawFd) -> Self { Self::Borrowed(fd.as_raw_fd()) } |