valence_network/
packet_io.rs

1use std::io::ErrorKind;
2use std::sync::Arc;
3use std::time::Instant;
4use std::{io, mem};
5
6use anyhow::bail;
7use bytes::BytesMut;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10use tokio::sync::Semaphore;
11use tokio::task::JoinHandle;
12use tracing::{debug, warn};
13use valence_protocol::CompressionThreshold;
14use valence_server::client::{ClientBundleArgs, ClientConnection, ReceivedPacket};
15use valence_server::protocol::decode::PacketFrame;
16use valence_server::protocol::{Decode, Encode, Packet, PacketDecoder, PacketEncoder};
17
18use crate::byte_channel::{byte_channel, ByteSender, TrySendError};
19use crate::{CleanupOnDrop, NewClientInfo};
20
21pub(crate) struct PacketIo {
22    stream: TcpStream,
23    enc: PacketEncoder,
24    dec: PacketDecoder,
25    frame: PacketFrame,
26}
27
/// Number of bytes of spare capacity reserved before each read from the
/// socket.
const READ_BUF_SIZE: usize = 4 * 1024;
29
30impl PacketIo {
31    pub(crate) fn new(stream: TcpStream, enc: PacketEncoder, dec: PacketDecoder) -> Self {
32        Self {
33            stream,
34            enc,
35            dec,
36            frame: PacketFrame {
37                id: -1,
38                body: BytesMut::new(),
39            },
40        }
41    }
42
43    pub(crate) async fn send_packet<P>(&mut self, pkt: &P) -> anyhow::Result<()>
44    where
45        P: Packet + Encode,
46    {
47        self.enc.append_packet(pkt)?;
48        let bytes = self.enc.take();
49        self.stream.write_all(&bytes).await?;
50        Ok(())
51    }
52
53    pub(crate) async fn recv_packet<'a, P>(&'a mut self) -> anyhow::Result<P>
54    where
55        P: Packet + Decode<'a>,
56    {
57        loop {
58            if let Some(frame) = self.dec.try_next_packet()? {
59                self.frame = frame;
60
61                return self.frame.decode();
62            }
63
64            self.dec.reserve(READ_BUF_SIZE);
65            let mut buf = self.dec.take_capacity();
66
67            if self.stream.read_buf(&mut buf).await? == 0 {
68                return Err(io::Error::from(ErrorKind::UnexpectedEof).into());
69            }
70
71            // This should always be an O(1) unsplit because we reserved space earlier and
72            // the call to `read_buf` shouldn't have grown the allocation.
73            self.dec.queue_bytes(buf);
74        }
75    }
76
77    #[allow(dead_code)]
78    pub(crate) fn set_compression(&mut self, threshold: CompressionThreshold) {
79        self.enc.set_compression(threshold);
80        self.dec.set_compression(threshold);
81    }
82
83    pub(crate) fn enable_encryption(&mut self, key: &[u8; 16]) {
84        self.enc.enable_encryption(key);
85        self.dec.enable_encryption(key);
86    }
87
88    pub(crate) fn into_client_args(
89        mut self,
90        info: NewClientInfo,
91        incoming_byte_limit: usize,
92        outgoing_byte_limit: usize,
93        cleanup: CleanupOnDrop,
94    ) -> ClientBundleArgs {
95        let (incoming_sender, incoming_receiver) = flume::unbounded();
96
97        let incoming_byte_limit = incoming_byte_limit.min(Semaphore::MAX_PERMITS);
98
99        let recv_sem = Arc::new(Semaphore::new(incoming_byte_limit));
100        let recv_sem_clone = recv_sem.clone();
101
102        let (mut reader, mut writer) = self.stream.into_split();
103
104        let reader_task = tokio::spawn(async move {
105            let mut buf = BytesMut::new();
106
107            loop {
108                let frame = match self.dec.try_next_packet() {
109                    Ok(Some(frame)) => frame,
110                    Ok(None) => {
111                        // Incomplete packet. Need more data.
112
113                        buf.reserve(READ_BUF_SIZE);
114                        match reader.read_buf(&mut buf).await {
115                            Ok(0) => break, // Reader is at EOF.
116                            Ok(_) => {}
117                            Err(e) => {
118                                debug!("error reading data from stream: {e}");
119                                break;
120                            }
121                        }
122
123                        self.dec.queue_bytes(buf.split());
124
125                        continue;
126                    }
127                    Err(e) => {
128                        warn!("error decoding packet frame: {e:#}");
129                        break;
130                    }
131                };
132
133                let timestamp = Instant::now();
134
135                // Estimate memory usage of this packet.
136                let cost = mem::size_of::<ReceivedPacket>() + frame.body.len();
137
138                if cost > incoming_byte_limit {
139                    debug!(
140                        cost,
141                        incoming_byte_limit,
142                        "cost of received packet is greater than the incoming memory limit"
143                    );
144                    // We would never acquire enough permits, so we should exit instead of getting
145                    // stuck.
146                    break;
147                }
148
149                // Wait until there's enough space for this packet.
150                let Ok(permits) = recv_sem.acquire_many(cost as u32).await else {
151                    // Semaphore closed.
152                    break;
153                };
154
155                // The permits will be added back on the other side of the channel.
156                permits.forget();
157
158                let packet = ReceivedPacket {
159                    timestamp,
160                    id: frame.id,
161                    body: frame.body.freeze(),
162                };
163
164                if incoming_sender.try_send(packet).is_err() {
165                    // Channel closed.
166                    break;
167                }
168            }
169        });
170
171        let (outgoing_sender, mut outgoing_receiver) = byte_channel(outgoing_byte_limit);
172
173        let writer_task = tokio::spawn(async move {
174            loop {
175                let bytes = match outgoing_receiver.recv_async().await {
176                    Ok(bytes) => bytes,
177                    Err(e) => {
178                        debug!("error receiving packet data: {e}");
179                        break;
180                    }
181                };
182
183                if let Err(e) = writer.write_all(&bytes).await {
184                    debug!("error writing data to stream: {e}");
185                }
186            }
187        });
188
189        ClientBundleArgs {
190            username: info.username,
191            uuid: info.uuid,
192            ip: info.ip,
193            properties: info.properties.0,
194            conn: Box::new(RealClientConnection {
195                send: outgoing_sender,
196                recv: incoming_receiver,
197                recv_sem: recv_sem_clone,
198                reader_task,
199                writer_task,
200                _cleanup: cleanup,
201            }),
202            enc: self.enc,
203        }
204    }
205}
206
207struct RealClientConnection {
208    send: ByteSender,
209    recv: flume::Receiver<ReceivedPacket>,
210    /// Limits the amount of data queued in the `recv` channel. Each permit
211    /// represents one byte.
212    recv_sem: Arc<Semaphore>,
213    _cleanup: CleanupOnDrop,
214    reader_task: JoinHandle<()>,
215    writer_task: JoinHandle<()>,
216}
217
218impl ClientConnection for RealClientConnection {
219    fn try_send(&mut self, bytes: BytesMut) -> anyhow::Result<()> {
220        match self.send.try_send(bytes) {
221            Ok(()) => Ok(()),
222            Err(TrySendError::Full(_)) => bail!(
223                "reached configured outgoing limit of {} bytes",
224                self.send.limit()
225            ),
226            Err(TrySendError::Disconnected(_)) => bail!("client disconnected"),
227        }
228    }
229
230    fn try_recv(&mut self) -> anyhow::Result<Option<ReceivedPacket>> {
231        match self.recv.try_recv() {
232            Ok(packet) => {
233                let cost = mem::size_of::<ReceivedPacket>() + packet.body.len();
234
235                // Add the permits back that we removed earlier.
236                self.recv_sem.add_permits(cost);
237
238                Ok(Some(packet))
239            }
240            Err(flume::TryRecvError::Empty) => Ok(None),
241            Err(flume::TryRecvError::Disconnected) => bail!("client disconnected"),
242        }
243    }
244
245    fn len(&self) -> usize {
246        self.recv.len()
247    }
248}
249
250impl Drop for RealClientConnection {
251    fn drop(&mut self) {
252        self.writer_task.abort();
253        self.reader_task.abort();
254    }
255}