valence_protocol_macros/
encode.rs

1use proc_macro2::{Ident, Span, TokenStream};
2use quote::quote;
3use syn::spanned::Spanned;
4use syn::{parse2, Data, DeriveInput, Error, Fields, LitInt, Result};
5
6use crate::{add_trait_bounds, pair_variants_with_discriminants};
7
8pub(super) fn derive_encode(item: TokenStream) -> Result<TokenStream> {
9    let mut input = parse2::<DeriveInput>(item)?;
10
11    let input_name = input.ident;
12
13    add_trait_bounds(
14        &mut input.generics,
15        quote!(::valence_protocol::__private::Encode),
16    );
17
18    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
19
20    match input.data {
21        Data::Struct(struct_) => {
22            let encode_fields = match &struct_.fields {
23                Fields::Named(fields) => fields
24                    .named
25                    .iter()
26                    .map(|f| {
27                        let name = &f.ident.as_ref().unwrap();
28                        let ctx = format!("failed to encode field `{name}` in `{input_name}`");
29                        quote! {
30                            self.#name.encode(&mut _w).context(#ctx)?;
31                        }
32                    })
33                    .collect(),
34                Fields::Unnamed(fields) => (0..fields.unnamed.len())
35                    .map(|i| {
36                        let lit = LitInt::new(&i.to_string(), Span::call_site());
37                        let ctx = format!("failed to encode field `{lit}` in `{input_name}`");
38                        quote! {
39                            self.#lit.encode(&mut _w).context(#ctx)?;
40                        }
41                    })
42                    .collect(),
43                Fields::Unit => TokenStream::new(),
44            };
45
46            Ok(quote! {
47                #[allow(unused_imports)]
48                impl #impl_generics ::valence_protocol::__private::Encode for #input_name #ty_generics
49                #where_clause
50                {
51                    fn encode(&self, mut _w: impl ::std::io::Write) -> ::valence_protocol::__private::Result<()> {
52                        use ::valence_protocol::__private::{Encode, Context};
53
54                        #encode_fields
55
56                        Ok(())
57                    }
58                }
59            })
60        }
61        Data::Enum(enum_) => {
62            let variants = pair_variants_with_discriminants(enum_.variants)?;
63
64            let encode_arms = variants
65                .iter()
66                .map(|(disc, variant)| {
67                    let variant_name = &variant.ident;
68
69                    let disc_ctx = format!(
70                        "failed to encode enum discriminant {disc} for variant `{variant_name}` \
71                         in `{input_name}`",
72                    );
73
74                    match &variant.fields {
75                        Fields::Named(fields) => {
76                            let field_names = fields
77                                .named
78                                .iter()
79                                .map(|f| f.ident.as_ref().unwrap())
80                                .collect::<Vec<_>>();
81
82                            let encode_fields = field_names
83                                .iter()
84                                .map(|name| {
85                                    let ctx = format!(
86                                        "failed to encode field `{name}` in variant \
87                                         `{variant_name}` in `{input_name}`",
88                                    );
89
90                                    quote! {
91                                        #name.encode(&mut _w).context(#ctx)?;
92                                    }
93                                })
94                                .collect::<TokenStream>();
95
96                            quote! {
97                                Self::#variant_name { #(#field_names,)* } => {
98                                    VarInt(#disc).encode(&mut _w).context(#disc_ctx)?;
99
100                                    #encode_fields
101                                    Ok(())
102                                }
103                            }
104                        }
105                        Fields::Unnamed(fields) => {
106                            let field_names = (0..fields.unnamed.len())
107                                .map(|i| Ident::new(&format!("_{i}"), Span::call_site()))
108                                .collect::<Vec<_>>();
109
110                            let encode_fields = field_names
111                                .iter()
112                                .map(|name| {
113                                    let ctx = format!(
114                                        "failed to encode field `{name}` in variant \
115                                         `{variant_name}` in `{input_name}`"
116                                    );
117
118                                    quote! {
119                                        #name.encode(&mut _w).context(#ctx)?;
120                                    }
121                                })
122                                .collect::<TokenStream>();
123
124                            quote! {
125                                Self::#variant_name(#(#field_names,)*) => {
126                                    VarInt(#disc).encode(&mut _w).context(#disc_ctx)?;
127
128                                    #encode_fields
129                                    Ok(())
130                                }
131                            }
132                        }
133                        Fields::Unit => quote! {
134                            Self::#variant_name => Ok(
135                                VarInt(#disc)
136                                    .encode(&mut _w)
137                                    .context(#disc_ctx)?
138                            ),
139                        },
140                    }
141                })
142                .collect::<TokenStream>();
143
144            Ok(quote! {
145                #[allow(unused_imports, unreachable_code)]
146                impl #impl_generics ::valence_protocol::__private::Encode for #input_name #ty_generics
147                #where_clause
148                {
149                    fn encode(&self, mut _w: impl ::std::io::Write) -> ::valence_protocol::__private::Result<()> {
150                        use ::valence_protocol::__private::{Encode, VarInt, Context};
151
152                        match self {
153                            #encode_arms
154                            _ => unreachable!(),
155                        }
156                    }
157                }
158            })
159        }
160        Data::Union(u) => Err(Error::new(
161            u.union_token.span(),
162            "cannot derive `Encode` on unions",
163        )),
164    }
165}