1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use std::fmt;
use std::io::Write;

use crate::{Decode, Encode};

// TODO: derive `BYTE_COUNT` from `BIT_COUNT` once generic const expressions
// are available, so callers no longer have to spell out both parameters.
/// A set of `BIT_COUNT` bits packed into a `[u8; BYTE_COUNT]` array.
///
/// Bit `i` is stored in byte `i / 8` at bit position `i % 8`, counting from the
/// least-significant bit of that byte. `BYTE_COUNT` is expected to equal
/// `ceil(BIT_COUNT / 8)`; the bit accessors and the encode/decode impls assert this.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct FixedBitSet<const BIT_COUNT: usize, const BYTE_COUNT: usize>(
    pub [u8; BYTE_COUNT],
);

impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> FixedBitSet<BIT_COUNT, BYTE_COUNT> {
    /// Returns whether the bit at `idx` is set.
    ///
    /// Bit `idx` lives in byte `idx / 8` at position `idx % 8`, counting from the
    /// least-significant bit of that byte.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= BIT_COUNT`, or if `BYTE_COUNT` does not match
    /// `ceil(BIT_COUNT / 8)`.
    pub fn bit(&self, idx: usize) -> bool {
        check_counts(BIT_COUNT, BYTE_COUNT);
        assert!(
            idx < BIT_COUNT,
            "bit index of {idx} out of range for bitset with {BIT_COUNT} bits"
        );

        ((self.0[idx / 8] >> (idx % 8)) & 1) == 1
    }

    /// Sets the bit at `idx` to `val`, leaving every other bit unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= BIT_COUNT`, or if `BYTE_COUNT` does not match
    /// `ceil(BIT_COUNT / 8)`.
    pub fn set_bit(&mut self, idx: usize, val: bool) {
        check_counts(BIT_COUNT, BYTE_COUNT);
        assert!(
            idx < BIT_COUNT,
            "bit index of {idx} out of range for bitset with {BIT_COUNT} bits"
        );

        let mask = 1u8 << (idx % 8);
        let byte = &mut self.0[idx / 8];
        // Explicitly clear when `val` is false: OR-ing in a zero would leave a
        // previously set bit untouched.
        if val {
            *byte |= mask;
        } else {
            *byte &= !mask;
        }
    }
}

impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> Encode
    for FixedBitSet<BIT_COUNT, BYTE_COUNT>
{
    fn encode(&self, w: impl Write) -> anyhow::Result<()> {
        check_counts(BIT_COUNT, BYTE_COUNT);
        self.0.encode(w)
    }
}

impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> Decode<'_>
    for FixedBitSet<BIT_COUNT, BYTE_COUNT>
{
    /// Decodes a `[u8; BYTE_COUNT]` array through its own `Decode` impl and wraps it,
    /// after asserting that `BYTE_COUNT` matches `BIT_COUNT`. Errors from the array
    /// decoder are propagated unchanged.
    fn decode(input: &mut &'_ [u8]) -> anyhow::Result<Self> {
        check_counts(BIT_COUNT, BYTE_COUNT);
        let raw: [u8; BYTE_COUNT] = Decode::decode(input)?;
        Ok(Self(raw))
    }
}

/// Asserts that `bytes` is exactly the number of bytes needed to hold `bits` bits,
/// i.e. `ceil(bits / 8)`.
///
/// Declared `const fn` so it can also be evaluated in const contexts.
///
/// # Panics
///
/// Panics if `bytes != ceil(bits / 8)`.
const fn check_counts(bits: usize, bytes: usize) {
    // Computed as `bits / 8` plus one for a partial trailing byte, rather than
    // `(bits + 7) / 8`, so it cannot overflow when `bits` is close to `usize::MAX`.
    let needed = bits / 8 + if bits % 8 == 0 { 0 } else { 1 };
    assert!(
        needed == bytes,
        "FixedBitSet BYTE_COUNT must equal ceil(BIT_COUNT / 8)"
    );
}

/// `Debug` output is the same as the `Display` output; the formatter is forwarded as-is.
impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> fmt::Debug
    for FixedBitSet<BIT_COUNT, BYTE_COUNT>
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}

impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> fmt::Display
    for FixedBitSet<BIT_COUNT, BYTE_COUNT>
{
    /// Renders the set as a binary literal: `0b` followed by all `BIT_COUNT` digits,
    /// highest index first (e.g. `0b00000000000000100000` for bit 5 of a 20-bit set).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("0b")?;

        (0..BIT_COUNT)
            .rev()
            .map(|idx| if self.bit(idx) { "1" } else { "0" })
            .try_for_each(|digit| f.write_str(digit))
    }
}

/// Returns a bitset with every bit cleared.
// An array-repeat expression `[0; N]` works for any `N` because `u8: Copy`, so a
// single generic impl covers every `BYTE_COUNT`. It replaces a macro that could only
// enumerate `BYTE_COUNT` values 0 through 16.
impl<const BIT_COUNT: usize, const BYTE_COUNT: usize> Default
    for FixedBitSet<BIT_COUNT, BYTE_COUNT>
{
    fn default() -> Self {
        Self([0; BYTE_COUNT])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fixed_bit_set_ops() {
        let mut set: FixedBitSet<20, 3> = Default::default();

        assert!(!set.bit(5), "a fresh set should start with bit 5 cleared");
        set.set_bit(5, true);
        assert!(set.bit(5), "bit 5 should read back as set");
        assert_eq!(set.0, [0b0010_0000, 0b0000_0000, 0b0000_0000]);
    }

    #[test]
    #[should_panic]
    fn fixed_bit_set_out_of_range() {
        // Valid indices are 0..20, so 20 is one past the end.
        let mut set: FixedBitSet<20, 3> = Default::default();
        set.set_bit(20, true);
    }

    #[test]
    fn display_fixed_bit_set() {
        let mut set: FixedBitSet<20, 3> = Default::default();
        set.set_bit(5, true);

        let rendered = set.to_string();
        assert_eq!(rendered, "0b00000000000000100000");
    }
}