1use std::borrow::Cow;
2use std::hash::Hash;
3use std::{fmt, mem};
4
5use byteorder::{BigEndian, ReadBytesExt};
6
7use crate::tag::Tag;
8use crate::{Compound, Error, List, Result, Value};
9
10pub fn from_binary<'de, S>(slice: &mut &'de [u8]) -> Result<(Compound<S>, S)>
15where
16 S: FromModifiedUtf8<'de> + Hash + Ord,
17{
18 let mut state = DecodeState { slice, depth: 0 };
19
20 let root_tag = state.read_tag()?;
21
22 if root_tag != Tag::Compound {
23 return Err(Error::new_owned(format!(
24 "expected root tag for compound (got {})",
25 root_tag.name(),
26 )));
27 }
28
29 let root_name = state.read_string::<S>()?;
30 let root = state.read_compound()?;
31
32 debug_assert_eq!(state.depth, 0);
33
34 Ok((root, root_name))
35}
36
/// Upper bound on `DecodeState::depth`.
///
/// `check_depth` refuses to descend once this many nested lists/compounds below
/// the root are being read. This bounds the decoder's recursion, so deeply
/// nested input yields an error instead of unbounded recursion.
const MAX_DEPTH: usize = 512;
39
/// Decoder state: a cursor over the remaining input plus the current nesting depth.
struct DecodeState<'a, 'de> {
    /// Remaining unread input. Reads advance this slice in place, so after
    /// decoding, the caller's slice starts just past the consumed bytes.
    slice: &'a mut &'de [u8],
    /// Number of list/compound levels currently being read (the root compound is
    /// level 0); kept below `MAX_DEPTH` by `check_depth`.
    depth: usize,
}
45
46impl<'de> DecodeState<'_, 'de> {
47 #[inline]
48 fn check_depth<T>(&mut self, f: impl FnOnce(&mut Self) -> Result<T>) -> Result<T> {
49 if self.depth >= MAX_DEPTH {
50 return Err(Error::new_static("reached maximum recursion depth"));
51 }
52
53 self.depth += 1;
54 let res = f(self);
55 self.depth -= 1;
56 res
57 }
58
59 fn read_tag(&mut self) -> Result<Tag> {
60 match self.slice.read_u8()? {
61 0 => Ok(Tag::End),
62 1 => Ok(Tag::Byte),
63 2 => Ok(Tag::Short),
64 3 => Ok(Tag::Int),
65 4 => Ok(Tag::Long),
66 5 => Ok(Tag::Float),
67 6 => Ok(Tag::Double),
68 7 => Ok(Tag::ByteArray),
69 8 => Ok(Tag::String),
70 9 => Ok(Tag::List),
71 10 => Ok(Tag::Compound),
72 11 => Ok(Tag::IntArray),
73 12 => Ok(Tag::LongArray),
74 byte => Err(Error::new_owned(format!("invalid tag byte of {byte:#x}"))),
75 }
76 }
77
78 fn read_value<S>(&mut self, tag: Tag) -> Result<Value<S>>
79 where
80 S: FromModifiedUtf8<'de> + Hash + Ord,
81 {
82 match tag {
83 Tag::End => unreachable!("illegal TAG_End argument"),
84 Tag::Byte => Ok(self.read_byte()?.into()),
85 Tag::Short => Ok(self.read_short()?.into()),
86 Tag::Int => Ok(self.read_int()?.into()),
87 Tag::Long => Ok(self.read_long()?.into()),
88 Tag::Float => Ok(self.read_float()?.into()),
89 Tag::Double => Ok(self.read_double()?.into()),
90 Tag::ByteArray => Ok(self.read_byte_array()?.into()),
91 Tag::String => Ok(Value::String(self.read_string::<S>()?)),
92 Tag::List => self.check_depth(|st| Ok(st.read_any_list::<S>()?.into())),
93 Tag::Compound => self.check_depth(|st| Ok(st.read_compound::<S>()?.into())),
94 Tag::IntArray => Ok(self.read_int_array()?.into()),
95 Tag::LongArray => Ok(self.read_long_array()?.into()),
96 }
97 }
98
99 fn read_byte(&mut self) -> Result<i8> {
100 Ok(self.slice.read_i8()?)
101 }
102
103 fn read_short(&mut self) -> Result<i16> {
104 Ok(self.slice.read_i16::<BigEndian>()?)
105 }
106
107 fn read_int(&mut self) -> Result<i32> {
108 Ok(self.slice.read_i32::<BigEndian>()?)
109 }
110
111 fn read_long(&mut self) -> Result<i64> {
112 Ok(self.slice.read_i64::<BigEndian>()?)
113 }
114
115 fn read_float(&mut self) -> Result<f32> {
116 Ok(self.slice.read_f32::<BigEndian>()?)
117 }
118
119 fn read_double(&mut self) -> Result<f64> {
120 Ok(self.slice.read_f64::<BigEndian>()?)
121 }
122
123 fn read_byte_array(&mut self) -> Result<Vec<i8>> {
124 let len = self.slice.read_i32::<BigEndian>()?;
125
126 if len.is_negative() {
127 return Err(Error::new_owned(format!(
128 "negative byte array length of {len}"
129 )));
130 }
131
132 if len as usize > self.slice.len() {
133 return Err(Error::new_owned(format!(
134 "byte array length of {len} exceeds remainder of input"
135 )));
136 }
137
138 let (left, right) = self.slice.split_at(len as usize);
139
140 let array = left.iter().map(|b| *b as i8).collect();
141 *self.slice = right;
142
143 Ok(array)
144 }
145
146 fn read_string<S>(&mut self) -> Result<S>
147 where
148 S: FromModifiedUtf8<'de>,
149 {
150 let len = self.slice.read_u16::<BigEndian>()?.into();
151
152 if len > self.slice.len() {
153 return Err(Error::new_owned(format!(
154 "string of length {len} exceeds remainder of input"
155 )));
156 }
157
158 let (left, right) = self.slice.split_at(len);
159
160 match S::from_modified_utf8(left) {
161 Ok(str) => {
162 *self.slice = right;
163 Ok(str)
164 }
165 Err(_) => Err(Error::new_static("could not decode modified UTF-8 data")),
166 }
167 }
168
169 fn read_any_list<S>(&mut self) -> Result<List<S>>
170 where
171 S: FromModifiedUtf8<'de> + Hash + Ord,
172 {
173 match self.read_tag()? {
174 Tag::End => match self.read_int()? {
175 0 => Ok(List::End),
176 len => Err(Error::new_owned(format!(
177 "TAG_End list with nonzero length of {len}"
178 ))),
179 },
180 Tag::Byte => Ok(self.read_list(Tag::Byte, 1, |st| st.read_byte())?.into()),
181 Tag::Short => Ok(self.read_list(Tag::Short, 2, |st| st.read_short())?.into()),
182 Tag::Int => Ok(self.read_list(Tag::Int, 4, |st| st.read_int())?.into()),
183 Tag::Long => Ok(self.read_list(Tag::Long, 8, |st| st.read_long())?.into()),
184 Tag::Float => Ok(self.read_list(Tag::Float, 4, |st| st.read_float())?.into()),
185 Tag::Double => Ok(self
186 .read_list(Tag::Double, 8, |st| st.read_double())?
187 .into()),
188 Tag::ByteArray => Ok(self
189 .read_list(Tag::ByteArray, 0, |st| st.read_byte_array())?
190 .into()),
191 Tag::String => Ok(List::String(
192 self.read_list(Tag::String, 0, |st| st.read_string::<S>())?,
193 )),
194 Tag::List => self.check_depth(|st| {
195 Ok(st
196 .read_list(Tag::List, 0, |st| st.read_any_list::<S>())?
197 .into())
198 }),
199 Tag::Compound => self.check_depth(|st| {
200 Ok(st
201 .read_list(Tag::Compound, 0, |st| st.read_compound::<S>())?
202 .into())
203 }),
204 Tag::IntArray => Ok(self
205 .read_list(Tag::IntArray, 0, |st| st.read_int_array())?
206 .into()),
207 Tag::LongArray => Ok(self
208 .read_list(Tag::LongArray, 0, |st| st.read_long_array())?
209 .into()),
210 }
211 }
212
213 #[inline]
217 fn read_list<T, F>(
218 &mut self,
219 elem_type: Tag,
220 elem_size: usize,
221 mut read_elem: F,
222 ) -> Result<Vec<T>>
223 where
224 F: FnMut(&mut Self) -> Result<T>,
225 {
226 let len = self.read_int()?;
227
228 if len.is_negative() {
229 return Err(Error::new_owned(format!(
230 "negative {} list length of {len}",
231 elem_type.name()
232 )));
233 }
234
235 if len as u64 * elem_size as u64 > self.slice.len() as u64 {
238 return Err(Error::new_owned(format!(
239 "{} list of length {len} exceeds remainder of input",
240 elem_type.name()
241 )));
242 }
243
244 let mut list = Vec::with_capacity(if elem_size == 0 { 0 } else { len as usize });
245
246 for _ in 0..len {
247 list.push(read_elem(self)?);
248 }
249
250 Ok(list)
251 }
252
253 fn read_compound<S>(&mut self) -> Result<Compound<S>>
254 where
255 S: FromModifiedUtf8<'de> + Hash + Ord,
256 {
257 let mut compound = Compound::new();
258
259 loop {
260 let tag = self.read_tag()?;
261 if tag == Tag::End {
262 return Ok(compound);
263 }
264
265 compound.insert(self.read_string::<S>()?, self.read_value::<S>(tag)?);
266 }
267 }
268
269 fn read_int_array(&mut self) -> Result<Vec<i32>> {
270 let len = self.read_int()?;
271
272 if len.is_negative() {
273 return Err(Error::new_owned(format!(
274 "negative int array length of {len}",
275 )));
276 }
277
278 if len as u64 * mem::size_of::<i32>() as u64 > self.slice.len() as u64 {
279 return Err(Error::new_owned(format!(
280 "int array of length {len} exceeds remainder of input"
281 )));
282 }
283
284 let mut array = Vec::with_capacity(len as usize);
285 for _ in 0..len {
286 array.push(self.read_int()?);
287 }
288
289 Ok(array)
290 }
291
292 fn read_long_array(&mut self) -> Result<Vec<i64>> {
293 let len = self.read_int()?;
294
295 if len.is_negative() {
296 return Err(Error::new_owned(format!(
297 "negative long array length of {len}",
298 )));
299 }
300
301 if len as u64 * mem::size_of::<i64>() as u64 > self.slice.len() as u64 {
302 return Err(Error::new_owned(format!(
303 "long array of length {len} exceeds remainder of input"
304 )));
305 }
306
307 let mut array = Vec::with_capacity(len as usize);
308 for _ in 0..len {
309 array.push(self.read_long()?);
310 }
311
312 Ok(array)
313 }
314}
315
/// Error returned by `FromModifiedUtf8::from_modified_utf8` when the input bytes
/// cannot be decoded as modified UTF-8.
///
/// Carries no details; its `Display` text is a fixed message.
#[derive(Clone, Copy, Debug)]
pub struct FromModifiedUtf8Error;

impl fmt::Display for FromModifiedUtf8Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "could not decode modified UTF-8 data")
    }
}

// No underlying cause is recorded, so the default `source()` (`None`) applies.
impl std::error::Error for FromModifiedUtf8Error {}
326
/// Conversion from modified UTF-8 bytes that live for `'de`.
///
/// The `'de` lifetime lets implementors such as `Cow<'de, str>` hold data tied to
/// the input slice rather than necessarily owning a copy.
pub trait FromModifiedUtf8<'de>: Sized {
    /// Decodes `modified_utf8`, returning [`FromModifiedUtf8Error`] if the bytes
    /// are not valid modified UTF-8.
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error>;
}
333
334impl<'de> FromModifiedUtf8<'de> for Cow<'de, str> {
335 fn from_modified_utf8(
336 modified_utf8: &'de [u8],
337 ) -> std::result::Result<Self, FromModifiedUtf8Error> {
338 cesu8::from_java_cesu8(modified_utf8).map_err(move |_| FromModifiedUtf8Error)
339 }
340}
341
342impl<'de> FromModifiedUtf8<'de> for String {
343 fn from_modified_utf8(
344 modified_utf8: &'de [u8],
345 ) -> std::result::Result<Self, FromModifiedUtf8Error> {
346 match cesu8::from_java_cesu8(modified_utf8) {
347 Ok(str) => Ok(str.into_owned()),
348 Err(_) => Err(FromModifiedUtf8Error),
349 }
350 }
351}
352
#[cfg(feature = "java_string")]
impl<'de> FromModifiedUtf8<'de> for Cow<'de, java_string::JavaStr> {
    /// Decodes with `JavaStr::from_modified_utf8`; the result may borrow from
    /// `modified_utf8`, since the `Cow` carries its lifetime.
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
        match java_string::JavaStr::from_modified_utf8(modified_utf8) {
            Ok(decoded) => Ok(decoded),
            Err(_) => Err(FromModifiedUtf8Error),
        }
    }
}
361
#[cfg(feature = "java_string")]
impl<'de> FromModifiedUtf8<'de> for java_string::JavaString {
    /// Decodes with `JavaStr::from_modified_utf8` and converts the result into an
    /// owned `JavaString`.
    fn from_modified_utf8(
        modified_utf8: &'de [u8],
    ) -> std::result::Result<Self, FromModifiedUtf8Error> {
        java_string::JavaStr::from_modified_utf8(modified_utf8)
            .map(|decoded| decoded.into_owned())
            .map_err(|_| FromModifiedUtf8Error)
    }
}
372}