1#![expect(
11 clippy::manual_midpoint,
12 reason = "must match vanilla's truncate-toward-zero behavior"
13)]
14
15use std::cmp::Ordering;
16
17use super::PARAMETER_COUNT;
18use super::types::{Parameter, ParameterPoint, TargetPoint};
19
/// Largest number of entries that `build_tree` places directly under a single
/// `SubTree`; also the base of the power computed by `expected_children_count`
/// when sizing buckets.
const CHILDREN_PER_NODE: usize = 6;
22
23enum RTreeNode {
26 Leaf {
28 parameter_space: [Parameter; PARAMETER_COUNT],
29 value_index: usize,
30 },
31 SubTree {
33 parameter_space: [Parameter; PARAMETER_COUNT],
34 children: Vec<RTreeNode>,
35 },
36}
37
38impl RTreeNode {
39 const fn parameter_space(&self) -> &[Parameter; PARAMETER_COUNT] {
41 match self {
42 Self::Leaf {
43 parameter_space, ..
44 }
45 | Self::SubTree {
46 parameter_space, ..
47 } => parameter_space,
48 }
49 }
50}
51
52#[derive(Clone)]
54struct BuildEntry {
55 parameter_space: [Parameter; PARAMETER_COUNT],
56 index: usize,
57}
58
59fn build_parameter_space(children: &[RTreeNode]) -> [Parameter; PARAMETER_COUNT] {
61 let mut bounds: [Option<Parameter>; PARAMETER_COUNT] = [None; PARAMETER_COUNT];
62 for child in children {
63 let ps = child.parameter_space();
64 for d in 0..PARAMETER_COUNT {
65 bounds[d] = Some(ps[d].span_with(bounds[d].as_ref()));
66 }
67 }
68 bounds.map(|b| b.expect("bounds should be initialized"))
69}
70
71fn cost(parameter_space: &[Parameter; PARAMETER_COUNT]) -> i64 {
73 let mut result = 0i64;
74 for p in parameter_space {
75 result += (p.max - p.min).abs();
76 }
77 result
78}
79
80fn build_tree(entries: &mut [BuildEntry]) -> RTreeNode {
82 assert!(!entries.is_empty());
83
84 if entries.len() == 1 {
85 return RTreeNode::Leaf {
86 parameter_space: entries[0].parameter_space,
87 value_index: entries[0].index,
88 };
89 }
90
91 if entries.len() <= CHILDREN_PER_NODE {
92 entries.sort_by_key(|e| {
94 let mut total: i64 = 0;
95 for d in 0..PARAMETER_COUNT {
96 let p = &e.parameter_space[d];
97 total += ((p.min + p.max) / 2).abs();
98 }
99 total
100 });
101
102 let children: Vec<RTreeNode> = entries
103 .iter()
104 .map(|e| RTreeNode::Leaf {
105 parameter_space: e.parameter_space,
106 value_index: e.index,
107 })
108 .collect();
109 let ps = build_parameter_space(&children);
110 return RTreeNode::SubTree {
111 parameter_space: ps,
112 children,
113 };
114 }
115
116 let mut min_cost = i64::MAX;
122 let mut best_dim = 0;
123 let mut best_buckets: Option<Vec<Vec<BuildEntry>>> = None;
124
125 for d in 0..PARAMETER_COUNT {
126 sort_entries(entries, d);
127 let (bucket_cost, buckets) = snapshot_buckets(entries);
128 if min_cost > bucket_cost {
129 min_cost = bucket_cost;
130 best_dim = d;
131 best_buckets = Some(buckets);
132 }
133 }
134
135 let buckets = best_buckets.expect("should have found at least one dimension");
137
138 let mut bucket_subtrees: Vec<([Parameter; PARAMETER_COUNT], Vec<BuildEntry>)> = buckets
140 .into_iter()
141 .map(|bucket_entries| {
142 let mut bounds: [Option<Parameter>; PARAMETER_COUNT] = [None; PARAMETER_COUNT];
143 for e in &bucket_entries {
144 #[expect(clippy::needless_range_loop, reason = "dim indexes parallel arrays")]
145 for dim in 0..PARAMETER_COUNT {
146 bounds[dim] = Some(e.parameter_space[dim].span_with(bounds[dim].as_ref()));
147 }
148 }
149 let ps = bounds.map(|b| b.expect("bounds should be initialized"));
150 (ps, bucket_entries)
151 })
152 .collect();
153
154 sort_bucket_subtrees(&mut bucket_subtrees, best_dim);
156
157 let mut final_children: Vec<RTreeNode> = Vec::new();
159 for (_, mut child_entries) in bucket_subtrees {
160 final_children.push(build_tree(&mut child_entries));
161 }
162
163 let ps = build_parameter_space(&final_children);
164 RTreeNode::SubTree {
165 parameter_space: ps,
166 children: final_children,
167 }
168}
169
170fn sort_entries(entries: &mut [BuildEntry], dimension: usize) {
172 entries.sort_by(|a, b| {
173 for offset in 0..PARAMETER_COUNT {
174 let d = (dimension + offset) % PARAMETER_COUNT;
175 let center_a = (a.parameter_space[d].min + a.parameter_space[d].max) / 2;
176 let center_b = (b.parameter_space[d].min + b.parameter_space[d].max) / 2;
177 let cmp = center_a.cmp(¢er_b);
178 if cmp != Ordering::Equal {
179 return cmp;
180 }
181 }
182 Ordering::Equal
183 });
184}
185
186fn sort_bucket_subtrees(
189 subtrees: &mut [([Parameter; PARAMETER_COUNT], Vec<BuildEntry>)],
190 dimension: usize,
191) {
192 subtrees.sort_by(|a, b| {
193 for offset in 0..PARAMETER_COUNT {
194 let d = (dimension + offset) % PARAMETER_COUNT;
195 let center_a = (a.0[d].min + a.0[d].max) / 2;
196 let center_b = (b.0[d].min + b.0[d].max) / 2;
197 let cmp = center_a.abs().cmp(¢er_b.abs());
198 if cmp != Ordering::Equal {
199 return cmp;
200 }
201 }
202 Ordering::Equal
203 });
204}
205
206fn expected_children_count(total: usize) -> usize {
208 let log_base_6 = ((total as f64) - 0.01).ln() / (CHILDREN_PER_NODE as f64).ln();
209 (CHILDREN_PER_NODE as f64).powf(log_base_6.floor()) as usize
210}
211
212#[expect(
218 clippy::needless_range_loop,
219 reason = "indexing into PARAMETER_COUNT parallel arrays; iterator would be less clear"
220)]
221fn snapshot_buckets(entries: &[BuildEntry]) -> (i64, Vec<Vec<BuildEntry>>) {
222 let expected = expected_children_count(entries.len());
223 let mut buckets = Vec::new();
224 let mut total_cost = 0i64;
225 let mut start = 0;
226 while start < entries.len() {
227 let end = (start + expected).min(entries.len());
228 let bucket = entries[start..end].to_vec();
229 let mut bounds: [Option<Parameter>; PARAMETER_COUNT] = [None; PARAMETER_COUNT];
231 for e in &bucket {
232 for d in 0..PARAMETER_COUNT {
233 bounds[d] = Some(e.parameter_space[d].span_with(bounds[d].as_ref()));
234 }
235 }
236 let ps = bounds.map(|b| b.expect("bounds should be initialized"));
237 total_cost += cost(&ps);
238 buckets.push(bucket);
239 start = end;
240 }
241 (total_cost, buckets)
242}
243
244struct FlatNode {
251 mins: [i64; PARAMETER_COUNT],
253 maxs: [i64; PARAMETER_COUNT],
255 value_index: u32,
258 children_start: u32,
260 children_count: u8,
262}
263
264impl FlatNode {
265 #[inline]
270 #[expect(
271 clippy::needless_range_loop,
272 reason = "indexing into parallel min/max arrays; iterator zip would be less clear"
273 )]
274 fn distance(&self, target: &[i64; PARAMETER_COUNT]) -> i64 {
275 let mut d = 0i64;
276 for i in 0..PARAMETER_COUNT {
277 let di = (target[i] - self.maxs[i])
278 .max(self.mins[i] - target[i])
279 .max(0);
280 d += di * di;
281 }
282 d
283 }
284
285 #[inline]
286 const fn is_leaf(&self) -> bool {
287 self.children_count == 0
288 }
289}
290
291fn flatten_tree(root: RTreeNode) -> Vec<FlatNode> {
296 use std::collections::VecDeque;
297
298 let mut nodes: Vec<FlatNode> = Vec::new();
299 let mut queue: VecDeque<(Vec<RTreeNode>, Option<u32>)> = VecDeque::new();
302 queue.push_back((vec![root], None));
303
304 while let Some((batch, parent_idx)) = queue.pop_front() {
305 let batch_start = nodes.len() as u32;
306
307 if let Some(pidx) = parent_idx {
309 nodes[pidx as usize].children_start = batch_start;
310 }
311
312 for node in batch {
313 let flat_idx = nodes.len() as u32;
314 match node {
315 RTreeNode::Leaf {
316 parameter_space,
317 value_index,
318 } => {
319 nodes.push(FlatNode {
320 mins: parameter_space.map(|p| p.min),
321 maxs: parameter_space.map(|p| p.max),
322 value_index: value_index as u32,
323 children_start: 0,
324 children_count: 0,
325 });
326 }
327 RTreeNode::SubTree {
328 parameter_space,
329 children,
330 } => {
331 let children_count = children.len() as u8;
332 nodes.push(FlatNode {
333 mins: parameter_space.map(|p| p.min),
334 maxs: parameter_space.map(|p| p.max),
335 value_index: u32::MAX,
336 children_start: 0, children_count,
338 });
339 queue.push_back((children, Some(flat_idx)));
340 }
341 }
342 }
343 }
344
345 nodes
346}
347
348fn search_nearest(
353 nodes: &[FlatNode],
354 node: &FlatNode,
355 target: &[i64; PARAMETER_COUNT],
356 best_dist: &mut i64,
357 best_idx: &mut Option<u32>,
358) {
359 let start = node.children_start as usize;
360 let end = start + node.children_count as usize;
361 let children = &nodes[start..end];
362
363 for child in children {
364 let child_dist = child.distance(target);
365 if *best_dist > child_dist {
367 if child.is_leaf() {
368 *best_dist = child_dist;
370 *best_idx = Some(child.value_index);
371 } else {
372 search_nearest(nodes, child, target, best_dist, best_idx);
374 }
375 }
376 }
377}
378
379pub struct ParameterList<T> {
383 values: Vec<(ParameterPoint, T)>,
385 param_spaces: Vec<[Parameter; PARAMETER_COUNT]>,
387 nodes: Vec<FlatNode>,
389}
390
391impl<T> ParameterList<T> {
392 #[must_use]
398 pub fn new(values: Vec<(ParameterPoint, T)>) -> Self {
399 assert!(!values.is_empty(), "Need at least one value");
400
401 let param_spaces: Vec<[Parameter; PARAMETER_COUNT]> =
402 values.iter().map(|(pp, _)| pp.parameter_space()).collect();
403
404 let mut entries: Vec<BuildEntry> = values
406 .iter()
407 .enumerate()
408 .map(|(i, (pp, _))| BuildEntry {
409 parameter_space: pp.parameter_space(),
410 index: i,
411 })
412 .collect();
413
414 let root = build_tree(&mut entries);
415 let nodes = flatten_tree(root);
416
417 Self {
418 values,
419 param_spaces,
420 nodes,
421 }
422 }
423
424 #[must_use]
426 pub fn values(&self) -> &[(ParameterPoint, T)] {
427 &self.values
428 }
429
430 #[must_use]
443 pub fn find_value(&self, target: &TargetPoint) -> &T {
444 let target_array = target.to_parameter_array();
445 let root = &self.nodes[0];
446 if root.is_leaf() {
447 return &self.values[root.value_index as usize].1;
448 }
449 let mut best_dist = i64::MAX;
450 let mut best_idx = None;
451 search_nearest(
452 &self.nodes,
453 root,
454 &target_array,
455 &mut best_dist,
456 &mut best_idx,
457 );
458 &self.values[best_idx.expect("R-Tree search should always find a value") as usize].1
459 }
460
461 #[must_use]
472 pub fn find_value_cached(&self, target: &TargetPoint, cache: &mut Option<usize>) -> &T {
473 let target_array = target.to_parameter_array();
474
475 let root = &self.nodes[0];
476 if root.is_leaf() {
477 let idx = root.value_index as usize;
478 *cache = Some(idx);
479 return &self.values[idx].1;
480 }
481
482 let (mut best_dist, init_idx) = match *cache {
484 Some(idx) => {
485 let ps = &self.param_spaces[idx];
486 let mut d = 0i64;
487 for i in 0..PARAMETER_COUNT {
488 let di = ps[i].distance(target_array[i]);
489 d += di * di;
490 }
491 (d, Some(idx as u32))
492 }
493 None => (i64::MAX, None),
494 };
495
496 let mut best_idx = init_idx;
497 search_nearest(
498 &self.nodes,
499 root,
500 &target_array,
501 &mut best_dist,
502 &mut best_idx,
503 );
504 let result_idx = best_idx.expect("R-Tree search should always find a value") as usize;
505
506 *cache = Some(result_idx);
507 &self.values[result_idx].1
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Two entries that differ only in their first parameter: a target on
    /// either side of zero must resolve to the entry covering that side.
    #[test]
    fn test_parameter_list_find_value() {
        let zero = Parameter::new(0, 0);
        let list = ParameterList::new(vec![
            (
                ParameterPoint::new(Parameter::new(-10000, 0), zero, zero, zero, zero, zero, 0),
                "cold",
            ),
            (
                ParameterPoint::new(Parameter::new(0, 10000), zero, zero, zero, zero, zero, 0),
                "hot",
            ),
        ]);

        for (first_coordinate, expected) in [(-5000, "cold"), (5000, "hot")] {
            let target = TargetPoint::new(first_coordinate, 0, 0, 0, 0, 0);
            assert_eq!(*list.find_value(&target), expected);
        }
    }
}