valence_math/
aabb.rs

1use std::ops::{Add, Sub};
2
3use glam::DVec3;
4
5/// A three-dimensional axis-aligned bounding box, or "AABB".
6///
7/// The AABB is defined by two points—`min` and `max`. `min` is less than or
8/// equal to `max` componentwise.
9#[derive(Copy, Clone, PartialEq, Default, Debug)]
10pub struct Aabb {
11    min: DVec3,
12    max: DVec3,
13}
14
15impl Aabb {
16    pub const ZERO: Self = Self {
17        min: DVec3::ZERO,
18        max: DVec3::ZERO,
19    };
20
21    /// Constructs a new AABB from `min` and `max` points.
22    ///
23    /// # Panics
24    ///
25    /// Panics if `debug_assertions` are enabled and `min` is not less than or
26    /// equal to `max` componentwise.
27    #[cfg_attr(debug_assertions, track_caller)]
28    pub fn new(min: DVec3, max: DVec3) -> Self {
29        debug_assert!(
30            min.x <= max.x && min.y <= max.y && min.z <= max.z,
31            "`min` must be less than or equal to `max` componentwise (min = {min}, max = {max})"
32        );
33
34        Self { min, max }
35    }
36
37    // TODO: remove when the assertion in `new` can be done in a `const` context.
38    #[doc(hidden)]
39    pub const fn new_unchecked(min: DVec3, max: DVec3) -> Self {
40        Self { min, max }
41    }
42
43    /// Returns a new AABB containing a single point `p`.
44    pub fn new_point(p: DVec3) -> Self {
45        Self::new(p, p)
46    }
47
48    pub fn from_bottom_size(bottom: DVec3, size: DVec3) -> Self {
49        Self::new(
50            DVec3 {
51                x: bottom.x - size.x / 2.0,
52                y: bottom.y,
53                z: bottom.z - size.z / 2.0,
54            },
55            DVec3 {
56                x: bottom.x + size.x / 2.0,
57                y: bottom.y + size.y,
58                z: bottom.z + size.z / 2.0,
59            },
60        )
61    }
62
63    pub const fn min(self) -> DVec3 {
64        self.min
65    }
66
67    pub const fn max(self) -> DVec3 {
68        self.max
69    }
70
71    pub fn union(self, other: Self) -> Self {
72        Self::new(self.min.min(other.min), self.max.max(other.max))
73    }
74
75    pub fn intersects(self, other: Self) -> bool {
76        self.max.x >= other.min.x
77            && other.max.x >= self.min.x
78            && self.max.y >= other.min.y
79            && other.max.y >= self.min.y
80            && self.max.z >= other.min.z
81            && other.max.z >= self.min.z
82    }
83
84    /// Does this bounding box contain the given point?
85    pub fn contains_point(self, p: DVec3) -> bool {
86        self.min.x <= p.x
87            && self.min.y <= p.y
88            && self.min.z <= p.z
89            && self.max.x >= p.x
90            && self.max.y >= p.y
91            && self.max.z >= p.z
92    }
93
94    /// Returns the closest point in the AABB to the given point.
95    pub fn projected_point(self, p: DVec3) -> DVec3 {
96        p.clamp(self.min, self.max)
97    }
98
99    /// Returns the smallest distance from the AABB to the point.
100    pub fn distance_to_point(self, p: DVec3) -> f64 {
101        self.projected_point(p).distance(p)
102    }
103
104    /// Calculates the intersection between this AABB and a ray
105    /// defined by its `origin` point and `direction` vector.
106    ///
107    /// If an intersection occurs, `Some([near, far])` is returned. `near` and
108    /// `far` are the values of `t` in the equation `origin + t * direction =
109    /// point` where `point` is the nearest or furthest intersection point to
110    /// the `origin`. If no intersection occurs, then `None` is returned.
111    ///
112    /// In other words, if `direction` is normalized, then `near` and `far` are
113    /// the distances to the nearest and furthest intersection points.
114    pub fn ray_intersection(self, origin: DVec3, direction: DVec3) -> Option<[f64; 2]> {
115        let mut near: f64 = 0.0;
116        let mut far = f64::INFINITY;
117
118        for i in 0..3 {
119            // Rust's definition of `min` and `max` properly handle the NaNs these
120            // computations may produce.
121            let t0 = (self.min[i] - origin[i]) / direction[i];
122            let t1 = (self.max[i] - origin[i]) / direction[i];
123
124            near = near.max(t0.min(t1));
125            far = far.min(t0.max(t1));
126        }
127
128        (near <= far).then_some([near, far])
129    }
130}
131
132impl Add<DVec3> for Aabb {
133    type Output = Aabb;
134
135    fn add(self, rhs: DVec3) -> Self::Output {
136        Self::new(self.min + rhs, self.max + rhs)
137    }
138}
139
140impl Add<Aabb> for DVec3 {
141    type Output = Aabb;
142
143    fn add(self, rhs: Aabb) -> Self::Output {
144        rhs + self
145    }
146}
147
148impl Sub<DVec3> for Aabb {
149    type Output = Aabb;
150
151    fn sub(self, rhs: DVec3) -> Self::Output {
152        Self::new(self.min - rhs, self.max - rhs)
153    }
154}
155
156impl Sub<Aabb> for DVec3 {
157    type Output = Aabb;
158
159    fn sub(self, rhs: Aabb) -> Self::Output {
160        rhs - self
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// The closed unit cube `[0, 1]^3` shared by the tests below.
    fn unit_cube() -> Aabb {
        Aabb::new([0.0, 0.0, 0.0].into(), [1.0, 1.0, 1.0].into())
    }

    /// The six axis-aligned unit directions.
    fn axis_directions() -> [DVec3; 6] {
        [
            DVec3::new(1.0, 0.0, 0.0),
            DVec3::new(-1.0, 0.0, 0.0),
            DVec3::new(0.0, 1.0, 0.0),
            DVec3::new(0.0, -1.0, 0.0),
            DVec3::new(0.0, 0.0, 1.0),
            DVec3::new(0.0, 0.0, -1.0),
        ]
    }

    /// Whatever the hit/miss outcome, a reported hit must be a sane interval.
    #[test]
    fn ray_intersect_edge_cases() {
        let bb = unit_cube();

        let ros = [
            // On a corner
            DVec3::new(0.0, 0.0, 0.0),
            // Outside
            DVec3::new(-0.5, 0.5, -0.5),
            // In the center
            DVec3::new(0.5, 0.5, 0.5),
            // On an edge
            DVec3::new(0.0, 0.5, 0.0),
            // On a face
            DVec3::new(0.0, 0.5, 0.5),
            // Outside slabs
            DVec3::new(-2.0, -2.0, -2.0),
        ];

        let rds = axis_directions();

        assert!(rds.iter().all(|d| d.is_normalized()));

        for ro in ros {
            for rd in rds {
                if let Some([near, far]) = bb.ray_intersection(ro, rd) {
                    assert!(near.is_finite());
                    assert!(far.is_finite());
                    assert!(near <= far);
                    assert!(near >= 0.0);
                    assert!(far >= 0.0);
                }
            }
        }
    }

    /// A ray starting at the centre must hit in every direction, exiting
    /// half a unit away. This guards against an implementation that never
    /// reports a hit, which `ray_intersect_edge_cases` alone would accept.
    #[test]
    fn ray_from_center_hits_in_every_direction() {
        let bb = unit_cube();
        let center = DVec3::new(0.5, 0.5, 0.5);

        for rd in axis_directions() {
            assert_eq!(bb.ray_intersection(center, rd), Some([0.0, 0.5]));
        }
    }

    /// A ray from outside hits only when it points at the box.
    #[test]
    fn ray_from_outside_hits_only_when_facing_box() {
        let bb = unit_cube();
        let ro = DVec3::new(-0.5, 0.5, 0.5);

        assert_eq!(
            bb.ray_intersection(ro, DVec3::new(1.0, 0.0, 0.0)),
            Some([0.5, 1.5])
        );
        assert_eq!(bb.ray_intersection(ro, DVec3::new(-1.0, 0.0, 0.0)), None);
        assert_eq!(bb.ray_intersection(ro, DVec3::new(0.0, 1.0, 0.0)), None);
    }

    /// `intersects` and `contains_point` are inclusive of the boundary.
    #[test]
    fn boundary_contact_counts_as_overlap() {
        let bb = unit_cube();
        let touching = Aabb::new([1.0, 1.0, 1.0].into(), [2.0, 2.0, 2.0].into());
        let separate = Aabb::new([1.5, 0.0, 0.0].into(), [2.0, 1.0, 1.0].into());

        assert!(bb.intersects(touching));
        assert!(touching.intersects(bb));
        assert!(!bb.intersects(separate));

        assert!(bb.contains_point(DVec3::new(1.0, 1.0, 1.0)));
        assert!(!bb.contains_point(DVec3::new(0.5, 0.5, 1.5)));
    }
}