valence_nbt/binary/
decode.rs

1use std::borrow::Cow;
2use std::hash::Hash;
3use std::{fmt, mem};
4
5use byteorder::{BigEndian, ReadBytesExt};
6
7use crate::tag::Tag;
8use crate::{Compound, Error, List, Result, Value};
9
10/// Decodes uncompressed NBT binary data from the provided slice.
11///
12/// The string returned in the tuple is the name of the root compound
13/// (typically the empty string).
14pub fn from_binary<'de, S>(slice: &mut &'de [u8]) -> Result<(Compound<S>, S)>
15where
16    S: FromModifiedUtf8<'de> + Hash + Ord,
17{
18    let mut state = DecodeState { slice, depth: 0 };
19
20    let root_tag = state.read_tag()?;
21
22    if root_tag != Tag::Compound {
23        return Err(Error::new_owned(format!(
24            "expected root tag for compound (got {})",
25            root_tag.name(),
26        )));
27    }
28
29    let root_name = state.read_string::<S>()?;
30    let root = state.read_compound()?;
31
32    debug_assert_eq!(state.depth, 0);
33
34    Ok((root, root_name))
35}
36
/// Maximum recursion depth permitted while decoding nested lists and
/// compounds.
///
/// Exceeding it makes `DecodeState::check_depth` return an error rather than
/// risk overflowing the call stack.
const MAX_DEPTH: usize = 512;
39
/// Decoder state: the input still to be read plus the bookkeeping used to
/// bound recursion.
struct DecodeState<'a, 'de> {
    /// Remaining undecoded input. Reads advance the caller's slice through
    /// this mutable reference.
    slice: &'a mut &'de [u8],
    /// Current recursion depth.
    depth: usize,
}
45
46impl<'de> DecodeState<'_, 'de> {
47    #[inline]
48    fn check_depth<T>(&mut self, f: impl FnOnce(&mut Self) -> Result<T>) -> Result<T> {
49        if self.depth >= MAX_DEPTH {
50            return Err(Error::new_static("reached maximum recursion depth"));
51        }
52
53        self.depth += 1;
54        let res = f(self);
55        self.depth -= 1;
56        res
57    }
58
59    fn read_tag(&mut self) -> Result<Tag> {
60        match self.slice.read_u8()? {
61            0 => Ok(Tag::End),
62            1 => Ok(Tag::Byte),
63            2 => Ok(Tag::Short),
64            3 => Ok(Tag::Int),
65            4 => Ok(Tag::Long),
66            5 => Ok(Tag::Float),
67            6 => Ok(Tag::Double),
68            7 => Ok(Tag::ByteArray),
69            8 => Ok(Tag::String),
70            9 => Ok(Tag::List),
71            10 => Ok(Tag::Compound),
72            11 => Ok(Tag::IntArray),
73            12 => Ok(Tag::LongArray),
74            byte => Err(Error::new_owned(format!("invalid tag byte of {byte:#x}"))),
75        }
76    }
77
78    fn read_value<S>(&mut self, tag: Tag) -> Result<Value<S>>
79    where
80        S: FromModifiedUtf8<'de> + Hash + Ord,
81    {
82        match tag {
83            Tag::End => unreachable!("illegal TAG_End argument"),
84            Tag::Byte => Ok(self.read_byte()?.into()),
85            Tag::Short => Ok(self.read_short()?.into()),
86            Tag::Int => Ok(self.read_int()?.into()),
87            Tag::Long => Ok(self.read_long()?.into()),
88            Tag::Float => Ok(self.read_float()?.into()),
89            Tag::Double => Ok(self.read_double()?.into()),
90            Tag::ByteArray => Ok(self.read_byte_array()?.into()),
91            Tag::String => Ok(Value::String(self.read_string::<S>()?)),
92            Tag::List => self.check_depth(|st| Ok(st.read_any_list::<S>()?.into())),
93            Tag::Compound => self.check_depth(|st| Ok(st.read_compound::<S>()?.into())),
94            Tag::IntArray => Ok(self.read_int_array()?.into()),
95            Tag::LongArray => Ok(self.read_long_array()?.into()),
96        }
97    }
98
99    fn read_byte(&mut self) -> Result<i8> {
100        Ok(self.slice.read_i8()?)
101    }
102
103    fn read_short(&mut self) -> Result<i16> {
104        Ok(self.slice.read_i16::<BigEndian>()?)
105    }
106
107    fn read_int(&mut self) -> Result<i32> {
108        Ok(self.slice.read_i32::<BigEndian>()?)
109    }
110
111    fn read_long(&mut self) -> Result<i64> {
112        Ok(self.slice.read_i64::<BigEndian>()?)
113    }
114
115    fn read_float(&mut self) -> Result<f32> {
116        Ok(self.slice.read_f32::<BigEndian>()?)
117    }
118
119    fn read_double(&mut self) -> Result<f64> {
120        Ok(self.slice.read_f64::<BigEndian>()?)
121    }
122
123    fn read_byte_array(&mut self) -> Result<Vec<i8>> {
124        let len = self.slice.read_i32::<BigEndian>()?;
125
126        if len.is_negative() {
127            return Err(Error::new_owned(format!(
128                "negative byte array length of {len}"
129            )));
130        }
131
132        if len as usize > self.slice.len() {
133            return Err(Error::new_owned(format!(
134                "byte array length of {len} exceeds remainder of input"
135            )));
136        }
137
138        let (left, right) = self.slice.split_at(len as usize);
139
140        let array = left.iter().map(|b| *b as i8).collect();
141        *self.slice = right;
142
143        Ok(array)
144    }
145
146    fn read_string<S>(&mut self) -> Result<S>
147    where
148        S: FromModifiedUtf8<'de>,
149    {
150        let len = self.slice.read_u16::<BigEndian>()?.into();
151
152        if len > self.slice.len() {
153            return Err(Error::new_owned(format!(
154                "string of length {len} exceeds remainder of input"
155            )));
156        }
157
158        let (left, right) = self.slice.split_at(len);
159
160        match S::from_modified_utf8(left) {
161            Ok(str) => {
162                *self.slice = right;
163                Ok(str)
164            }
165            Err(_) => Err(Error::new_static("could not decode modified UTF-8 data")),
166        }
167    }
168
169    fn read_any_list<S>(&mut self) -> Result<List<S>>
170    where
171        S: FromModifiedUtf8<'de> + Hash + Ord,
172    {
173        match self.read_tag()? {
174            Tag::End => match self.read_int()? {
175                0 => Ok(List::End),
176                len => Err(Error::new_owned(format!(
177                    "TAG_End list with nonzero length of {len}"
178                ))),
179            },
180            Tag::Byte => Ok(self.read_list(Tag::Byte, 1, |st| st.read_byte())?.into()),
181            Tag::Short => Ok(self.read_list(Tag::Short, 2, |st| st.read_short())?.into()),
182            Tag::Int => Ok(self.read_list(Tag::Int, 4, |st| st.read_int())?.into()),
183            Tag::Long => Ok(self.read_list(Tag::Long, 8, |st| st.read_long())?.into()),
184            Tag::Float => Ok(self.read_list(Tag::Float, 4, |st| st.read_float())?.into()),
185            Tag::Double => Ok(self
186                .read_list(Tag::Double, 8, |st| st.read_double())?
187                .into()),
188            Tag::ByteArray => Ok(self
189                .read_list(Tag::ByteArray, 0, |st| st.read_byte_array())?
190                .into()),
191            Tag::String => Ok(List::String(
192                self.read_list(Tag::String, 0, |st| st.read_string::<S>())?,
193            )),
194            Tag::List => self.check_depth(|st| {
195                Ok(st
196                    .read_list(Tag::List, 0, |st| st.read_any_list::<S>())?
197                    .into())
198            }),
199            Tag::Compound => self.check_depth(|st| {
200                Ok(st
201                    .read_list(Tag::Compound, 0, |st| st.read_compound::<S>())?
202                    .into())
203            }),
204            Tag::IntArray => Ok(self
205                .read_list(Tag::IntArray, 0, |st| st.read_int_array())?
206                .into()),
207            Tag::LongArray => Ok(self
208                .read_list(Tag::LongArray, 0, |st| st.read_long_array())?
209                .into()),
210        }
211    }
212
213    /// Assumes the element tag has already been read.
214    ///
215    /// `min_elem_size` is the minimum size of the list element when encoded.
216    #[inline]
217    fn read_list<T, F>(
218        &mut self,
219        elem_type: Tag,
220        elem_size: usize,
221        mut read_elem: F,
222    ) -> Result<Vec<T>>
223    where
224        F: FnMut(&mut Self) -> Result<T>,
225    {
226        let len = self.read_int()?;
227
228        if len.is_negative() {
229            return Err(Error::new_owned(format!(
230                "negative {} list length of {len}",
231                elem_type.name()
232            )));
233        }
234
235        // Ensure we don't reserve more than the maximum amount of memory required given
236        // the size of the remaining input.
237        if len as u64 * elem_size as u64 > self.slice.len() as u64 {
238            return Err(Error::new_owned(format!(
239                "{} list of length {len} exceeds remainder of input",
240                elem_type.name()
241            )));
242        }
243
244        let mut list = Vec::with_capacity(if elem_size == 0 { 0 } else { len as usize });
245
246        for _ in 0..len {
247            list.push(read_elem(self)?);
248        }
249
250        Ok(list)
251    }
252
253    fn read_compound<S>(&mut self) -> Result<Compound<S>>
254    where
255        S: FromModifiedUtf8<'de> + Hash + Ord,
256    {
257        let mut compound = Compound::new();
258
259        loop {
260            let tag = self.read_tag()?;
261            if tag == Tag::End {
262                return Ok(compound);
263            }
264
265            compound.insert(self.read_string::<S>()?, self.read_value::<S>(tag)?);
266        }
267    }
268
269    fn read_int_array(&mut self) -> Result<Vec<i32>> {
270        let len = self.read_int()?;
271
272        if len.is_negative() {
273            return Err(Error::new_owned(format!(
274                "negative int array length of {len}",
275            )));
276        }
277
278        if len as u64 * mem::size_of::<i32>() as u64 > self.slice.len() as u64 {
279            return Err(Error::new_owned(format!(
280                "int array of length {len} exceeds remainder of input"
281            )));
282        }
283
284        let mut array = Vec::with_capacity(len as usize);
285        for _ in 0..len {
286            array.push(self.read_int()?);
287        }
288
289        Ok(array)
290    }
291
292    fn read_long_array(&mut self) -> Result<Vec<i64>> {
293        let len = self.read_int()?;
294
295        if len.is_negative() {
296            return Err(Error::new_owned(format!(
297                "negative long array length of {len}",
298            )));
299        }
300
301        if len as u64 * mem::size_of::<i64>() as u64 > self.slice.len() as u64 {
302            return Err(Error::new_owned(format!(
303                "long array of length {len} exceeds remainder of input"
304            )));
305        }
306
307        let mut array = Vec::with_capacity(len as usize);
308        for _ in 0..len {
309            array.push(self.read_long()?);
310        }
311
312        Ok(array)
313    }
314}
315
/// Error returned when a byte sequence is not valid modified UTF-8.
///
/// The error carries no further detail, so every value of this type is equal
/// to every other.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct FromModifiedUtf8Error;
318
319impl fmt::Display for FromModifiedUtf8Error {
320    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321        f.write_str("could not decode modified UTF-8 data")
322    }
323}
324
// Empty impl: no methods are overridden, so the trait's provided defaults are
// used.
impl std::error::Error for FromModifiedUtf8Error {}
326
/// A string type which can be decoded from Java's [modified UTF-8](https://docs.oracle.com/javase/8/docs/api/java/io/DataInput.html#modified-utf-8).
///
/// `'de` is the lifetime of the input bytes, which lets implementations return
/// a value that borrows from the input.
pub trait FromModifiedUtf8<'de>: Sized {
    /// Decodes `modified_utf8` into `Self`.
    ///
    /// Returns [`FromModifiedUtf8Error`] if the bytes cannot be decoded.
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error>;
}
333
334impl<'de> FromModifiedUtf8<'de> for Cow<'de, str> {
335    fn from_modified_utf8(
336        modified_utf8: &'de [u8],
337    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
338        cesu8::from_java_cesu8(modified_utf8).map_err(move |_| FromModifiedUtf8Error)
339    }
340}
341
342impl<'de> FromModifiedUtf8<'de> for String {
343    fn from_modified_utf8(
344        modified_utf8: &'de [u8],
345    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
346        match cesu8::from_java_cesu8(modified_utf8) {
347            Ok(str) => Ok(str.into_owned()),
348            Err(_) => Err(FromModifiedUtf8Error),
349        }
350    }
351}
352
#[cfg(feature = "java_string")]
impl<'de> FromModifiedUtf8<'de> for Cow<'de, java_string::JavaStr> {
    /// Decodes with `JavaStr::from_modified_utf8` and returns the resulting
    /// `Cow` as-is; any decoder error is replaced by
    /// [`FromModifiedUtf8Error`].
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
        match java_string::JavaStr::from_modified_utf8(modified_utf8) {
            Ok(decoded) => Ok(decoded),
            Err(_) => Err(FromModifiedUtf8Error),
        }
    }
}
361
#[cfg(feature = "java_string")]
impl<'de> FromModifiedUtf8<'de> for java_string::JavaString {
    /// Decodes with `JavaStr::from_modified_utf8` and converts the result into
    /// an owned `JavaString`; any decoder error is replaced by
    /// [`FromModifiedUtf8Error`].
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
        java_string::JavaStr::from_modified_utf8(modified_utf8)
            .map(|decoded| decoded.into_owned())
            .map_err(|_| FromModifiedUtf8Error)
    }
}