1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
use std::ops::{Add, Sub};

use glam::DVec3;

/// An axis-aligned bounding box ("AABB") in three dimensions.
///
/// The box is described by its two extreme corners, `min` and `max`. Every
/// component of `min` is expected to be less than or equal to the matching
/// component of `max`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Aabb {
    /// Corner holding the smallest coordinate on each axis.
    min: DVec3,
    /// Corner holding the largest coordinate on each axis.
    max: DVec3,
}

impl Aabb {
    /// The degenerate AABB whose `min` and `max` are both the origin.
    pub const ZERO: Self = Self {
        min: DVec3::ZERO,
        max: DVec3::ZERO,
    };

    /// Constructs a new AABB from `min` and `max` points.
    ///
    /// # Panics
    ///
    /// Panics if `debug_assertions` are enabled and `min` is not less than or
    /// equal to `max` componentwise. A NaN component also fails this check.
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn new(min: DVec3, max: DVec3) -> Self {
        debug_assert!(
            min.x <= max.x && min.y <= max.y && min.z <= max.z,
            "`min` must be less than or equal to `max` componentwise (min = {min}, max = {max})"
        );

        Self { min, max }
    }

    /// Constructs an AABB from `min` and `max` without checking that `min` is
    /// less than or equal to `max` componentwise.
    // TODO: remove when the assertion in `new` can be done in a `const` context.
    #[doc(hidden)]
    pub const fn new_unchecked(min: DVec3, max: DVec3) -> Self {
        Self { min, max }
    }

    /// Returns a new AABB containing a single point `p`.
    pub fn new_point(p: DVec3) -> Self {
        Self::new(p, p)
    }

    /// Constructs an AABB from the center of its bottom face and its size.
    ///
    /// The box is centered on `bottom` along the X and Z axes and extends
    /// upward from `bottom.y` by `size.y` along the Y axis.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Aabb::new`], which happens when a
    /// component of `size` is negative or NaN.
    pub fn from_bottom_size(bottom: DVec3, size: DVec3) -> Self {
        Self::new(
            DVec3 {
                x: bottom.x - size.x / 2.0,
                y: bottom.y,
                z: bottom.z - size.z / 2.0,
            },
            DVec3 {
                x: bottom.x + size.x / 2.0,
                y: bottom.y + size.y,
                z: bottom.z + size.z / 2.0,
            },
        )
    }

    /// Returns the `min` corner of the AABB.
    pub const fn min(self) -> DVec3 {
        self.min
    }

    /// Returns the `max` corner of the AABB.
    pub const fn max(self) -> DVec3 {
        self.max
    }

    /// Returns the smallest AABB that contains both `self` and `other`.
    pub fn union(self, other: Self) -> Self {
        Self::new(self.min.min(other.min), self.max.max(other.max))
    }

    /// Do the two bounding boxes overlap?
    ///
    /// The comparison is inclusive: boxes that merely touch on a face, edge
    /// or corner are considered intersecting.
    pub fn intersects(self, other: Self) -> bool {
        self.max.x >= other.min.x
            && other.max.x >= self.min.x
            && self.max.y >= other.min.y
            && other.max.y >= self.min.y
            && self.max.z >= other.min.z
            && other.max.z >= self.min.z
    }

    /// Does this bounding box contain the given point?
    ///
    /// Points lying exactly on the boundary are contained.
    pub fn contains_point(self, p: DVec3) -> bool {
        self.min.x <= p.x
            && self.min.y <= p.y
            && self.min.z <= p.z
            && self.max.x >= p.x
            && self.max.y >= p.y
            && self.max.z >= p.z
    }

    /// Returns the closest point in the AABB to the given point.
    pub fn projected_point(self, p: DVec3) -> DVec3 {
        p.clamp(self.min, self.max)
    }

    /// Returns the smallest distance from the AABB to the point.
    ///
    /// The distance is zero when the point is inside the AABB.
    pub fn distance_to_point(self, p: DVec3) -> f64 {
        self.projected_point(p).distance(p)
    }

    /// Calculates the intersection between this AABB and a ray
    /// defined by its `origin` point and `direction` vector.
    ///
    /// If an intersection occurs, `Some([near, far])` is returned. `near` and
    /// `far` are the values of `t` in the equation `origin + t * direction =
    /// point` where `point` is the nearest or furthest intersection point to
    /// the `origin`. If no intersection occurs, then `None` is returned.
    ///
    /// In other words, if `direction` is normalized, then `near` and `far` are
    /// the distances to the nearest and furthest intersection points.
    ///
    /// Only the part of the ray with `t >= 0` is considered, so `near` is
    /// `0.0` when `origin` is inside the AABB. The box is treated as closed:
    /// a ray that runs exactly along a face or edge intersects it, matching
    /// [`Aabb::contains_point`] and [`Aabb::intersects`].
    ///
    /// If every component of `direction` is zero the ray degenerates to the
    /// point `origin`; the result is then `Some([0.0, f64::INFINITY])` when
    /// `origin` is contained in the AABB and `None` otherwise.
    pub fn ray_intersection(self, origin: DVec3, direction: DVec3) -> Option<[f64; 2]> {
        let mut near: f64 = 0.0;
        let mut far = f64::INFINITY;

        for i in 0..3 {
            if direction[i] == 0.0 {
                // The ray is parallel to this pair of bounding planes, so it
                // never crosses them: it either stays inside the slab forever
                // or never enters it. Dividing by zero here would produce
                // `0/0 = NaN` next to an infinity when the origin lies exactly
                // on a plane, and `f64::min`/`f64::max` would then keep the
                // infinity and wrongly discard the hit.
                //
                // The test is phrased as "strictly outside" so that a NaN
                // origin component leaves this axis unconstrained.
                if origin[i] < self.min[i] || origin[i] > self.max[i] {
                    return None;
                }
                continue;
            }

            // Parameters at which the ray crosses the two planes of this slab.
            let t0 = (self.min[i] - origin[i]) / direction[i];
            let t1 = (self.max[i] - origin[i]) / direction[i];

            // Shrink the running interval to its overlap with this slab.
            near = near.max(t0.min(t1));
            far = far.min(t0.max(t1));
        }

        (near <= far).then_some([near, far])
    }
}

impl Add<DVec3> for Aabb {
    type Output = Aabb;

    fn add(self, rhs: DVec3) -> Self::Output {
        Self::new(self.min + rhs, self.max + rhs)
    }
}

/// Translates an AABB by a vector with the operands swapped: `offset + aabb`.
impl Add<Aabb> for DVec3 {
    type Output = Aabb;

    /// Addition is commutative, so this defers to `aabb + offset`.
    fn add(self, rhs: Aabb) -> Self::Output {
        <Aabb as Add<DVec3>>::add(rhs, self)
    }
}

impl Sub<DVec3> for Aabb {
    type Output = Aabb;

    fn sub(self, rhs: DVec3) -> Self::Output {
        Self::new(self.min - rhs, self.max - rhs)
    }
}

impl Sub<Aabb> for DVec3 {
    type Output = Aabb;

    fn sub(self, rhs: Aabb) -> Self::Output {
        rhs - self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Casts axis-aligned rays at the unit cube from origins placed on its
    /// corner, edge, face, center and outside of it.
    ///
    /// Any reported intersection must be finite, ordered and non-negative.
    /// Two origins have an unambiguous expected result and are pinned
    /// exactly, so the test cannot pass vacuously when `ray_intersection`
    /// reports no hits at all.
    #[test]
    fn ray_intersect_edge_cases() {
        let bb = Aabb::new([0.0, 0.0, 0.0].into(), [1.0, 1.0, 1.0].into());

        let center = DVec3::new(0.5, 0.5, 0.5);
        let outside_slabs = DVec3::new(-2.0, -2.0, -2.0);

        let ros = [
            // On a corner
            DVec3::new(0.0, 0.0, 0.0),
            // Outside
            DVec3::new(-0.5, 0.5, -0.5),
            // In the center
            center,
            // On an edge
            DVec3::new(0.0, 0.5, 0.0),
            // On a face
            DVec3::new(0.0, 0.5, 0.5),
            // Outside slabs
            outside_slabs,
        ];

        let rds = [
            DVec3::new(1.0, 0.0, 0.0),
            DVec3::new(-1.0, 0.0, 0.0),
            DVec3::new(0.0, 1.0, 0.0),
            DVec3::new(0.0, -1.0, 0.0),
            DVec3::new(0.0, 0.0, 1.0),
            DVec3::new(0.0, 0.0, -1.0),
        ];

        assert!(rds.iter().all(|d| d.is_normalized()));

        for ro in ros {
            for rd in rds {
                let hit = bb.ray_intersection(ro, rd);

                if let Some([near, far]) = hit {
                    assert!(near.is_finite());
                    assert!(far.is_finite());
                    assert!(near <= far);
                    assert!(near >= 0.0);
                    assert!(far >= 0.0);
                }

                // From the center, every axis-aligned ray starts inside the
                // box and leaves it after half a unit.
                if ro == center {
                    assert_eq!(
                        hit,
                        Some([0.0, 0.5]),
                        "ray from the center must hit (ro = {ro}, rd = {rd})"
                    );
                }

                // This origin is outside the box on every axis, so a ray that
                // only moves along one axis can never reach it.
                if ro == outside_slabs {
                    assert_eq!(
                        hit, None,
                        "ray from outside all slabs must miss (ro = {ro}, rd = {rd})"
                    );
                }
            }
        }
    }
}