valence_protocol/
bit_set.rs

1use std::fmt;
2use std::io::Write;
3
4use crate::{Decode, Encode};
5
/// A set of exactly `BIT_COUNT` bits packed into a `BYTE_COUNT`-byte array.
///
/// Bit `i` lives in byte `i / 8`, at bit position `i % 8` counted from the
/// least-significant bit. `BYTE_COUNT` must equal `BIT_COUNT.div_ceil(8)`.
/// The bit accessors and the encode/decode impls assert this, but building
/// the tuple struct directly through its public field does not.
// TODO: when better const exprs are available, compute BYTE_COUNT from
// BIT_COUNT.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct FixedBitSet<const BIT_COUNT: usize, const BYTE_COUNT: usize>(pub [u8; BYTE_COUNT]);
10
11impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> FixedBitSet<BIT_COUNT, BYTE_COUNT> {
12    pub fn bit(&self, idx: usize) -> bool {
13        check_counts(BIT_COUNT, BYTE_COUNT);
14        assert!(
15            idx < BIT_COUNT,
16            "bit index of {idx} out of range for bitset with {BIT_COUNT} bits"
17        );
18
19        (self.0[idx / 8] >> (idx % 8)) & 1 == 1
20    }
21
22    pub fn set_bit(&mut self, idx: usize, val: bool) {
23        check_counts(BIT_COUNT, BYTE_COUNT);
24        assert!(
25            idx < BIT_COUNT,
26            "bit index of {idx} out of range for bitset with {BIT_COUNT} bits"
27        );
28
29        let byte = &mut self.0[idx / 8];
30        *byte |= u8::from(val) << (idx % 8);
31    }
32}
33
34impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> Encode
35    for FixedBitSet<BIT_COUNT, BYTE_COUNT>
36{
37    fn encode(&self, w: impl Write) -> anyhow::Result<()> {
38        check_counts(BIT_COUNT, BYTE_COUNT);
39        self.0.encode(w)
40    }
41}
42
43impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> Decode<'_>
44    for FixedBitSet<BIT_COUNT, BYTE_COUNT>
45{
46    fn decode(r: &mut &'_ [u8]) -> anyhow::Result<Self> {
47        check_counts(BIT_COUNT, BYTE_COUNT);
48        Ok(Self(Decode::decode(r)?))
49    }
50}
51
/// Asserts that `bytes` is exactly the number of bytes needed to hold `bits`
/// bits, i.e. `bits.div_ceil(8)`.
///
/// This is a `const fn`, so it can also be evaluated at compile time.
///
/// # Panics
///
/// Panics if `bytes != bits.div_ceil(8)`. The message names the
/// `FixedBitSet` const parameters so that a mismatched instantiation such as
/// `FixedBitSet<20, 2>` is easy to trace back. Only a literal message is
/// used, because formatted panics are not available in `const fn`.
const fn check_counts(bits: usize, bytes: usize) {
    assert!(
        bits.div_ceil(8) == bytes,
        "FixedBitSet: BYTE_COUNT must equal BIT_COUNT.div_ceil(8)"
    );
}
55
56impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> fmt::Debug
57    for FixedBitSet<BIT_COUNT, BYTE_COUNT>
58{
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        fmt::Display::fmt(self, f)
61    }
62}
63
64impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> fmt::Display
65    for FixedBitSet<BIT_COUNT, BYTE_COUNT>
66{
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        write!(f, "0b")?;
69
70        for i in (0..BIT_COUNT).rev() {
71            if self.bit(i) {
72                write!(f, "1")?;
73            } else {
74                write!(f, "0")?;
75            }
76        }
77
78        Ok(())
79    }
80}
81
82/// 😔
83macro_rules! impl_default {
84    ($($N:literal)*) => {
85        $(
86            impl<const BIT_COUNT: usize> Default for FixedBitSet<BIT_COUNT, $N> {
87                fn default() -> Self {
88                    Self(Default::default())
89                }
90            }
91        )*
92    }
93}
94
95impl_default!(0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16);
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Setting a bit on a fresh set flips exactly that bit in the backing
    /// bytes (bit 5 -> byte 0, position 5).
    #[test]
    fn fixed_bit_set_ops() {
        let mut set: FixedBitSet<20, 3> = Default::default();
        assert!(!set.bit(5));

        set.set_bit(5, true);

        assert!(set.bit(5));
        assert_eq!(set.0, [1 << 5, 0, 0]);
    }

    /// Index 20 is one past the last valid bit of a 20-bit set.
    #[test]
    #[should_panic]
    fn fixed_bit_set_out_of_range() {
        FixedBitSet::<20, 3>::default().set_bit(20, true);
    }

    /// `Display` prints all 20 digits, highest index first.
    #[test]
    fn display_fixed_bit_set() {
        let mut set = FixedBitSet::<20, 3>::default();
        set.set_bit(5, true);

        let rendered = set.to_string();
        assert_eq!(rendered, "0b00000000000000100000");
    }
}