valence_network/
packet_io.rs
1use std::io::ErrorKind;
2use std::sync::Arc;
3use std::time::Instant;
4use std::{io, mem};
5
6use anyhow::bail;
7use bytes::BytesMut;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10use tokio::sync::Semaphore;
11use tokio::task::JoinHandle;
12use tracing::{debug, warn};
13use valence_protocol::CompressionThreshold;
14use valence_server::client::{ClientBundleArgs, ClientConnection, ReceivedPacket};
15use valence_server::protocol::decode::PacketFrame;
16use valence_server::protocol::{Decode, Encode, Packet, PacketDecoder, PacketEncoder};
17
18use crate::byte_channel::{byte_channel, ByteSender, TrySendError};
19use crate::{CleanupOnDrop, NewClientInfo};
20
21pub(crate) struct PacketIo {
22 stream: TcpStream,
23 enc: PacketEncoder,
24 dec: PacketDecoder,
25 frame: PacketFrame,
26}
27
/// Number of bytes of buffer space reserved before each read from the socket.
const READ_BUF_SIZE: usize = 4 * 1024;
29
30impl PacketIo {
31 pub(crate) fn new(stream: TcpStream, enc: PacketEncoder, dec: PacketDecoder) -> Self {
32 Self {
33 stream,
34 enc,
35 dec,
36 frame: PacketFrame {
37 id: -1,
38 body: BytesMut::new(),
39 },
40 }
41 }
42
43 pub(crate) async fn send_packet<P>(&mut self, pkt: &P) -> anyhow::Result<()>
44 where
45 P: Packet + Encode,
46 {
47 self.enc.append_packet(pkt)?;
48 let bytes = self.enc.take();
49 self.stream.write_all(&bytes).await?;
50 Ok(())
51 }
52
53 pub(crate) async fn recv_packet<'a, P>(&'a mut self) -> anyhow::Result<P>
54 where
55 P: Packet + Decode<'a>,
56 {
57 loop {
58 if let Some(frame) = self.dec.try_next_packet()? {
59 self.frame = frame;
60
61 return self.frame.decode();
62 }
63
64 self.dec.reserve(READ_BUF_SIZE);
65 let mut buf = self.dec.take_capacity();
66
67 if self.stream.read_buf(&mut buf).await? == 0 {
68 return Err(io::Error::from(ErrorKind::UnexpectedEof).into());
69 }
70
71 self.dec.queue_bytes(buf);
74 }
75 }
76
77 #[allow(dead_code)]
78 pub(crate) fn set_compression(&mut self, threshold: CompressionThreshold) {
79 self.enc.set_compression(threshold);
80 self.dec.set_compression(threshold);
81 }
82
83 pub(crate) fn enable_encryption(&mut self, key: &[u8; 16]) {
84 self.enc.enable_encryption(key);
85 self.dec.enable_encryption(key);
86 }
87
88 pub(crate) fn into_client_args(
89 mut self,
90 info: NewClientInfo,
91 incoming_byte_limit: usize,
92 outgoing_byte_limit: usize,
93 cleanup: CleanupOnDrop,
94 ) -> ClientBundleArgs {
95 let (incoming_sender, incoming_receiver) = flume::unbounded();
96
97 let incoming_byte_limit = incoming_byte_limit.min(Semaphore::MAX_PERMITS);
98
99 let recv_sem = Arc::new(Semaphore::new(incoming_byte_limit));
100 let recv_sem_clone = recv_sem.clone();
101
102 let (mut reader, mut writer) = self.stream.into_split();
103
104 let reader_task = tokio::spawn(async move {
105 let mut buf = BytesMut::new();
106
107 loop {
108 let frame = match self.dec.try_next_packet() {
109 Ok(Some(frame)) => frame,
110 Ok(None) => {
111 buf.reserve(READ_BUF_SIZE);
114 match reader.read_buf(&mut buf).await {
115 Ok(0) => break, Ok(_) => {}
117 Err(e) => {
118 debug!("error reading data from stream: {e}");
119 break;
120 }
121 }
122
123 self.dec.queue_bytes(buf.split());
124
125 continue;
126 }
127 Err(e) => {
128 warn!("error decoding packet frame: {e:#}");
129 break;
130 }
131 };
132
133 let timestamp = Instant::now();
134
135 let cost = mem::size_of::<ReceivedPacket>() + frame.body.len();
137
138 if cost > incoming_byte_limit {
139 debug!(
140 cost,
141 incoming_byte_limit,
142 "cost of received packet is greater than the incoming memory limit"
143 );
144 break;
147 }
148
149 let Ok(permits) = recv_sem.acquire_many(cost as u32).await else {
151 break;
153 };
154
155 permits.forget();
157
158 let packet = ReceivedPacket {
159 timestamp,
160 id: frame.id,
161 body: frame.body.freeze(),
162 };
163
164 if incoming_sender.try_send(packet).is_err() {
165 break;
167 }
168 }
169 });
170
171 let (outgoing_sender, mut outgoing_receiver) = byte_channel(outgoing_byte_limit);
172
173 let writer_task = tokio::spawn(async move {
174 loop {
175 let bytes = match outgoing_receiver.recv_async().await {
176 Ok(bytes) => bytes,
177 Err(e) => {
178 debug!("error receiving packet data: {e}");
179 break;
180 }
181 };
182
183 if let Err(e) = writer.write_all(&bytes).await {
184 debug!("error writing data to stream: {e}");
185 }
186 }
187 });
188
189 ClientBundleArgs {
190 username: info.username,
191 uuid: info.uuid,
192 ip: info.ip,
193 properties: info.properties.0,
194 conn: Box::new(RealClientConnection {
195 send: outgoing_sender,
196 recv: incoming_receiver,
197 recv_sem: recv_sem_clone,
198 reader_task,
199 writer_task,
200 _cleanup: cleanup,
201 }),
202 enc: self.enc,
203 }
204 }
205}
206
207struct RealClientConnection {
208 send: ByteSender,
209 recv: flume::Receiver<ReceivedPacket>,
210 recv_sem: Arc<Semaphore>,
213 _cleanup: CleanupOnDrop,
214 reader_task: JoinHandle<()>,
215 writer_task: JoinHandle<()>,
216}
217
218impl ClientConnection for RealClientConnection {
219 fn try_send(&mut self, bytes: BytesMut) -> anyhow::Result<()> {
220 match self.send.try_send(bytes) {
221 Ok(()) => Ok(()),
222 Err(TrySendError::Full(_)) => bail!(
223 "reached configured outgoing limit of {} bytes",
224 self.send.limit()
225 ),
226 Err(TrySendError::Disconnected(_)) => bail!("client disconnected"),
227 }
228 }
229
230 fn try_recv(&mut self) -> anyhow::Result<Option<ReceivedPacket>> {
231 match self.recv.try_recv() {
232 Ok(packet) => {
233 let cost = mem::size_of::<ReceivedPacket>() + packet.body.len();
234
235 self.recv_sem.add_permits(cost);
237
238 Ok(Some(packet))
239 }
240 Err(flume::TryRecvError::Empty) => Ok(None),
241 Err(flume::TryRecvError::Disconnected) => bail!("client disconnected"),
242 }
243 }
244
245 fn len(&self) -> usize {
246 self.recv.len()
247 }
248}
249
250impl Drop for RealClientConnection {
251 fn drop(&mut self) {
252 self.writer_task.abort();
253 self.reader_task.abort();
254 }
255}