// packet_inspector/packet_io.rs

1use std::io;
2use std::io::ErrorKind;
3
4use anyhow::ensure;
5use bytes::{BufMut, BytesMut};
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::TcpStream;
8use valence_protocol::decode::{PacketDecoder, PacketFrame};
9use valence_protocol::encode::PacketEncoder;
10use valence_protocol::{CompressionThreshold, Encode, VarInt, MAX_PACKET_SIZE};
11
12pub(crate) struct PacketIoReader {
13    reader: tokio::io::ReadHalf<tokio::net::TcpStream>,
14    dec: PacketDecoder,
15    threshold: CompressionThreshold,
16}
17
18impl PacketIoReader {
19    pub(crate) async fn recv_packet_raw(&mut self) -> anyhow::Result<PacketFrame> {
20        loop {
21            if let Some(frame) = self.dec.try_next_packet()? {
22                // self.logger
23                //     .log("Unknown".to_string(), self.direction.clone(), frame.clone());
24
25                return Ok(frame);
26            }
27
28            self.dec.reserve(READ_BUF_SIZE);
29            let mut buf = self.dec.take_capacity();
30
31            if self.reader.read_buf(&mut buf).await? == 0 {
32                return Err(io::Error::from(ErrorKind::UnexpectedEof).into());
33            }
34
35            // This should always be an O(1) unsplit because we reserved space earlier and
36            // the call to `read_buf` shouldn't have grown the allocation.
37            self.dec.queue_bytes(buf);
38        }
39    }
40
41    #[allow(dead_code)]
42    pub(crate) fn set_compression(&mut self, threshold: CompressionThreshold) {
43        self.threshold = threshold;
44        self.dec.set_compression(threshold);
45    }
46}
47
48pub(crate) struct PacketIoWriter {
49    writer: tokio::io::WriteHalf<tokio::net::TcpStream>,
50    enc: PacketEncoder,
51    threshold: CompressionThreshold,
52}
53
54impl PacketIoWriter {
55    /*
56      No  | Packet Length |  VarInt     | Length of (Data Length) + Compressed length of (Packet ID + Data)
57      No  | Data Length   |  VarInt     | Length of uncompressed (Packet ID + Data) or 0
58      Yes | Packet ID	  |  VarInt     | zlib compressed packet ID (see the sections below)
59      Yes | Data          |  Byte Array | zlib compressed packet data (see the sections below)
60    */
61    pub(crate) async fn send_packet_raw(&mut self, frame: &PacketFrame) -> anyhow::Result<()> {
62        let id_varint = VarInt(frame.id);
63        let id_buf = varint_to_bytes(id_varint);
64
65        let mut uncompressed_packet = BytesMut::new();
66        uncompressed_packet.extend_from_slice(&id_buf);
67        uncompressed_packet.extend_from_slice(&frame.body);
68        let uncompressed_packet_length = uncompressed_packet.len();
69        let uncompressed_packet_length_varint = VarInt(uncompressed_packet_length as i32);
70
71        if self.threshold.0 >= 0 {
72            if uncompressed_packet_length > self.threshold.0 as usize {
73                use std::io::Read;
74
75                use flate2::bufread::ZlibEncoder;
76                use flate2::Compression;
77
78                let mut z = ZlibEncoder::new(&uncompressed_packet[..], Compression::new(4));
79                let mut compressed = Vec::new();
80
81                let data_len_size = uncompressed_packet_length_varint.written_size();
82
83                let packet_len = data_len_size + z.read_to_end(&mut compressed)?;
84
85                ensure!(
86                    packet_len <= MAX_PACKET_SIZE as usize,
87                    "packet exceeds maximum length"
88                );
89
90                drop(z);
91
92                self.enc
93                    .append_bytes(&varint_to_bytes(VarInt(packet_len as i32)));
94
95                self.enc
96                    .append_bytes(&varint_to_bytes(uncompressed_packet_length_varint));
97
98                self.enc.append_bytes(&compressed);
99
100                let bytes = self.enc.take();
101
102                self.writer.write_all(&bytes).await?;
103                self.writer.flush().await?;
104
105                // now we need to compress the packet.
106            } else {
107                // no need to compress, but we do need to inject a zero
108                let empty = VarInt(0);
109
110                let data_len_size = empty.written_size();
111                let packet_len = data_len_size + uncompressed_packet_length;
112
113                self.enc
114                    .append_bytes(&varint_to_bytes(VarInt(packet_len as i32)));
115                self.enc.append_bytes(&varint_to_bytes(empty));
116                self.enc.append_bytes(&uncompressed_packet);
117                let bytes = self.enc.take();
118                self.writer.write_all(&bytes).await?;
119                self.writer.flush().await?;
120            }
121
122            return Ok(());
123        }
124
125        let length = varint_to_bytes(VarInt(uncompressed_packet_length as i32));
126
127        // the frame should be uncompressed at this point.
128        self.enc.append_bytes(&length);
129        self.enc.append_bytes(&uncompressed_packet);
130
131        let bytes = self.enc.take();
132
133        self.writer.write_all(&bytes).await?;
134
135        Ok(())
136    }
137
138    #[allow(dead_code)]
139    pub(crate) fn set_compression(&mut self, threshold: CompressionThreshold) {
140        self.threshold = threshold;
141        self.enc.set_compression(threshold);
142    }
143
144    pub(crate) async fn shutdown(&mut self) -> std::io::Result<()> {
145        self.writer.shutdown().await?;
146        Ok(())
147    }
148}
149
150pub(crate) struct PacketIo {
151    stream: TcpStream,
152    enc: PacketEncoder,
153    dec: PacketDecoder,
154    threshold: CompressionThreshold,
155}
156
157const READ_BUF_SIZE: usize = 1024;
158
159impl PacketIo {
160    pub(crate) fn new(stream: TcpStream) -> Self {
161        Self {
162            stream,
163            enc: PacketEncoder::new(),
164            dec: PacketDecoder::new(),
165            threshold: CompressionThreshold::DEFAULT,
166        }
167    }
168
169    pub(crate) fn split(self) -> (PacketIoReader, PacketIoWriter) {
170        let (reader, writer) = tokio::io::split(self.stream);
171
172        (
173            PacketIoReader {
174                reader,
175                dec: self.dec,
176                threshold: self.threshold,
177            },
178            PacketIoWriter {
179                writer,
180                enc: self.enc,
181                threshold: self.threshold,
182            },
183        )
184    }
185
186    #[allow(dead_code)]
187    pub(crate) async fn set_compression(&mut self, threshold: CompressionThreshold) {
188        self.threshold = threshold;
189        self.enc.set_compression(threshold);
190        self.dec.set_compression(threshold);
191    }
192}
193
194pub(crate) fn varint_to_bytes(i: VarInt) -> BytesMut {
195    let mut buf = BytesMut::new();
196    let mut writer = (&mut buf).writer();
197    i.encode(&mut writer).unwrap();
198
199    buf
200}