summary refs log blame commit diff
path: root/p9/wire_format_derive/wire_format_derive.rs
blob: 2d1c3640d2cdcdde0dce41eeed05327393ffab14 (plain) (tree)
1
2
3
4
5
6
7
8
9
10









                                                                             
                                     
                                  
                          
                                                               


                                                        


                                                                                  

 
                                                            





                                                                                       
                                
 


                                                                  
 

                                                                    

                    

                                        
 
                                  
 
                                            
 

                                            


                                   
                                                                                   











                                                                               
                                              

                                                     
                                                      




                                                       













                                                                            
                                                   

                                                     
                                                      




                                                               















                                                                            
                                                                      

                                                     
                                                      




                                                              


                                                       



                                     




                           
                               













                                
                         










                                               






                                                               
                                                                                 











                                               







                                                                 

                                                        












                                               
                                                              











                                                                

                                                                    


















                                               

                                            
 

                                          
                                                












































                                                                                       



                                                    

     
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Derives a 9P wire format encoding for a struct by recursively calling
//! `WireFormat::encode` or `WireFormat::decode` on the fields of the struct.
//! This is only intended to be used from within the `p9` crate.

#![recursion_limit = "256"]

use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};
use syn::spanned::Spanned;
use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident};

/// Entry point for `#[derive(P9WireFormat)]`.
///
/// Parses the annotated item as a `DeriveInput` and forwards it to
/// `p9_wire_format_inner`, which builds the generated `WireFormat` implementation.
/// If the input cannot be parsed, `parse_macro_input!` returns the parse error
/// as the macro output.
#[proc_macro_derive(P9WireFormat)]
pub fn p9_wire_format(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let expanded = p9_wire_format_inner(ast);
    proc_macro::TokenStream::from(expanded)
}

fn p9_wire_format_inner(input: DeriveInput) -> TokenStream {
    if !input.generics.params.is_empty() {
        return quote! {
            compile_error!("derive(P9WireFormat) does not support generic parameters");
        };
    }

    let container = input.ident;

    let byte_size_impl = byte_size_sum(&input.data);
    let encode_impl = encode_wire_format(&input.data);
    let decode_impl = decode_wire_format(&input.data, &container);

    let scope = format!("wire_format_{}", container).to_lowercase();
    let scope = Ident::new(&scope, Span::call_site());
    quote! {
        mod #scope {
            use std::io;
            use std::result::Result::Ok;

            use super::#container;

            use crate::protocol::WireFormat;

            impl WireFormat for #container {
                fn byte_size(&self) -> u32 {
                    #byte_size_impl
                }

                fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                    #encode_impl
                }

                fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                    #decode_impl
                }
            }
        }
    }
}

// Builds the body of the generated `byte_size` method: the literal `0` followed by
// `+ WireFormat::byte_size(&self.<field>)` for every named field, in declaration order.
//
// Panics with `unimplemented!()` when `data` is not a struct with named fields.
fn byte_size_sum(data: &Data) -> TokenStream {
    let named = match data {
        Data::Struct(item) => match &item.fields {
            Fields::Named(named) => &named.named,
            _ => unimplemented!(),
        },
        _ => unimplemented!(),
    };

    // Every term is emitted with the span of the field it refers to.
    let terms = named.iter().map(|member| {
        let name = &member.ident;
        quote_spanned! {name.span()=>
            WireFormat::byte_size(&self.#name)
        }
    });

    quote! {
        0 #(+ #terms)*
    }
}

// Builds the body of the generated `encode` method: one
// `WireFormat::encode(&self.<field>, _writer)?;` statement per named field, in
// declaration order, followed by `Ok(())`.
//
// Panics with `unimplemented!()` when `data` is not a struct with named fields.
fn encode_wire_format(data: &Data) -> TokenStream {
    let named = match data {
        Data::Struct(item) => match &item.fields {
            Fields::Named(named) => &named.named,
            _ => unimplemented!(),
        },
        _ => unimplemented!(),
    };

    // Every statement is emitted with the span of the field it refers to.
    let statements = named.iter().map(|member| {
        let name = &member.ident;
        quote_spanned! {name.span()=>
            WireFormat::encode(&self.#name, _writer)?;
        }
    });

    quote! {
        #(#statements)*

        Ok(())
    }
}

// Builds the body of the generated `decode` method. Each named field is first decoded
// into a local of the same name with `WireFormat::decode(_reader)?`, in declaration
// order, and the locals are then assembled into `Ok(<container> { ... })`.
//
// Panics with `unimplemented!()` when `data` is not a struct with named fields.
fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream {
    let named = match data {
        Data::Struct(item) => match &item.fields {
            Fields::Named(named) => &named.named,
            _ => unimplemented!(),
        },
        _ => unimplemented!(),
    };

    // `let <field> = ...;` bindings, each emitted with the span of its field.
    let bindings = named.iter().map(|member| {
        let name = &member.ident;
        quote_spanned! {name.span()=>
            let #name = WireFormat::decode(_reader)?;
        }
    });

    // `<field>: <field>,` initializers for the struct literal.
    let initializers = named.iter().map(|member| {
        let name = &member.ident;
        quote! {
            #name: #name,
        }
    });

    quote! {
        #(#bindings)*

        Ok(#container {
            #(#initializers)*
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    // Three-field struct shared by the tests that exercise a single generator.
    fn sample_item() -> DeriveInput {
        parse_quote! {
            struct Item {
                ident: u32,
                with_underscores: String,
                other: u8,
            }
        }
    }

    // `byte_size_sum` emits `0` plus one `byte_size` term per field.
    #[test]
    fn byte_size() {
        let item = sample_item();
        let got = byte_size_sum(&item.data);

        let want = quote! {
            0
                + WireFormat::byte_size(&self.ident)
                + WireFormat::byte_size(&self.with_underscores)
                + WireFormat::byte_size(&self.other)
        };

        assert_eq!(got.to_string(), want.to_string());
    }

    // `encode_wire_format` emits one `encode` call per field, then `Ok(())`.
    #[test]
    fn encode() {
        let item = sample_item();
        let got = encode_wire_format(&item.data);

        let want = quote! {
            WireFormat::encode(&self.ident, _writer)?;
            WireFormat::encode(&self.with_underscores, _writer)?;
            WireFormat::encode(&self.other, _writer)?;
            Ok(())
        };

        assert_eq!(got.to_string(), want.to_string());
    }

    // `decode_wire_format` decodes every field into a local and rebuilds the struct.
    #[test]
    fn decode() {
        let item = sample_item();
        let name = Ident::new("Item", Span::call_site());
        let got = decode_wire_format(&item.data, &name);

        let want = quote! {
            let ident = WireFormat::decode(_reader)?;
            let with_underscores = WireFormat::decode(_reader)?;
            let other = WireFormat::decode(_reader)?;
            Ok(Item {
                ident: ident,
                with_underscores: with_underscores,
                other: other,
            })
        };

        assert_eq!(got.to_string(), want.to_string());
    }

    // Full expansion, using a mixed-case, non-ASCII type name to cover the
    // lowercased module name.
    #[test]
    fn end_to_end() {
        let item: DeriveInput = parse_quote! {
            struct Niijima_先輩 {
                a: u8,
                b: u16,
                c: u32,
                d: u64,
                e: String,
                f: Vec<String>,
                g: Nested,
            }
        };
        let got = p9_wire_format_inner(item);

        let want = quote! {
            mod wire_format_niijima_先輩 {
                use std::io;
                use std::result::Result::Ok;

                use super::Niijima_先輩;

                use crate::protocol::WireFormat;

                impl WireFormat for Niijima_先輩 {
                    fn byte_size(&self) -> u32 {
                        0
                        + WireFormat::byte_size(&self.a)
                        + WireFormat::byte_size(&self.b)
                        + WireFormat::byte_size(&self.c)
                        + WireFormat::byte_size(&self.d)
                        + WireFormat::byte_size(&self.e)
                        + WireFormat::byte_size(&self.f)
                        + WireFormat::byte_size(&self.g)
                    }

                    fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                        WireFormat::encode(&self.a, _writer)?;
                        WireFormat::encode(&self.b, _writer)?;
                        WireFormat::encode(&self.c, _writer)?;
                        WireFormat::encode(&self.d, _writer)?;
                        WireFormat::encode(&self.e, _writer)?;
                        WireFormat::encode(&self.f, _writer)?;
                        WireFormat::encode(&self.g, _writer)?;
                        Ok(())
                    }
                    fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                        let a = WireFormat::decode(_reader)?;
                        let b = WireFormat::decode(_reader)?;
                        let c = WireFormat::decode(_reader)?;
                        let d = WireFormat::decode(_reader)?;
                        let e = WireFormat::decode(_reader)?;
                        let f = WireFormat::decode(_reader)?;
                        let g = WireFormat::decode(_reader)?;
                        Ok(Niijima_先輩 {
                            a: a,
                            b: b,
                            c: c,
                            d: d,
                            e: e,
                            f: f,
                            g: g,
                        })
                    }
                }
            }
        };

        assert_eq!(got.to_string(), want.to_string());
    }
}