valence_spatial/
bvh.rs

1use std::iter::FusedIterator;
2use std::mem;
3
4use approx::abs_diff_eq;
5use rayon::iter::{
6    IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
7};
8use vek::{Aabb, Vec3};
9
10use crate::{ray_box_intersect, Bounded3D, RaycastHit, SpatialIndex};
11
/// A bounding volume hierarchy: a binary tree of axis-aligned bounding boxes
/// built over a set of objects of type `T`.
///
/// The tree is stored in two flat arrays. A node index smaller than
/// `internal_nodes.len()` refers to an internal node; any other index `i`
/// refers to the leaf `leaf_nodes[i - internal_nodes.len()]`.
#[derive(Clone, Debug)]
pub struct Bvh<T> {
    /// Internal nodes. After a non-empty `rebuild` there is exactly one fewer
    /// of these than there are leaves.
    internal_nodes: Vec<InternalNode>,
    /// The stored objects. `rebuild` reorders them while building the tree.
    leaf_nodes: Vec<T>,
    /// Index of the root node. Only meaningful while `leaf_nodes` is
    /// non-empty; `new` initialises it to `NodeIdx::MAX`.
    root: NodeIdx,
}
18
/// An internal (non-leaf) node of a [`Bvh`].
#[derive(Clone, Debug)]
struct InternalNode {
    /// Bounding box enclosing both children (union of their boxes).
    bb: Aabb<f64>,
    /// Index of the left child, which may be an internal node or a leaf.
    left: NodeIdx,
    /// Index of the right child, which may be an internal node or a leaf.
    right: NodeIdx,
}
25
/// Index of a node in a `Bvh`. Values below the number of internal nodes
/// address `internal_nodes`; larger values address `leaf_nodes` after
/// subtracting that count. `NodeIdx::MAX` is used as a "no node" placeholder.
// TODO: we could use usize here to store more elements.
type NodeIdx = u32;
28
29impl<T: Bounded3D + Send + Sync> Bvh<T> {
30    pub fn new() -> Self {
31        Self {
32            internal_nodes: vec![],
33            leaf_nodes: vec![],
34            root: NodeIdx::MAX,
35        }
36    }
37
38    pub fn rebuild<I: IntoIterator<Item = T>>(&mut self, leaves: I) {
39        self.internal_nodes.clear();
40        self.leaf_nodes.clear();
41
42        self.leaf_nodes.extend(leaves);
43
44        let leaf_count = self.leaf_nodes.len();
45
46        if leaf_count == 0 {
47            return;
48        }
49
50        self.internal_nodes.reserve_exact(leaf_count - 1);
51        self.internal_nodes.resize(
52            leaf_count - 1,
53            InternalNode {
54                bb: Aabb::default(),
55                left: NodeIdx::MAX,
56                right: NodeIdx::MAX,
57            },
58        );
59
60        if NodeIdx::try_from(leaf_count)
61            .ok()
62            .and_then(|count| count.checked_add(count - 1))
63            .is_none()
64        {
65            panic!("too many elements in BVH");
66        }
67
68        let id = self.leaf_nodes[0].aabb();
69        let scene_bounds = self
70            .leaf_nodes
71            .par_iter()
72            .map(|l| l.aabb())
73            .reduce(|| id, Aabb::union);
74
75        self.root = rebuild_rec(
76            0,
77            scene_bounds,
78            &mut self.internal_nodes,
79            &mut self.leaf_nodes,
80            leaf_count as NodeIdx,
81        )
82        .0;
83
84        debug_assert_eq!(self.internal_nodes.len(), self.leaf_nodes.len() - 1);
85    }
86
87    pub fn traverse(&self) -> Option<Node<T>> {
88        if !self.leaf_nodes.is_empty() {
89            Some(Node::from_idx(self, self.root))
90        } else {
91            None
92        }
93    }
94
95    pub fn iter(&self) -> impl ExactSizeIterator<Item = &T> + FusedIterator + Clone + '_ {
96        self.leaf_nodes.iter()
97    }
98
99    pub fn iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut T> + FusedIterator + '_ {
100        self.leaf_nodes.iter_mut()
101    }
102
103    pub fn par_iter(&self) -> impl IndexedParallelIterator<Item = &T> + Clone + '_ {
104        self.leaf_nodes.par_iter()
105    }
106
107    pub fn par_iter_mut(&mut self) -> impl IndexedParallelIterator<Item = &mut T> + '_ {
108        self.leaf_nodes.par_iter_mut()
109    }
110}
111
112impl<T: Bounded3D + Send + Sync> Default for Bvh<T> {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
/// A borrowed view of one node of a [`Bvh`], obtained from [`Bvh::traverse`]
/// or [`Internal::split`].
#[derive(Debug)]
pub enum Node<'a, T> {
    /// A node with two children; use [`Internal::split`] to descend into them.
    Internal(Internal<'a, T>),
    /// A reference to one of the objects stored in the BVH.
    Leaf(&'a T),
}
123
124impl<'a, T> Node<'a, T> {
125    fn from_idx(bvh: &'a Bvh<T>, idx: NodeIdx) -> Self {
126        if idx < bvh.internal_nodes.len() as NodeIdx {
127            Self::Internal(Internal { bvh, idx })
128        } else {
129            Self::Leaf(&bvh.leaf_nodes[(idx - bvh.internal_nodes.len() as NodeIdx) as usize])
130        }
131    }
132}
133
134impl<T: Bounded3D> Bounded3D for Node<'_, T> {
135    fn aabb(&self) -> Aabb<f64> {
136        match self {
137            Node::Internal(int) => int.aabb(),
138            Node::Leaf(t) => t.aabb(),
139        }
140    }
141}
142
// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add
// `T: Clone` / `T: Copy` bounds, but a `Node` only holds references into the
// BVH, so it can be copied whatever `T` is.
impl<T> Clone for Node<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Node<'_, T> {}
150
/// Handle to an internal node of a [`Bvh`].
#[derive(Debug)]
pub struct Internal<'a, T> {
    /// The tree this node belongs to.
    bvh: &'a Bvh<T>,
    /// Index of this node within the tree's internal-node array.
    idx: NodeIdx,
}
156
157impl<'a, T> Internal<'a, T> {
158    pub fn split(self) -> (Aabb<f64>, Node<'a, T>, Node<'a, T>) {
159        let internal = &self.bvh.internal_nodes[self.idx as usize];
160
161        let bb = internal.bb;
162        let left = Node::from_idx(self.bvh, internal.left);
163        let right = Node::from_idx(self.bvh, internal.right);
164
165        (bb, left, right)
166    }
167}
168
169impl<T> Bounded3D for Internal<'_, T> {
170    fn aabb(&self) -> Aabb<f64> {
171        self.bvh.internal_nodes[self.idx as usize].bb
172    }
173}
174
// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add
// `T: Clone` / `T: Copy` bounds, but `Internal` only holds a reference and an
// index, so it can be copied whatever `T` is.
impl<T> Clone for Internal<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Internal<'_, T> {}
182
183fn rebuild_rec<T: Send + Bounded3D>(
184    idx: NodeIdx,
185    mut bounds: Aabb<f64>,
186    internal_nodes: &mut [InternalNode],
187    leaf_nodes: &mut [T],
188    total_leaf_count: NodeIdx,
189) -> (NodeIdx, Aabb<f64>) {
190    debug_assert_eq!(leaf_nodes.len() - 1, internal_nodes.len());
191
192    if leaf_nodes.len() == 1 {
193        // Leaf node
194        return (total_leaf_count - 1 + idx, leaf_nodes[0].aabb());
195    }
196
197    loop {
198        debug_assert!(bounds.is_valid());
199        let dims = bounds.max - bounds.min;
200
201        let (mut split, bounds_left, bounds_right) = if dims.x >= dims.y && dims.x >= dims.z {
202            let mid = middle(bounds.min.x, bounds.max.x);
203            let [bounds_left, bounds_right] = bounds.split_at_x(mid);
204
205            let p = partition(leaf_nodes, |l| {
206                middle(l.aabb().min.x, l.aabb().max.x) <= mid
207            });
208
209            (p, bounds_left, bounds_right)
210        } else if dims.y >= dims.x && dims.y >= dims.z {
211            let mid = middle(bounds.min.y, bounds.max.y);
212            let [bounds_left, bounds_right] = bounds.split_at_y(mid);
213
214            let p = partition(leaf_nodes, |l| {
215                middle(l.aabb().min.y, l.aabb().max.y) <= mid
216            });
217
218            (p, bounds_left, bounds_right)
219        } else {
220            let mid = middle(bounds.min.z, bounds.max.z);
221            let [bounds_left, bounds_right] = bounds.split_at_z(mid);
222
223            let p = partition(leaf_nodes, |l| {
224                middle(l.aabb().min.z, l.aabb().max.z) <= mid
225            });
226
227            (p, bounds_left, bounds_right)
228        };
229
230        // Check if one of the halves is empty. (We can't have empty nodes)
231        // Also take care to handle the edge case of overlapping points.
232        if split == 0 {
233            if abs_diff_eq!(
234                bounds_right.min,
235                bounds_right.max,
236                epsilon = f64::EPSILON * 100.0
237            ) {
238                split += 1;
239            } else {
240                bounds = bounds_right;
241                continue;
242            }
243        } else if split == leaf_nodes.len() {
244            if abs_diff_eq!(
245                bounds_left.min,
246                bounds_left.max,
247                epsilon = f64::EPSILON * 100.0
248            ) {
249                split -= 1;
250            } else {
251                bounds = bounds_left;
252                continue;
253            }
254        }
255
256        let (leaves_left, leaves_right) = leaf_nodes.split_at_mut(split);
257
258        let (internal_left, internal_right) = internal_nodes.split_at_mut(split);
259        let (internal, internal_left) = internal_left.split_last_mut().unwrap();
260
261        let ((left, bounds_left), (right, bounds_right)) = rayon::join(
262            || {
263                rebuild_rec(
264                    idx,
265                    bounds_left,
266                    internal_left,
267                    leaves_left,
268                    total_leaf_count,
269                )
270            },
271            || {
272                rebuild_rec(
273                    idx + split as NodeIdx,
274                    bounds_right,
275                    internal_right,
276                    leaves_right,
277                    total_leaf_count,
278                )
279            },
280        );
281
282        internal.bb = bounds_left.union(bounds_right);
283        internal.left = left;
284        internal.right = right;
285
286        break (idx + split as NodeIdx - 1, internal.bb);
287    }
288}
289
/// Reorders `s` in place so that every element satisfying `pred` comes before
/// every element that does not, and returns how many elements satisfy `pred`.
///
/// `pred` is evaluated exactly once per element. The relative order within each
/// group is not guaranteed to be preserved.
fn partition<T>(s: &mut [T], mut pred: impl FnMut(&T) -> bool) -> usize {
    // Elements not yet classified; consumed from both ends.
    let mut pending = s.iter_mut();
    let mut in_front = 0;

    loop {
        // Walk forward over elements already on the correct side until one is
        // found that belongs at the back.
        let misplaced = loop {
            match pending.next() {
                Some(x) => {
                    if !pred(x) {
                        break x;
                    }
                    in_front += 1;
                }
                None => return in_front,
            }
        };

        // Walk backward until an element that belongs at the front turns up to
        // trade places with it.
        let partner = loop {
            match pending.next_back() {
                Some(x) => {
                    if pred(x) {
                        break x;
                    }
                }
                None => return in_front,
            }
        };

        mem::swap(misplaced, partner);
        in_front += 1;
    }
}
311
/// Returns the midpoint of `a` and `b`, i.e. `(a + b) / 2`.
fn middle(a: f64, b: f64) -> f64 {
    let sum = a + b;
    sum * 0.5
}
315
/// Tree-accelerated [`SpatialIndex`]: both queries start at the root and only
/// descend into subtrees whose stored bounding box passes the relevant test.
impl<O: Bounded3D + Send + Sync> SpatialIndex for Bvh<O> {
    type Object = O;

    /// Returns the first `Some` produced by `f`, searching depth-first with left
    /// children before right children.
    ///
    /// `collides` is called on a node's bounding box before anything below it is
    /// visited; subtrees and leaves whose box it rejects are skipped. Returns
    /// `None` if the BVH is empty or `f` never returns `Some`.
    fn query<C, F, T>(&self, mut collides: C, mut f: F) -> Option<T>
    where
        C: FnMut(Aabb<f64>) -> bool,
        F: FnMut(&O) -> Option<T>,
    {
        // Recursive helper; `collides` and `f` are passed by mutable reference so
        // that both subtrees share the same closures.
        fn query_rec<C, F, O, T>(node: Node<O>, collides: &mut C, f: &mut F) -> Option<T>
        where
            C: FnMut(Aabb<f64>) -> bool,
            F: FnMut(&O) -> Option<T>,
            O: Bounded3D,
        {
            match node {
                Node::Internal(int) => {
                    let (bb, left, right) = int.split();

                    if collides(bb) {
                        // The right subtree is only searched if the left one
                        // produced nothing.
                        query_rec(left, collides, f).or_else(|| query_rec(right, collides, f))
                    } else {
                        None
                    }
                }
                Node::Leaf(leaf) => {
                    if collides(leaf.aabb()) {
                        f(leaf)
                    } else {
                        None
                    }
                }
            }
        }

        query_rec(self.traverse()?, &mut collides, &mut f)
    }

    /// Finds the closest leaf hit along a ray that `f` accepts.
    ///
    /// Leaves whose box the ray intersects are offered to `f` as a `RaycastHit`
    /// carrying the object and the `near`/`far` values returned by
    /// `ray_box_intersect` (presumably distances along the ray; see that
    /// function). When `f` returns `true` the hit becomes the current best.
    /// Children are explored in order of their `near` value, and a subtree is
    /// skipped when the best hit's `near` is already `<=` the subtree's `near`,
    /// so `f` is not necessarily called for every intersected leaf.
    ///
    /// Returns `None` if the BVH is empty, the ray misses the root's box, or `f`
    /// accepts nothing. `direction` must be normalized (checked only by a
    /// `debug_assert!`).
    fn raycast<F>(&self, origin: Vec3<f64>, direction: Vec3<f64>, mut f: F) -> Option<RaycastHit<O>>
    where
        F: FnMut(RaycastHit<O>) -> bool,
    {
        // Visits `node`, whose box the ray enters at `near` and leaves at `far`
        // (as computed by the caller), updating `hit` with accepted closer hits.
        fn raycast_rec<'a, O: Bounded3D>(
            node: Node<'a, O>,
            hit: &mut Option<RaycastHit<'a, O>>,
            near: f64,
            far: f64,
            origin: Vec3<f64>,
            direction: Vec3<f64>,
            f: &mut impl FnMut(RaycastHit<O>) -> bool,
        ) {
            // Prune: the best hit so far is no farther than where this node's
            // box begins, so (assuming child boxes lie inside it) nothing here
            // can be closer.
            if let Some(hit) = hit {
                if hit.near <= near {
                    return;
                }
            }

            match node {
                Node::Internal(int) => {
                    let (_, left, right) = int.split();

                    let int_left = ray_box_intersect(origin, direction, left.aabb());
                    let int_right = ray_box_intersect(origin, direction, right.aabb());

                    match (int_left, int_right) {
                        (Some((near_left, far_left)), Some((near_right, far_right))) => {
                            // Explore closest subtree first.
                            if near_left < near_right {
                                raycast_rec(left, hit, near_left, far_left, origin, direction, f);
                                raycast_rec(
                                    right, hit, near_right, far_right, origin, direction, f,
                                );
                            } else {
                                raycast_rec(
                                    right, hit, near_right, far_right, origin, direction, f,
                                );
                                raycast_rec(left, hit, near_left, far_left, origin, direction, f);
                            }
                        }
                        (Some((near, far)), None) => {
                            raycast_rec(left, hit, near, far, origin, direction, f)
                        }
                        (None, Some((near, far))) => {
                            raycast_rec(right, hit, near, far, origin, direction, f)
                        }
                        (None, None) => {}
                    }
                }
                Node::Leaf(leaf) => {
                    let this_hit = RaycastHit {
                        object: leaf,
                        near,
                        far,
                    };

                    // Reaching this point means `near` beat any previous best
                    // (see the pruning check above), so an accepted hit always
                    // replaces it.
                    if f(this_hit) {
                        *hit = Some(this_hit);
                    }
                }
            }
        }

        debug_assert!(
            direction.is_normalized(),
            "the ray direction must be normalized"
        );

        let root = self.traverse()?;
        let (near, far) = ray_box_intersect(origin, direction, root.aabb())?;

        let mut hit = None;
        raycast_rec(root, &mut hit, near, far, origin, direction, &mut f);
        hit
    }
}