// java_string/validations.rs

1use std::ops::{Bound, Range, RangeBounds, RangeTo};
2
3use crate::{JavaStr, Utf8Error};
4
/// Tag bits of a UTF-8 continuation byte (`10xx_xxxx`).
pub(crate) const TAG_CONT: u8 = 0x80;
/// Tag bits of the lead byte of a two-byte sequence (`110x_xxxx`).
pub(crate) const TAG_TWO_B: u8 = 0xC0;
/// Tag bits of the lead byte of a three-byte sequence (`1110_xxxx`).
pub(crate) const TAG_THREE_B: u8 = 0xE0;
/// Tag bits of the lead byte of a four-byte sequence (`1111_0xxx`).
pub(crate) const TAG_FOUR_B: u8 = 0xF0;
/// Selects the six payload bits of a continuation byte.
pub(crate) const CONT_MASK: u8 = 0x3F;
10
/// Keeps the low `7 - width` bits of `byte`, i.e. the payload bits of the lead byte
/// of a `width`-byte UTF-8 sequence, widened to `u32`.
#[inline]
const fn utf8_first_byte(byte: u8, width: u32) -> u32 {
    let payload_mask = 0x7F_u8 >> width;
    (byte & payload_mask) as u32
}
15
16#[inline]
17const fn utf8_acc_cont_byte(ch: u32, byte: u8) -> u32 {
18    (ch << 6) | (byte & CONT_MASK) as u32
19}
20
/// Returns `true` for UTF-8 continuation bytes, i.e. bytes in `0x80..=0xBF`.
#[inline]
const fn utf8_is_cont_byte(byte: u8) -> bool {
    matches!(byte, 0x80..=0xBF)
}
25
26/// # Safety
27///
28/// `bytes` must produce a semi-valid UTF-8 string
29#[inline]
30pub(crate) unsafe fn next_code_point<'a, I: Iterator<Item = &'a u8>>(bytes: &mut I) -> Option<u32> {
31    // Decode UTF-8
32    let x = *bytes.next()?;
33    if x < 128 {
34        return Some(x.into());
35    }
36
37    // Multibyte case follows
38    // Decode from a byte combination out of: [[[x y] z] w]
39    // NOTE: Performance is sensitive to the exact formulation here
40    let init = utf8_first_byte(x, 2);
41    // SAFETY: `bytes` produces an UTF-8-like string,
42    // so the iterator must produce a value here.
43    let y = unsafe { *bytes.next().unwrap_unchecked() };
44    let mut ch = utf8_acc_cont_byte(init, y);
45    if x >= 0xe0 {
46        // [[x y z] w] case
47        // 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid
48        // SAFETY: `bytes` produces an UTF-8-like string,
49        // so the iterator must produce a value here.
50        let z = unsafe { *bytes.next().unwrap_unchecked() };
51        let y_z = utf8_acc_cont_byte((y & CONT_MASK).into(), z);
52        ch = (init << 12) | y_z;
53        if x >= 0xf0 {
54            // [x y z w] case
55            // use only the lower 3 bits of `init`
56            // SAFETY: `bytes` produces an UTF-8-like string,
57            // so the iterator must produce a value here.
58            let w = unsafe { *bytes.next().unwrap_unchecked() };
59            ch = ((init & 7) << 18) | utf8_acc_cont_byte(y_z, w);
60        }
61    }
62
63    Some(ch)
64}
65
66/// # Safety
67///
68/// `bytes` must produce a semi-valid UTF-8 string
69#[inline]
70pub(crate) unsafe fn next_code_point_reverse<'a, I: DoubleEndedIterator<Item = &'a u8>>(
71    bytes: &mut I,
72) -> Option<u32> {
73    // Decode UTF-8
74    let w = match *bytes.next_back()? {
75        next_byte if next_byte < 128 => return Some(next_byte.into()),
76        back_byte => back_byte,
77    };
78
79    // Multibyte case follows
80    // Decode from a byte combination out of: [x [y [z w]]]
81    let mut ch;
82    // SAFETY: `bytes` produces an UTF-8-like string,
83    // so the iterator must produce a value here.
84    let z = unsafe { *bytes.next_back().unwrap_unchecked() };
85    ch = utf8_first_byte(z, 2);
86    if utf8_is_cont_byte(z) {
87        // SAFETY: `bytes` produces an UTF-8-like string,
88        // so the iterator must produce a value here.
89        let y = unsafe { *bytes.next_back().unwrap_unchecked() };
90        ch = utf8_first_byte(y, 3);
91        if utf8_is_cont_byte(y) {
92            // SAFETY: `bytes` produces an UTF-8-like string,
93            // so the iterator must produce a value here.
94            let x = unsafe { *bytes.next_back().unwrap_unchecked() };
95            ch = utf8_first_byte(x, 4);
96            ch = utf8_acc_cont_byte(ch, y);
97        }
98        ch = utf8_acc_cont_byte(ch, z);
99    }
100    ch = utf8_acc_cont_byte(ch, w);
101
102    Some(ch)
103}
104
105#[inline(always)]
106pub(crate) fn run_utf8_semi_validation(v: &[u8]) -> Result<(), Utf8Error> {
107    let mut index = 0;
108    let len = v.len();
109
110    let usize_bytes = std::mem::size_of::<usize>();
111    let ascii_block_size = 2 * usize_bytes;
112    let blocks_end = if len >= ascii_block_size {
113        len - ascii_block_size + 1
114    } else {
115        0
116    };
117    let align = v.as_ptr().align_offset(usize_bytes);
118
119    while index < len {
120        let old_offset = index;
121        macro_rules! err {
122            ($error_len:expr) => {
123                return Err(Utf8Error {
124                    valid_up_to: old_offset,
125                    error_len: $error_len,
126                })
127            };
128        }
129
130        macro_rules! next {
131            () => {{
132                index += 1;
133                // we needed data, but there was none: error!
134                if index >= len {
135                    err!(None)
136                }
137                v[index]
138            }};
139        }
140
141        let first = v[index];
142        if first >= 128 {
143            let w = utf8_char_width(first);
144            // 2-byte encoding is for codepoints  \u{0080} to  \u{07ff}
145            //        first  C2 80        last DF BF
146            // 3-byte encoding is for codepoints  \u{0800} to  \u{ffff}
147            //        first  E0 A0 80     last EF BF BF
148            //   INCLUDING surrogates codepoints  \u{d800} to  \u{dfff}
149            //               ED A0 80 to       ED BF BF
150            // 4-byte encoding is for codepoints \u{1000}0 to \u{10ff}ff
151            //        first  F0 90 80 80  last F4 8F BF BF
152            //
153            // Use the UTF-8 syntax from the RFC
154            //
155            // https://tools.ietf.org/html/rfc3629
156            // UTF8-1      = %x00-7F
157            // UTF8-2      = %xC2-DF UTF8-tail
158            // UTF8-3      = %xE0 %xA0-BF UTF8-tail / %xE1-EC 2( UTF8-tail ) /
159            //               %xED %x80-9F UTF8-tail / %xEE-EF 2( UTF8-tail )
160            // UTF8-4      = %xF0 %x90-BF 2( UTF8-tail ) / %xF1-F3 3( UTF8-tail ) /
161            //               %xF4 %x80-8F 2( UTF8-tail )
162            match w {
163                2 => {
164                    if next!() as i8 >= -64 {
165                        err!(Some(1))
166                    }
167                }
168                3 => {
169                    match (first, next!()) {
170                        (0xe0, 0xa0..=0xbf) | (0xe1..=0xef, 0x80..=0xbf) => {} /* INCLUDING surrogate codepoints here */
171                        _ => err!(Some(1)),
172                    }
173                    if next!() as i8 >= -64 {
174                        err!(Some(2))
175                    }
176                }
177                4 => {
178                    match (first, next!()) {
179                        (0xf0, 0x90..=0xbf) | (0xf1..=0xf3, 0x80..=0xbf) | (0xf4, 0x80..=0x8f) => {}
180                        _ => err!(Some(1)),
181                    }
182                    if next!() as i8 >= -64 {
183                        err!(Some(2))
184                    }
185                    if next!() as i8 >= -64 {
186                        err!(Some(3))
187                    }
188                }
189                _ => err!(Some(1)),
190            }
191            index += 1;
192        } else {
193            // Ascii case, try to skip forward quickly.
194            // When the pointer is aligned, read 2 words of data per iteration
195            // until we find a word containing a non-ascii byte.
196            if align != usize::MAX && align.wrapping_sub(index) % usize_bytes == 0 {
197                let ptr = v.as_ptr();
198                while index < blocks_end {
199                    // SAFETY: since `align - index` and `ascii_block_size` are
200                    // multiples of `usize_bytes`, `block = ptr.add(index)` is
201                    // always aligned with a `usize` so it's safe to dereference
202                    // both `block` and `block.add(1)`.
203                    unsafe {
204                        let block = ptr.add(index) as *const usize;
205                        // break if there is a nonascii byte
206                        let zu = contains_nonascii(*block);
207                        let zv = contains_nonascii(*block.add(1));
208                        if zu || zv {
209                            break;
210                        }
211                    }
212                    index += ascii_block_size;
213                }
214                // step from the point where the wordwise loop stopped
215                while index < len && v[index] < 128 {
216                    index += 1;
217                }
218            } else {
219                index += 1;
220            }
221        }
222    }
223
224    Ok(())
225}
226
227#[inline(always)]
228pub(crate) const fn run_utf8_full_validation_from_semi(v: &[u8]) -> Result<(), Utf8Error> {
229    // this function checks for surrogate codepoints, between \u{d800} to \u{dfff},
230    // or ED A0 80 to ED BF BF of width 3 unicode chars. The valid range of width 3
231    // characters is ED 80 80 to ED BF BF, so we need to check for an ED byte
232    // followed by a >=A0 byte.
233    let mut index = 0;
234    while index + 3 <= v.len() {
235        if v[index] == 0xed && v[index + 1] >= 0xa0 {
236            return Err(Utf8Error {
237                valid_up_to: index,
238                error_len: Some(1),
239            });
240        }
241        index += 1;
242    }
243
244    Ok(())
245}
246
/// Returns the length in bytes of the UTF-8 sequence introduced by `first_byte`.
///
/// Returns `0` for bytes that can never start a sequence: continuation bytes
/// (`80..=BF`), the overlong leads `C0`/`C1`, and `F5..=FF`.
#[inline]
pub(crate) const fn utf8_char_width(first_byte: u8) -> usize {
    match first_byte {
        0x00..=0x7F => 1,
        0xC2..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xF4 => 4,
        _ => 0,
    }
}
263
/// Returns `true` if any byte packed into the word `x` has its high bit set, i.e. is
/// not ASCII.
#[inline]
const fn contains_nonascii(x: usize) -> bool {
    // `usize::MAX / 0xFF` is 0x0101..01; times 0x80 gives 0x8080..80, the high bit
    // of every byte in the word regardless of pointer width.
    const HIGH_BITS: usize = (usize::MAX / 0xFF) * 0x80;
    (x & HIGH_BITS) != 0
}
269
270#[cold]
271#[track_caller]
272pub(crate) fn slice_error_fail(s: &JavaStr, begin: usize, end: usize) -> ! {
273    const MAX_DISPLAY_LENGTH: usize = 256;
274    let trunc_len = s.floor_char_boundary(MAX_DISPLAY_LENGTH);
275    let s_trunc = &s[..trunc_len];
276    let ellipsis = if trunc_len < s.len() { "[...]" } else { "" };
277
278    // 1. out of bounds
279    if begin > s.len() || end > s.len() {
280        let oob_index = if begin > s.len() { begin } else { end };
281        panic!("byte index {oob_index} is out of bounds of `{s_trunc}`{ellipsis}");
282    }
283
284    // 2. begin <= end
285    assert!(
286        begin <= end,
287        "begin <= end ({begin} <= {end}) when slicing `{s_trunc}`{ellipsis}",
288    );
289
290    // 3. character boundary
291    let index = if !s.is_char_boundary(begin) {
292        begin
293    } else {
294        end
295    };
296    // find the character
297    let char_start = s.floor_char_boundary(index);
298    // `char_start` must be less than len and a char boundary
299    let ch = s[char_start..].chars().next().unwrap();
300    let char_range = char_start..char_start + ch.len_utf8();
301    panic!(
302        "byte index {index} is not a char boundary; it is inside {ch:?} (bytes {char_range:?}) of \
303         `{s_trunc}`{ellipsis}",
304    );
305}
306
/// Panics because the range end `index` exceeds the length `len` of a `JavaStr`.
#[cold]
#[track_caller]
pub(crate) fn str_end_index_len_fail(index: usize, len: usize) -> ! {
    panic!("range end index {index} out of range for JavaStr of length {len}")
}
312
/// Panics because a `JavaStr` range starts at `index` but ends earlier, at `end`.
#[cold]
#[track_caller]
pub(crate) fn str_index_order_fail(index: usize, end: usize) -> ! {
    panic!("JavaStr index starts at {index} but ends at {end}")
}
318
/// Panics because an exclusive range start of `usize::MAX` has no representable
/// inclusive equivalent.
#[cold]
#[track_caller]
pub(crate) fn str_start_index_overflow_fail() -> ! {
    panic!("attempted to index JavaStr from after maximum usize")
}
324
/// Panics because an inclusive range end of `usize::MAX` has no representable
/// exclusive equivalent.
#[cold]
#[track_caller]
pub(crate) fn str_end_index_overflow_fail() -> ! {
    panic!("attempted to index JavaStr up to maximum usize")
}
330
331#[inline]
332#[track_caller]
333pub(crate) fn to_range_checked<R>(range: R, bounds: RangeTo<usize>) -> Range<usize>
334where
335    R: RangeBounds<usize>,
336{
337    let len = bounds.end;
338
339    let start = range.start_bound();
340    let start = match start {
341        Bound::Included(&start) => start,
342        Bound::Excluded(start) => start
343            .checked_add(1)
344            .unwrap_or_else(|| str_start_index_overflow_fail()),
345        Bound::Unbounded => 0,
346    };
347
348    let end: Bound<&usize> = range.end_bound();
349    let end = match end {
350        Bound::Included(end) => end
351            .checked_add(1)
352            .unwrap_or_else(|| str_end_index_overflow_fail()),
353        Bound::Excluded(&end) => end,
354        Bound::Unbounded => len,
355    };
356
357    if start > end {
358        str_index_order_fail(start, end);
359    }
360    if end > len {
361        str_end_index_len_fail(end, len);
362    }
363
364    Range { start, end }
365}