1use std::mem;
2use std::ops::Range;
3
4use valence_protocol::ChunkPos;
5
6use crate::ChunkView;
7
/// A bounding volume hierarchy over values that each have a chunk position,
/// used to find the values located inside a `ChunkView` without testing every
/// stored value: `query` skips whole subtrees whose bounds miss the view.
///
/// Call `build` to (re)populate it and `query` to visit the values in a view.
///
/// `MAX_SURFACE_AREA` controls when subdivision stops. A region becomes a leaf
/// once the "surface area" of its bounding box is at most this value. That area
/// is `2 * (x extent + z extent)`, where an extent is `max - min` in chunks. The
/// default of `8 * 4` corresponds, for example, to a box whose x and z extents
/// are both 8. It must be positive (`new` asserts this).
#[derive(Clone, Debug)]
pub struct ChunkBvh<T, const MAX_SURFACE_AREA: i32 = { 8 * 4 }> {
    // Tree nodes in post-order: children are pushed before their parent, so the
    // root is the last element. Empty when no values are stored.
    nodes: Vec<Node>,
    // The stored values, reordered by `build` so that each leaf owns a
    // contiguous range of them.
    values: Vec<T>,
}
14
15impl<T, const MAX_SURFACE_AREA: i32> Default for ChunkBvh<T, MAX_SURFACE_AREA> {
16 fn default() -> Self {
17 Self::new()
18 }
19}
20
/// A node of a `ChunkBvh`, stored in `ChunkBvh::nodes`.
#[derive(Clone, Debug)]
enum Node {
    /// A node with exactly two children.
    Internal {
        /// Bounding box of everything below this node (the union of both
        /// children's bounds).
        bounds: Aabb,
        /// Index of the left child in `ChunkBvh::nodes`.
        left: NodeIdx,
        /// Index of the right child in `ChunkBvh::nodes`.
        right: NodeIdx,
    },
    /// A node that directly owns a contiguous run of values.
    Leaf {
        /// Bounding box of the chunk positions of the values in `values`.
        bounds: Aabb,
        /// Range of indices into `ChunkBvh::values` (not into `nodes`, despite
        /// the `NodeIdx` element type).
        values: Range<NodeIdx>,
    },
}
34
#[cfg(test)]
impl Node {
    /// Returns the bounding box stored in this node, regardless of its variant.
    fn bounds(&self) -> Aabb {
        match self {
            Node::Internal { bounds, .. } | Node::Leaf { bounds, .. } => *bounds,
        }
    }
}
44
/// Index type used inside `Node`: child indices into `ChunkBvh::nodes`, and,
/// for leaves, the endpoints of a range into `ChunkBvh::values`.
type NodeIdx = u32;
46
/// An axis-aligned bounding box over chunk positions (x/z plane).
///
/// Both corners are inclusive: a box built from a single position has
/// `min == max`, and `intersects` treats touching edges as overlapping.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
struct Aabb {
    /// Corner with the smallest x and z (expected to be `<= max` on each axis).
    min: ChunkPos,
    /// Corner with the largest x and z.
    max: ChunkPos,
}
52
53impl Aabb {
54 fn point(pos: ChunkPos) -> Self {
55 Self { min: pos, max: pos }
56 }
57
58 fn surface_area(self) -> i32 {
60 (self.length_x() + self.length_z()) * 2
61 }
62
63 fn union(self, other: Self) -> Self {
65 Self {
66 min: ChunkPos::new(self.min.x.min(other.min.x), self.min.z.min(other.min.z)),
67 max: ChunkPos::new(self.max.x.max(other.max.x), self.max.z.max(other.max.z)),
68 }
69 }
70
71 fn length_x(self) -> i32 {
72 self.max.x - self.min.x
73 }
74
75 fn length_z(self) -> i32 {
76 self.max.z - self.min.z
77 }
78
79 fn intersects(self, other: Self) -> bool {
80 self.min.x <= other.max.x
81 && self.max.x >= other.min.x
82 && self.min.z <= other.max.z
83 && self.max.z >= other.min.z
84 }
85}
86
/// Implemented by values that can be stored in a `ChunkBvh`.
///
/// The reported position is used both to place the value in the hierarchy
/// during `build` and to test it against the `ChunkView` passed to `query`.
pub trait GetChunkPos {
    /// Returns the chunk position of this value.
    fn chunk_pos(&self) -> ChunkPos;
}
91
/// A `ChunkPos` is its own position, so a `ChunkBvh<ChunkPos>` can index bare
/// positions.
impl GetChunkPos for ChunkPos {
    fn chunk_pos(&self) -> ChunkPos {
        *self
    }
}
97
98impl<T, const MAX_SURFACE_AREA: i32> ChunkBvh<T, MAX_SURFACE_AREA> {
99 pub fn new() -> Self {
100 assert!(MAX_SURFACE_AREA > 0);
101
102 Self {
103 nodes: vec![],
104 values: vec![],
105 }
106 }
107}
108
109impl<T: GetChunkPos, const MAX_SURFACE_AREA: i32> ChunkBvh<T, MAX_SURFACE_AREA> {
110 pub fn build<I: IntoIterator<Item = T>>(&mut self, items: I) {
111 self.nodes.clear();
112 self.values.clear();
113
114 self.values.extend(items);
115
116 if let Some(bounds) = value_bounds(&self.values) {
117 self.build_rec(bounds, 0..self.values.len());
118 }
119 }
120
121 fn build_rec(&mut self, bounds: Aabb, value_range: Range<usize>) {
122 if bounds.surface_area() <= MAX_SURFACE_AREA {
123 self.nodes.push(Node::Leaf {
124 bounds,
125 values: value_range.start as u32..value_range.end as u32,
126 });
127
128 return;
129 }
130
131 let values = &mut self.values[value_range.clone()];
132
133 let point = if bounds.length_x() >= bounds.length_z() {
138 let mid = middle(bounds.min.x, bounds.max.x);
141 partition(values, |v| v.chunk_pos().x >= mid)
142 } else {
143 let mid = middle(bounds.min.z, bounds.max.z);
146 partition(values, |v| v.chunk_pos().z >= mid)
147 };
148
149 let left_range = value_range.start..value_range.start + point;
150 let right_range = left_range.end..value_range.end;
151
152 let left_bounds =
153 value_bounds(&self.values[left_range.clone()]).expect("left half should be nonempty");
154
155 let right_bounds =
156 value_bounds(&self.values[right_range.clone()]).expect("right half should be nonempty");
157
158 self.build_rec(left_bounds, left_range);
159 let left_idx = (self.nodes.len() - 1) as NodeIdx;
160
161 self.build_rec(right_bounds, right_range);
162 let right_idx = (self.nodes.len() - 1) as NodeIdx;
163
164 self.nodes.push(Node::Internal {
165 bounds,
166 left: left_idx,
167 right: right_idx,
168 });
169 }
170
171 pub fn query<F: FnMut(&T)>(&self, view: ChunkView, mut f: F) {
172 if let Some(root) = self.nodes.last() {
173 let (min, max) = view.bounding_box();
174 self.query_rec(root, view, Aabb { min, max }, &mut f);
175 }
176 }
177
178 fn query_rec<F: FnMut(&T)>(&self, node: &Node, view: ChunkView, view_aabb: Aabb, f: &mut F) {
179 match node {
180 Node::Internal {
181 bounds,
182 left,
183 right,
184 } => {
185 if bounds.intersects(view_aabb) {
186 self.query_rec(&self.nodes[*left as usize], view, view_aabb, f);
187 self.query_rec(&self.nodes[*right as usize], view, view_aabb, f);
188 }
189 }
190 Node::Leaf { bounds, values } => {
191 if bounds.intersects(view_aabb) {
192 for val in &self.values[values.start as usize..values.end as usize] {
193 if view.contains(val.chunk_pos()) {
194 f(val)
195 }
196 }
197 }
198 }
199 }
200 }
201
202 pub fn shrink_to_fit(&mut self) {
203 self.nodes.shrink_to_fit();
204 self.values.shrink_to_fit();
205 }
206
207 #[cfg(test)]
208 fn check_invariants(&self) {
209 if let Some(root) = self.nodes.last() {
210 self.check_invariants_rec(root);
211 }
212 }
213
214 #[cfg(test)]
215 fn check_invariants_rec(&self, node: &Node) {
216 match node {
217 Node::Internal {
218 bounds,
219 left,
220 right,
221 } => {
222 let left = &self.nodes[*left as usize];
223 let right = &self.nodes[*right as usize];
224
225 assert_eq!(left.bounds().union(right.bounds()), *bounds);
226
227 self.check_invariants_rec(left);
228 self.check_invariants_rec(right);
229 }
230 Node::Leaf {
231 bounds: leaf_bounds,
232 values,
233 } => {
234 let bounds = value_bounds(&self.values[values.start as usize..values.end as usize])
235 .expect("leaf should be nonempty");
236
237 assert_eq!(*leaf_bounds, bounds);
238 }
239 }
240 }
241}
242
243fn value_bounds<T: GetChunkPos>(values: &[T]) -> Option<Aabb> {
244 values
245 .iter()
246 .map(|v| Aabb::point(v.chunk_pos()))
247 .reduce(Aabb::union)
248}
249
/// Returns the midpoint of `min` and `max`, rounded toward zero.
fn middle(min: i32, max: i32) -> i32 {
    // Widen before adding so the sum cannot overflow. Half of the sum always
    // lies between `min` and `max`, so narrowing back to i32 is lossless.
    let sum = i64::from(min) + i64::from(max);
    let half = sum / 2;
    half as i32
}
254
/// Reorders `s` in place so that every element for which `pred` returns `true`
/// comes before every element for which it returns `false`. Returns the number
/// of `true` elements, which is also the index of the first `false` one.
///
/// The order within each group is not preserved. A forward scan looks for an
/// element that belongs at the back, and a backward scan looks for one that
/// belongs at the front. The two are swapped, and this repeats until the scans
/// meet. `pred` is called exactly once per element.
fn partition<T>(s: &mut [T], mut pred: impl FnMut(&T) -> bool) -> usize {
    // At the top of each outer iteration, `s[..front]` holds only `true`
    // elements, `s[back..]` only `false` ones, and `s[front..back]` is unexamined.
    let mut front = 0;
    let mut back = s.len();
    let mut true_count = 0;

    loop {
        // Forward scan: step over elements that already belong at the front.
        let misplaced_front = loop {
            if front == back {
                return true_count;
            }
            let idx = front;
            front += 1;
            if pred(&s[idx]) {
                true_count += 1;
            } else {
                break idx;
            }
        };

        // Backward scan: step over elements that already belong at the back.
        let misplaced_back = loop {
            if front == back {
                return true_count;
            }
            back -= 1;
            if pred(&s[back]) {
                break back;
            }
        };

        // `misplaced_front < misplaced_back`, so split the slice to borrow both
        // elements mutably at once.
        let (head, tail) = s.split_at_mut(misplaced_back);
        mem::swap(&mut head[misplaced_front], &mut tail[0]);
        true_count += 1;
    }
}
278
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::*;

    /// `partition` moves the elements matching the predicate to the front and
    /// reports how many there are.
    #[test]
    fn partition_middle() {
        let mut values = [2, 3, 4, 5];
        let split = middle(values[0], values[values.len() - 1]);

        let n_true = partition(&mut values, |&v| split >= v);

        assert_eq!(n_true, 2);
        assert_eq!(&values[..n_true], &[2, 3]);
        assert_eq!(&values[n_true..], &[4, 5]);
    }

    /// A query must report exactly the positions a brute-force scan finds,
    /// each exactly as many times as it was inserted.
    #[test]
    fn query_visits_correct_nodes() {
        let mut bvh = ChunkBvh::<ChunkPos>::new();

        let size = 500;
        let mut rng = rand::thread_rng();

        let positions: Vec<ChunkPos> = (0..100_000)
            .map(|_| ChunkPos {
                x: rng.gen_range(-size / 2..size / 2),
                z: rng.gen_range(-size / 2..size / 2),
            })
            .collect();

        let view = ChunkView::new(ChunkPos::default(), 32);

        let mut expected: Vec<ChunkPos> = positions
            .iter()
            .copied()
            .filter(|&p| view.contains(p))
            .collect();

        bvh.build(positions);

        bvh.check_invariants();

        // Cross off each reported position; anything left over was missed.
        bvh.query(view, |pos| {
            let idx = expected.iter().position(|p| p == pos).expect("😔");
            expected.remove(idx);
        });

        assert!(expected.is_empty());
    }
}