ATTPCROOT  0.3.0-alpha
A ROOT-based framework for analyzing data from active target detectors
kdtree.cxx
Go to the documentation of this file.
1 //
2 // KdTree implementation.
3 //
4 // Copyright: Christoph Dalitz, 2018
5 // Jens Wilberg, 2018
6 // License: BSD style license
7 // (see the file LICENSE for details)
8 //
9 
10 #include "kdtree.hpp"
11 // IWYU pragma: no_include <ext/alloc_traits.h>
12 #include <math.h>
13 
14 #include <algorithm>
15 #include <limits>
16 #include <memory>
17 #include <stdexcept>
18 
19 namespace Kdtree {
20 
21 //--------------------------------------------------------------
22 // function object for comparing only dimension d of two vecotrs
23 //--------------------------------------------------------------
25 public:
26  compare_dimension(size_t dim) { d = dim; }
27  bool operator()(const KdNode &p, const KdNode &q) { return (p.point[d] < q.point[d]); }
28  size_t d;
29 };
30 
31 //--------------------------------------------------------------
32 // internal node structure used by kdtree
33 //--------------------------------------------------------------
34 class kdtree_node {
35 public:
37  {
38  dataindex = cutdim = 0;
39  loson = hison = (kdtree_node *)NULL;
40  }
42  {
43  if (loson)
44  delete loson;
45  if (hison)
46  delete hison;
47  }
48  // index of node data in kdtree array "allnodes"
49  size_t dataindex;
50  // cutting dimension
51  size_t cutdim;
52  // value of point
53  // double cutval; // == point[cutdim]
55  // roots of the two subtrees
57  // bounding rectangle of this node's subtree
59 };
60 
61 //--------------------------------------------------------------
62 // different distance metrics
63 //--------------------------------------------------------------
65 public:
67  virtual ~DistanceMeasure() {}
68  virtual double distance(const CoordPoint &p, const CoordPoint &q) = 0;
69  virtual double coordinate_distance(double x, double y, size_t dim) = 0;
70 };
71 // Maximum distance (Linfinite norm)
72 class DistanceL0 : virtual public DistanceMeasure {
73  DoubleVector *w;
74 
75 public:
76  DistanceL0(const DoubleVector *weights = NULL)
77  {
78  if (weights)
79  w = new DoubleVector(*weights);
80  else
81  w = (DoubleVector *)NULL;
82  }
84  {
85  if (w)
86  delete w;
87  }
88  double distance(const CoordPoint &p, const CoordPoint &q)
89  {
90  size_t i;
91  double dist, test;
92  if (w) {
93  dist = (*w)[0] * fabs(p[0] - q[0]);
94  for (i = 1; i < p.size(); i++) {
95  test = (*w)[i] * fabs(p[i] - q[i]);
96  if (test > dist)
97  dist = test;
98  }
99  } else {
100  dist = fabs(p[0] - q[0]);
101  for (i = 1; i < p.size(); i++) {
102  test = fabs(p[i] - q[i]);
103  if (test > dist)
104  dist = test;
105  }
106  }
107  return dist;
108  }
109  double coordinate_distance(double x, double y, size_t dim)
110  {
111  if (w)
112  return (*w)[dim] * fabs(x - y);
113  else
114  return fabs(x - y);
115  }
116 };
117 // Manhatten distance (L1 norm)
118 class DistanceL1 : virtual public DistanceMeasure {
119  DoubleVector *w;
120 
121 public:
122  DistanceL1(const DoubleVector *weights = NULL)
123  {
124  if (weights)
125  w = new DoubleVector(*weights);
126  else
127  w = (DoubleVector *)NULL;
128  }
130  {
131  if (w)
132  delete w;
133  }
134  double distance(const CoordPoint &p, const CoordPoint &q)
135  {
136  size_t i;
137  double dist = 0.0;
138  if (w) {
139  for (i = 0; i < p.size(); i++)
140  dist += (*w)[i] * fabs(p[i] - q[i]);
141  } else {
142  for (i = 0; i < p.size(); i++)
143  dist += fabs(p[i] - q[i]);
144  }
145  return dist;
146  }
147  double coordinate_distance(double x, double y, size_t dim)
148  {
149  if (w)
150  return (*w)[dim] * fabs(x - y);
151  else
152  return fabs(x - y);
153  }
154 };
155 // Euklidean distance (L2 norm)
156 class DistanceL2 : virtual public DistanceMeasure {
157  DoubleVector *w;
158 
159 public:
160  DistanceL2(const DoubleVector *weights = NULL)
161  {
162  if (weights)
163  w = new DoubleVector(*weights);
164  else
165  w = (DoubleVector *)NULL;
166  }
168  {
169  if (w)
170  delete w;
171  }
172  double distance(const CoordPoint &p, const CoordPoint &q)
173  {
174  size_t i;
175  double dist = 0.0;
176  if (w) {
177  for (i = 0; i < p.size(); i++)
178  dist += (*w)[i] * (p[i] - q[i]) * (p[i] - q[i]);
179  } else {
180  for (i = 0; i < p.size(); i++)
181  dist += (p[i] - q[i]) * (p[i] - q[i]);
182  }
183  return dist;
184  }
185  double coordinate_distance(double x, double y, size_t dim)
186  {
187  if (w)
188  return (*w)[dim] * (x - y) * (x - y);
189  else
190  return (x - y) * (x - y);
191  }
192 };
193 
194 //--------------------------------------------------------------
195 // destructor and constructor of kdtree
196 //--------------------------------------------------------------
198 {
199  if (root)
200  delete root;
201  delete distance;
202 }
203 // distance_type can be 0 (Maximum), 1 (Manhatten), or 2 (Euklidean)
204 KdTree::KdTree(const KdNodeVector *nodes, int distance_type_ /*=2*/)
205 {
206  size_t i, j;
207  double val;
208  // copy over input data
209  dimension = nodes->begin()->point.size();
210  allnodes = *nodes;
211  // initialize distance values
212  distance = NULL;
213  this->distance_type = -1;
214  set_distance(distance_type_);
215  // compute global bounding box
216  lobound = nodes->begin()->point;
217  upbound = nodes->begin()->point;
218  for (i = 1; i < nodes->size(); i++) {
219  for (j = 0; j < dimension; j++) {
220  val = allnodes[i].point[j];
221  if (lobound[j] > val)
222  lobound[j] = val;
223  if (upbound[j] < val)
224  upbound[j] = val;
225  }
226  }
227  // build tree recursively
228  root = build_tree(0, 0, allnodes.size());
229 }
230 
231 // distance_type can be 0 (Maximum), 1 (Manhatten), or 2 (Euklidean)
232 void KdTree::set_distance(int distance_type_, const DoubleVector *weights /*=NULL*/)
233 {
234  if (distance)
235  delete distance;
236  this->distance_type = distance_type_;
237  if (distance_type_ == 0) {
238  distance = (DistanceMeasure *)new DistanceL0(weights);
239  } else if (distance_type_ == 1) {
240  distance = (DistanceMeasure *)new DistanceL1(weights);
241  } else {
242  distance = (DistanceMeasure *)new DistanceL2(weights);
243  }
244 }
245 
246 //--------------------------------------------------------------
247 // recursive build of tree
248 // "a" and "b"-1 are the lower and upper indices
249 // from "allnodes" from which the subtree is to be built
250 //--------------------------------------------------------------
251 kdtree_node *KdTree::build_tree(size_t depth, size_t a, size_t b)
252 {
253  size_t m;
254  double temp, cutval;
255  kdtree_node *node = new kdtree_node();
256  node->lobound = lobound;
257  node->upbound = upbound;
258  node->cutdim = depth % dimension; // NOLINT
259  if (b - a <= 1) {
260  node->dataindex = a;
261  node->point = allnodes[a].point;
262  } else {
263  m = (a + b) / 2;
264  std::nth_element(allnodes.begin() + a, allnodes.begin() + m, allnodes.begin() + b,
265  compare_dimension(node->cutdim));
266  node->point = allnodes[m].point;
267  cutval = allnodes[m].point[node->cutdim];
268  node->dataindex = m;
269  if (m - a > 0) {
270  temp = upbound[node->cutdim];
271  upbound[node->cutdim] = cutval;
272  node->loson = build_tree(depth + 1, a, m);
273  upbound[node->cutdim] = temp;
274  }
275  if (b - m > 1) {
276  temp = lobound[node->cutdim];
277  lobound[node->cutdim] = cutval;
278  node->hison = build_tree(depth + 1, m + 1, b);
279  lobound[node->cutdim] = temp;
280  }
281  }
282  return node;
283 }
284 
285 //--------------------------------------------------------------
286 // k nearest neighbor search
287 // returns the *k* nearest neighbors of *point* in O(log(n))
288 // time. The result is returned in *result* and is sorted by
289 // distance from *point*. Also the distances from *point* are
290 // returned in *distances*.
291 // The optional search predicate is a callable class (aka "functor")
292 // derived from KdNodePredicate. When Null (default, no search
293 // predicate is applied).
294 //--------------------------------------------------------------
295 void KdTree::k_nearest_neighbors(const CoordPoint &point, size_t k, KdNodeVector *result,
296  std::vector<double> *distances, KdNodePredicate *pred /*=NULL*/)
297 {
298  size_t i;
299  double d, temp_dist;
300  KdNode temp;
301  searchpredicate = pred;
302 
303  result->clear();
304  if (k < 1)
305  return;
306  if (point.size() != dimension)
307  throw std::invalid_argument("kdtree::k_nearest_neighbors(): point must be of same dimension as "
308  "kdtree");
309 
310  // collect result of k values in neighborheap
311  neighborheap = new std::priority_queue<nn4heap, std::vector<nn4heap>, compare_nn4heap>();
312  if (k > allnodes.size()) {
313  // when more neighbors asked than nodes in tree, return everything
314  k = allnodes.size();
315  for (i = 0; i < k; i++) {
316  if (!(searchpredicate && !(*searchpredicate)(allnodes[i])))
317  neighborheap->push(nn4heap(i, distance->distance(allnodes[i].point, point)));
318  }
319  } else {
320  neighbor_search(point, root, k);
321  }
322 
323  // copy over result sorted by distance
324  // (we must revert the vector for ascending order)
325  while (!neighborheap->empty()) {
326  i = neighborheap->top().dataindex;
327  d = neighborheap->top().distance;
328  neighborheap->pop();
329  result->push_back(allnodes[i]);
330  distances->push_back(d);
331  }
332  // beware that less than k results might have been returned
333  k = result->size();
334  for (i = 0; i < k / 2; i++) {
335  temp = (*result)[i];
336  (*result)[i] = (*result)[k - 1 - i];
337  (*result)[k - 1 - i] = temp;
338  temp_dist = (*distances)[i];
339  (*distances)[i] = (*distances)[k - 1 - i];
340  (*distances)[k - 1 - i] = temp_dist;
341  }
342  delete neighborheap;
343 }
344 
345 //--------------------------------------------------------------
346 // range nearest neighbor search
347 // returns the nearest neighbors of *point* in the given range
348 // *r*. The result is returned in *result* and is sorted by
349 // distance from *point*.
350 //--------------------------------------------------------------
351 void KdTree::range_nearest_neighbors(const CoordPoint &point, double r, KdNodeVector *result)
352 {
353  KdNode temp;
354 
355  result->clear();
356  if (point.size() != dimension)
357  throw std::invalid_argument("kdtree::k_nearest_neighbors(): point must be of same dimension as "
358  "kdtree");
359  if (this->distance_type == 2) {
360  // if euclidien distance is used the range must be squared because we
361  // get squred distances from this implementation
362  r *= r;
363  }
364 
365  // collect result in neighborheap
366  range_search(point, root, r);
367 
368  // copy over result
369  for (std::vector<size_t>::iterator i = range_result.begin(); i != range_result.end(); ++i) {
370  result->push_back(allnodes[*i]);
371  }
372 
373  // clear vector
374  range_result.clear();
375 }
376 
377 //--------------------------------------------------------------
378 // recursive function for nearest neighbor search in subtree
379 // under *node*. Updates the heap (class member) *neighborheap*.
380 // returns "true" when no nearer neighbor elsewhere possible
381 //--------------------------------------------------------------
382 bool KdTree::neighbor_search(const CoordPoint &point, kdtree_node *node, size_t k)
383 {
384  double curdist, dist;
385 
386  curdist = distance->distance(point, node->point);
387  if (!(searchpredicate && !(*searchpredicate)(allnodes[node->dataindex]))) {
388  if (neighborheap->size() < k) {
389  neighborheap->push(nn4heap(node->dataindex, curdist));
390  } else if (curdist < neighborheap->top().distance) {
391  neighborheap->pop();
392  neighborheap->push(nn4heap(node->dataindex, curdist));
393  }
394  }
395  // first search on side closer to point
396  if (point[node->cutdim] < node->point[node->cutdim]) {
397  if (node->loson)
398  if (neighbor_search(point, node->loson, k))
399  return true;
400  } else {
401  if (node->hison)
402  if (neighbor_search(point, node->hison, k))
403  return true;
404  }
405  // second search on farther side, if necessary
406  if (neighborheap->size() < k) {
407  dist = std::numeric_limits<double>::max();
408  } else {
409  dist = neighborheap->top().distance;
410  }
411  if (point[node->cutdim] < node->point[node->cutdim]) {
412  if (node->hison && bounds_overlap_ball(point, dist, node->hison))
413  if (neighbor_search(point, node->hison, k))
414  return true;
415  } else {
416  if (node->loson && bounds_overlap_ball(point, dist, node->loson))
417  if (neighbor_search(point, node->loson, k))
418  return true;
419  }
420 
421  if (neighborheap->size() == k)
422  dist = neighborheap->top().distance;
423  return ball_within_bounds(point, dist, node);
424 }
425 
426 //--------------------------------------------------------------
427 // recursive function for range search in subtree under *node*.
428 // Updates the heap (class member) *neighborheap*.
429 //--------------------------------------------------------------
430 void KdTree::range_search(const CoordPoint &point, kdtree_node *node, double r)
431 {
432  double curdist = distance->distance(point, node->point);
433  if (curdist <= r) {
434  range_result.push_back(node->dataindex);
435  }
436  if (node->loson != NULL && this->bounds_overlap_ball(point, r, node->loson)) {
437  range_search(point, node->loson, r);
438  }
439  if (node->hison != NULL && this->bounds_overlap_ball(point, r, node->hison)) {
440  range_search(point, node->hison, r);
441  }
442 }
443 
444 // returns true when the bounds of *node* overlap with the
445 // ball with radius *dist* around *point*
446 bool KdTree::bounds_overlap_ball(const CoordPoint &point, double dist, kdtree_node *node)
447 {
448  double distsum = 0.0;
449  size_t i;
450  for (i = 0; i < dimension; i++) {
451  if (point[i] < node->lobound[i]) { // lower than low boundary
452  distsum += distance->coordinate_distance(point[i], node->lobound[i], i);
453  if (distsum > dist)
454  return false;
455  } else if (point[i] > node->upbound[i]) { // higher than high boundary
456  distsum += distance->coordinate_distance(point[i], node->upbound[i], i);
457  if (distsum > dist)
458  return false;
459  }
460  }
461  return true;
462 }
463 
464 // returns true when the bounds of *node* completely contain the
465 // ball with radius *dist* around *point*
466 bool KdTree::ball_within_bounds(const CoordPoint &point, double dist, kdtree_node *node)
467 {
468  size_t i;
469  for (i = 0; i < dimension; i++)
470  if (distance->coordinate_distance(point[i], node->lobound[i], i) <= dist ||
471  distance->coordinate_distance(point[i], node->upbound[i], i) <= dist)
472  return false;
473  return true;
474 }
475 
476 } // namespace Kdtree
Kdtree::DistanceL1
Definition: kdtree.cxx:118
Kdtree::KdNodePredicate
Definition: kdtree.hpp:38
Kdtree::nn4heap
Definition: kdtree.hpp:51
Kdtree::DistanceMeasure::coordinate_distance
virtual double coordinate_distance(double x, double y, size_t dim)=0
Kdtree::KdTree::allnodes
KdNodeVector allnodes
Definition: kdtree.hpp:89
Kdtree::compare_dimension::compare_dimension
compare_dimension(size_t dim)
Definition: kdtree.cxx:26
Kdtree::DistanceL1::coordinate_distance
double coordinate_distance(double x, double y, size_t dim)
Definition: kdtree.cxx:147
Kdtree::DoubleVector
std::vector< double > DoubleVector
Definition: kdtree.hpp:19
Kdtree::DistanceL2::distance
double distance(const CoordPoint &p, const CoordPoint &q)
Definition: kdtree.cxx:172
Kdtree::DistanceMeasure::DistanceMeasure
DistanceMeasure()
Definition: kdtree.cxx:66
Kdtree::compare_dimension::d
size_t d
Definition: kdtree.cxx:28
Kdtree::KdTree::range_nearest_neighbors
void range_nearest_neighbors(const CoordPoint &point, double r, KdNodeVector *result)
Definition: kdtree.cxx:351
Kdtree::KdNodeVector
std::vector< KdNode > KdNodeVector
Definition: kdtree.hpp:32
Kdtree
Definition: kdtree.cxx:19
Kdtree::DistanceL0::distance
double distance(const CoordPoint &p, const CoordPoint &q)
Definition: kdtree.cxx:88
Kdtree::KdTree::root
kdtree_node * root
Definition: kdtree.hpp:91
Kdtree::kdtree_node::lobound
CoordPoint lobound
Definition: kdtree.cxx:58
Kdtree::kdtree_node::kdtree_node
kdtree_node()
Definition: kdtree.cxx:36
Kdtree::KdTree::KdTree
KdTree(const KdNodeVector *nodes, int distance_type=2)
Definition: kdtree.cxx:204
node
Definition: fastcluster_dm.cxx:213
Kdtree::DistanceMeasure::distance
virtual double distance(const CoordPoint &p, const CoordPoint &q)=0
Kdtree::DistanceL2::coordinate_distance
double coordinate_distance(double x, double y, size_t dim)
Definition: kdtree.cxx:185
Kdtree::KdNode
Definition: kdtree.hpp:22
Kdtree::CoordPoint
std::vector< double > CoordPoint
Definition: kdtree.hpp:18
Kdtree::DistanceL0
Definition: kdtree.cxx:72
Kdtree::KdTree::k_nearest_neighbors
void k_nearest_neighbors(const CoordPoint &point, size_t k, KdNodeVector *result, std::vector< double > *distances, KdNodePredicate *pred=NULL)
Definition: kdtree.cxx:295
Kdtree::DistanceL0::~DistanceL0
~DistanceL0()
Definition: kdtree.cxx:83
Kdtree::kdtree_node
Definition: kdtree.cxx:34
Kdtree::compare_nn4heap
Definition: kdtree.hpp:61
Kdtree::DistanceL0::DistanceL0
DistanceL0(const DoubleVector *weights=NULL)
Definition: kdtree.cxx:76
Kdtree::DistanceL1::distance
double distance(const CoordPoint &p, const CoordPoint &q)
Definition: kdtree.cxx:134
Kdtree::compare_dimension::operator()
bool operator()(const KdNode &p, const KdNode &q)
Definition: kdtree.cxx:27
Kdtree::kdtree_node::cutdim
size_t cutdim
Definition: kdtree.cxx:51
y
const double * y
Definition: lmcurve.cxx:20
Kdtree::kdtree_node::loson
kdtree_node * loson
Definition: kdtree.cxx:56
Kdtree::DistanceL2::~DistanceL2
~DistanceL2()
Definition: kdtree.cxx:167
Kdtree::kdtree_node::hison
kdtree_node * hison
Definition: kdtree.cxx:56
Kdtree::KdTree::dimension
size_t dimension
Definition: kdtree.hpp:90
Kdtree::DistanceL1::DistanceL1
DistanceL1(const DoubleVector *weights=NULL)
Definition: kdtree.cxx:122
Kdtree::KdNode::point
CoordPoint point
Definition: kdtree.hpp:23
Kdtree::kdtree_node::dataindex
size_t dataindex
Definition: kdtree.cxx:49
Kdtree::KdTree::set_distance
void set_distance(int distance_type, const DoubleVector *weights=NULL)
Definition: kdtree.cxx:232
Kdtree::DistanceL1::~DistanceL1
~DistanceL1()
Definition: kdtree.cxx:129
Kdtree::compare_dimension
Definition: kdtree.cxx:24
Kdtree::DistanceL0::coordinate_distance
double coordinate_distance(double x, double y, size_t dim)
Definition: kdtree.cxx:109
Kdtree::DistanceL2::DistanceL2
DistanceL2(const DoubleVector *weights=NULL)
Definition: kdtree.cxx:160
Kdtree::kdtree_node::~kdtree_node
~kdtree_node()
Definition: kdtree.cxx:41
Kdtree::kdtree_node::upbound
CoordPoint upbound
Definition: kdtree.cxx:58
Kdtree::DistanceMeasure::~DistanceMeasure
virtual ~DistanceMeasure()
Definition: kdtree.cxx:67
Kdtree::DistanceL2
Definition: kdtree.cxx:156
Kdtree::KdTree::~KdTree
~KdTree()
Definition: kdtree.cxx:197
kdtree.hpp
Kdtree::kdtree_node::point
CoordPoint point
Definition: kdtree.cxx:54
Kdtree::DistanceMeasure
Definition: kdtree.cxx:64