1use std::iter::FusedIterator;
2use std::mem;
3
4use approx::abs_diff_eq;
5use rayon::iter::{
6 IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
7};
8use vek::{Aabb, Vec3};
9
10use crate::{ray_box_intersect, Bounded3D, RaycastHit, SpatialIndex};
11
/// A bounding volume hierarchy over objects of type `T`, stored as a binary tree flattened
/// into two vectors.
///
/// Every node is addressed by one `NodeIdx`: values below `internal_nodes.len()` refer to
/// internal nodes, and any other value refers to `leaf_nodes` after subtracting that length.
#[derive(Clone, Debug)]
pub struct Bvh<T> {
    /// Branch nodes of the tree; a non-empty rebuild sizes this to `leaf_nodes.len() - 1`.
    internal_nodes: Vec<InternalNode>,
    /// The stored objects. A rebuild reorders them while partitioning.
    leaf_nodes: Vec<T>,
    /// Index of the root node. Set to `NodeIdx::MAX` by `new`; only meaningful while the
    /// tree holds at least one leaf.
    root: NodeIdx,
}
18
/// A branch node of the tree: a bounding box plus the indices of its two children.
#[derive(Clone, Debug)]
struct InternalNode {
    /// Bounding box of this node; the build sets it to the union of both child boxes.
    bb: Aabb<f64>,
    /// Index of the left child (internal node or leaf, see `Node::from_idx`).
    left: NodeIdx,
    /// Index of the right child (internal node or leaf, see `Node::from_idx`).
    right: NodeIdx,
}
25
/// Index type used to address both internal nodes and leaves of a `Bvh`.
type NodeIdx = u32;
28
29impl<T: Bounded3D + Send + Sync> Bvh<T> {
30 pub fn new() -> Self {
31 Self {
32 internal_nodes: vec![],
33 leaf_nodes: vec![],
34 root: NodeIdx::MAX,
35 }
36 }
37
38 pub fn rebuild<I: IntoIterator<Item = T>>(&mut self, leaves: I) {
39 self.internal_nodes.clear();
40 self.leaf_nodes.clear();
41
42 self.leaf_nodes.extend(leaves);
43
44 let leaf_count = self.leaf_nodes.len();
45
46 if leaf_count == 0 {
47 return;
48 }
49
50 self.internal_nodes.reserve_exact(leaf_count - 1);
51 self.internal_nodes.resize(
52 leaf_count - 1,
53 InternalNode {
54 bb: Aabb::default(),
55 left: NodeIdx::MAX,
56 right: NodeIdx::MAX,
57 },
58 );
59
60 if NodeIdx::try_from(leaf_count)
61 .ok()
62 .and_then(|count| count.checked_add(count - 1))
63 .is_none()
64 {
65 panic!("too many elements in BVH");
66 }
67
68 let id = self.leaf_nodes[0].aabb();
69 let scene_bounds = self
70 .leaf_nodes
71 .par_iter()
72 .map(|l| l.aabb())
73 .reduce(|| id, Aabb::union);
74
75 self.root = rebuild_rec(
76 0,
77 scene_bounds,
78 &mut self.internal_nodes,
79 &mut self.leaf_nodes,
80 leaf_count as NodeIdx,
81 )
82 .0;
83
84 debug_assert_eq!(self.internal_nodes.len(), self.leaf_nodes.len() - 1);
85 }
86
87 pub fn traverse(&self) -> Option<Node<T>> {
88 if !self.leaf_nodes.is_empty() {
89 Some(Node::from_idx(self, self.root))
90 } else {
91 None
92 }
93 }
94
95 pub fn iter(&self) -> impl ExactSizeIterator<Item = &T> + FusedIterator + Clone + '_ {
96 self.leaf_nodes.iter()
97 }
98
99 pub fn iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut T> + FusedIterator + '_ {
100 self.leaf_nodes.iter_mut()
101 }
102
103 pub fn par_iter(&self) -> impl IndexedParallelIterator<Item = &T> + Clone + '_ {
104 self.leaf_nodes.par_iter()
105 }
106
107 pub fn par_iter_mut(&mut self) -> impl IndexedParallelIterator<Item = &mut T> + '_ {
108 self.leaf_nodes.par_iter_mut()
109 }
110}
111
impl<T: Bounded3D + Send + Sync> Default for Bvh<T> {
    /// Equivalent to [`Bvh::new`]: an empty hierarchy.
    fn default() -> Self {
        Self::new()
    }
}
117
/// A borrowed view of one node of a `Bvh`, handed out for traversal.
#[derive(Debug)]
pub enum Node<'a, T> {
    /// A branch node; call `Internal::split` to reach its two children.
    Internal(Internal<'a, T>),
    /// A leaf: a reference to one of the stored objects.
    Leaf(&'a T),
}
123
124impl<'a, T> Node<'a, T> {
125 fn from_idx(bvh: &'a Bvh<T>, idx: NodeIdx) -> Self {
126 if idx < bvh.internal_nodes.len() as NodeIdx {
127 Self::Internal(Internal { bvh, idx })
128 } else {
129 Self::Leaf(&bvh.leaf_nodes[(idx - bvh.internal_nodes.len() as NodeIdx) as usize])
130 }
131 }
132}
133
134impl<T: Bounded3D> Bounded3D for Node<'_, T> {
135 fn aabb(&self) -> Aabb<f64> {
136 match self {
137 Node::Internal(int) => int.aabb(),
138 Node::Leaf(t) => t.aabb(),
139 }
140 }
141}
142
// Implemented by hand rather than derived so that no `T: Clone` bound is required;
// a `Node` only holds references into the `Bvh`.
impl<T> Clone for Node<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}
148
// A `Node` is just a reference or a reference plus an index, so it is `Copy` for every `T`.
impl<T> Copy for Node<'_, T> {}
150
/// A borrowed handle to one internal (branch) node of a `Bvh`.
#[derive(Debug)]
pub struct Internal<'a, T> {
    /// The hierarchy this node belongs to.
    bvh: &'a Bvh<T>,
    /// Position of this node within `bvh.internal_nodes`.
    idx: NodeIdx,
}
156
157impl<'a, T> Internal<'a, T> {
158 pub fn split(self) -> (Aabb<f64>, Node<'a, T>, Node<'a, T>) {
159 let internal = &self.bvh.internal_nodes[self.idx as usize];
160
161 let bb = internal.bb;
162 let left = Node::from_idx(self.bvh, internal.left);
163 let right = Node::from_idx(self.bvh, internal.right);
164
165 (bb, left, right)
166 }
167}
168
169impl<T> Bounded3D for Internal<'_, T> {
170 fn aabb(&self) -> Aabb<f64> {
171 self.bvh.internal_nodes[self.idx as usize].bb
172 }
173}
174
// Implemented by hand rather than derived so that no `T: Clone` bound is required.
impl<T> Clone for Internal<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}
180
// An `Internal` is a reference plus an index, so it is `Copy` for every `T`.
impl<T> Copy for Internal<'_, T> {}
182
183fn rebuild_rec<T: Send + Bounded3D>(
184 idx: NodeIdx,
185 mut bounds: Aabb<f64>,
186 internal_nodes: &mut [InternalNode],
187 leaf_nodes: &mut [T],
188 total_leaf_count: NodeIdx,
189) -> (NodeIdx, Aabb<f64>) {
190 debug_assert_eq!(leaf_nodes.len() - 1, internal_nodes.len());
191
192 if leaf_nodes.len() == 1 {
193 return (total_leaf_count - 1 + idx, leaf_nodes[0].aabb());
195 }
196
197 loop {
198 debug_assert!(bounds.is_valid());
199 let dims = bounds.max - bounds.min;
200
201 let (mut split, bounds_left, bounds_right) = if dims.x >= dims.y && dims.x >= dims.z {
202 let mid = middle(bounds.min.x, bounds.max.x);
203 let [bounds_left, bounds_right] = bounds.split_at_x(mid);
204
205 let p = partition(leaf_nodes, |l| {
206 middle(l.aabb().min.x, l.aabb().max.x) <= mid
207 });
208
209 (p, bounds_left, bounds_right)
210 } else if dims.y >= dims.x && dims.y >= dims.z {
211 let mid = middle(bounds.min.y, bounds.max.y);
212 let [bounds_left, bounds_right] = bounds.split_at_y(mid);
213
214 let p = partition(leaf_nodes, |l| {
215 middle(l.aabb().min.y, l.aabb().max.y) <= mid
216 });
217
218 (p, bounds_left, bounds_right)
219 } else {
220 let mid = middle(bounds.min.z, bounds.max.z);
221 let [bounds_left, bounds_right] = bounds.split_at_z(mid);
222
223 let p = partition(leaf_nodes, |l| {
224 middle(l.aabb().min.z, l.aabb().max.z) <= mid
225 });
226
227 (p, bounds_left, bounds_right)
228 };
229
230 if split == 0 {
233 if abs_diff_eq!(
234 bounds_right.min,
235 bounds_right.max,
236 epsilon = f64::EPSILON * 100.0
237 ) {
238 split += 1;
239 } else {
240 bounds = bounds_right;
241 continue;
242 }
243 } else if split == leaf_nodes.len() {
244 if abs_diff_eq!(
245 bounds_left.min,
246 bounds_left.max,
247 epsilon = f64::EPSILON * 100.0
248 ) {
249 split -= 1;
250 } else {
251 bounds = bounds_left;
252 continue;
253 }
254 }
255
256 let (leaves_left, leaves_right) = leaf_nodes.split_at_mut(split);
257
258 let (internal_left, internal_right) = internal_nodes.split_at_mut(split);
259 let (internal, internal_left) = internal_left.split_last_mut().unwrap();
260
261 let ((left, bounds_left), (right, bounds_right)) = rayon::join(
262 || {
263 rebuild_rec(
264 idx,
265 bounds_left,
266 internal_left,
267 leaves_left,
268 total_leaf_count,
269 )
270 },
271 || {
272 rebuild_rec(
273 idx + split as NodeIdx,
274 bounds_right,
275 internal_right,
276 leaves_right,
277 total_leaf_count,
278 )
279 },
280 );
281
282 internal.bb = bounds_left.union(bounds_right);
283 internal.left = left;
284 internal.right = right;
285
286 break (idx + split as NodeIdx - 1, internal.bb);
287 }
288}
289
/// Reorders `s` in place so that every element satisfying `pred` precedes every element
/// that does not, and returns how many elements satisfy `pred`.
///
/// The relative order of elements is not preserved. The slice is scanned from both ends:
/// the front scan stops at the first element failing `pred`, the back scan stops at the
/// last element passing it, and the two are swapped.
fn partition<T>(s: &mut [T], mut pred: impl FnMut(&T) -> bool) -> usize {
    // `front..back` is the range that has not been examined yet.
    let mut front = 0;
    let mut back = s.len();
    let mut matched = 0;

    loop {
        // Scan forward for an element that belongs in the second group.
        let misplaced_front = loop {
            if front == back {
                return matched;
            }
            let candidate = front;
            front += 1;
            if pred(&s[candidate]) {
                matched += 1;
            } else {
                break candidate;
            }
        };

        // Scan backward for an element that belongs in the first group.
        let misplaced_back = loop {
            if front == back {
                return matched;
            }
            back -= 1;
            if pred(&s[back]) {
                break back;
            }
        };

        s.swap(misplaced_front, misplaced_back);
        matched += 1;
    }
}
311
/// Returns the arithmetic mean of `a` and `b`.
fn middle(a: f64, b: f64) -> f64 {
    let sum = a + b;
    sum / 2.0
}
315
316impl<O: Bounded3D + Send + Sync> SpatialIndex for Bvh<O> {
317 type Object = O;
318
319 fn query<C, F, T>(&self, mut collides: C, mut f: F) -> Option<T>
320 where
321 C: FnMut(Aabb<f64>) -> bool,
322 F: FnMut(&O) -> Option<T>,
323 {
324 fn query_rec<C, F, O, T>(node: Node<O>, collides: &mut C, f: &mut F) -> Option<T>
325 where
326 C: FnMut(Aabb<f64>) -> bool,
327 F: FnMut(&O) -> Option<T>,
328 O: Bounded3D,
329 {
330 match node {
331 Node::Internal(int) => {
332 let (bb, left, right) = int.split();
333
334 if collides(bb) {
335 query_rec(left, collides, f).or_else(|| query_rec(right, collides, f))
336 } else {
337 None
338 }
339 }
340 Node::Leaf(leaf) => {
341 if collides(leaf.aabb()) {
342 f(leaf)
343 } else {
344 None
345 }
346 }
347 }
348 }
349
350 query_rec(self.traverse()?, &mut collides, &mut f)
351 }
352
353 fn raycast<F>(&self, origin: Vec3<f64>, direction: Vec3<f64>, mut f: F) -> Option<RaycastHit<O>>
354 where
355 F: FnMut(RaycastHit<O>) -> bool,
356 {
357 fn raycast_rec<'a, O: Bounded3D>(
358 node: Node<'a, O>,
359 hit: &mut Option<RaycastHit<'a, O>>,
360 near: f64,
361 far: f64,
362 origin: Vec3<f64>,
363 direction: Vec3<f64>,
364 f: &mut impl FnMut(RaycastHit<O>) -> bool,
365 ) {
366 if let Some(hit) = hit {
367 if hit.near <= near {
368 return;
369 }
370 }
371
372 match node {
373 Node::Internal(int) => {
374 let (_, left, right) = int.split();
375
376 let int_left = ray_box_intersect(origin, direction, left.aabb());
377 let int_right = ray_box_intersect(origin, direction, right.aabb());
378
379 match (int_left, int_right) {
380 (Some((near_left, far_left)), Some((near_right, far_right))) => {
381 if near_left < near_right {
383 raycast_rec(left, hit, near_left, far_left, origin, direction, f);
384 raycast_rec(
385 right, hit, near_right, far_right, origin, direction, f,
386 );
387 } else {
388 raycast_rec(
389 right, hit, near_right, far_right, origin, direction, f,
390 );
391 raycast_rec(left, hit, near_left, far_left, origin, direction, f);
392 }
393 }
394 (Some((near, far)), None) => {
395 raycast_rec(left, hit, near, far, origin, direction, f)
396 }
397 (None, Some((near, far))) => {
398 raycast_rec(right, hit, near, far, origin, direction, f)
399 }
400 (None, None) => {}
401 }
402 }
403 Node::Leaf(leaf) => {
404 let this_hit = RaycastHit {
405 object: leaf,
406 near,
407 far,
408 };
409
410 if f(this_hit) {
411 *hit = Some(this_hit);
412 }
413 }
414 }
415 }
416
417 debug_assert!(
418 direction.is_normalized(),
419 "the ray direction must be normalized"
420 );
421
422 let root = self.traverse()?;
423 let (near, far) = ray_box_intersect(origin, direction, root.aabb())?;
424
425 let mut hit = None;
426 raycast_rec(root, &mut hit, near, far, origin, direction, &mut f);
427 hit
428 }
429}