valence_protocol_macros/
decode.rs
1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::spanned::Spanned;
4use syn::{parse2, parse_quote, Data, DeriveInput, Error, Fields, Result};
5
6use crate::{add_trait_bounds, decode_split_for_impl, pair_variants_with_discriminants};
7
8pub(super) fn derive_decode(item: TokenStream) -> Result<TokenStream> {
9 let mut input = parse2::<DeriveInput>(item)?;
10
11 let input_name = input.ident;
12
13 if input.generics.lifetimes().count() > 1 {
14 return Err(Error::new(
15 input.generics.params.span(),
16 "type deriving `Decode` must have no more than one lifetime",
17 ));
18 }
19
20 let lifetime = input
23 .generics
24 .lifetimes()
25 .next()
26 .map_or_else(|| parse_quote!('a), |l| l.lifetime.clone());
27
28 match input.data {
29 Data::Struct(struct_) => {
30 let decode_fields = match struct_.fields {
31 Fields::Named(fields) => {
32 let init = fields.named.iter().map(|f| {
33 let name = f.ident.as_ref().unwrap();
34 let ctx = format!("failed to decode field `{name}` in `{input_name}`");
35 quote! {
36 #name: Decode::decode(_r).context(#ctx)?,
37 }
38 });
39
40 quote! {
41 Self {
42 #(#init)*
43 }
44 }
45 }
46 Fields::Unnamed(fields) => {
47 let init = (0..fields.unnamed.len())
48 .map(|i| {
49 let ctx = format!("failed to decode field `{i}` in `{input_name}`");
50 quote! {
51 Decode::decode(_r).context(#ctx)?,
52 }
53 })
54 .collect::<TokenStream>();
55
56 quote! {
57 Self(#init)
58 }
59 }
60 Fields::Unit => quote!(Self),
61 };
62
63 add_trait_bounds(
64 &mut input.generics,
65 quote!(::valence_protocol::__private::Decode<#lifetime>),
66 );
67
68 let (impl_generics, ty_generics, where_clause) =
69 decode_split_for_impl(input.generics, lifetime.clone());
70
71 Ok(quote! {
72 #[allow(unused_imports)]
73 impl #impl_generics ::valence_protocol::__private::Decode<#lifetime> for #input_name #ty_generics
74 #where_clause
75 {
76 fn decode(_r: &mut &#lifetime [u8]) -> ::valence_protocol::__private::Result<Self> {
77 use ::valence_protocol::__private::{Decode, Context, ensure};
78
79 Ok(#decode_fields)
80 }
81 }
82 })
83 }
84 Data::Enum(enum_) => {
85 let variants = pair_variants_with_discriminants(enum_.variants)?;
86
87 let decode_arms = variants
88 .iter()
89 .map(|(disc, variant)| {
90 let name = &variant.ident;
91
92 match &variant.fields {
93 Fields::Named(fields) => {
94 let fields = fields
95 .named
96 .iter()
97 .map(|f| {
98 let field = f.ident.as_ref().unwrap();
99 let ctx = format!(
100 "failed to decode field `{field}` in variant `{name}` in \
101 `{input_name}`",
102 );
103 quote! {
104 #field: Decode::decode(_r).context(#ctx)?,
105 }
106 })
107 .collect::<TokenStream>();
108
109 quote! {
110 #disc => Ok(Self::#name { #fields }),
111 }
112 }
113 Fields::Unnamed(fields) => {
114 let init = (0..fields.unnamed.len())
115 .map(|i| {
116 let ctx = format!(
117 "failed to decode field `{i}` in variant `{name}` in \
118 `{input_name}`",
119 );
120 quote! {
121 Decode::decode(_r).context(#ctx)?,
122 }
123 })
124 .collect::<TokenStream>();
125
126 quote! {
127 #disc => Ok(Self::#name(#init)),
128 }
129 }
130 Fields::Unit => quote!(#disc => Ok(Self::#name),),
131 }
132 })
133 .collect::<TokenStream>();
134
135 add_trait_bounds(
136 &mut input.generics,
137 quote!(::valence_protocol::__private::Decode<#lifetime>),
138 );
139
140 let (impl_generics, ty_generics, where_clause) =
141 decode_split_for_impl(input.generics, lifetime.clone());
142
143 Ok(quote! {
144 #[allow(unused_imports)]
145 impl #impl_generics ::valence_protocol::__private::Decode<#lifetime> for #input_name #ty_generics
146 #where_clause
147 {
148 fn decode(_r: &mut &#lifetime [u8]) -> ::valence_protocol::__private::Result<Self> {
149 use ::valence_protocol::__private::{Decode, Context, VarInt, bail};
150
151 let ctx = concat!("failed to decode enum discriminant in `", stringify!(#input_name), "`");
152 let disc = VarInt::decode(_r).context(ctx)?.0;
153 match disc {
154 #decode_arms
155 n => bail!("unexpected enum discriminant {} in `{}`", disc, stringify!(#input_name)),
156 }
157 }
158 }
159 })
160 }
161 Data::Union(u) => Err(Error::new(
162 u.union_token.span(),
163 "cannot derive `Decode` on unions",
164 )),
165 }
166}