diff options
author | Alyssa Ross <hi@alyssa.is> | 2020-03-26 09:03:39 +0000 |
---|---|---|
committer | Alyssa Ross <hi@alyssa.is> | 2020-06-15 09:37:19 +0000 |
commit | 353b1d9091b9095282463f36e26643506e2d2897 (patch) | |
tree | 714c613f9140cfebd45a3308ff2f035b0bb10958 | |
parent | e682f11cd0d063fe9111b13ab8ecff676592acdb (diff) | |
download | crosvm-353b1d9091b9095282463f36e26643506e2d2897.tar crosvm-353b1d9091b9095282463f36e26643506e2d2897.tar.gz crosvm-353b1d9091b9095282463f36e26643506e2d2897.tar.bz2 crosvm-353b1d9091b9095282463f36e26643506e2d2897.tar.lz crosvm-353b1d9091b9095282463f36e26643506e2d2897.tar.xz crosvm-353b1d9091b9095282463f36e26643506e2d2897.tar.zst crosvm-353b1d9091b9095282463f36e26643506e2d2897.zip |
use lifetimes in serialization to prevent closing
-rw-r--r-- | Cargo.lock | 1 | ||||
-rw-r--r-- | devices/src/virtio/controller.rs | 5 | ||||
-rw-r--r-- | devices/src/virtio/wl.rs | 35 | ||||
-rw-r--r-- | msg_socket/Cargo.toml | 1 | ||||
-rw-r--r-- | msg_socket/src/lib.rs | 11 | ||||
-rw-r--r-- | msg_socket2/derive/lib.rs | 61 | ||||
-rw-r--r-- | msg_socket2/src/ser.rs | 174 | ||||
-rw-r--r-- | msg_socket2/src/socket.rs | 16 | ||||
-rw-r--r-- | msg_socket2/tests/round_trip.rs | 12 |
9 files changed, 174 insertions, 142 deletions
diff --git a/Cargo.lock b/Cargo.lock index 8e079d2..ba5218a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -495,6 +495,7 @@ dependencies = [ "futures 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "libc 0.2.44 (registry+https://github.com/rust-lang/crates.io-index)", "msg_on_socket_derive 0.1.0", + "msg_socket2 0.1.0", "sync 0.1.0", "sys_util 0.1.0", ] diff --git a/devices/src/virtio/controller.rs b/devices/src/virtio/controller.rs index 644ece1..ba2543f 100644 --- a/devices/src/virtio/controller.rs +++ b/devices/src/virtio/controller.rs @@ -201,7 +201,10 @@ impl SerializeWithFds for Request { } } - fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + fn serialize_fds<'fds, S>(&'fds self, serializer: S) -> Result<S::Ok, S::Error> + where + S: FdSerializer<'fds>, + { use Request::*; match self { diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs index 7a436f5..5a3505b 100644 --- a/devices/src/virtio/wl.rs +++ b/devices/src/virtio/wl.rs @@ -53,6 +53,7 @@ use data_model::*; use msg_socket::{MsgError, MsgReceiver, MsgSender}; use msg_socket2::de::VisitorWithFds; +use msg_socket2::ser::SerializeAdapter; #[cfg(feature = "wl-dmabuf")] use resources::GpuMemoryDesc; #[cfg(feature = "wl-dmabuf")] @@ -1540,10 +1541,7 @@ pub struct Wl { use msg_socket2::{ de::{DeserializeWithFds, DeserializerWithFds}, - ser::{ - FdSerializer, SerializeRawFd, SerializeStruct, SerializeStructFds, SerializeWithFds, - Serializer, - }, + ser::{FdSerializer, SerializeStruct, SerializeStructFds, SerializeWithFds, Serializer}, }; use std::fmt::Formatter; @@ -1559,29 +1557,22 @@ impl SerializeWithFds for Params { fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { let mut state = serializer.serialize_struct("Params", 3)?; state.serialize_field("wayland_paths", &self.wayland_paths)?; - state.serialize_field("vm_socket", &())?; - - let resource_bridge_marker = self.resource_bridge.as_ref().map(|_| ()); - state.serialize_field("resource_bridge", &resource_bridge_marker)?; - + state.serialize_field("vm_socket", &SerializeAdapter::new(&self.vm_socket))?; + state.serialize_field( + "resource_bridge", + &SerializeAdapter::new(&self.resource_bridge), + )?; state.end() } - fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + fn serialize_fds<'fds, S>(&'fds self, serializer: S) -> Result<S::Ok, S::Error> + where + S: FdSerializer<'fds>, + { let mut state = serializer.serialize_struct("Params", 3)?; - state.serialize_field("wayland_paths", &self.wayland_paths)?; - state.serialize_field( - "vm_socket", - &SerializeRawFd::new(&self.vm_socket.as_raw_fd()), - )?; - - let bridge_fd = self.resource_bridge.as_ref().map(AsRawFd::as_raw_fd); - state.serialize_field( - "resource_bridge", - &bridge_fd.as_ref().map(SerializeRawFd::new), - )?; - + state.serialize_field("vm_socket", &self.vm_socket)?; + state.serialize_field("resource_bridge", &self.resource_bridge)?; state.end() } } diff --git a/msg_socket/Cargo.toml b/msg_socket/Cargo.toml index 80eba0b..030b7d3 100644 --- a/msg_socket/Cargo.toml +++ b/msg_socket/Cargo.toml @@ -10,5 +10,6 @@ data_model = { path = "../data_model" } futures = "*" libc = "*" msg_on_socket_derive = { path = "msg_on_socket_derive" } +msg_socket2 = { path = "../msg_socket2" } sys_util = { path = "../sys_util" } sync = { path = "../sync" } diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs index be86070..5dedac8 100644 --- a/msg_socket/src/lib.rs +++ b/msg_socket/src/lib.rs @@ -6,7 +6,7 @@ mod msg_on_socket; use std::io::{IoSlice, Result}; use std::marker::PhantomData; -use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::prelude::*; use std::pin::Pin; use std::task::{Context, Poll}; @@ -14,6 +14,7 @@ use futures::Stream; use libc::{EWOULDBLOCK, O_NONBLOCK}; use cros_async::add_read_waker; +use msg_socket2::{DeserializeWithFds, SerializeWithFds}; use sys_util::{ add_fd_flags, clear_fd_flags, error, handle_eintr, net::UnixSeqpacket, Error as SysError, ScmSocket, @@ -33,6 +34,8 @@ pub fn pair<Request: MsgOnSocket, Response: MsgOnSocket>( } /// Bidirection sock that support both send and recv. +#[derive(SerializeWithFds, DeserializeWithFds)] +#[msg_socket2(strategy = "AsRawFd")] pub struct MsgSocket<I: MsgOnSocket, O: MsgOnSocket> { sock: UnixSeqpacket, _i: PhantomData<I>, @@ -99,6 +102,12 @@ impl<I: MsgOnSocket, O: MsgOnSocket> AsRawFd for MsgSocket<I, O> { } } +impl<I: MsgOnSocket, O: MsgOnSocket> FromRawFd for MsgSocket<I, O> { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + Self::new(UnixSeqpacket::from_raw_fd(fd)) + } +} + impl<M: MsgOnSocket> AsRef<UnixSeqpacket> for Sender<M> { fn as_ref(&self) -> &UnixSeqpacket { &self.sock diff --git a/msg_socket2/derive/lib.rs b/msg_socket2/derive/lib.rs index 42c17c3..dacfc02 100644 --- a/msg_socket2/derive/lib.rs +++ b/msg_socket2/derive/lib.rs @@ -8,8 +8,8 @@ use std::convert::{TryFrom, TryInto}; use std::fmt::{self, Display, Formatter}; use syn::spanned::Spanned; use syn::{ - parse_macro_input, Attribute, DeriveInput, GenericParam, Generics, Lifetime, LifetimeDef, Lit, - Meta, MetaList, MetaNameValue, NestedMeta, + parse_macro_input, punctuated::Punctuated, token::Comma, Attribute, DeriveInput, GenericParam, + Generics, Lifetime, LifetimeDef, Lit, Meta, MetaList, MetaNameValue, NestedMeta, }; #[proc_macro_derive(SerializeWithFds, attributes(msg_socket2))] @@ -128,7 +128,16 @@ fn get_config(attrs: &[Attribute]) -> Result<Config, Error> { Ok(Config { strategy }) } -fn split_for_deserialize_impl(mut generics: Generics) -> (TokenStream, TokenStream, TokenStream) { +/// Like `syn::Generics::split_for_impl`, but a list of lifetimes can +/// be added to the start of the generics list to go next to the +/// `impl` keyword, but not the generics list to go after the type. +/// +/// This allows for generics that are required for a trait lifetime, +/// rather than for the type its being implemented for. +fn split_for_impl_with_trait_lifetimes( + lifetimes: &'static [&'static str], + mut generics: Generics, +) -> (TokenStream, TokenStream, TokenStream) { let (_, ty_generics, where_clause) = generics.split_for_impl(); // Convert to token streams to take ownership, so generics can be @@ -136,14 +145,20 @@ fn split_for_deserialize_impl(mut generics: Generics) -> (TokenStream, TokenStre let ty_generics = ty_generics.into_token_stream(); let where_clause = where_clause.into_token_stream(); - // Add the deserializer lifetime to the list of generics, then - // calculate the impl_generics. Can't do this before calculating - // ty_generics because then the deserializer lifetime would be - // passed as a lifetime parameter to the type. + // Add the trait lifetimes to the list of generics, then calculate + // the impl_generics. Can't do this before calculating + // ty_generics because then the trait lifetimes would be passed as + // lifetime parameters to the type. // // FIXME: there must be a better way of doing this. - let lifetime = LifetimeDef::new(Lifetime::new("'__de", Span::call_site())); - generics.params.insert(0, GenericParam::Lifetime(lifetime)); + let defs: Punctuated<_, Comma> = lifetimes + .iter() + .map(|name| Lifetime::new(name, Span::call_site())) + .map(LifetimeDef::new) + .map(GenericParam::Lifetime) + .chain(generics.params.into_iter()) + .collect(); + generics.params = defs; let (impl_generics, _, _) = generics.split_for_impl(); (impl_generics.into_token_stream(), ty_generics, where_clause) @@ -154,7 +169,9 @@ fn as_raw_fd_serialize_impl(input: DeriveInput) -> TokenStream { let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); quote! { - impl #impl_generics ::msg_socket2::SerializeWithFds for #name #ty_generics #where_clause { + impl #impl_generics ::msg_socket2::SerializeWithFds for #name #ty_generics + #where_clause + { fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> where S: ::msg_socket2::Serializer, @@ -162,11 +179,11 @@ fn as_raw_fd_serialize_impl(input: DeriveInput) -> TokenStream { serializer.serialize_unit() } - fn serialize_fds<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> - where - S: ::msg_socket2::FdSerializer, - { - serializer.serialize_raw_fd(::std::os::unix::io::AsRawFd::as_raw_fd(self)) + fn serialize_fds<'__fds, S: ::msg_socket2::FdSerializer<'__fds>>( + &'__fds self, + serializer: S, + ) -> ::std::result::Result<S::Ok, S::Error> { + serializer.serialize_raw_fd(self) } } } @@ -174,7 +191,8 @@ fn as_raw_fd_serialize_impl(input: DeriveInput) -> TokenStream { fn as_raw_fd_deserialize_impl(input: DeriveInput) -> TokenStream { let name = input.ident; - let (impl_generics, ty_generics, where_clause) = split_for_deserialize_impl(input.generics); + let (impl_generics, ty_generics, where_clause) = + split_for_impl_with_trait_lifetimes(&["'__de"], input.generics); quote! { impl #impl_generics ::msg_socket2::DeserializeWithFds<'__de> for #name #ty_generics @@ -203,10 +221,10 @@ fn serde_serialize_impl(input: DeriveInput) -> TokenStream { ::msg_socket2::Serialize::serialize(&self, serializer) } - fn serialize_fds<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> - where - S: ::msg_socket2::FdSerializer, - { + fn serialize_fds<'__fds, S: ::msg_socket2::FdSerializer<'__fds>>( + &'__fds self, + serializer: S, + ) -> ::std::result::Result<S::Ok, S::Error> { serializer.serialize_unit() } } @@ -215,7 +233,8 @@ fn serde_serialize_impl(input: DeriveInput) -> TokenStream { fn serde_deserialize_impl(input: DeriveInput) -> TokenStream { let name = input.ident; - let (impl_generics, ty_generics, where_clause) = split_for_deserialize_impl(input.generics); + let (impl_generics, ty_generics, where_clause) = + split_for_impl_with_trait_lifetimes(&["'__de"], input.generics); quote! { impl #impl_generics ::msg_socket2::DeserializeWithFds<'__de> for #name #ty_generics diff --git a/msg_socket2/src/ser.rs b/msg_socket2/src/ser.rs index ed85621..1391d12 100644 --- a/msg_socket2/src/ser.rs +++ b/msg_socket2/src/ser.rs @@ -42,7 +42,7 @@ pub(crate) struct Composite<'ser, S> { } /// Like `SerializeSeq`, but accepts file descriptors instead of data. -pub trait SerializeSeqFds { +pub trait SerializeSeqFds<'fds> { /// Like `SerializeSeq::Ok`. type Ok; @@ -51,7 +51,7 @@ pub trait SerializeSeqFds { /// Like `SerializeSeq::serialize_element`, but `value` is a /// `SerializeWithFds` instead of a `Serialize`. - fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error> + fn serialize_element<T>(&mut self, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized; @@ -59,11 +59,11 @@ pub trait SerializeSeqFds { fn end(self) -> Result<Self::Ok, Self::Error>; } -impl<'ser, 'fds> SerializeSeqFds for Composite<'ser, FdSerializerImpl<'fds>> { +impl<'ser, 'fds: 'ser> SerializeSeqFds<'fds> for Composite<'ser, FdSerializerImpl<'fds>> { type Ok = (); - type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer<'fds>>::Error; - fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error> + fn serialize_element<T>(&mut self, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized, { @@ -77,7 +77,7 @@ impl<'ser, 'fds> SerializeSeqFds for Composite<'ser, FdSerializerImpl<'fds>> { } /// Like `SerializeTupleStruct`, but accepts file descriptors instead of data. -pub trait SerializeTupleStructFds { +pub trait SerializeTupleStructFds<'fds> { /// Like `SerializeTupleStruct::Ok`. type Ok; @@ -86,7 +86,7 @@ pub trait SerializeTupleStructFds { /// Like `SerializeTupleStruct::serialize_field`, but `value` is a /// `SerializeWithFds` instead of a `Serialize`. - fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + fn serialize_field<T>(&mut self, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized; @@ -94,11 +94,11 @@ pub trait SerializeTupleStructFds { fn end(self) -> Result<Self::Ok, Self::Error>; } -impl<'ser, 'fds> SerializeTupleStructFds for Composite<'ser, FdSerializerImpl<'fds>> { +impl<'ser, 'fds: 'ser> SerializeTupleStructFds<'fds> for Composite<'ser, FdSerializerImpl<'fds>> { type Ok = (); - type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer<'fds>>::Error; - fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + fn serialize_field<T>(&mut self, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized, { @@ -113,7 +113,7 @@ impl<'ser, 'fds> SerializeTupleStructFds for Composite<'ser, FdSerializerImpl<'f /// Like `SerializeTupleVariant`, but accepts file descriptors as well /// as data. -pub trait SerializeTupleVariantFds { +pub trait SerializeTupleVariantFds<'fds> { /// Like `SerializeTupleVariant::Ok`. type Ok; @@ -122,7 +122,7 @@ pub trait SerializeTupleVariantFds { /// Like `SerializeTupleVariant::serialize_field`, but `value` is /// a `SerializeWithFds` instead of a `Serialize`. - fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + fn serialize_field<T>(&mut self, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized; @@ -130,11 +130,11 @@ pub trait SerializeTupleVariantFds { fn end(self) -> Result<Self::Ok, Self::Error>; } -impl<'ser, 'fds> SerializeTupleVariantFds for Composite<'ser, FdSerializerImpl<'fds>> { +impl<'ser, 'fds: 'ser> SerializeTupleVariantFds<'fds> for Composite<'ser, FdSerializerImpl<'fds>> { type Ok = (); - type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer<'fds>>::Error; - fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> + fn serialize_field<T>(&mut self, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized, { @@ -148,7 +148,7 @@ impl<'ser, 'fds> SerializeTupleVariantFds for Composite<'ser, FdSerializerImpl<' } /// Like `SerializeMap`, but provides file descriptors instead of data. -pub trait SerializeMapFds { +pub trait SerializeMapFds<'fds> { /// Like `SerializeMap::Ok`. type Ok; @@ -156,10 +156,13 @@ pub trait SerializeMapFds { type Error: Error; /// Like `SerializeMap::serialize_key`, but `key` is a `SerializeWithFds` instead of a `Serialize`. - fn serialize_key<T: SerializeWithFds + ?Sized>(&mut self, key: &T) -> Result<(), Self::Error>; + fn serialize_key<T: SerializeWithFds + ?Sized>( + &mut self, + key: &'fds T, + ) -> Result<(), Self::Error>; /// Like `SerializeMap::serialize_value`, but `value` is a `SerializeWithFds` instead of a `Serialize`. - fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error> + fn serialize_value<T>(&mut self, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized; @@ -167,8 +170,8 @@ pub trait SerializeMapFds { /// `SerializeWithFds` instead of `Serialize`. fn serialize_entry<K: SerializeWithFds + ?Sized, V: SerializeWithFds + ?Sized>( &mut self, - key: &K, - value: &V, + key: &'fds K, + value: &'fds V, ) -> Result<(), Self::Error> { self.serialize_key(key)?; self.serialize_value(value) @@ -178,16 +181,19 @@ pub trait SerializeMapFds { fn end(self) -> Result<Self::Ok, Self::Error>; } -impl<'ser, 'fds> SerializeMapFds for Composite<'ser, FdSerializerImpl<'fds>> { +impl<'ser, 'fds: 'ser> SerializeMapFds<'fds> for Composite<'ser, FdSerializerImpl<'fds>> { type Ok = (); - type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer<'fds>>::Error; - fn serialize_key<T: SerializeWithFds + ?Sized>(&mut self, key: &T) -> Result<(), Self::Error> { + fn serialize_key<T>(&mut self, key: &'fds T) -> Result<(), Self::Error> + where + T: SerializeWithFds + ?Sized, + { key.serialize_fds(&mut *self.serializer)?; Ok(()) } - fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error> + fn serialize_value<T>(&mut self, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized, { @@ -202,7 +208,7 @@ impl<'ser, 'fds> SerializeMapFds for Composite<'ser, FdSerializerImpl<'fds>> { /// Like `SerializeStruct`, but provides file descriptors instead of /// data. -pub trait SerializeStructFds { +pub trait SerializeStructFds<'fds> { /// Like `SerializeStruct::Ok`. type Ok; @@ -211,7 +217,7 @@ pub trait SerializeStructFds { /// Like `SerializeStruct::serialize_field`, but `value` is a /// `SerializeWithFds` instead of a `Serialize`. - fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + fn serialize_field<T>(&mut self, key: &'static str, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized; @@ -225,11 +231,11 @@ pub trait SerializeStructFds { fn end(self) -> Result<Self::Ok, Self::Error>; } -impl<'ser, 'fds> SerializeStructFds for Composite<'ser, FdSerializerImpl<'fds>> { +impl<'ser, 'fds: 'ser> SerializeStructFds<'fds> for Composite<'ser, FdSerializerImpl<'fds>> { type Ok = (); - type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer<'fds>>::Error; - fn serialize_field<T>(&mut self, _key: &'static str, value: &T) -> Result<(), Self::Error> + fn serialize_field<T>(&mut self, _key: &'static str, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized, { @@ -244,7 +250,7 @@ impl<'ser, 'fds> SerializeStructFds for Composite<'ser, FdSerializerImpl<'fds>> /// Like `SerializeStructVariant`, but provides file descriptors /// instead of data. -pub trait SerializeStructVariantFds { +pub trait SerializeStructVariantFds<'fds> { /// Like `SerializeStructVariant::Ok`. type Ok; @@ -253,7 +259,7 @@ pub trait SerializeStructVariantFds { /// Like `SerializeStructVariant::serialize_field`, but `key` and /// `value` are `SerializeWithFds` instead of `Serialize`. - fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + fn serialize_field<T>(&mut self, key: &'static str, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized; @@ -267,11 +273,11 @@ pub trait SerializeStructVariantFds { fn end(self) -> Result<Self::Ok, Self::Error>; } -impl<'ser, 'fds> SerializeStructVariantFds for Composite<'ser, FdSerializerImpl<'fds>> { +impl<'ser, 'fds: 'ser> SerializeStructVariantFds<'fds> for Composite<'ser, FdSerializerImpl<'fds>> { type Ok = (); - type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error; + type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer<'fds>>::Error; - fn serialize_field<T>(&mut self, _key: &'static str, value: &T) -> Result<(), Self::Error> + fn serialize_field<T>(&mut self, _key: &'static str, value: &'fds T) -> Result<(), Self::Error> where T: SerializeWithFds + ?Sized, { @@ -285,7 +291,7 @@ impl<'ser, 'fds> SerializeStructVariantFds for Composite<'ser, FdSerializerImpl< } /// Like `Serializer`, but accepts file descriptors instead of data. -pub trait FdSerializer: Sized { +pub trait FdSerializer<'fds>: Sized { /// Like `Serializer::Ok`. type Ok; @@ -294,37 +300,45 @@ pub trait FdSerializer: Sized { /// Like `Serializer::SerializeSeq`, but serializes /// `SerializeWithFds` instead of `Serialize`. - type SerializeSeqFds: SerializeSeqFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeSeqFds: SerializeSeqFds<'fds, Ok = Self::Ok, Error = Self::Error>; /// Like `Serializer::SerializeTupleStruct`, but serializes /// `SerializeWithFds` instead of `Serialize`. - type SerializeTupleStructFds: SerializeTupleStructFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeTupleStructFds: SerializeTupleStructFds<'fds, Ok = Self::Ok, Error = Self::Error>; /// Like `Serializer::SerializeTupleVariant`, but serializes /// `SerializeWithFds` instead of `Serialize`. - type SerializeTupleVariantFds: SerializeTupleVariantFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeTupleVariantFds: SerializeTupleVariantFds< + 'fds, + Ok = Self::Ok, + Error = Self::Error, + >; /// Like `Serializer::SerializeMap`, but serializes /// `SerializeWithFds` instead of `Serialize`. - type SerializeMapFds: SerializeMapFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeMapFds: SerializeMapFds<'fds, Ok = Self::Ok, Error = Self::Error>; /// Like `Serializer::SerializeStruct`, but serializes /// `SerializeWithFds` instead of `Serialize`. - type SerializeStructFds: SerializeStructFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeStructFds: SerializeStructFds<'fds, Ok = Self::Ok, Error = Self::Error>; /// Like `Serializer::SerializeStructVariant`, but serializes /// `SerializeWithFds` instead of `Serialize`. - type SerializeStructVariantFds: SerializeStructVariantFds<Ok = Self::Ok, Error = Self::Error>; + type SerializeStructVariantFds: SerializeStructVariantFds< + 'fds, + Ok = Self::Ok, + Error = Self::Error, + >; /// Serialize a file descriptor. - fn serialize_raw_fd(self, value: RawFd) -> Result<Self::Ok, Self::Error>; + fn serialize_raw_fd(self, value: &'fds dyn AsRawFd) -> Result<Self::Ok, Self::Error>; /// Like `Serializer::serialize_none`. fn serialize_none(self) -> Result<Self::Ok, Self::Error>; /// Like `Serializer::serialize_some`, but `value` is a /// `SerializeWithFds` instead of a `Serialize`. - fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error> + fn serialize_some<T>(self, value: &'fds T) -> Result<Self::Ok, Self::Error> where T: SerializeWithFds + ?Sized; @@ -390,12 +404,11 @@ pub trait FdSerializer: Sized { ) -> Result<Self::SerializeStructVariantFds, Self::Error>; } -#[derive(Debug)] pub(crate) struct FdSerializerImpl<'fds> { - pub fds: &'fds mut Vec<RawFd>, + pub fds: Vec<&'fds dyn AsRawFd>, } -impl<'ser, 'fds> FdSerializer for &'ser mut FdSerializerImpl<'fds> { +impl<'ser, 'fds: 'ser> FdSerializer<'fds> for &'ser mut FdSerializerImpl<'fds> { type Ok = (); type Error = Never; @@ -406,7 +419,7 @@ impl<'ser, 'fds> FdSerializer for &'ser mut FdSerializerImpl<'fds> { type SerializeStructFds = Composite<'ser, FdSerializerImpl<'fds>>; type SerializeStructVariantFds = Composite<'ser, FdSerializerImpl<'fds>>; - fn serialize_raw_fd(self, value: RawFd) -> Result<Self::Ok, Self::Error> { + fn serialize_raw_fd(self, value: &'fds dyn AsRawFd) -> Result<Self::Ok, Self::Error> { self.fds.push(value); Ok(()) } @@ -415,7 +428,7 @@ impl<'ser, 'fds> FdSerializer for &'ser mut FdSerializerImpl<'fds> { Ok(()) } - fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error> + fn serialize_some<T>(self, value: &'fds T) -> Result<Self::Ok, Self::Error> where T: SerializeWithFds + ?Sized, { @@ -489,7 +502,9 @@ pub trait SerializeWithFds { /// /// In most cases, the implementation should largely mirror the /// `serialize` implementation. - fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error>; + fn serialize_fds<'fds, S>(&'fds self, serializer: S) -> Result<S::Ok, S::Error> + where + S: FdSerializer<'fds>; } macro_rules! serialize_impl { @@ -499,7 +514,10 @@ macro_rules! serialize_impl { Serialize::serialize(&self, serializer) } - fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + fn serialize_fds<'fds, S>(&'fds self, serializer: S) -> Result<S::Ok, S::Error> + where + S: FdSerializer<'fds>, + { serializer.serialize_unit() } } @@ -512,8 +530,11 @@ macro_rules! fd_impl_body { serializer.serialize_unit() } - fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { - serializer.serialize_raw_fd(AsRawFd::as_raw_fd(self)) + fn serialize_fds<'fds, S>(&'fds self, serializer: S) -> Result<S::Ok, S::Error> + where + S: FdSerializer<'fds>, + { + serializer.serialize_raw_fd(self) } }; } @@ -562,14 +583,17 @@ fd_impl!(std::process::ChildStdout); fd_impl!(sys_util::net::UnixSeqpacket); impl<T: SerializeWithFds> SerializeWithFds for Option<T> { - fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { match self { Some(value) => serializer.serialize_some(&SerializeAdapter(value)), None => serializer.serialize_none(), } } - fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + fn serialize_fds<'fds, S>(&'fds self, serializer: S) -> Result<S::Ok, S::Error> + where + S: FdSerializer<'fds>, + { match self { Some(value) => serializer.serialize_some(value), None => serializer.serialize_none(), @@ -586,7 +610,10 @@ impl<T: SerializeWithFds> SerializeWithFds for Vec<T> { seq.end() } - fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + fn serialize_fds<'fds, S>(&'fds self, serializer: S) -> Result<S::Ok, S::Error> + where + S: FdSerializer<'fds>, + { let mut seq = serializer.serialize_seq(Some(self.len()))?; for element in self { seq.serialize_element(element)?; @@ -596,7 +623,7 @@ impl<T: SerializeWithFds> SerializeWithFds for Vec<T> { } impl<K: SerializeWithFds, V: SerializeWithFds> SerializeWithFds for BTreeMap<K, V> { - fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { let mut map = serializer.serialize_map(Some(self.len()))?; for (k, v) in self { map.serialize_entry(&SerializeAdapter(k), &SerializeAdapter(v))?; @@ -604,7 +631,10 @@ impl<K: SerializeWithFds, V: SerializeWithFds> SerializeWithFds for BTreeMap<K, map.end() } - fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + fn serialize_fds<'fds, S>(&'fds self, serializer: S) -> Result<S::Ok, S::Error> + where + S: FdSerializer<'fds>, + { let mut map = serializer.serialize_map(Some(self.len()))?; for (k, v) in self { map.serialize_entry(k, v)?; @@ -613,36 +643,6 @@ impl<K: SerializeWithFds, V: SerializeWithFds> SerializeWithFds for BTreeMap<K, } } -/// A convenience for serializing a `RawFd` field in a composite type, -/// using a `serialize_raw_fd` method. -/// -/// `RawFd` cannot directly implement `SerializeWithFds` because it is -/// a type alias to `i16`, and so it would be ambiguous whether to -/// serialize an `i16` value as an integer or as a file descriptor. -/// -/// In general, using `RawFd` should be avoided in favour of types -/// like `File` that represent fd ownership. `SerializeRawFd` borrows -/// a `RawFd` even though `RawFd` is `Copy` to prevent its misuse as a -/// generic container for a `RawFd`, which it is not. -#[derive(Debug)] -pub struct SerializeRawFd<'a>(&'a RawFd); - -impl<'a> SerializeRawFd<'a> { - pub fn new(fd: &'a RawFd) -> Self { - Self(fd) - } -} - -impl<'a> SerializeWithFds for SerializeRawFd<'a> { - fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { - serializer.serialize_unit() - } - - fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { - serializer.serialize_raw_fd(*self.0) - } -} - /// A wrapper struct to prevent `SerializeWithFds` implementations /// directly implementing `Serialize` for their `serialize` method. /// diff --git a/msg_socket2/src/socket.rs b/msg_socket2/src/socket.rs index bd1bc00..33b10ba 100644 --- a/msg_socket2/src/socket.rs +++ b/msg_socket2/src/socket.rs @@ -9,7 +9,7 @@ use sys_util::{net::UnixSeqpacket, ScmSocket}; use crate::{de, DeserializeWithFds, Error, Fd, FdSerializerImpl, SerializeWithFds}; -/// A Unix **SOCK_SEQPACKET** socket that can send and receive values +/// A Unix SOCK_SEQPACKET socket that can send and receive values /// consisting of binary data and file descriptors. #[derive(Debug)] pub struct Socket<Send, Recv> { @@ -29,14 +29,16 @@ impl<Send, Recv> Socket<Send, Recv> { impl<Send: SerializeWithFds, Recv> Socket<Send, Recv> { pub fn send(&self, value: Send) -> Result<(), Error> { let mut bytes: Vec<u8> = vec![]; - let mut fds: Vec<RawFd> = vec![]; - value.serialize(&mut Serializer::new(&mut bytes, DefaultOptions::new()))?; - value - .serialize_fds(&mut FdSerializerImpl { fds: &mut fds }) - .unwrap(); + let mut serializer = Serializer::new(&mut bytes, DefaultOptions::new()); + let mut fd_serializer = FdSerializerImpl { fds: vec![] }; - self.sock.send_with_fds(&[IoSlice::new(&bytes)], &fds)?; + value.serialize(&mut serializer)?; + value.serialize_fds(&mut fd_serializer).unwrap(); + + let fd_borrows: Vec<_> = fd_serializer.fds.iter().map(|x| x.as_raw_fd()).collect(); + self.sock + .send_with_fds(&[IoSlice::new(&bytes)], &fd_borrows)?; Ok(()) } diff --git a/msg_socket2/tests/round_trip.rs b/msg_socket2/tests/round_trip.rs index a89f414..8bdd1e4 100644 --- a/msg_socket2/tests/round_trip.rs +++ b/msg_socket2/tests/round_trip.rs @@ -15,13 +15,16 @@ use sys_util::net::UnixSeqpacket; struct Inner(File, u16); impl SerializeWithFds for Inner { - fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { let mut state = serializer.serialize_tuple_struct("Inner", 1)?; state.serialize_field(&self.1)?; state.end() } - fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + fn serialize_fds<'fds, S: FdSerializer<'fds>>( + &'fds self, + serializer: S, + ) -> Result<S::Ok, S::Error> { let mut state = serializer.serialize_tuple_struct("Inner", 1)?; state.serialize_field(&self.0)?; state.end() @@ -71,7 +74,10 @@ impl SerializeWithFds for Test { state.end() } - fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> { + fn serialize_fds<'fds, S: FdSerializer<'fds>>( + &'fds self, + serializer: S, + ) -> Result<S::Ok, S::Error> { let mut state = serializer.serialize_struct("Test", 2)?; state.serialize_field("fd", &self.fd)?; state.serialize_field("inner", &self.inner)?; |