valence_protocol_macros/
decode.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::spanned::Spanned;
4use syn::{parse2, parse_quote, Data, DeriveInput, Error, Fields, Result};
5
6use crate::{add_trait_bounds, decode_split_for_impl, pair_variants_with_discriminants};
7
8pub(super) fn derive_decode(item: TokenStream) -> Result<TokenStream> {
9    let mut input = parse2::<DeriveInput>(item)?;
10
11    let input_name = input.ident;
12
13    if input.generics.lifetimes().count() > 1 {
14        return Err(Error::new(
15            input.generics.params.span(),
16            "type deriving `Decode` must have no more than one lifetime",
17        ));
18    }
19
20    // Use the lifetime specified in the type definition or just use `'a` if not
21    // present.
22    let lifetime = input
23        .generics
24        .lifetimes()
25        .next()
26        .map_or_else(|| parse_quote!('a), |l| l.lifetime.clone());
27
28    match input.data {
29        Data::Struct(struct_) => {
30            let decode_fields = match struct_.fields {
31                Fields::Named(fields) => {
32                    let init = fields.named.iter().map(|f| {
33                        let name = f.ident.as_ref().unwrap();
34                        let ctx = format!("failed to decode field `{name}` in `{input_name}`");
35                        quote! {
36                            #name: Decode::decode(_r).context(#ctx)?,
37                        }
38                    });
39
40                    quote! {
41                        Self {
42                            #(#init)*
43                        }
44                    }
45                }
46                Fields::Unnamed(fields) => {
47                    let init = (0..fields.unnamed.len())
48                        .map(|i| {
49                            let ctx = format!("failed to decode field `{i}` in `{input_name}`");
50                            quote! {
51                                Decode::decode(_r).context(#ctx)?,
52                            }
53                        })
54                        .collect::<TokenStream>();
55
56                    quote! {
57                        Self(#init)
58                    }
59                }
60                Fields::Unit => quote!(Self),
61            };
62
63            add_trait_bounds(
64                &mut input.generics,
65                quote!(::valence_protocol::__private::Decode<#lifetime>),
66            );
67
68            let (impl_generics, ty_generics, where_clause) =
69                decode_split_for_impl(input.generics, lifetime.clone());
70
71            Ok(quote! {
72                #[allow(unused_imports)]
73                impl #impl_generics ::valence_protocol::__private::Decode<#lifetime> for #input_name #ty_generics
74                #where_clause
75                {
76                    fn decode(_r: &mut &#lifetime [u8]) -> ::valence_protocol::__private::Result<Self> {
77                        use ::valence_protocol::__private::{Decode, Context, ensure};
78
79                        Ok(#decode_fields)
80                    }
81                }
82            })
83        }
84        Data::Enum(enum_) => {
85            let variants = pair_variants_with_discriminants(enum_.variants)?;
86
87            let decode_arms = variants
88                .iter()
89                .map(|(disc, variant)| {
90                    let name = &variant.ident;
91
92                    match &variant.fields {
93                        Fields::Named(fields) => {
94                            let fields = fields
95                                .named
96                                .iter()
97                                .map(|f| {
98                                    let field = f.ident.as_ref().unwrap();
99                                    let ctx = format!(
100                                        "failed to decode field `{field}` in variant `{name}` in \
101                                         `{input_name}`",
102                                    );
103                                    quote! {
104                                        #field: Decode::decode(_r).context(#ctx)?,
105                                    }
106                                })
107                                .collect::<TokenStream>();
108
109                            quote! {
110                                #disc => Ok(Self::#name { #fields }),
111                            }
112                        }
113                        Fields::Unnamed(fields) => {
114                            let init = (0..fields.unnamed.len())
115                                .map(|i| {
116                                    let ctx = format!(
117                                        "failed to decode field `{i}` in variant `{name}` in \
118                                         `{input_name}`",
119                                    );
120                                    quote! {
121                                        Decode::decode(_r).context(#ctx)?,
122                                    }
123                                })
124                                .collect::<TokenStream>();
125
126                            quote! {
127                                #disc => Ok(Self::#name(#init)),
128                            }
129                        }
130                        Fields::Unit => quote!(#disc => Ok(Self::#name),),
131                    }
132                })
133                .collect::<TokenStream>();
134
135            add_trait_bounds(
136                &mut input.generics,
137                quote!(::valence_protocol::__private::Decode<#lifetime>),
138            );
139
140            let (impl_generics, ty_generics, where_clause) =
141                decode_split_for_impl(input.generics, lifetime.clone());
142
143            Ok(quote! {
144                #[allow(unused_imports)]
145                impl #impl_generics ::valence_protocol::__private::Decode<#lifetime> for #input_name #ty_generics
146                #where_clause
147                {
148                    fn decode(_r: &mut &#lifetime [u8]) -> ::valence_protocol::__private::Result<Self> {
149                        use ::valence_protocol::__private::{Decode, Context, VarInt, bail};
150
151                        let ctx = concat!("failed to decode enum discriminant in `", stringify!(#input_name), "`");
152                        let disc = VarInt::decode(_r).context(ctx)?.0;
153                        match disc {
154                            #decode_arms
155                            n => bail!("unexpected enum discriminant {} in `{}`", disc, stringify!(#input_name)),
156                        }
157                    }
158                }
159            })
160        }
161        Data::Union(u) => Err(Error::new(
162            u.union_token.span(),
163            "cannot derive `Decode` on unions",
164        )),
165    }
166}