valence_protocol/
encode.rs
1use std::io::Write;
2
3#[cfg(feature = "encryption")]
4use aes::cipher::generic_array::GenericArray;
5#[cfg(feature = "encryption")]
6use aes::cipher::{BlockEncryptMut, BlockSizeUser, KeyIvInit};
7use anyhow::ensure;
8use bytes::{BufMut, BytesMut};
9use tracing::warn;
10
11use crate::var_int::VarInt;
12use crate::{CompressionThreshold, Encode, Packet, MAX_PACKET_SIZE};
13
/// AES-128 in CFB-8 mode: the stream cipher applied to outgoing bytes once
/// encryption has been enabled.
#[cfg(feature = "encryption")]
type Cipher = cfb8::Encryptor<aes::Aes128>;

/// Buffers outgoing packets, prefixing each one with its length (and, with the
/// `compression` feature, optionally zlib-compressing it). When encryption is
/// enabled, the buffered bytes are encrypted at the moment they are taken out.
#[derive(Default)]
pub struct PacketEncoder {
    /// Framed packets and raw appended bytes, in the order they will be sent.
    buf: BytesMut,
    /// Scratch buffer for zlib output, cleared and reused for every packet.
    #[cfg(feature = "compression")]
    compress_buf: Vec<u8>,
    /// Packets whose encoded size exceeds this value are compressed; a
    /// negative value disables compression entirely.
    #[cfg(feature = "compression")]
    threshold: CompressionThreshold,
    /// Active cipher, or `None` while encryption is disabled.
    #[cfg(feature = "encryption")]
    cipher: Option<Cipher>,
}
29
30impl PacketEncoder {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 #[inline]
36 pub fn append_bytes(&mut self, bytes: &[u8]) {
37 self.buf.extend_from_slice(bytes)
38 }
39
40 pub fn prepend_packet<P>(&mut self, pkt: &P) -> anyhow::Result<()>
41 where
42 P: Packet + Encode,
43 {
44 let start_len = self.buf.len();
45 self.append_packet(pkt)?;
46
47 let end_len = self.buf.len();
48 let total_packet_len = end_len - start_len;
49
50 self.buf.put_bytes(0, total_packet_len);
54 self.buf.copy_within(..end_len, total_packet_len);
55 self.buf.copy_within(total_packet_len + start_len.., 0);
56 self.buf.truncate(end_len);
57
58 Ok(())
59 }
60
61 #[allow(clippy::needless_borrows_for_generic_args)]
62 pub fn append_packet<P>(&mut self, pkt: &P) -> anyhow::Result<()>
63 where
64 P: Packet + Encode,
65 {
66 let start_len = self.buf.len();
67
68 pkt.encode_with_id((&mut self.buf).writer())?;
69
70 let data_len = self.buf.len() - start_len;
71
72 #[cfg(feature = "compression")]
73 if self.threshold.0 >= 0 {
74 use std::io::Read;
75
76 use flate2::bufread::ZlibEncoder;
77 use flate2::Compression;
78
79 if data_len > self.threshold.0 as usize {
80 let mut z = ZlibEncoder::new(&self.buf[start_len..], Compression::new(4));
81
82 self.compress_buf.clear();
83
84 let data_len_size = VarInt(data_len as i32).written_size();
85
86 let packet_len = data_len_size + z.read_to_end(&mut self.compress_buf)?;
87
88 ensure!(
89 packet_len <= MAX_PACKET_SIZE as usize,
90 "packet exceeds maximum length"
91 );
92
93 drop(z);
94
95 self.buf.truncate(start_len);
96
97 let mut writer = (&mut self.buf).writer();
98
99 VarInt(packet_len as i32).encode(&mut writer)?;
100 VarInt(data_len as i32).encode(&mut writer)?;
101 self.buf.extend_from_slice(&self.compress_buf);
102 } else {
103 let data_len_size = 1;
104 let packet_len = data_len_size + data_len;
105
106 ensure!(
107 packet_len <= MAX_PACKET_SIZE as usize,
108 "packet exceeds maximum length"
109 );
110
111 let packet_len_size = VarInt(packet_len as i32).written_size();
112
113 let data_prefix_len = packet_len_size + data_len_size;
114
115 self.buf.put_bytes(0, data_prefix_len);
116 self.buf
117 .copy_within(start_len..start_len + data_len, start_len + data_prefix_len);
118
119 let mut front = &mut self.buf[start_len..];
120
121 VarInt(packet_len as i32).encode(&mut front)?;
122 VarInt(0).encode(front)?;
124 }
125
126 return Ok(());
127 }
128
129 let packet_len = data_len;
130
131 ensure!(
132 packet_len <= MAX_PACKET_SIZE as usize,
133 "packet exceeds maximum length"
134 );
135
136 let packet_len_size = VarInt(packet_len as i32).written_size();
137
138 self.buf.put_bytes(0, packet_len_size);
139 self.buf
140 .copy_within(start_len..start_len + data_len, start_len + packet_len_size);
141
142 let front = &mut self.buf[start_len..];
143 VarInt(packet_len as i32).encode(front)?;
144
145 Ok(())
146 }
147
148 pub fn take(&mut self) -> BytesMut {
151 #[cfg(feature = "encryption")]
152 if let Some(cipher) = &mut self.cipher {
153 for chunk in self.buf.chunks_mut(Cipher::block_size()) {
154 let gen_arr = GenericArray::from_mut_slice(chunk);
155 cipher.encrypt_block_mut(gen_arr);
156 }
157 }
158
159 self.buf.split()
160 }
161
162 pub fn clear(&mut self) {
163 self.buf.clear();
164 }
165
166 #[cfg(feature = "compression")]
167 pub fn set_compression(&mut self, threshold: CompressionThreshold) {
168 self.threshold = threshold;
169 }
170
171 #[cfg(feature = "encryption")]
180 pub fn enable_encryption(&mut self, key: &[u8; 16]) {
181 assert!(self.cipher.is_none(), "encryption is already enabled");
182 self.cipher = Some(Cipher::new_from_slices(key, key).expect("invalid key"));
183 }
184}
185
186pub trait WritePacket {
188 fn write_packet<P>(&mut self, packet: &P)
191 where
192 P: Packet + Encode,
193 {
194 if let Err(e) = self.write_packet_fallible(packet) {
195 warn!("failed to write packet '{}': {e:#}", P::NAME);
196 }
197 }
198
199 fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
202 where
203 P: Packet + Encode;
204
205 fn write_packet_bytes(&mut self, bytes: &[u8]);
208}
209
210impl<W: WritePacket> WritePacket for &mut W {
211 fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
212 where
213 P: Packet + Encode,
214 {
215 (*self).write_packet_fallible(packet)
216 }
217
218 fn write_packet_bytes(&mut self, bytes: &[u8]) {
219 (*self).write_packet_bytes(bytes)
220 }
221}
222
223impl<T: WritePacket> WritePacket for bevy_ecs::world::Mut<'_, T> {
224 fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
225 where
226 P: Packet + Encode,
227 {
228 self.as_mut().write_packet_fallible(packet)
229 }
230
231 fn write_packet_bytes(&mut self, bytes: &[u8]) {
232 self.as_mut().write_packet_bytes(bytes)
233 }
234}
235
236#[derive(Debug)]
241pub struct PacketWriter<'a> {
242 pub buf: &'a mut Vec<u8>,
243 pub threshold: CompressionThreshold,
244}
245
246impl<'a> PacketWriter<'a> {
247 pub fn new(buf: &'a mut Vec<u8>, threshold: CompressionThreshold) -> Self {
248 Self { buf, threshold }
249 }
250}
251
252impl WritePacket for PacketWriter<'_> {
253 #[cfg_attr(not(feature = "compression"), track_caller)]
254 fn write_packet_fallible<P>(&mut self, pkt: &P) -> anyhow::Result<()>
255 where
256 P: Packet + Encode,
257 {
258 let start = self.buf.len();
259
260 let res;
261
262 if self.threshold.0 >= 0 {
263 #[cfg(feature = "compression")]
264 {
265 res = encode_packet_compressed(self.buf, pkt, self.threshold.0 as u32);
266 }
267
268 #[cfg(not(feature = "compression"))]
269 {
270 panic!("\"compression\" feature must be enabled to write compressed packets");
271 }
272 } else {
273 res = encode_packet(self.buf, pkt)
274 };
275
276 if res.is_err() {
277 self.buf.truncate(start);
278 }
279
280 res
281 }
282
283 fn write_packet_bytes(&mut self, bytes: &[u8]) {
284 if let Err(e) = self.buf.write_all(bytes) {
285 warn!("failed to write packet bytes: {e:#}");
286 }
287 }
288}
289
290impl WritePacket for PacketEncoder {
291 fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
292 where
293 P: Packet + Encode,
294 {
295 self.append_packet(packet)
296 }
297
298 fn write_packet_bytes(&mut self, bytes: &[u8]) {
299 self.append_bytes(bytes)
300 }
301}
302
303fn encode_packet<P>(buf: &mut Vec<u8>, pkt: &P) -> anyhow::Result<()>
304where
305 P: Packet + Encode,
306{
307 let start_len = buf.len();
308
309 pkt.encode_with_id(&mut *buf)?;
310
311 let packet_len = buf.len() - start_len;
312
313 ensure!(
314 packet_len <= MAX_PACKET_SIZE as usize,
315 "packet exceeds maximum length"
316 );
317
318 let packet_len_size = VarInt(packet_len as i32).written_size();
319
320 buf.put_bytes(0, packet_len_size);
321 buf.copy_within(
322 start_len..start_len + packet_len,
323 start_len + packet_len_size,
324 );
325
326 let front = &mut buf[start_len..];
327 VarInt(packet_len as i32).encode(front)?;
328
329 Ok(())
330}
331
/// Appends `pkt` to `buf` using the compressed-connection framing.
///
/// Packets whose encoded size exceeds `threshold` become
/// `[length][uncompressed length][zlib data]`. All other packets become
/// `[length][0][id + body]`, where the zero marks them as uncompressed.
///
/// # Errors
///
/// Fails if encoding or compression fails, or if the frame would exceed
/// `MAX_PACKET_SIZE`. Partial output may remain in `buf` on error.
#[cfg(feature = "compression")]
#[allow(clippy::needless_borrows_for_generic_args)]
fn encode_packet_compressed<P>(buf: &mut Vec<u8>, pkt: &P, threshold: u32) -> anyhow::Result<()>
where
    P: Packet + Encode,
{
    use std::io::Read;

    use flate2::bufread::ZlibEncoder;
    use flate2::Compression;

    let body_start = buf.len();
    pkt.encode_with_id(&mut *buf)?;
    let body_len = buf.len() - body_start;

    if body_len <= threshold as usize {
        // The zero data length always encodes to a single byte.
        let marker_size = 1;
        let frame_len = marker_size + body_len;

        ensure!(
            frame_len <= MAX_PACKET_SIZE as usize,
            "packet exceeds maximum length"
        );

        let header_size = VarInt(frame_len as i32).written_size() + marker_size;

        buf.resize(buf.len() + header_size, 0);
        buf.copy_within(body_start..body_start + body_len, body_start + header_size);

        let mut header = &mut buf[body_start..];
        VarInt(frame_len as i32).encode(&mut header)?;
        VarInt(0).encode(header)?;

        return Ok(());
    }

    let mut compressed = Vec::new();
    let frame_len = {
        // Scoped so the encoder's borrow of `buf` ends before `buf` is rewritten.
        let mut encoder = ZlibEncoder::new(&buf[body_start..], Compression::new(4));
        VarInt(body_len as i32).written_size() + encoder.read_to_end(&mut compressed)?
    };

    ensure!(
        frame_len <= MAX_PACKET_SIZE as usize,
        "packet exceeds maximum length"
    );

    buf.truncate(body_start);
    VarInt(frame_len as i32).encode(&mut *buf)?;
    VarInt(body_len as i32).encode(&mut *buf)?;
    buf.extend_from_slice(&compressed);

    Ok(())
}