summary refs log tree commit diff
path: root/sys_util/poll_token_derive/poll_token_derive.rs
blob: 7b7baac1c088abd0a9e6cd23a000ea20d4f60372 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#![recursion_limit = "128"]

extern crate proc_macro;

use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, Index, Member, Variant};

#[cfg(test)]
mod tests;

// The method for packing an enum into a u64 is as follows:
// 1) Reserve the lowest "ceil(log_2(x))" bits, where x is the number of enum variants.
// 2) Store the enum variant's index (0-based, in order of the enum definition) in the
//    reserved bits.
// 3) If the enum variant has data, store the data in the remaining bits.
// The method for unpacking is as follows:
// 1) Mask the raw token to just the reserved bits.
// 2) Match the reserved bits against each enum variant's index.
// 3) If the indicated enum variant has data, extract it from the unreserved bits.

// Number of low token bits reserved for the variant index: ceil(log2(n)) for n variants.
// Both an empty enum and a single-variant enum need zero bits (the index is always 0).
fn variant_bits(variants: &[Variant]) -> u32 {
    match variants.len() {
        // Degenerate enum with no variants: nothing to encode.
        0 => 0,
        // Rounding up to a power of two and counting its trailing zeros yields ceil(log2(n)).
        count => count.next_power_of_two().trailing_zeros(),
    }
}

// Returns how `field` is addressed inside a braced pattern or struct expression. A named field
// uses its identifier. An unnamed field uses index 0, because callers only ever pass the first
// field of a variant.
fn field_member(field: &Field) -> Member {
    field
        .ident
        .clone()
        .map_or_else(|| Member::Unnamed(Index::from(0)), Member::Named)
}

// Generates the function body for `as_raw_token`.
fn generate_as_raw_token(enum_name: &Ident, variants: &[Variant]) -> TokenStream {
    let variant_bits = variant_bits(variants);

    // Each iteration corresponds to one variant's match arm.
    let cases = variants.iter().enumerate().map(|(index, variant)| {
        let variant_name = &variant.ident;
        let index = index as u64;

        // The capture string is for everything between the variant identifier and the `=>` in
        // the match arm: the variant's data capture.
        let capture = variant.fields.iter().next().map(|field| {
            let member = field_member(&field);
            quote!({ #member: data })
        });

        // The modifier string ORs the variant index with extra bits from the variant data
        // field.
        let modifier = match variant.fields {
            Fields::Named(_) | Fields::Unnamed(_) => Some(quote! {
                | ((data as u64) << #variant_bits)
            }),
            Fields::Unit => None,
        };

        // Assembly of the match arm.
        quote! {
            #enum_name::#variant_name #capture => #index #modifier
        }
    });

    quote! {
        match *self {
            #(
                #cases,
            )*
        }
    }
}

// Generates the function body for `from_raw_token`.
fn generate_from_raw_token(enum_name: &Ident, variants: &[Variant]) -> TokenStream {
    let variant_bits = variant_bits(variants);
    let variant_mask = ((1 << variant_bits) - 1) as u64;

    // Each iteration corresponds to one variant's match arm.
    let cases = variants.iter().enumerate().map(|(index, variant)| {
        let variant_name = &variant.ident;
        let index = index as u64;

        // The data string is for extracting the enum variant's data bits out of the raw token
        // data, which includes both variant index and data bits.
        let data = variant.fields.iter().next().map(|field| {
            let member = field_member(&field);
            let ty = &field.ty;
            quote!({ #member: (data >> #variant_bits) as #ty })
        });

        // Assembly of the match arm.
        quote! {
            #index => #enum_name::#variant_name #data
        }
    });

    quote! {
        // The match expression only matches the bits for the variant index.
        match data & #variant_mask {
            #(
                #cases,
            )*
            _ => unreachable!(),
        }
    }
}

// The proc_macro::TokenStream type can only be constructed from within a
// procedural macro, meaning that unit tests are not able to invoke `fn
// poll_token` below as an ordinary Rust function. We factor out the logic into
// a signature that deals with Syn and proc-macro2 types only which are not
// restricted to a procedural macro invocation.
//
// Panics if the input is not an `enum`, or if any variant has more than one field.
fn poll_token_inner(input: DeriveInput) -> TokenStream {
    let enum_name = input.ident;
    let variants: Vec<Variant> = match input.data {
        Data::Enum(data) => data.variants.into_iter().collect(),
        Data::Struct(_) | Data::Union(_) => panic!("input must be an enum"),
    };

    // Only a single data field can be packed next to the variant index. Name the offending
    // variant so the user is not left with a bare assertion expression.
    for variant in &variants {
        assert!(
            variant.fields.iter().count() <= 1,
            "variant `{}` of `{}` must have at most one field",
            variant.ident,
            enum_name
        );
    }

    // Given our basic model of a user given enum that is suitable as a token, we generate the
    // implementation. The implementation is NOT always well formed, such as when a variant's data
    // type is not bit shiftable or castable to u64, but we let Rust generate such errors as it
    // would be difficult to detect every kind of error. Importantly, every implementation that we
    // generate here and goes on to compile successfully is sound.

    let as_raw_token = generate_as_raw_token(&enum_name, &variants);
    let from_raw_token = generate_from_raw_token(&enum_name, &variants);

    quote! {
        impl PollToken for #enum_name {
            fn as_raw_token(&self) -> u64 {
                #as_raw_token
            }

            fn from_raw_token(data: u64) -> Self {
                #from_raw_token
            }
        }
    }
}

/// Derives the `PollToken` trait for an `enum` by packing each variant into a `u64` token.
///
/// The variant index occupies the lowest bits of the token, and any variant data sits in the bits
/// above it. The derive only works for `enum`s that meet these constraints:
///
/// * Every variant is either a unit variant (no data) or carries exactly one data field, named or
///   unnamed.
/// * Variant data must be a primitive type that can be cast to and from `u64` with `as`.
/// * Data at least as wide as a `u64` must have zeros in its most significant bits. Any bits above
///   64 are dropped, and so are as many further top bits as are used to store the variant index.
#[proc_macro_derive(PollToken)]
pub fn poll_token(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let derive_input = parse_macro_input!(input as DeriveInput);
    let expanded = poll_token_inner(derive_input);
    proc_macro::TokenStream::from(expanded)
}