valence_protocol/
encode.rs

1use std::io::Write;
2
3#[cfg(feature = "encryption")]
4use aes::cipher::generic_array::GenericArray;
5#[cfg(feature = "encryption")]
6use aes::cipher::{BlockEncryptMut, BlockSizeUser, KeyIvInit};
7use anyhow::ensure;
8use bytes::{BufMut, BytesMut};
9use tracing::warn;
10
11use crate::var_int::VarInt;
12use crate::{CompressionThreshold, Encode, Packet, MAX_PACKET_SIZE};
13
/// The AES block cipher with a 128 bit key, using the CFB-8 mode of
/// operation.
///
/// [`PacketEncoder`] stores one of these once encryption has been enabled and
/// runs its buffered bytes through it in [`PacketEncoder::take`].
#[cfg(feature = "encryption")]
type Cipher = cfb8::Encryptor<aes::Aes128>;
18
/// Accumulates outgoing packet data in an internal buffer.
///
/// Packets are framed (and, depending on the configured threshold,
/// compressed) as they are appended. Encryption, if enabled, is applied to
/// the buffered bytes when they are taken out of the encoder.
#[derive(Default)]
pub struct PacketEncoder {
    /// Bytes written so far that have not been taken yet.
    buf: BytesMut,
    /// Scratch buffer holding the zlib output of the packet currently being
    /// compressed; cleared before each use.
    #[cfg(feature = "compression")]
    compress_buf: Vec<u8>,
    /// Compression threshold. A negative value selects the plain
    /// (uncompressed) packet format.
    #[cfg(feature = "compression")]
    threshold: CompressionThreshold,
    /// Cipher state; `None` until encryption is enabled.
    #[cfg(feature = "encryption")]
    cipher: Option<Cipher>,
}
29
30impl PacketEncoder {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    #[inline]
36    pub fn append_bytes(&mut self, bytes: &[u8]) {
37        self.buf.extend_from_slice(bytes)
38    }
39
40    pub fn prepend_packet<P>(&mut self, pkt: &P) -> anyhow::Result<()>
41    where
42        P: Packet + Encode,
43    {
44        let start_len = self.buf.len();
45        self.append_packet(pkt)?;
46
47        let end_len = self.buf.len();
48        let total_packet_len = end_len - start_len;
49
50        // 1) Move everything back by the length of the packet.
51        // 2) Move the packet to the new space at the front.
52        // 3) Truncate the old packet away.
53        self.buf.put_bytes(0, total_packet_len);
54        self.buf.copy_within(..end_len, total_packet_len);
55        self.buf.copy_within(total_packet_len + start_len.., 0);
56        self.buf.truncate(end_len);
57
58        Ok(())
59    }
60
61    #[allow(clippy::needless_borrows_for_generic_args)]
62    pub fn append_packet<P>(&mut self, pkt: &P) -> anyhow::Result<()>
63    where
64        P: Packet + Encode,
65    {
66        let start_len = self.buf.len();
67
68        pkt.encode_with_id((&mut self.buf).writer())?;
69
70        let data_len = self.buf.len() - start_len;
71
72        #[cfg(feature = "compression")]
73        if self.threshold.0 >= 0 {
74            use std::io::Read;
75
76            use flate2::bufread::ZlibEncoder;
77            use flate2::Compression;
78
79            if data_len > self.threshold.0 as usize {
80                let mut z = ZlibEncoder::new(&self.buf[start_len..], Compression::new(4));
81
82                self.compress_buf.clear();
83
84                let data_len_size = VarInt(data_len as i32).written_size();
85
86                let packet_len = data_len_size + z.read_to_end(&mut self.compress_buf)?;
87
88                ensure!(
89                    packet_len <= MAX_PACKET_SIZE as usize,
90                    "packet exceeds maximum length"
91                );
92
93                drop(z);
94
95                self.buf.truncate(start_len);
96
97                let mut writer = (&mut self.buf).writer();
98
99                VarInt(packet_len as i32).encode(&mut writer)?;
100                VarInt(data_len as i32).encode(&mut writer)?;
101                self.buf.extend_from_slice(&self.compress_buf);
102            } else {
103                let data_len_size = 1;
104                let packet_len = data_len_size + data_len;
105
106                ensure!(
107                    packet_len <= MAX_PACKET_SIZE as usize,
108                    "packet exceeds maximum length"
109                );
110
111                let packet_len_size = VarInt(packet_len as i32).written_size();
112
113                let data_prefix_len = packet_len_size + data_len_size;
114
115                self.buf.put_bytes(0, data_prefix_len);
116                self.buf
117                    .copy_within(start_len..start_len + data_len, start_len + data_prefix_len);
118
119                let mut front = &mut self.buf[start_len..];
120
121                VarInt(packet_len as i32).encode(&mut front)?;
122                // Zero for no compression on this packet.
123                VarInt(0).encode(front)?;
124            }
125
126            return Ok(());
127        }
128
129        let packet_len = data_len;
130
131        ensure!(
132            packet_len <= MAX_PACKET_SIZE as usize,
133            "packet exceeds maximum length"
134        );
135
136        let packet_len_size = VarInt(packet_len as i32).written_size();
137
138        self.buf.put_bytes(0, packet_len_size);
139        self.buf
140            .copy_within(start_len..start_len + data_len, start_len + packet_len_size);
141
142        let front = &mut self.buf[start_len..];
143        VarInt(packet_len as i32).encode(front)?;
144
145        Ok(())
146    }
147
148    /// Takes all the packets written so far and encrypts them if encryption is
149    /// enabled.
150    pub fn take(&mut self) -> BytesMut {
151        #[cfg(feature = "encryption")]
152        if let Some(cipher) = &mut self.cipher {
153            for chunk in self.buf.chunks_mut(Cipher::block_size()) {
154                let gen_arr = GenericArray::from_mut_slice(chunk);
155                cipher.encrypt_block_mut(gen_arr);
156            }
157        }
158
159        self.buf.split()
160    }
161
162    pub fn clear(&mut self) {
163        self.buf.clear();
164    }
165
166    #[cfg(feature = "compression")]
167    pub fn set_compression(&mut self, threshold: CompressionThreshold) {
168        self.threshold = threshold;
169    }
170
171    /// Initializes the cipher with the given key. All future packets **and any
172    /// that have not been [taken] yet** are encrypted.
173    ///
174    /// [taken]: Self::take
175    ///
176    /// # Panics
177    ///
178    /// Panics if encryption is already enabled.
179    #[cfg(feature = "encryption")]
180    pub fn enable_encryption(&mut self, key: &[u8; 16]) {
181        assert!(self.cipher.is_none(), "encryption is already enabled");
182        self.cipher = Some(Cipher::new_from_slices(key, key).expect("invalid key"));
183    }
184}
185
186/// Types that can have packets written to them.
187pub trait WritePacket {
188    /// Writes a packet to this object. Encoding errors are typically logged and
189    /// discarded.
190    fn write_packet<P>(&mut self, packet: &P)
191    where
192        P: Packet + Encode,
193    {
194        if let Err(e) = self.write_packet_fallible(packet) {
195            warn!("failed to write packet '{}': {e:#}", P::NAME);
196        }
197    }
198
199    /// Writes a packet to this object. The result of encoding the packet is
200    /// returned.
201    fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
202    where
203        P: Packet + Encode;
204
205    /// Copies raw packet data directly into this object. Don't use this unless
206    /// you know what you're doing.
207    fn write_packet_bytes(&mut self, bytes: &[u8]);
208}
209
210impl<W: WritePacket> WritePacket for &mut W {
211    fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
212    where
213        P: Packet + Encode,
214    {
215        (*self).write_packet_fallible(packet)
216    }
217
218    fn write_packet_bytes(&mut self, bytes: &[u8]) {
219        (*self).write_packet_bytes(bytes)
220    }
221}
222
223impl<T: WritePacket> WritePacket for bevy_ecs::world::Mut<'_, T> {
224    fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
225    where
226        P: Packet + Encode,
227    {
228        self.as_mut().write_packet_fallible(packet)
229    }
230
231    fn write_packet_bytes(&mut self, bytes: &[u8]) {
232        self.as_mut().write_packet_bytes(bytes)
233    }
234}
235
/// An implementor of [`WritePacket`] backed by a `Vec` mutable reference.
///
/// Packets are written by appending to the contained vec. If an error occurs
/// while writing, the written bytes are truncated away.
#[derive(Debug)]
pub struct PacketWriter<'a> {
    /// Destination vec; every packet is appended to its end.
    pub buf: &'a mut Vec<u8>,
    /// Compression threshold. A negative value writes packets in the plain
    /// format; a non-negative value writes them in the compressed packet
    /// format, which requires the "compression" feature.
    pub threshold: CompressionThreshold,
}
245
246impl<'a> PacketWriter<'a> {
247    pub fn new(buf: &'a mut Vec<u8>, threshold: CompressionThreshold) -> Self {
248        Self { buf, threshold }
249    }
250}
251
252impl WritePacket for PacketWriter<'_> {
253    #[cfg_attr(not(feature = "compression"), track_caller)]
254    fn write_packet_fallible<P>(&mut self, pkt: &P) -> anyhow::Result<()>
255    where
256        P: Packet + Encode,
257    {
258        let start = self.buf.len();
259
260        let res;
261
262        if self.threshold.0 >= 0 {
263            #[cfg(feature = "compression")]
264            {
265                res = encode_packet_compressed(self.buf, pkt, self.threshold.0 as u32);
266            }
267
268            #[cfg(not(feature = "compression"))]
269            {
270                panic!("\"compression\" feature must be enabled to write compressed packets");
271            }
272        } else {
273            res = encode_packet(self.buf, pkt)
274        };
275
276        if res.is_err() {
277            self.buf.truncate(start);
278        }
279
280        res
281    }
282
283    fn write_packet_bytes(&mut self, bytes: &[u8]) {
284        if let Err(e) = self.buf.write_all(bytes) {
285            warn!("failed to write packet bytes: {e:#}");
286        }
287    }
288}
289
290impl WritePacket for PacketEncoder {
291    fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
292    where
293        P: Packet + Encode,
294    {
295        self.append_packet(packet)
296    }
297
298    fn write_packet_bytes(&mut self, bytes: &[u8]) {
299        self.append_bytes(bytes)
300    }
301}
302
303fn encode_packet<P>(buf: &mut Vec<u8>, pkt: &P) -> anyhow::Result<()>
304where
305    P: Packet + Encode,
306{
307    let start_len = buf.len();
308
309    pkt.encode_with_id(&mut *buf)?;
310
311    let packet_len = buf.len() - start_len;
312
313    ensure!(
314        packet_len <= MAX_PACKET_SIZE as usize,
315        "packet exceeds maximum length"
316    );
317
318    let packet_len_size = VarInt(packet_len as i32).written_size();
319
320    buf.put_bytes(0, packet_len_size);
321    buf.copy_within(
322        start_len..start_len + packet_len,
323        start_len + packet_len_size,
324    );
325
326    let front = &mut buf[start_len..];
327    VarInt(packet_len as i32).encode(front)?;
328
329    Ok(())
330}
331
/// Encodes `pkt` (packet ID included) onto the end of `buf` as a frame in
/// the compressed packet format.
///
/// Data strictly longer than `threshold` bytes is zlib-compressed and
/// prefixed with its uncompressed length. Shorter data is stored as-is
/// behind a zero data-length marker.
///
/// # Errors
///
/// Fails if encoding or compression fails, or if the frame is longer than
/// `MAX_PACKET_SIZE`. On error, bytes already appended to `buf` are left in
/// place; the caller is expected to truncate them.
#[cfg(feature = "compression")]
#[allow(clippy::needless_borrows_for_generic_args)]
fn encode_packet_compressed<P>(buf: &mut Vec<u8>, pkt: &P, threshold: u32) -> anyhow::Result<()>
where
    P: Packet + Encode,
{
    use std::io::Read;

    use flate2::bufread::ZlibEncoder;
    use flate2::Compression;

    let frame_start = buf.len();

    pkt.encode_with_id(&mut *buf)?;

    let raw_len = buf.len() - frame_start;

    if raw_len <= threshold as usize {
        // At or below the threshold: store the data uncompressed. The
        // data-length field is a single zero byte.
        let marker_len = 1;
        let frame_len = marker_len + raw_len;

        ensure!(
            frame_len <= MAX_PACKET_SIZE as usize,
            "packet exceeds maximum length"
        );

        let header_len = VarInt(frame_len as i32).written_size() + marker_len;

        // Grow the vec, then slide the data back to open a gap for the
        // header in front of it.
        let framed_end = buf.len() + header_len;
        buf.resize(framed_end, 0);
        buf.copy_within(frame_start..frame_start + raw_len, frame_start + header_len);

        let mut header = &mut buf[frame_start..];

        VarInt(frame_len as i32).encode(&mut header)?;
        // Zero for no compression on this packet.
        VarInt(0).encode(header)?;

        return Ok(());
    }

    // Above the threshold: compress the data into a temporary vec.
    let mut deflated = Vec::new();

    let mut deflater = ZlibEncoder::new(&buf[frame_start..], Compression::new(4));
    let deflated_len = deflater.read_to_end(&mut deflated)?;
    // Release the borrow of `buf` held by the encoder.
    drop(deflater);

    // Frame length = size of the data-length field + compressed payload.
    let frame_len = VarInt(raw_len as i32).written_size() + deflated_len;

    ensure!(
        frame_len <= MAX_PACKET_SIZE as usize,
        "packet exceeds maximum length"
    );

    // Replace the raw data with the compressed frame.
    buf.truncate(frame_start);

    VarInt(frame_len as i32).encode(&mut *buf)?;
    VarInt(raw_len as i32).encode(&mut *buf)?;
    buf.extend_from_slice(&deflated);

    Ok(())
}