summary refs log tree commit diff
path: root/msg_socket2
diff options
context:
space:
mode:
Diffstat (limited to 'msg_socket2')
-rw-r--r--msg_socket2/Cargo.toml1
-rw-r--r--msg_socket2/derive/Cargo.toml19
-rw-r--r--msg_socket2/derive/lib.rs264
-rw-r--r--msg_socket2/src/lib.rs6
-rw-r--r--msg_socket2/src/ser.rs600
-rw-r--r--msg_socket2/src/socket.rs11
-rw-r--r--msg_socket2/tests/round_trip.rs54
7 files changed, 903 insertions, 52 deletions
diff --git a/msg_socket2/Cargo.toml b/msg_socket2/Cargo.toml
index ba167d8..7a7cf73 100644
--- a/msg_socket2/Cargo.toml
+++ b/msg_socket2/Cargo.toml
@@ -5,6 +5,7 @@ authors = ["Alyssa Ross <hi@alyssa.is>"]
 edition = "2018"
 
 [dependencies]
+msg_socket2_derive = { path = "derive" }
 serde = "1.0.104"
 sys_util = { path = "../sys_util" }
 
diff --git a/msg_socket2/derive/Cargo.toml b/msg_socket2/derive/Cargo.toml
new file mode 100644
index 0000000..19badea
--- /dev/null
+++ b/msg_socket2/derive/Cargo.toml
@@ -0,0 +1,19 @@
+# SPDX-License-Identifier: MIT OR Apache-2.0
+# Copyright 2020, Alyssa Ross
+
+[package]
+name = "msg_socket2_derive"
+version = "0.1.0"
+authors = ["Alyssa Ross <hi@alyssa.is>"]
+license = "MIT OR Apache-2.0"
+edition = "2018"
+description = "Derive macros for msg_socket2"
+
+[dependencies]
+proc-macro2 = "^1"
+quote = "^1"
+syn = "^1"
+
+[lib]
+proc-macro = true
+path = "lib.rs"
diff --git a/msg_socket2/derive/lib.rs b/msg_socket2/derive/lib.rs
new file mode 100644
index 0000000..42c17c3
--- /dev/null
+++ b/msg_socket2/derive/lib.rs
@@ -0,0 +1,264 @@
+// SPDX-License-Identifier: MIT OR Apache-2.0
+// Copyright 2020, Alyssa Ross
+
+use proc_macro;
+use proc_macro2::{Span, TokenStream};
+use quote::{quote, ToTokens};
+use std::convert::{TryFrom, TryInto};
+use std::fmt::{self, Display, Formatter};
+use syn::spanned::Spanned;
+use syn::{
+    parse_macro_input, Attribute, DeriveInput, GenericParam, Generics, Lifetime, LifetimeDef, Lit,
+    Meta, MetaList, MetaNameValue, NestedMeta,
+};
+
+#[proc_macro_derive(SerializeWithFds, attributes(msg_socket2))]
+pub fn serialize_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let input = parse_macro_input!(input as DeriveInput);
+    let impl_for_input = impl_serialize_with_fds(input);
+    // println!("{}", impl_for_input);
+    impl_for_input.into()
+}
+
+#[proc_macro_derive(DeserializeWithFds, attributes(msg_socket2))]
+pub fn deserialize_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let input = parse_macro_input!(input as DeriveInput);
+    let impl_for_input = impl_deserialize_with_fds(input);
+    // println!("{}", impl_for_input);
+    impl_for_input.into()
+}
+
+#[derive(Debug)]
+enum Strategy {
+    AsRawFd,
+    Serde,
+}
+
+impl TryFrom<Lit> for Strategy {
+    type Error = Error;
+
+    fn try_from(lit: Lit) -> Result<Self, Self::Error> {
+        match lit {
+            Lit::Str(value) if value.value() == "AsRawFd" => Ok(Strategy::AsRawFd),
+            Lit::Str(value) if value.value() == "serde" => Ok(Strategy::Serde),
+            _ => Err(Error::Str {
+                message: "invalid strategy",
+                span: Some(lit.span()),
+            }),
+        }
+    }
+}
+
+#[derive(Debug)]
+struct Config {
+    strategy: Strategy,
+}
+
+#[derive(Debug)]
+enum Error {
+    Syn(syn::Error),
+    Str {
+        message: &'static str,
+        span: Option<Span>,
+    },
+}
+
+impl Display for Error {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        use Error::*;
+        match self {
+            Syn(error) => write!(f, "{}", error),
+            Str { message, .. } => write!(f, "{}", message),
+        }
+    }
+}
+
+impl Spanned for Error {
+    fn span(&self) -> Span {
+        use Error::*;
+        match self {
+            Syn(error) => error.span(),
+            Str { span, .. } => span.unwrap_or(Span::call_site()),
+        }
+    }
+}
+
+impl From<syn::Error> for Error {
+    fn from(error: syn::Error) -> Self {
+        Self::Syn(error)
+    }
+}
+
+fn get_strategy(attrs: &[Attribute]) -> Result<Option<Strategy>, Error> {
+    for attr in attrs {
+        let (attr_path, nested) = match attr.parse_meta()? {
+            Meta::List(MetaList { path, nested, .. }) => (path, nested),
+            _ => continue,
+        };
+
+        match attr_path.get_ident() {
+            Some(ident) if ident == "msg_socket2" => (),
+            _ => continue,
+        }
+
+        for meta in nested {
+            let (name_path, value) = match meta {
+                NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => (path, lit),
+                _ => continue,
+            };
+
+            match name_path.get_ident() {
+                Some(ident) if ident == "strategy" => (),
+                _ => continue,
+            }
+
+            return value.try_into().map(Some);
+        }
+    }
+
+    Ok(None)
+}
+
+fn get_config(attrs: &[Attribute]) -> Result<Config, Error> {
+    let strategy = get_strategy(&attrs)?.ok_or_else(|| Error::Str {
+        message: "missing strategy",
+        span: None,
+    })?;
+
+    Ok(Config { strategy })
+}
+
+fn split_for_deserialize_impl(mut generics: Generics) -> (TokenStream, TokenStream, TokenStream) {
+    let (_, ty_generics, where_clause) = generics.split_for_impl();
+
+    // Convert to token streams to take ownership, so generics can be
+    // mutated to generate impl_generics.
+    let ty_generics = ty_generics.into_token_stream();
+    let where_clause = where_clause.into_token_stream();
+
+    // Add the deserializer lifetime to the list of generics, then
+    // calculate the impl_generics.  Can't do this before calculating
+    // ty_generics because then the deserializer lifetime would be
+    // passed as a lifetime parameter to the type.
+    //
+    // FIXME: there must be a better way of doing this.
+    let lifetime = LifetimeDef::new(Lifetime::new("'__de", Span::call_site()));
+    generics.params.insert(0, GenericParam::Lifetime(lifetime));
+    let (impl_generics, _, _) = generics.split_for_impl();
+
+    (impl_generics.into_token_stream(), ty_generics, where_clause)
+}
+
+fn as_raw_fd_serialize_impl(input: DeriveInput) -> TokenStream {
+    let name = input.ident;
+    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
+
+    quote! {
+        impl #impl_generics ::msg_socket2::SerializeWithFds for #name #ty_generics #where_clause {
+            fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
+            where
+                S: ::msg_socket2::Serializer,
+            {
+                serializer.serialize_unit()
+            }
+
+            fn serialize_fds<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
+            where
+                S: ::msg_socket2::FdSerializer,
+            {
+                serializer.serialize_raw_fd(::std::os::unix::io::AsRawFd::as_raw_fd(self))
+            }
+        }
+    }
+}
+
+fn as_raw_fd_deserialize_impl(input: DeriveInput) -> TokenStream {
+    let name = input.ident;
+    let (impl_generics, ty_generics, where_clause) = split_for_deserialize_impl(input.generics);
+
+    quote! {
+        impl #impl_generics ::msg_socket2::DeserializeWithFds<'__de> for #name #ty_generics
+        #where_clause
+        {
+            fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
+            where
+                D: ::msg_socket2::DeserializerWithFds<'__de>,
+            {
+                deserializer.deserialize_fd().map(|x| x.specialize())
+            }
+        }
+    }
+}
+
+fn serde_serialize_impl(input: DeriveInput) -> TokenStream {
+    let name = input.ident;
+    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
+
+    quote! {
+        impl #impl_generics ::msg_socket2::SerializeWithFds for #name #ty_generics #where_clause {
+            fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
+            where
+                S: ::msg_socket2::Serializer,
+            {
+                ::msg_socket2::Serialize::serialize(&self, serializer)
+            }
+
+            fn serialize_fds<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
+            where
+                S: ::msg_socket2::FdSerializer,
+            {
+                serializer.serialize_unit()
+            }
+        }
+    }
+}
+
+fn serde_deserialize_impl(input: DeriveInput) -> TokenStream {
+    let name = input.ident;
+    let (impl_generics, ty_generics, where_clause) = split_for_deserialize_impl(input.generics);
+
+    quote! {
+        impl #impl_generics ::msg_socket2::DeserializeWithFds<'__de> for #name #ty_generics
+        #where_clause
+        {
+            fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
+            where
+                D: ::msg_socket2::DeserializerWithFds<'__de>,
+            {
+                deserializer.invite_deserialize()
+            }
+        }
+    }
+}
+
+fn impl_serialize_with_fds(input: DeriveInput) -> TokenStream {
+    let config = match get_config(&input.attrs) {
+        Ok(config) => config,
+        Err(error) => {
+            // FIXME: Tried using compile_error! but it just got
+            // swallowed.  Is that not possible with derive macros?
+            panic!("{}", error);
+        }
+    };
+
+    match config.strategy {
+        Strategy::AsRawFd => as_raw_fd_serialize_impl(input),
+        Strategy::Serde => serde_serialize_impl(input),
+    }
+}
+
+fn impl_deserialize_with_fds(input: DeriveInput) -> TokenStream {
+    let config = match get_config(&input.attrs) {
+        Ok(config) => config,
+        Err(error) => {
+            // FIXME: Tried using compile_error! but it just got
+            // swallowed.  Is that not possible with derive macros?
+            panic!("{}", error);
+        }
+    };
+
+    match config.strategy {
+        Strategy::AsRawFd => as_raw_fd_deserialize_impl(input),
+        Strategy::Serde => serde_deserialize_impl(input),
+    }
+}
diff --git a/msg_socket2/src/lib.rs b/msg_socket2/src/lib.rs
index a1d8ceb..9bb2526 100644
--- a/msg_socket2/src/lib.rs
+++ b/msg_socket2/src/lib.rs
@@ -25,12 +25,12 @@
 
 mod de;
 mod error;
-mod ser;
+pub mod ser;
 mod socket;
 
-pub(crate) use ser::SerializerWithFdsImpl;
+pub(crate) use ser::FdSerializerImpl;
 
 pub use de::{DeserializeWithFds, DeserializerWithFds};
 pub use error::Error;
-pub use ser::{SerializeWithFds, SerializerWithFds};
+pub use ser::{FdSerializer, Serialize, SerializeWithFds, Serializer};
 pub use socket::Socket;
diff --git a/msg_socket2/src/ser.rs b/msg_socket2/src/ser.rs
index 7bffa22..e2b9900 100644
--- a/msg_socket2/src/ser.rs
+++ b/msg_socket2/src/ser.rs
@@ -1,40 +1,598 @@
-use serde::Serializer;
+// Copyright 2020, Alyssa Ross
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//     * Redistributions of source code must retain the above copyright
+//       notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above copyright
+//       notice, this list of conditions and the following disclaimer in the
+//       documentation and/or other materials provided with the distribution.
+//     * Neither the name of the <organization> nor the
+//       names of its contributors may be used to endorse or promote products
+//       derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY
+// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Portions Copyright Erick Tryzelaar and David Tolnay,
+// Licensed under either of Apache License, Version 2.0
+// or MIT license at your option.
+
+//! Data structure serialization framework for Unix domain sockets.
+//!
+//! Much like serde::ser, except sending file descriptors is also
+//! supported.
+
+use serde::ser::*;
+use std::collections::BTreeMap;
+use std::fmt::{self, Display, Formatter};
 use std::os::unix::prelude::*;
 
-pub trait SerializerWithFds {
-    type Ser: Serializer;
+pub use serde::ser::{
+    Serialize, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTupleStruct,
+    SerializeTupleVariant, Serializer,
+};
+
+#[derive(Debug)]
+pub(crate) enum Never {}
+
+impl Display for Never {
+    fn fmt(&self, _: &mut Formatter) -> fmt::Result {
+        unreachable!()
+    }
+}
+
+impl StdError for Never {}
+
+impl Error for Never {
+    fn custom<T>(_: T) -> Self {
+        unreachable!()
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct Composite<'ser, S> {
+    serializer: &'ser mut S,
+}
+
+/// Returned from `SerializerWithFds::serialize_seq`.
+pub trait SerializeSeqFds {
+    type Ok;
+    type Error: Error;
+
+    fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized;
+
+    fn end(self) -> Result<Self::Ok, Self::Error>;
+}
+
+impl<'ser, 'fds> SerializeSeqFds for Composite<'ser, FdSerializerImpl<'fds>> {
+    type Ok = ();
+    type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error;
 
-    fn serializer(self) -> Self::Ser;
-    fn fds(&mut self) -> &mut Vec<RawFd>;
+    fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized,
+    {
+        value.serialize_fds(&mut *self.serializer)?;
+        Ok(())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+}
+
+/// Returned from `SerializerWithFds::serialize_tuple_struct`.
+pub trait SerializeTupleStructFds {
+    type Ok;
+    type Error: Error;
+
+    fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized;
+
+    fn end(self) -> Result<Self::Ok, Self::Error>;
+}
+
+impl<'ser, 'fds> SerializeTupleStructFds for Composite<'ser, FdSerializerImpl<'fds>> {
+    type Ok = ();
+    type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error;
+
+    fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized,
+    {
+        value.serialize_fds(&mut *self.serializer)?;
+        Ok(())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+}
+
+/// Returned from `SerializerWithFds::serialize_tuple_variant`.
+pub trait SerializeTupleVariantFds {
+    type Ok;
+    type Error: Error;
+
+    fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized;
+
+    fn end(self) -> Result<Self::Ok, Self::Error>;
 }
 
+impl<'ser, 'fds> SerializeTupleVariantFds for Composite<'ser, FdSerializerImpl<'fds>> {
+    type Ok = ();
+    type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error;
+
+    fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized,
+    {
+        value.serialize_fds(&mut *self.serializer)?;
+        Ok(())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+}
+
+/// Returned from `SerializerWithFds::serialize_map`.
+pub trait SerializeMapFds {
+    type Ok;
+    type Error: Error;
+
+    fn serialize_key<T: SerializeWithFds + ?Sized>(&mut self, key: &T) -> Result<(), Self::Error>;
+
+    fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized;
+
+    fn serialize_entry<K: SerializeWithFds + ?Sized, V: SerializeWithFds + ?Sized>(
+        &mut self,
+        key: &K,
+        value: &V,
+    ) -> Result<(), Self::Error> {
+        self.serialize_key(key)?;
+        self.serialize_value(value)
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error>;
+}
+
+impl<'ser, 'fds> SerializeMapFds for Composite<'ser, FdSerializerImpl<'fds>> {
+    type Ok = ();
+    type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error;
+
+    fn serialize_key<T: SerializeWithFds + ?Sized>(&mut self, key: &T) -> Result<(), Self::Error> {
+        key.serialize_fds(&mut *self.serializer)?;
+        Ok(())
+    }
+
+    fn serialize_value<T>(&mut self, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized,
+    {
+        value.serialize_fds(&mut *self.serializer)?;
+        Ok(())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+}
+
+/// Returned from `SerializerWithFds::serialize_struct`.
+pub trait SerializeStructFds {
+    type Ok;
+    type Error: Error;
+
+    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized;
+
+    fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> {
+        let _ = key;
+        Ok(())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error>;
+}
+
+impl<'ser, 'fds> SerializeStructFds for Composite<'ser, FdSerializerImpl<'fds>> {
+    type Ok = ();
+    type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error;
+
+    fn serialize_field<T>(&mut self, _key: &'static str, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized,
+    {
+        value.serialize_fds(&mut *self.serializer)?;
+        Ok(())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+}
+
+/// Returned from `SerializerWithFds::serialize_struct_variant`.
+pub trait SerializeStructVariantFds {
+    type Ok;
+    type Error: Error;
+
+    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized;
+
+    fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> {
+        let _ = key;
+        Ok(())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error>;
+}
+
+impl<'ser, 'fds> SerializeStructVariantFds for Composite<'ser, FdSerializerImpl<'fds>> {
+    type Ok = ();
+    type Error = <&'ser mut FdSerializerImpl<'fds> as FdSerializer>::Error;
+
+    fn serialize_field<T>(&mut self, _key: &'static str, value: &T) -> Result<(), Self::Error>
+    where
+        T: SerializeWithFds + ?Sized,
+    {
+        value.serialize_fds(&mut *self.serializer)?;
+        Ok(())
+    }
+
+    fn end(self) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+}
+
+/// A **data format** that can send file descriptors between Unix processes.
+pub trait FdSerializer: Sized {
+    type Ok;
+    type Error: Error;
+
+    type SerializeSeqFds: SerializeSeqFds<Ok = Self::Ok, Error = Self::Error>;
+    type SerializeTupleStructFds: SerializeTupleStructFds<Ok = Self::Ok, Error = Self::Error>;
+    type SerializeTupleVariantFds: SerializeTupleVariantFds<Ok = Self::Ok, Error = Self::Error>;
+    type SerializeMapFds: SerializeMapFds<Ok = Self::Ok, Error = Self::Error>;
+    type SerializeStructFds: SerializeStructFds<Ok = Self::Ok, Error = Self::Error>;
+    type SerializeStructVariantFds: SerializeStructVariantFds<Ok = Self::Ok, Error = Self::Error>;
+
+    fn serialize_raw_fd(self, value: RawFd) -> Result<Self::Ok, Self::Error>;
+
+    fn serialize_none(self) -> Result<Self::Ok, Self::Error>;
+
+    fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
+    where
+        T: SerializeWithFds + ?Sized;
+
+    fn serialize_unit(self) -> Result<Self::Ok, Self::Error>;
+
+    fn serialize_unit_variant(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+    ) -> Result<Self::Ok, Self::Error>;
+
+    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeqFds, Self::Error>;
+
+    fn serialize_tuple_struct(
+        self,
+        name: &'static str,
+        len: usize,
+    ) -> Result<Self::SerializeTupleStructFds, Self::Error>;
+
+    fn serialize_tuple_variant(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+        len: usize,
+    ) -> Result<Self::SerializeTupleVariantFds, Self::Error>;
+
+    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMapFds, Self::Error>;
+
+    fn serialize_struct(
+        self,
+        name: &'static str,
+        len: usize,
+    ) -> Result<Self::SerializeStructFds, Self::Error>;
+
+    fn serialize_struct_variant(
+        self,
+        name: &'static str,
+        variant_index: u32,
+        variant: &'static str,
+        len: usize,
+    ) -> Result<Self::SerializeStructVariantFds, Self::Error>;
+}
+
+#[derive(Debug)]
+pub(crate) struct FdSerializerImpl<'fds> {
+    pub fds: &'fds mut Vec<RawFd>,
+}
+
+impl<'ser, 'fds> FdSerializer for &'ser mut FdSerializerImpl<'fds> {
+    type Ok = ();
+    type Error = Never;
+
+    type SerializeSeqFds = Composite<'ser, FdSerializerImpl<'fds>>;
+    type SerializeTupleStructFds = Composite<'ser, FdSerializerImpl<'fds>>;
+    type SerializeTupleVariantFds = Composite<'ser, FdSerializerImpl<'fds>>;
+    type SerializeMapFds = Composite<'ser, FdSerializerImpl<'fds>>;
+    type SerializeStructFds = Composite<'ser, FdSerializerImpl<'fds>>;
+    type SerializeStructVariantFds = Composite<'ser, FdSerializerImpl<'fds>>;
+
+    fn serialize_raw_fd(self, value: RawFd) -> Result<Self::Ok, Self::Error> {
+        self.fds.push(value);
+        Ok(())
+    }
+
+    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+
+    fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
+    where
+        T: SerializeWithFds + ?Sized,
+    {
+        value.serialize_fds(self)
+    }
+
+    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+
+    fn serialize_unit_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+    ) -> Result<Self::Ok, Self::Error> {
+        Ok(())
+    }
+
+    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeqFds, Self::Error> {
+        Ok(Composite { serializer: self })
+    }
+
+    fn serialize_tuple_struct(
+        self,
+        _name: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeTupleStructFds, Self::Error> {
+        Ok(Composite { serializer: self })
+    }
+
+    fn serialize_tuple_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeStructVariantFds, Self::Error> {
+        Ok(Composite { serializer: self })
+    }
+
+    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeSeqFds, Self::Error> {
+        Ok(Composite { serializer: self })
+    }
+
+    fn serialize_struct(
+        self,
+        _name: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeTupleStructFds, Self::Error> {
+        Ok(Composite { serializer: self })
+    }
+
+    fn serialize_struct_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeStructVariantFds, Self::Error> {
+        Ok(Composite { serializer: self })
+    }
+}
+
+/// A **data structure** that can be sent over a Unix domain socket.
 pub trait SerializeWithFds {
-    fn serialize<Ser: SerializerWithFds>(
-        &self,
-        serializer: Ser,
-    ) -> Result<<Ser::Ser as Serializer>::Ok, <Ser::Ser as Serializer>::Error>;
+    fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>;
+    fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>;
+}
+
+macro_rules! serialize_impl {
+    ($type:ty) => {
+        impl SerializeWithFds for $type {
+            fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+                Serialize::serialize(&self, serializer)
+            }
+
+            fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+                serializer.serialize_unit()
+            }
+        }
+    };
+}
+
+macro_rules! fd_impl_body {
+    () => {
+        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+            serializer.serialize_unit()
+        }
+
+        fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+            serializer.serialize_raw_fd(AsRawFd::as_raw_fd(self))
+        }
+    };
+}
+
+// FIXME: it would be nice if these two cases could be consolidated
+// with a better pattern, and fd_impl_body!() could be inlined.
+macro_rules! fd_impl {
+    ($type:ty) => {
+        impl SerializeWithFds for $type {
+            fd_impl_body!();
+        }
+    };
+
+    ($type:ty, $($args:tt),+) => {
+        impl<$($args),*> SerializeWithFds for $type {
+            fd_impl_body!();
+        }
+    };
 }
 
+serialize_impl!(u8);
+serialize_impl!(u16);
+serialize_impl!(u32);
+serialize_impl!(u64);
+serialize_impl!(usize);
+
+serialize_impl!(String);
+serialize_impl!(std::path::PathBuf);
+
+fd_impl!(std::fs::File);
+fd_impl!(std::io::Stderr);
+fd_impl!(std::io::StderrLock<'a>, 'a);
+fd_impl!(std::io::Stdin);
+fd_impl!(std::io::StdinLock<'a>, 'a);
+fd_impl!(std::io::Stdout);
+fd_impl!(std::io::StdoutLock<'a>, 'a);
+fd_impl!(std::net::TcpListener);
+fd_impl!(std::net::TcpStream);
+fd_impl!(std::net::UdpSocket);
+fd_impl!(std::os::unix::net::UnixDatagram);
+fd_impl!(std::os::unix::net::UnixListener);
+fd_impl!(std::os::unix::net::UnixStream);
+fd_impl!(std::process::ChildStderr);
+fd_impl!(std::process::ChildStdin);
+fd_impl!(std::process::ChildStdout);
+
+impl<T: SerializeWithFds> SerializeWithFds for Option<T> {
+    fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
+        match self {
+            Some(value) => serializer.serialize_some(&SerializeAdapter(value)),
+            None => serializer.serialize_none(),
+        }
+    }
+
+    fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
+        match self {
+            Some(value) => serializer.serialize_some(value),
+            None => serializer.serialize_none(),
+        }
+    }
+}
+
+impl<T: SerializeWithFds> SerializeWithFds for Vec<T> {
+    fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
+        let mut seq = serializer.serialize_seq(Some(self.len()))?;
+        for element in self {
+            seq.serialize_element(&SerializeAdapter(element))?;
+        }
+        seq.end()
+    }
+
+    fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
+        let mut seq = serializer.serialize_seq(Some(self.len()))?;
+        for element in self {
+            seq.serialize_element(element)?;
+        }
+        seq.end()
+    }
+}
+
+impl<K: SerializeWithFds, V: SerializeWithFds> SerializeWithFds for BTreeMap<K, V> {
+    fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
+        let mut map = serializer.serialize_map(Some(self.len()))?;
+        for (k, v) in self {
+            map.serialize_entry(&SerializeAdapter(k), &SerializeAdapter(v))?;
+        }
+        map.end()
+    }
+
+    fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
+        let mut map = serializer.serialize_map(Some(self.len()))?;
+        for (k, v) in self {
+            map.serialize_entry(k, v)?;
+        }
+        map.end()
+    }
+}
+
+/// A convenience for serializing a `RawFd` field in a composite type,
+/// using a `serialize_raw_fd` method.
+///
+/// `RawFd` cannot directly implement `SerializeWithFds` because it is
+/// a type alias to `i16`, and so it would be ambiguous whether to
+/// serialize an `i16` value as an integer or as a file descriptor.
+///
+/// In general, using `RawFd` should be avoided in favour of types
+/// like `File` that represent fd ownership.  `SerializeRawFd` borrows
+/// a `RawFd` even though `RawFd` is `Copy` to prevent its misuse as a
+/// generic container for a `RawFd`, which it is not.
 #[derive(Debug)]
-pub struct SerializerWithFdsImpl<'fds, Ser> {
-    serializer: Ser,
-    fds: &'fds mut Vec<RawFd>,
+pub struct SerializeRawFd<'a>(&'a RawFd);
+
+impl<'a> SerializeRawFd<'a> {
+    pub fn new(fd: &'a RawFd) -> Self {
+        Self(fd)
+    }
 }
 
-impl<'fds, Ser> SerializerWithFdsImpl<'fds, Ser> {
-    pub fn new(fds: &'fds mut Vec<RawFd>, serializer: Ser) -> Self {
-        Self { serializer, fds }
+impl<'a> SerializeWithFds for SerializeRawFd<'a> {
+    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        serializer.serialize_unit()
+    }
+
+    fn serialize_fds<S: FdSerializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        serializer.serialize_raw_fd(*self.0)
     }
 }
 
-impl<'fds, Ser: Serializer> SerializerWithFds for SerializerWithFdsImpl<'fds, Ser> {
-    type Ser = Ser;
+/// A wrapper struct to prevent `SerializeWithFds` implementations
+/// directly implementing `Serialize` for their `serialize` method.
+///
+/// Just doing `Serialize` on a `SerializeWithFds` implmentation would
+/// return a partially serialized value, because the file descriptors
+/// won't have been serialized.
+///
+/// This should therefore only be used to implement
+/// `SerializeWithFds`, where the file descriptors are then seperately
+/// serialized in `serialize_fds`.
+#[derive(Debug)]
+pub struct SerializeAdapter<'a, T: ?Sized>(&'a T);
 
-    fn serializer(self) -> Self::Ser {
-        self.serializer
+impl<'a, T: ?Sized> SerializeAdapter<'a, T> {
+    pub fn new(value: &'a T) -> Self {
+        Self(value)
     }
+}
 
-    fn fds(&mut self) -> &mut Vec<RawFd> {
-        &mut self.fds
+impl<'a, T: SerializeWithFds + ?Sized> Serialize for SerializeAdapter<'a, T> {
+    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        self.0.serialize(serializer)
     }
 }
diff --git a/msg_socket2/src/socket.rs b/msg_socket2/src/socket.rs
index 4e75e82..b836365 100644
--- a/msg_socket2/src/socket.rs
+++ b/msg_socket2/src/socket.rs
@@ -4,9 +4,7 @@ use std::marker::PhantomData;
 use std::os::unix::prelude::*;
 use sys_util::{net::UnixSeqpacket, ScmSocket};
 
-use crate::{
-    DeserializeWithFds, DeserializerWithFds, Error, SerializeWithFds, SerializerWithFdsImpl,
-};
+use crate::{DeserializeWithFds, DeserializerWithFds, Error, FdSerializerImpl, SerializeWithFds};
 
 #[derive(Debug)]
 pub struct Socket<Send, Recv> {
@@ -28,9 +26,10 @@ impl<Send: SerializeWithFds, Recv> Socket<Send, Recv> {
         let mut bytes: Vec<u8> = vec![];
         let mut fds: Vec<RawFd> = vec![];
 
-        let mut serializer = Serializer::new(&mut bytes, DefaultOptions::new());
-        let serializer_with_fds = SerializerWithFdsImpl::new(&mut fds, &mut serializer);
-        value.serialize(serializer_with_fds)?;
+        value.serialize(&mut Serializer::new(&mut bytes, DefaultOptions::new()))?;
+        value
+            .serialize_fds(&mut FdSerializerImpl { fds: &mut fds })
+            .unwrap();
 
         self.sock.send_with_fds(&[IoSlice::new(&bytes)], &fds)?;
 
diff --git a/msg_socket2/tests/round_trip.rs b/msg_socket2/tests/round_trip.rs
index efec94b..1bb6636 100644
--- a/msg_socket2/tests/round_trip.rs
+++ b/msg_socket2/tests/round_trip.rs
@@ -3,14 +3,33 @@ use std::os::unix::prelude::*;
 use std::fmt::{self, Formatter};
 use std::marker::PhantomData;
 
-use msg_socket2::*;
-use serde::de::*;
-use serde::ser::*;
+use msg_socket2::{
+    ser::{
+        SerializeAdapter, SerializeRawFd, SerializeStruct, SerializeStructFds,
+        SerializeTupleStruct, SerializeTupleStructFds,
+    },
+    DeserializeWithFds, DeserializerWithFds, FdSerializer, SerializeWithFds, Serializer, Socket,
+};
+use serde::de::{Deserializer, SeqAccess};
 use sys_util::net::UnixSeqpacket;
 
 #[derive(Debug)]
 struct Inner(RawFd, u16);
 
+impl SerializeWithFds for Inner {
+    fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
+        let mut state = serializer.serialize_tuple_struct("Inner", 1)?;
+        state.serialize_field(&self.1)?;
+        state.end()
+    }
+
+    fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
+        let mut state = serializer.serialize_tuple_struct("Inner", 1)?;
+        state.serialize_field(&SerializeRawFd::new(&self.0))?;
+        state.end()
+    }
+}
+
 #[derive(Debug)]
 struct Test {
     fd: RawFd,
@@ -18,26 +37,17 @@ struct Test {
 }
 
 impl SerializeWithFds for Test {
-    fn serialize<Ser: SerializerWithFds>(
-        &self,
-        mut serializer: Ser,
-    ) -> Result<<Ser::Ser as Serializer>::Ok, <Ser::Ser as Serializer>::Error> {
-        serializer.fds().push(self.fd);
-        serializer.fds().push(self.inner.0);
-
-        struct SerializableInner<'a>(&'a Inner);
-
-        impl<'a> Serialize for SerializableInner<'a> {
-            fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
-                let mut state = serializer.serialize_tuple_struct("Inner", 1)?;
-                state.serialize_field(&(self.0).1)?;
-                state.end()
-            }
-        }
-
-        let mut state = serializer.serializer().serialize_struct("Test", 1)?;
+    fn serialize<Ser: Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
+        let mut state = serializer.serialize_struct("Test", 1)?;
         state.skip_field("fd")?;
-        state.serialize_field("inner", &SerializableInner(&self.inner))?;
+        state.serialize_field("inner", &SerializeAdapter::new(&self.inner))?;
+        state.end()
+    }
+
+    fn serialize_fds<Ser: FdSerializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
+        let mut state = serializer.serialize_struct("Test", 2)?;
+        state.serialize_field("fd", &SerializeRawFd::new(&self.fd))?;
+        state.serialize_field("inner", &self.inner)?;
         state.end()
     }
 }