valence_protocol/
decode.rs

1#[cfg(feature = "encryption")]
2use aes::cipher::{generic_array::GenericArray, BlockDecryptMut, BlockSizeUser, KeyIvInit};
3use anyhow::{bail, ensure, Context};
4use bytes::{Buf, BytesMut};
5
6use crate::var_int::{VarInt, VarIntDecodeError};
7#[cfg(feature = "compression")]
8use crate::CompressionThreshold;
9use crate::{Decode, Packet, MAX_PACKET_SIZE};
10
/// The AES block cipher with a 128 bit key, using the CFB-8 mode of
/// operation.
///
/// Only the decrypting half of the cipher is needed, because this module only
/// processes incoming bytes.
#[cfg(feature = "encryption")]
type Cipher = cfb8::Decryptor<aes::Aes128>;
15
/// Incrementally splits a stream of incoming bytes into [`PacketFrame`]s.
///
/// Bytes are appended with [`PacketDecoder::queue_bytes`] or
/// [`PacketDecoder::queue_slice`]. Complete frames are pulled off the front
/// with [`PacketDecoder::try_next_packet`]. When the corresponding cargo
/// features are enabled, incoming bytes are decrypted as they are queued
/// (`encryption`) and packet bodies are inflated as they are split off
/// (`compression`).
#[derive(Default)]
pub struct PacketDecoder {
    /// Queued bytes (already decrypted, if encryption is on) that have not yet
    /// been split off as packets.
    buf: BytesMut,
    /// Scratch buffer that compressed packet bodies are inflated into before
    /// being split off and handed out as a frame body.
    #[cfg(feature = "compression")]
    decompress_buf: BytesMut,
    /// Current compression threshold. A negative value means packets are
    /// framed without the compressed-packet data-length header.
    #[cfg(feature = "compression")]
    threshold: CompressionThreshold,
    /// The active cipher once encryption has been enabled; `None` before that.
    #[cfg(feature = "encryption")]
    cipher: Option<Cipher>,
}
26
27impl PacketDecoder {
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    pub fn try_next_packet(&mut self) -> anyhow::Result<Option<PacketFrame>> {
33        let mut r = &self.buf[..];
34
35        let packet_len = match VarInt::decode_partial(&mut r) {
36            Ok(len) => len,
37            Err(VarIntDecodeError::Incomplete) => return Ok(None),
38            Err(VarIntDecodeError::TooLarge) => bail!("malformed packet length VarInt"),
39        };
40
41        ensure!(
42            (0..=MAX_PACKET_SIZE).contains(&packet_len),
43            "packet length of {packet_len} is out of bounds"
44        );
45
46        if r.len() < packet_len as usize {
47            // Not enough data arrived yet.
48            return Ok(None);
49        }
50
51        let packet_len_len = VarInt(packet_len).written_size();
52
53        let mut data;
54
55        #[cfg(feature = "compression")]
56        if self.threshold.0 >= 0 {
57            use std::io::Write;
58
59            use bytes::BufMut;
60            use flate2::write::ZlibDecoder;
61
62            r = &r[..packet_len as usize];
63
64            let data_len = VarInt::decode(&mut r)?.0;
65
66            ensure!(
67                (0..MAX_PACKET_SIZE).contains(&data_len),
68                "decompressed packet length of {data_len} is out of bounds"
69            );
70
71            // Is this packet compressed?
72            if data_len > 0 {
73                ensure!(
74                    data_len > self.threshold.0,
75                    "decompressed packet length of {data_len} is <= the compression threshold of \
76                     {}",
77                    self.threshold.0
78                );
79
80                debug_assert!(self.decompress_buf.is_empty());
81
82                self.decompress_buf.put_bytes(0, data_len as usize);
83
84                // TODO: use libdeflater or zune-inflate?
85                let mut z = ZlibDecoder::new(&mut self.decompress_buf[..]);
86
87                z.write_all(r)?;
88
89                ensure!(
90                    z.finish()?.is_empty(),
91                    "decompressed packet length is shorter than expected"
92                );
93
94                let total_packet_len = VarInt(packet_len).written_size() + packet_len as usize;
95
96                self.buf.advance(total_packet_len);
97
98                data = self.decompress_buf.split();
99            } else {
100                debug_assert_eq!(data_len, 0);
101
102                ensure!(
103                    r.len() <= self.threshold.0 as usize,
104                    "uncompressed packet length of {} exceeds compression threshold of {}",
105                    r.len(),
106                    self.threshold.0
107                );
108
109                let remaining_len = r.len();
110
111                self.buf.advance(packet_len_len + 1);
112
113                data = self.buf.split_to(remaining_len);
114            }
115        } else {
116            self.buf.advance(packet_len_len);
117            data = self.buf.split_to(packet_len as usize);
118        }
119
120        #[cfg(not(feature = "compression"))]
121        {
122            self.buf.advance(packet_len_len);
123            data = self.buf.split_to(packet_len as usize);
124        }
125
126        // Decode the leading packet ID.
127        r = &data[..];
128        let packet_id = VarInt::decode(&mut r)
129            .context("failed to decode packet ID")?
130            .0;
131
132        data.advance(data.len() - r.len());
133
134        Ok(Some(PacketFrame {
135            id: packet_id,
136            body: data,
137        }))
138    }
139
140    #[cfg(feature = "compression")]
141    pub fn compression(&self) -> CompressionThreshold {
142        self.threshold
143    }
144
145    #[cfg(feature = "compression")]
146    pub fn set_compression(&mut self, threshold: CompressionThreshold) {
147        self.threshold = threshold;
148    }
149
150    #[cfg(feature = "encryption")]
151    pub fn enable_encryption(&mut self, key: &[u8; 16]) {
152        assert!(self.cipher.is_none(), "encryption is already enabled");
153
154        let mut cipher = Cipher::new_from_slices(key, key).expect("invalid key");
155
156        // Don't forget to decrypt the data we already have.
157        Self::decrypt_bytes(&mut cipher, &mut self.buf);
158
159        self.cipher = Some(cipher);
160    }
161
162    /// Decrypts the provided byte slice in place using the cipher, without
163    /// consuming the cipher.
164    #[cfg(feature = "encryption")]
165    fn decrypt_bytes(cipher: &mut Cipher, bytes: &mut [u8]) {
166        for chunk in bytes.chunks_mut(Cipher::block_size()) {
167            let gen_arr = GenericArray::from_mut_slice(chunk);
168            cipher.decrypt_block_mut(gen_arr);
169        }
170    }
171
172    pub fn queue_bytes(&mut self, mut bytes: BytesMut) {
173        #![allow(unused_mut)]
174
175        #[cfg(feature = "encryption")]
176        if let Some(cipher) = &mut self.cipher {
177            Self::decrypt_bytes(cipher, &mut bytes);
178        }
179
180        self.buf.unsplit(bytes);
181    }
182
183    pub fn queue_slice(&mut self, bytes: &[u8]) {
184        #[cfg(feature = "encryption")]
185        let len = self.buf.len();
186
187        self.buf.extend_from_slice(bytes);
188
189        #[cfg(feature = "encryption")]
190        if let Some(cipher) = &mut self.cipher {
191            let slice = &mut self.buf[len..];
192            Self::decrypt_bytes(cipher, slice);
193        }
194    }
195
196    pub fn take_capacity(&mut self) -> BytesMut {
197        self.buf.split_off(self.buf.len())
198    }
199
200    pub fn reserve(&mut self, additional: usize) {
201        self.buf.reserve(additional);
202    }
203}
204
/// A single packet produced by [`PacketDecoder::try_next_packet`].
///
/// The packet ID has already been decoded and stripped from `body`. When
/// compression is enabled, `body` has already been decompressed.
#[derive(Clone, Debug)]
pub struct PacketFrame {
    /// The ID of the decoded packet.
    pub id: i32,
    /// The contents of the packet after the leading `VarInt` ID.
    pub body: BytesMut,
}
212
213impl PacketFrame {
214    /// Attempts to decode this packet as type `P`. An error is returned if the
215    /// packet ID does not match, the body of the packet failed to decode, or
216    /// some input was missed.
217    pub fn decode<'a, P>(&'a self) -> anyhow::Result<P>
218    where
219        P: Packet + Decode<'a>,
220    {
221        ensure!(
222            P::ID == self.id,
223            "packet ID mismatch while decoding '{}': expected {}, got {}",
224            P::NAME,
225            P::ID,
226            self.id
227        );
228
229        let mut r = &self.body[..];
230
231        let pkt = P::decode(&mut r)?;
232
233        ensure!(
234            r.is_empty(),
235            "missed {} bytes while decoding '{}'",
236            r.len(),
237            P::NAME
238        );
239
240        Ok(pkt)
241    }
242}