summary refs log tree commit diff
path: root/bit_field/bit_field_derive/bit_field_derive.rs
diff options
context:
space:
mode:
Diffstat (limited to 'bit_field/bit_field_derive/bit_field_derive.rs')
-rw-r--r--bit_field/bit_field_derive/bit_field_derive.rs507
1 files changed, 507 insertions, 0 deletions
diff --git a/bit_field/bit_field_derive/bit_field_derive.rs b/bit_field/bit_field_derive/bit_field_derive.rs
new file mode 100644
index 0000000..1f2290d
--- /dev/null
+++ b/bit_field/bit_field_derive/bit_field_derive.rs
@@ -0,0 +1,507 @@
+// Copyright 2018 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#![recursion_limit = "256"]
+extern crate proc_macro;
+extern crate proc_macro2;
+
+#[macro_use]
+extern crate quote;
+
+#[cfg_attr(test, macro_use)]
+extern crate syn;
+
+use std::string::String;
+use std::vec::Vec;
+
+use proc_macro::TokenStream;
+use proc_macro2::{Span, TokenNode};
+use quote::Tokens;
+use syn::{Attribute, Data, DeriveInput, Fields, Ident};
+
+type Result<T> = std::result::Result<T, String>;
+
+/// The function that derives the actual implementation.
+#[proc_macro_derive(BitField, attributes(passthrough))]
+pub fn bitfield(input: TokenStream) -> TokenStream {
+    bitfield_impl(syn::parse(input).unwrap()).into()
+}
+
+fn bitfield_impl(ast: DeriveInput) -> Tokens {
+    if !ast.generics.params.is_empty() {
+        return quote! {
+            compile_error!("derive(BitField) does not support generic parameters");
+        };
+    }
+
+    let name = ast.ident.to_string();
+    let struct_name = match get_struct_name(name.as_str()) {
+        Some(name) => name,
+        None => {
+            return quote! {
+                compile_error!("Check schema name, it should end with Schema.");
+            };
+        }
+    };
+
+    let ident = Ident::new(struct_name, Span::call_site());
+    let test_mod_ident = Ident::new(format!("test_{}", struct_name).as_str(), Span::call_site());
+    // Name of the struct
+    let name = quote!(#ident);
+    let vis = ast.vis.clone();
+    let attrs = match get_attrs(ast.attrs.as_slice()) {
+        Ok(attrs) => attrs,
+        Err(err_str) => {
+            return quote! {
+                compile_error!(#err_str);
+            };
+        }
+    };
+    // Visibility.
+    let vis = quote!(#vis);
+    let fields = match get_struct_fields(ast) {
+        Ok(f) => f,
+        Err(err_str) => {
+            return quote! {
+                compile_error!(#err_str);
+            }
+        }
+    };
+    let struct_def = get_struct_def(&vis, &name, fields.as_slice());
+    let bits_impl = get_bits_impl(&name);
+    let fields_impl = get_fields_impl(fields.as_slice());
+    let tests_impl = get_tests_impl(&name, fields.as_slice());
+    quote!(
+        #(#attrs)*
+        #struct_def
+        #bits_impl
+        impl #name {
+            #(#fields_impl)*
+        }
+        #[cfg(test)]
+        mod #test_mod_ident {
+            use super::*;
+            #(#tests_impl)*
+        }
+        )
+}
+
+// Generate struct name from schema_name. "MyTypeSchema" -> "MyType".
+fn get_struct_name(schema_name: &str) -> Option<&str> {
+    let struct_name = schema_name.trim_right_matches("Schema");
+    if struct_name.len() + "Schema".len() != schema_name.len() {
+        return None;
+    }
+    if struct_name.len() == 0 {
+        return None;
+    }
+    Some(struct_name)
+}
+
+// Unwrap ast to get the named fields. Anything unexpected will be treated as an
+// error.
+// We only care about field names and types.
+// "myfield : BitField3" -> ("myfield", Token(BitField3))
+fn get_struct_fields(ast: DeriveInput) -> Result<Vec<(String, Tokens)>> {
+    let fields = match ast.data {
+        Data::Struct(data_struct) => match data_struct.fields {
+            Fields::Named(fields_named) => fields_named.named,
+            _ => {
+                return Err(format!("Schema must have named fields."));
+            }
+        },
+        _ => {
+            return Err(format!("Schema must be a struct."));
+        }
+    };
+    let mut vec = Vec::new();
+    for field in fields {
+        let ident = match field.ident {
+            Some(ident) => ident,
+            None => {
+                return Err(format!(
+                    "Unknown Error. bit_field_derive library might have a bug."
+                ))
+            }
+        };
+        let ty = field.ty;
+        vec.push((ident.to_string(), quote!(#ty)));
+    }
+
+    Ok(vec)
+}
+
+// We only support attributes of the form #[passthrough(derive(Clone))].
+// Any other attribute will cause undefined behavior.
+fn get_attrs(attrs: &[Attribute]) -> Result<Vec<Tokens>> {
+    let mut attr_tokens = Vec::new();
+    for a in attrs {
+        let tts = a.tts.clone();
+        for t in tts {
+            match t.kind {
+                TokenNode::Group(_d, g) => attr_tokens.push(quote!(#[#g])),
+                _ => return Err(format!("Unsupported attribute.")),
+            }
+        }
+    }
+    Ok(attr_tokens)
+}
+
+fn get_struct_def(vis: &Tokens, name: &Tokens, fields: &[(String, Tokens)]) -> Tokens {
+    let mut field_types = Vec::new();
+    for &(ref _name, ref ty) in fields {
+        field_types.push(ty.clone());
+    }
+
+    // It will be something like:
+    // "(BitField1::FIELD_WIDTH + BitField3::FIELD_WIDTH + BitField4::FIELD_WIDTH) / 8)"
+    let data_size_in_bytes = quote!(
+            ( #( #field_types::FIELD_WIDTH as usize )+* ) / 8
+        );
+    quote! (
+        #vis struct #name {
+            data: [u8; #data_size_in_bytes],
+        }
+        impl #name {
+            pub fn new() -> #name {
+                #name { data: [0; #data_size_in_bytes], }
+            }
+        }
+    )
+}
+
+// Implement setter and getter for all fields.
+fn get_fields_impl(fields: &[(String, Tokens)]) -> Vec<Tokens> {
+    let mut impls = Vec::new();
+    // This vec keeps track of types before this field, used to generate the offset.
+    let mut current_types = vec![quote!(BitField0)];
+
+    for &(ref name, ref ty) in fields {
+        // Creating two copies of current types. As they are going to be moved in quote!.
+        let ct0 = current_types.clone();
+        let ct1 = current_types.clone();
+        let getter_ident = Ident::new(format!("get_{}", name).as_str(), Span::call_site());
+        let setter_ident = Ident::new(format!("set_{}", name).as_str(), Span::call_site());
+        impls.push(quote!(
+                pub fn #getter_ident(&self) -> <#ty as BitFieldSpecifier>::DefaultFieldType {
+                    let offset = #(#ct0::FIELD_WIDTH as usize)+*;
+                    return self.get(offset, #ty::FIELD_WIDTH) as
+                        <#ty as BitFieldSpecifier>::DefaultFieldType;
+                }
+
+                pub fn #setter_ident(&mut self, val: <#ty as BitFieldSpecifier>::DefaultFieldType) {
+                    debug_assert!((val as u64) <= #ty::FIELD_MAX);
+                    let offset = #(#ct1::FIELD_WIDTH as usize)+*;
+                    return self.set(offset, #ty::FIELD_WIDTH, val as u64);
+                }
+                ));
+        current_types.push(ty.clone());
+    }
+    impls
+}
+
+fn get_tests_impl(struct_name: &Tokens, fields: &[(String, Tokens)]) -> Vec<Tokens> {
+    let mut field_types = Vec::new();
+    for &(ref _name, ref ty) in fields {
+        field_types.push(ty.clone());
+    }
+    let field_types2 = field_types.clone();
+    let mut impls = Vec::new();
+    impls.push(quote!(
+                #[test]
+                fn test_total_size() {
+                    let total_size = #(#field_types::FIELD_WIDTH as usize)+*;
+                    assert_eq!(total_size % 8, 0);
+                }));
+    impls.push(quote!(#[test]
+    fn test_bits_boundary() {
+        let fields_sizes = vec![#(#field_types2::FIELD_WIDTH as usize),*];
+        let mut sum = 0usize;
+        for s in fields_sizes {
+            if sum % 64 == 0 {
+                assert!(s <= 64);
+            } else {
+                if (sum + s) % 64 != 0 {
+                    assert_eq!(sum / 64, (sum + s) / 64);
+                }
+            }
+            sum += s;
+        }
+    }));
+
+    for &(ref name, ref ty) in fields {
+        let testname = Ident::new(
+            format!("test_{}", name.as_str()).as_str(),
+            Span::call_site(),
+        );
+        let getter_ident = Ident::new(format!("get_{}", name.as_str()).as_str(), Span::call_site());
+        let setter_ident = Ident::new(format!("set_{}", name.as_str()).as_str(), Span::call_site());
+        impls.push(quote!(
+                #[test]
+                fn #testname() {
+                    let mut a = #struct_name::new();
+                    assert_eq!(a.#getter_ident() as u64, 0);
+                    a.#setter_ident(#ty::FIELD_MAX as <#ty as BitFieldSpecifier>::DefaultFieldType);
+                    assert_eq!(a.#getter_ident() as u64, #ty::FIELD_MAX);
+                }
+                ));
+    }
+    impls
+}
+
+fn get_bits_impl(name: &Tokens) -> Tokens {
+    quote!(
+        impl #name {
+            #[inline]
+            fn check_access(&self, offset: usize, width: u8) {
+                debug_assert!(width <= 64);
+                debug_assert!(offset / 8 < self.data.len());
+                debug_assert!((offset + (width as usize)) <= (self.data.len() * 8));
+            }
+
+
+            #[inline]
+            pub fn get_bit(&self, offset: usize) -> bool {
+                self.check_access(offset, 1);
+
+                let byte_index = offset / 8;
+                let bit_offset = offset % 8;
+
+                let byte = self.data[byte_index];
+                let mask = 1 << bit_offset;
+
+                byte & mask == mask
+            }
+
+            #[inline]
+            pub fn set_bit(&mut self, offset: usize, val: bool) {
+                self.check_access(offset, 1);
+
+                let byte_index = offset / 8;
+                let bit_offset = offset % 8;
+
+                let byte = &mut self.data[byte_index];
+                let mask = 1 << bit_offset;
+
+                if val {
+                    *byte |= mask;
+                } else {
+                    *byte &= !mask;
+                }
+            }
+
+            #[inline]
+            pub fn get(&self, offset: usize, width: u8) -> u64 {
+                self.check_access(offset, width);
+                let mut val = 0;
+
+                for i in 0..(width as usize) {
+                    if self.get_bit(i + offset) {
+                        val |= 1 << i;
+                    }
+                }
+
+                val
+            }
+
+            #[inline]
+            pub fn set(&mut self, offset: usize, width: u8, val: u64) {
+                self.check_access(offset, width);
+
+                for i in 0..(width as usize) {
+                    let mask = 1 << i;
+                    let val_bit_is_set = val & mask == mask;
+                    self.set_bit(i + offset, val_bit_is_set);
+                }
+            }
+        }
+        )
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn struct_name_parsing_valid() {
+        assert_eq!(get_struct_name("MyBitFieldSchema"), Some("MyBitField"));
+    }
+
+    #[test]
+    fn struct_name_parsing_invalid() {
+        assert_eq!(get_struct_name("MyBitFieldS"), None);
+    }
+
+    #[test]
+    fn end_to_end() {
+        let input: DeriveInput = parse_quote! {
+            #[passthrough(derive(Clone))]
+            struct MyBitFieldSchema {
+                a: BitField1,
+                b: BitField2,
+                c: BitField5,
+            }
+        };
+
+        let expected = quote! {
+            #[derive(Clone)]
+            struct MyBitField {
+                data: [u8; (BitField1::FIELD_WIDTH as usize
+                    + BitField2::FIELD_WIDTH as usize
+                    + BitField5::FIELD_WIDTH as usize) / 8],
+            }
+            impl MyBitField {
+                pub fn new() -> MyBitField {
+                    MyBitField {
+                        data: [0; (BitField1::FIELD_WIDTH as usize
+                            + BitField2::FIELD_WIDTH as usize
+                            + BitField5::FIELD_WIDTH as usize) / 8],
+                    }
+                }
+            }
+            impl MyBitField {
+                #[inline]
+                fn check_access(&self, offset: usize, width: u8) {
+                    debug_assert!(width <= 64);
+                    debug_assert!(offset / 8 < self.data.len());
+                    debug_assert!((offset + (width as usize)) <= (self.data.len() * 8));
+                }
+                #[inline]
+                pub fn get_bit(&self, offset: usize) -> bool {
+                    self.check_access(offset, 1);
+                    let byte_index = offset / 8;
+                    let bit_offset = offset % 8;
+                    let byte = self.data[byte_index];
+                    let mask = 1 << bit_offset;
+                    byte & mask == mask
+                }
+                #[inline]
+                pub fn set_bit(&mut self, offset: usize, val: bool) {
+                    self.check_access(offset, 1);
+                    let byte_index = offset / 8;
+                    let bit_offset = offset % 8;
+                    let byte = &mut self.data[byte_index];
+                    let mask = 1 << bit_offset;
+                    if val {
+                        *byte |= mask;
+                    } else {
+                        *byte &= !mask;
+                    }
+                }
+                #[inline]
+                pub fn get(&self, offset: usize, width: u8) -> u64 {
+                    self.check_access(offset, width);
+                    let mut val = 0;
+                    for i in 0..(width as usize) {
+                        if self.get_bit(i + offset) {
+                            val |= 1 << i;
+                        }
+                    }
+                    val
+                }
+                #[inline]
+                pub fn set(&mut self, offset: usize, width: u8, val: u64) {
+                    self.check_access(offset, width);
+                    for i in 0..(width as usize) {
+                        let mask = 1 << i;
+                        let val_bit_is_set = val & mask == mask;
+                        self.set_bit(i + offset, val_bit_is_set);
+                    }
+                }
+            }
+            impl MyBitField {
+                pub fn get_a(&self) -> <BitField1 as BitFieldSpecifier>::DefaultFieldType {
+                    let offset = BitField0::FIELD_WIDTH as usize;
+                    return self.get(offset, BitField1::FIELD_WIDTH)
+                        as <BitField1 as BitFieldSpecifier>::DefaultFieldType;
+                }
+                pub fn set_a(&mut self, val: <BitField1 as BitFieldSpecifier>::DefaultFieldType) {
+                    debug_assert!((val as u64) <= BitField1::FIELD_MAX);
+                    let offset = BitField0::FIELD_WIDTH as usize;
+                    return self.set(offset, BitField1::FIELD_WIDTH, val as u64);
+                }
+                pub fn get_b(&self) -> <BitField2 as BitFieldSpecifier>::DefaultFieldType {
+                    let offset = BitField0::FIELD_WIDTH as usize + BitField1::FIELD_WIDTH as usize;
+                    return self.get(offset, BitField2::FIELD_WIDTH)
+                        as <BitField2 as BitFieldSpecifier>::DefaultFieldType;
+                }
+                pub fn set_b(&mut self, val: <BitField2 as BitFieldSpecifier>::DefaultFieldType) {
+                    debug_assert!((val as u64) <= BitField2::FIELD_MAX);
+                    let offset = BitField0::FIELD_WIDTH as usize + BitField1::FIELD_WIDTH as usize;
+                    return self.set(offset, BitField2::FIELD_WIDTH, val as u64);
+                }
+                pub fn get_c(&self) -> <BitField5 as BitFieldSpecifier>::DefaultFieldType {
+                    let offset = BitField0::FIELD_WIDTH as usize
+                        + BitField1::FIELD_WIDTH as usize
+                        + BitField2::FIELD_WIDTH as usize;
+                    return self.get(offset, BitField5::FIELD_WIDTH)
+                        as <BitField5 as BitFieldSpecifier>::DefaultFieldType;
+                }
+                pub fn set_c(&mut self, val: <BitField5 as BitFieldSpecifier>::DefaultFieldType) {
+                    debug_assert!((val as u64) <= BitField5::FIELD_MAX);
+                    let offset = BitField0::FIELD_WIDTH as usize
+                        + BitField1::FIELD_WIDTH as usize
+                        + BitField2::FIELD_WIDTH as usize;
+                    return self.set(offset, BitField5::FIELD_WIDTH, val as u64);
+                }
+            }
+            #[cfg(test)]
+            mod test_MyBitField {
+                use super::*;
+                #[test]
+                fn test_total_size() {
+                    let total_size = BitField1::FIELD_WIDTH as usize
+                        + BitField2::FIELD_WIDTH as usize
+                        + BitField5::FIELD_WIDTH as usize;
+                    assert_eq!(total_size % 8, 0);
+                }
+                #[test]
+                fn test_bits_boundary() {
+                    let fields_sizes = vec![
+                        BitField1::FIELD_WIDTH as usize,
+                        BitField2::FIELD_WIDTH as usize,
+                        BitField5::FIELD_WIDTH as usize
+                    ];
+                    let mut sum = 0usize;
+                    for s in fields_sizes {
+                        if sum % 64 == 0 {
+                            assert!(s <= 64);
+                        } else {
+                            if (sum + s) % 64 != 0 {
+                                assert_eq!(sum / 64, (sum + s) / 64);
+                            }
+                        }
+                        sum += s;
+                    }
+                }
+                #[test]
+                fn test_a() {
+                    let mut a = MyBitField::new();
+                    assert_eq!(a.get_a() as u64, 0);
+                    a.set_a(BitField1::FIELD_MAX as
+                            <BitField1 as BitFieldSpecifier>::DefaultFieldType);
+                    assert_eq!(a.get_a() as u64, BitField1::FIELD_MAX);
+                }
+                #[test]
+                fn test_b() {
+                    let mut a = MyBitField::new();
+                    assert_eq!(a.get_b() as u64, 0);
+                    a.set_b(BitField2::FIELD_MAX as
+                            <BitField2 as BitFieldSpecifier>::DefaultFieldType);
+                    assert_eq!(a.get_b() as u64, BitField2::FIELD_MAX);
+                }
+                #[test]
+                fn test_c() {
+                    let mut a = MyBitField::new();
+                    assert_eq!(a.get_c() as u64, 0);
+                    a.set_c(BitField5::FIELD_MAX as
+                            <BitField5 as BitFieldSpecifier>::DefaultFieldType);
+                    assert_eq!(a.get_c() as u64, BitField5::FIELD_MAX);
+                }
+            }
+        };
+
+        assert_eq!(bitfield_impl(input), expected);
+    }
+}