use std::io::Write;
#[cfg(feature = "encryption")]
use aes::cipher::generic_array::GenericArray;
#[cfg(feature = "encryption")]
use aes::cipher::{BlockEncryptMut, BlockSizeUser, KeyIvInit};
use anyhow::ensure;
use bytes::{BufMut, BytesMut};
use tracing::warn;
use crate::var_int::VarInt;
use crate::{CompressionThreshold, Encode, Packet, MAX_PACKET_SIZE};
/// AES-128 in CFB8 mode; `PacketEncoder::take` uses it to encrypt outgoing
/// bytes once `PacketEncoder::enable_encryption` has been called.
#[cfg(feature = "encryption")]
type Cipher = cfb8::Encryptor<aes::Aes128>;
/// Accumulates length-prefixed packet frames in a byte buffer, optionally
/// zlib-compressing them (feature `compression`) and encrypting the whole
/// buffer when it is taken (feature `encryption`).
#[derive(Default)]
pub struct PacketEncoder {
    /// Encoded bytes waiting to be handed out by `take`.
    buf: BytesMut,
    /// Scratch space for zlib output, reused across packets so compressing
    /// does not allocate a fresh `Vec` every time.
    #[cfg(feature = "compression")]
    compress_buf: Vec<u8>,
    /// Compression threshold; a negative value disables compression
    /// (checked in `append_packet`).
    #[cfg(feature = "compression")]
    threshold: CompressionThreshold,
    /// Encryption state; `None` until `enable_encryption` is called.
    #[cfg(feature = "encryption")]
    cipher: Option<Cipher>,
}
impl PacketEncoder {
    /// Creates an encoder with an empty buffer; equivalent to
    /// `PacketEncoder::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `bytes` to the buffer exactly as given.
    ///
    /// No length prefix, compression or size check is applied, so `bytes`
    /// should already be complete, correctly framed packet data.
    #[inline]
    pub fn append_bytes(&mut self, bytes: &[u8]) {
        self.buf.extend_from_slice(bytes)
    }

    /// Encodes `pkt` like [`Self::append_packet`] but places the resulting
    /// frame *before* every byte already in the buffer.
    ///
    /// # Errors
    ///
    /// Same as [`Self::append_packet`]. On error the buffer is left unchanged.
    pub fn prepend_packet<P>(&mut self, pkt: &P) -> anyhow::Result<()>
    where
        P: Packet + Encode,
    {
        let start_len = self.buf.len();
        self.append_packet(pkt)?;
        let total_packet_len = self.buf.len() - start_len;

        // The buffer is now `[old bytes][new frame]`; rotating right by the
        // frame's length turns it into `[new frame][old bytes]` in place.
        self.buf.rotate_right(total_packet_len);
        Ok(())
    }

    /// Encodes `pkt` (packet ID followed by its body) and appends it to the
    /// buffer as a length-prefixed frame.
    ///
    /// When compression is enabled (non-negative threshold), packets whose
    /// encoded length exceeds the threshold are zlib-compressed. Smaller
    /// packets are stored raw with a data-length marker of `0`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - encoding the packet fails;
    /// - zlib compression fails;
    /// - the frame's packet length exceeds `MAX_PACKET_SIZE`.
    ///
    /// On error every byte written for this packet is discarded, so the buffer
    /// only ever contains complete frames.
    pub fn append_packet<P>(&mut self, pkt: &P) -> anyhow::Result<()>
    where
        P: Packet + Encode,
    {
        let start_len = self.buf.len();
        let res = self.append_packet_inner(pkt, start_len);
        if res.is_err() {
            // Without this, a failed packet would leave its raw (possibly
            // unprefixed) bytes in the buffer and corrupt the next `take()`.
            self.buf.truncate(start_len);
        }
        res
    }

    /// Encodes and frames `pkt`, where `start_len` is the buffer length before
    /// the call. May leave partial output behind on error; `append_packet`
    /// rolls that back.
    // `(&mut self.buf).writer()` is intentional: `BufMut::writer` takes its
    // receiver by value, so it is called on `&mut BytesMut` to keep the buffer.
    #[allow(clippy::needless_borrows_for_generic_args)]
    fn append_packet_inner<P>(&mut self, pkt: &P, start_len: usize) -> anyhow::Result<()>
    where
        P: Packet + Encode,
    {
        pkt.encode_with_id((&mut self.buf).writer())?;
        let data_len = self.buf.len() - start_len;

        #[cfg(feature = "compression")]
        if self.threshold.0 >= 0 {
            use std::io::Read;

            use flate2::bufread::ZlibEncoder;
            use flate2::Compression;

            if data_len > self.threshold.0 as usize {
                // Compressed frame: VarInt(packet_len) VarInt(data_len) zlib(data).
                let mut z = ZlibEncoder::new(&self.buf[start_len..], Compression::new(4));
                self.compress_buf.clear();
                let data_len_size = VarInt(data_len as i32).written_size();
                let packet_len = data_len_size + z.read_to_end(&mut self.compress_buf)?;
                ensure!(
                    packet_len <= MAX_PACKET_SIZE as usize,
                    "packet exceeds maximum length"
                );
                // Release the borrow of `self.buf` before overwriting it.
                drop(z);
                self.buf.truncate(start_len);
                let mut writer = (&mut self.buf).writer();
                VarInt(packet_len as i32).encode(&mut writer)?;
                VarInt(data_len as i32).encode(&mut writer)?;
                self.buf.extend_from_slice(&self.compress_buf);
            } else {
                // Uncompressed frame in compressed mode:
                // VarInt(packet_len) VarInt(0) data. `VarInt(0)` is one byte.
                let data_len_size = 1;
                let packet_len = data_len_size + data_len;
                ensure!(
                    packet_len <= MAX_PACKET_SIZE as usize,
                    "packet exceeds maximum length"
                );
                let packet_len_size = VarInt(packet_len as i32).written_size();
                let data_prefix_len = packet_len_size + data_len_size;
                // Open a gap in front of the data for the two prefixes.
                self.buf.put_bytes(0, data_prefix_len);
                self.buf
                    .copy_within(start_len..start_len + data_len, start_len + data_prefix_len);
                let mut front = &mut self.buf[start_len..];
                VarInt(packet_len as i32).encode(&mut front)?;
                VarInt(0).encode(front)?;
            }
            return Ok(());
        }

        // Plain frame (compression disabled): VarInt(packet_len) data.
        let packet_len = data_len;
        ensure!(
            packet_len <= MAX_PACKET_SIZE as usize,
            "packet exceeds maximum length"
        );
        let packet_len_size = VarInt(packet_len as i32).written_size();
        self.buf.put_bytes(0, packet_len_size);
        self.buf
            .copy_within(start_len..start_len + data_len, start_len + packet_len_size);
        let front = &mut self.buf[start_len..];
        VarInt(packet_len as i32).encode(front)?;
        Ok(())
    }

    /// Removes and returns everything buffered so far, encrypting it first if
    /// encryption has been enabled.
    pub fn take(&mut self) -> BytesMut {
        #[cfg(feature = "encryption")]
        if let Some(cipher) = &mut self.cipher {
            // NOTE: `GenericArray::from_mut_slice` requires every chunk to be a
            // full block; this relies on CFB8's block size being a single byte.
            for chunk in self.buf.chunks_mut(Cipher::block_size()) {
                let gen_arr = GenericArray::from_mut_slice(chunk);
                cipher.encrypt_block_mut(gen_arr);
            }
        }
        self.buf.split()
    }

    /// Discards all buffered bytes without encrypting or returning them.
    /// Cipher state (if any) is not touched.
    pub fn clear(&mut self) {
        self.buf.clear();
    }

    /// Sets the compression threshold for subsequently appended packets.
    ///
    /// A negative threshold disables compression. Otherwise packets whose
    /// encoded length exceeds the threshold are zlib-compressed.
    #[cfg(feature = "compression")]
    pub fn set_compression(&mut self, threshold: CompressionThreshold) {
        self.threshold = threshold;
    }

    /// Enables encryption of all bytes returned by future calls to `take`,
    /// using `key` as both the AES key and the CFB8 IV.
    ///
    /// # Panics
    ///
    /// Panics if encryption is already enabled.
    #[cfg(feature = "encryption")]
    pub fn enable_encryption(&mut self, key: &[u8; 16]) {
        assert!(self.cipher.is_none(), "encryption is already enabled");
        self.cipher = Some(Cipher::new_from_slices(key, key).expect("invalid key"));
    }
}
/// A sink that packets can be encoded into.
pub trait WritePacket {
    /// Writes `packet`; if that fails, the error is logged as a warning
    /// instead of being returned.
    fn write_packet<P>(&mut self, packet: &P)
    where
        P: Packet + Encode,
    {
        match self.write_packet_fallible(packet) {
            Ok(()) => {}
            Err(e) => {
                warn!("failed to write packet '{}': {e:#}", P::NAME);
            }
        }
    }

    /// Writes `packet`, returning any error to the caller.
    fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
    where
        P: Packet + Encode;

    /// Writes already-encoded packet bytes. The implementations in this file
    /// append them without further framing.
    fn write_packet_bytes(&mut self, bytes: &[u8]);
}
/// Lets a mutable reference to any writer be used as a writer itself by
/// forwarding every call to the referent.
impl<W: WritePacket> WritePacket for &mut W {
    fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
    where
        P: Packet + Encode,
    {
        <W as WritePacket>::write_packet_fallible(&mut **self, packet)
    }

    fn write_packet_bytes(&mut self, bytes: &[u8]) {
        <W as WritePacket>::write_packet_bytes(&mut **self, bytes)
    }
}
/// Forwards writes through a bevy `Mut` smart pointer to the wrapped writer.
// NOTE(review): access goes through `as_mut`, which presumably also flags the
// component as changed for bevy change detection — confirm against bevy docs.
impl<T: WritePacket> WritePacket for bevy_ecs::world::Mut<'_, T> {
    fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
    where
        P: Packet + Encode,
    {
        let inner: &mut T = self.as_mut();
        inner.write_packet_fallible(packet)
    }

    fn write_packet_bytes(&mut self, bytes: &[u8]) {
        let inner: &mut T = self.as_mut();
        inner.write_packet_bytes(bytes)
    }
}
/// Writes packet frames into a borrowed `Vec<u8>`.
#[derive(Debug)]
pub struct PacketWriter<'a> {
    /// Destination buffer; encoded frames are appended to its end.
    pub buf: &'a mut Vec<u8>,
    /// Compression threshold. A negative value writes plain frames; otherwise
    /// compressed-protocol framing is used and packets longer than the
    /// threshold are zlib-compressed.
    pub threshold: CompressionThreshold,
}
impl<'a> PacketWriter<'a> {
pub fn new(buf: &'a mut Vec<u8>, threshold: CompressionThreshold) -> Self {
Self { buf, threshold }
}
}
impl WritePacket for PacketWriter<'_> {
    /// Encodes `pkt` into the borrowed buffer, using compressed-protocol
    /// framing when `threshold` is non-negative. On error the buffer is
    /// truncated back to its length before the call.
    ///
    /// # Panics
    ///
    /// Panics if `threshold` is non-negative but the `compression` feature is
    /// disabled.
    #[cfg_attr(not(feature = "compression"), track_caller)]
    fn write_packet_fallible<P>(&mut self, pkt: &P) -> anyhow::Result<()>
    where
        P: Packet + Encode,
    {
        let rollback_len = self.buf.len();

        let outcome;
        if self.threshold.0 < 0 {
            outcome = encode_packet(self.buf, pkt);
        } else {
            #[cfg(feature = "compression")]
            {
                outcome = encode_packet_compressed(self.buf, pkt, self.threshold.0 as u32);
            }
            #[cfg(not(feature = "compression"))]
            {
                panic!("\"compression\" feature must be enabled to write compressed packets");
            }
        }

        outcome.map_err(|err| {
            // Drop any partially written frame so the buffer stays well-formed.
            self.buf.truncate(rollback_len);
            err
        })
    }

    /// Appends raw bytes to the buffer, logging a warning if the write fails.
    fn write_packet_bytes(&mut self, bytes: &[u8]) {
        match self.buf.write_all(bytes) {
            Ok(()) => {}
            Err(e) => {
                warn!("failed to write packet bytes: {e:#}");
            }
        }
    }
}
impl WritePacket for PacketEncoder {
fn write_packet_fallible<P>(&mut self, packet: &P) -> anyhow::Result<()>
where
P: Packet + Encode,
{
self.append_packet(packet)
}
fn write_packet_bytes(&mut self, bytes: &[u8]) {
self.append_bytes(bytes)
}
}
/// Appends `pkt` to `buf` as a plain frame: `VarInt(length)` followed by the
/// encoded packet ID and body.
///
/// # Errors
///
/// Fails if encoding fails or the encoded length exceeds `MAX_PACKET_SIZE`.
/// Bytes already written are not removed here; the caller is expected to
/// truncate `buf` on error.
fn encode_packet<P>(buf: &mut Vec<u8>, pkt: &P) -> anyhow::Result<()>
where
    P: Packet + Encode,
{
    let body_start = buf.len();
    pkt.encode_with_id(&mut *buf)?;
    let body_len = buf.len() - body_start;

    ensure!(
        body_len <= MAX_PACKET_SIZE as usize,
        "packet exceeds maximum length"
    );

    // Grow the buffer by the prefix size, slide the body right to open a gap
    // at its start, then write the length prefix into that gap.
    let prefix_len = VarInt(body_len as i32).written_size();
    let body_end = body_start + body_len;
    buf.resize(body_end + prefix_len, 0);
    buf.copy_within(body_start..body_end, body_start + prefix_len);
    VarInt(body_len as i32).encode(&mut buf[body_start..])?;

    Ok(())
}
/// Appends `pkt` to `buf` using compressed-protocol framing:
/// `VarInt(packet_len)`, `VarInt(data_len)`, then the payload.
///
/// If the encoded packet (ID + body) is longer than `threshold` bytes, the
/// payload is zlib-compressed and `data_len` is the uncompressed size.
/// Otherwise the payload is stored raw and `data_len` is `0`.
///
/// # Errors
///
/// Fails if encoding or compression fails, or if `packet_len` exceeds
/// `MAX_PACKET_SIZE`. Partially written bytes are left for the caller to
/// truncate.
#[cfg(feature = "compression")]
#[allow(clippy::needless_borrows_for_generic_args)]
fn encode_packet_compressed<P>(buf: &mut Vec<u8>, pkt: &P, threshold: u32) -> anyhow::Result<()>
where
    P: Packet + Encode,
{
    use std::io::Read;

    use flate2::bufread::ZlibEncoder;
    use flate2::Compression;

    let body_start = buf.len();
    pkt.encode_with_id(&mut *buf)?;
    let body_len = buf.len() - body_start;

    if body_len <= threshold as usize {
        // Small packet: keep the body raw and mark it uncompressed with a
        // `VarInt(0)` data length, which always encodes to a single byte.
        const ZERO_MARKER_LEN: usize = 1;
        let packet_len = ZERO_MARKER_LEN + body_len;
        ensure!(
            packet_len <= MAX_PACKET_SIZE as usize,
            "packet exceeds maximum length"
        );

        let header_len = VarInt(packet_len as i32).written_size() + ZERO_MARKER_LEN;
        let body_end = body_start + body_len;
        buf.resize(body_end + header_len, 0);
        buf.copy_within(body_start..body_end, body_start + header_len);

        let mut header = &mut buf[body_start..];
        VarInt(packet_len as i32).encode(&mut header)?;
        VarInt(0).encode(header)?;
    } else {
        // Large packet: replace the raw body with its zlib-compressed form.
        let mut compressed = Vec::new();
        let mut zlib = ZlibEncoder::new(&buf[body_start..], Compression::new(4));
        let compressed_len = zlib.read_to_end(&mut compressed)?;
        let packet_len = VarInt(body_len as i32).written_size() + compressed_len;
        ensure!(
            packet_len <= MAX_PACKET_SIZE as usize,
            "packet exceeds maximum length"
        );
        // The encoder borrows `buf`; release it before rewriting the buffer.
        drop(zlib);

        buf.truncate(body_start);
        VarInt(packet_len as i32).encode(&mut *buf)?;
        VarInt(body_len as i32).encode(&mut *buf)?;
        buf.extend_from_slice(&compressed);
    }

    Ok(())
}