// Copyright 2018 The Chromium OS Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. //! Derives a 9P wire format encoding for a struct by recursively calling //! `WireFormat::encode` or `WireFormat::decode` on the fields of the struct. //! This is only intended to be used from within the `p9` crate. #![recursion_limit = "256"] use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; use syn::spanned::Spanned; use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident}; /// The function that derives the actual implementation. #[proc_macro_derive(P9WireFormat)] pub fn p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); p9_wire_format_inner(input).into() } fn p9_wire_format_inner(input: DeriveInput) -> TokenStream { if !input.generics.params.is_empty() { return quote! { compile_error!("derive(P9WireFormat) does not support generic parameters"); }; } let container = input.ident; let byte_size_impl = byte_size_sum(&input.data); let encode_impl = encode_wire_format(&input.data); let decode_impl = decode_wire_format(&input.data, &container); let scope = format!("wire_format_{}", container).to_lowercase(); let scope = Ident::new(&scope, Span::call_site()); quote! { mod #scope { use std::io; use std::result::Result::Ok; use super::#container; use crate::protocol::WireFormat; impl WireFormat for #container { fn byte_size(&self) -> u32 { #byte_size_impl } fn encode(&self, _writer: &mut W) -> io::Result<()> { #encode_impl } fn decode(_reader: &mut R) -> io::Result { #decode_impl } } } } } // Generate code that recursively calls byte_size on every field in the struct. fn byte_size_sum(data: &Data) -> TokenStream { if let Data::Struct(data) = data { if let Fields::Named(fields) = &data.fields { let fields = fields.named.iter().map(|f| { let field = &f.ident; let span = field.span(); quote_spanned! {span=> WireFormat::byte_size(&self.#field) } }); quote! { 0 #(+ #fields)* } } else { unimplemented!(); } } else { unimplemented!(); } } // Generate code that recursively calls encode on every field in the struct. fn encode_wire_format(data: &Data) -> TokenStream { if let Data::Struct(data) = data { if let Fields::Named(fields) = &data.fields { let fields = fields.named.iter().map(|f| { let field = &f.ident; let span = field.span(); quote_spanned! {span=> WireFormat::encode(&self.#field, _writer)?; } }); quote! { #(#fields)* Ok(()) } } else { unimplemented!(); } } else { unimplemented!(); } } // Generate code that recursively calls decode on every field in the struct. fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream { if let Data::Struct(data) = data { if let Fields::Named(fields) = &data.fields { let values = fields.named.iter().map(|f| { let field = &f.ident; let span = field.span(); quote_spanned! {span=> let #field = WireFormat::decode(_reader)?; } }); let members = fields.named.iter().map(|f| { let field = &f.ident; quote! { #field: #field, } }); quote! { #(#values)* Ok(#container { #(#members)* }) } } else { unimplemented!(); } } else { unimplemented!(); } } #[cfg(test)] mod tests { use super::*; use syn::parse_quote; #[test] fn byte_size() { let input: DeriveInput = parse_quote! { struct Item { ident: u32, with_underscores: String, other: u8, } }; let expected = quote! { 0 + WireFormat::byte_size(&self.ident) + WireFormat::byte_size(&self.with_underscores) + WireFormat::byte_size(&self.other) }; assert_eq!(byte_size_sum(&input.data).to_string(), expected.to_string()); } #[test] fn encode() { let input: DeriveInput = parse_quote! { struct Item { ident: u32, with_underscores: String, other: u8, } }; let expected = quote! { WireFormat::encode(&self.ident, _writer)?; WireFormat::encode(&self.with_underscores, _writer)?; WireFormat::encode(&self.other, _writer)?; Ok(()) }; assert_eq!( encode_wire_format(&input.data).to_string(), expected.to_string(), ); } #[test] fn decode() { let input: DeriveInput = parse_quote! { struct Item { ident: u32, with_underscores: String, other: u8, } }; let container = Ident::new("Item", Span::call_site()); let expected = quote! { let ident = WireFormat::decode(_reader)?; let with_underscores = WireFormat::decode(_reader)?; let other = WireFormat::decode(_reader)?; Ok(Item { ident: ident, with_underscores: with_underscores, other: other, }) }; assert_eq!( decode_wire_format(&input.data, &container).to_string(), expected.to_string(), ); } #[test] fn end_to_end() { let input: DeriveInput = parse_quote! { struct Niijima_先輩 { a: u8, b: u16, c: u32, d: u64, e: String, f: Vec, g: Nested, } }; let expected = quote! { mod wire_format_niijima_先輩 { use std::io; use std::result::Result::Ok; use super::Niijima_先輩; use crate::protocol::WireFormat; impl WireFormat for Niijima_先輩 { fn byte_size(&self) -> u32 { 0 + WireFormat::byte_size(&self.a) + WireFormat::byte_size(&self.b) + WireFormat::byte_size(&self.c) + WireFormat::byte_size(&self.d) + WireFormat::byte_size(&self.e) + WireFormat::byte_size(&self.f) + WireFormat::byte_size(&self.g) } fn encode(&self, _writer: &mut W) -> io::Result<()> { WireFormat::encode(&self.a, _writer)?; WireFormat::encode(&self.b, _writer)?; WireFormat::encode(&self.c, _writer)?; WireFormat::encode(&self.d, _writer)?; WireFormat::encode(&self.e, _writer)?; WireFormat::encode(&self.f, _writer)?; WireFormat::encode(&self.g, _writer)?; Ok(()) } fn decode(_reader: &mut R) -> io::Result { let a = WireFormat::decode(_reader)?; let b = WireFormat::decode(_reader)?; let c = WireFormat::decode(_reader)?; let d = WireFormat::decode(_reader)?; let e = WireFormat::decode(_reader)?; let f = WireFormat::decode(_reader)?; let g = WireFormat::decode(_reader)?; Ok(Niijima_先輩 { a: a, b: b, c: c, d: d, e: e, f: f, g: g, }) } } } }; assert_eq!( p9_wire_format_inner(input).to_string(), expected.to_string(), ); } }