valence_protocol/
decode.rs
1#[cfg(feature = "encryption")]
2use aes::cipher::{generic_array::GenericArray, BlockDecryptMut, BlockSizeUser, KeyIvInit};
3use anyhow::{bail, ensure, Context};
4use bytes::{Buf, BytesMut};
5
6use crate::var_int::{VarInt, VarIntDecodeError};
7#[cfg(feature = "compression")]
8use crate::CompressionThreshold;
9use crate::{Decode, Packet, MAX_PACKET_SIZE};
10
/// AES-128 decryptor running in CFB-8 mode; applied to incoming bytes once
/// encryption has been enabled on a [`PacketDecoder`].
#[cfg(feature = "encryption")]
type Cipher = cfb8::Decryptor<aes::Aes128>;
15
16#[derive(Default)]
17pub struct PacketDecoder {
18 buf: BytesMut,
19 #[cfg(feature = "compression")]
20 decompress_buf: BytesMut,
21 #[cfg(feature = "compression")]
22 threshold: CompressionThreshold,
23 #[cfg(feature = "encryption")]
24 cipher: Option<Cipher>,
25}
26
27impl PacketDecoder {
28 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn try_next_packet(&mut self) -> anyhow::Result<Option<PacketFrame>> {
33 let mut r = &self.buf[..];
34
35 let packet_len = match VarInt::decode_partial(&mut r) {
36 Ok(len) => len,
37 Err(VarIntDecodeError::Incomplete) => return Ok(None),
38 Err(VarIntDecodeError::TooLarge) => bail!("malformed packet length VarInt"),
39 };
40
41 ensure!(
42 (0..=MAX_PACKET_SIZE).contains(&packet_len),
43 "packet length of {packet_len} is out of bounds"
44 );
45
46 if r.len() < packet_len as usize {
47 return Ok(None);
49 }
50
51 let packet_len_len = VarInt(packet_len).written_size();
52
53 let mut data;
54
55 #[cfg(feature = "compression")]
56 if self.threshold.0 >= 0 {
57 use std::io::Write;
58
59 use bytes::BufMut;
60 use flate2::write::ZlibDecoder;
61
62 r = &r[..packet_len as usize];
63
64 let data_len = VarInt::decode(&mut r)?.0;
65
66 ensure!(
67 (0..MAX_PACKET_SIZE).contains(&data_len),
68 "decompressed packet length of {data_len} is out of bounds"
69 );
70
71 if data_len > 0 {
73 ensure!(
74 data_len > self.threshold.0,
75 "decompressed packet length of {data_len} is <= the compression threshold of \
76 {}",
77 self.threshold.0
78 );
79
80 debug_assert!(self.decompress_buf.is_empty());
81
82 self.decompress_buf.put_bytes(0, data_len as usize);
83
84 let mut z = ZlibDecoder::new(&mut self.decompress_buf[..]);
86
87 z.write_all(r)?;
88
89 ensure!(
90 z.finish()?.is_empty(),
91 "decompressed packet length is shorter than expected"
92 );
93
94 let total_packet_len = VarInt(packet_len).written_size() + packet_len as usize;
95
96 self.buf.advance(total_packet_len);
97
98 data = self.decompress_buf.split();
99 } else {
100 debug_assert_eq!(data_len, 0);
101
102 ensure!(
103 r.len() <= self.threshold.0 as usize,
104 "uncompressed packet length of {} exceeds compression threshold of {}",
105 r.len(),
106 self.threshold.0
107 );
108
109 let remaining_len = r.len();
110
111 self.buf.advance(packet_len_len + 1);
112
113 data = self.buf.split_to(remaining_len);
114 }
115 } else {
116 self.buf.advance(packet_len_len);
117 data = self.buf.split_to(packet_len as usize);
118 }
119
120 #[cfg(not(feature = "compression"))]
121 {
122 self.buf.advance(packet_len_len);
123 data = self.buf.split_to(packet_len as usize);
124 }
125
126 r = &data[..];
128 let packet_id = VarInt::decode(&mut r)
129 .context("failed to decode packet ID")?
130 .0;
131
132 data.advance(data.len() - r.len());
133
134 Ok(Some(PacketFrame {
135 id: packet_id,
136 body: data,
137 }))
138 }
139
140 #[cfg(feature = "compression")]
141 pub fn compression(&self) -> CompressionThreshold {
142 self.threshold
143 }
144
145 #[cfg(feature = "compression")]
146 pub fn set_compression(&mut self, threshold: CompressionThreshold) {
147 self.threshold = threshold;
148 }
149
150 #[cfg(feature = "encryption")]
151 pub fn enable_encryption(&mut self, key: &[u8; 16]) {
152 assert!(self.cipher.is_none(), "encryption is already enabled");
153
154 let mut cipher = Cipher::new_from_slices(key, key).expect("invalid key");
155
156 Self::decrypt_bytes(&mut cipher, &mut self.buf);
158
159 self.cipher = Some(cipher);
160 }
161
162 #[cfg(feature = "encryption")]
165 fn decrypt_bytes(cipher: &mut Cipher, bytes: &mut [u8]) {
166 for chunk in bytes.chunks_mut(Cipher::block_size()) {
167 let gen_arr = GenericArray::from_mut_slice(chunk);
168 cipher.decrypt_block_mut(gen_arr);
169 }
170 }
171
172 pub fn queue_bytes(&mut self, mut bytes: BytesMut) {
173 #![allow(unused_mut)]
174
175 #[cfg(feature = "encryption")]
176 if let Some(cipher) = &mut self.cipher {
177 Self::decrypt_bytes(cipher, &mut bytes);
178 }
179
180 self.buf.unsplit(bytes);
181 }
182
183 pub fn queue_slice(&mut self, bytes: &[u8]) {
184 #[cfg(feature = "encryption")]
185 let len = self.buf.len();
186
187 self.buf.extend_from_slice(bytes);
188
189 #[cfg(feature = "encryption")]
190 if let Some(cipher) = &mut self.cipher {
191 let slice = &mut self.buf[len..];
192 Self::decrypt_bytes(cipher, slice);
193 }
194 }
195
196 pub fn take_capacity(&mut self) -> BytesMut {
197 self.buf.split_off(self.buf.len())
198 }
199
200 pub fn reserve(&mut self, additional: usize) {
201 self.buf.reserve(additional);
202 }
203}
204
205#[derive(Clone, Debug)]
206pub struct PacketFrame {
207 pub id: i32,
209 pub body: BytesMut,
211}
212
213impl PacketFrame {
214 pub fn decode<'a, P>(&'a self) -> anyhow::Result<P>
218 where
219 P: Packet + Decode<'a>,
220 {
221 ensure!(
222 P::ID == self.id,
223 "packet ID mismatch while decoding '{}': expected {}, got {}",
224 P::NAME,
225 P::ID,
226 self.id
227 );
228
229 let mut r = &self.body[..];
230
231 let pkt = P::decode(&mut r)?;
232
233 ensure!(
234 r.is_empty(),
235 "missed {} bytes while decoding '{}'",
236 r.len(),
237 P::NAME
238 );
239
240 Ok(pkt)
241 }
242}