K-d tree

From Rosetta Code
Revision as of 23:44, 17 April 2012 by rosettacode>Bearophile (Templated coordinate type too in D entry)
K-d tree is a draft programming task. It is not yet considered ready to be promoted as a complete task, for reasons that should be found in its talk page.
This page uses content from Wikipedia. The original article was at K-d tree. The list of authors can be seen in the page history. As with Rosetta Code, the text of Wikipedia is available under the GNU FDL. (See links for details on variance)

A k-d tree (short for k-dimensional tree) is a space-partitioning data structure for organizing points in a k-dimensional space. k-d trees are a useful data structure for several applications, such as searches involving a multidimensional search key (e.g. range searches and nearest neighbor searches). k-d trees are a special case of binary space partitioning trees.

k-d trees are not suitable, however, for efficiently finding the nearest neighbor in high-dimensional spaces. As a general rule, if the dimensionality is k, the number of points in the data, N, should satisfy $N \gg 2^k$. Otherwise, when k-d trees are used with high-dimensional data, most of the points in the tree will be evaluated and the efficiency is no better than exhaustive search, and other methods such as approximate nearest-neighbor are used instead.

Task: Construct a k-d tree and perform a nearest neighbor search for two example data sets:

  1. The Wikipedia example data of [(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)].
  2. 1000 3-d points uniformly distributed in a 3-d cube.

For the Wikipedia example, find the nearest neighbor to point (9, 2). For the random data, pick a random location and find the nearest neighbor.

In addition, instrument your code to count the number of nodes visited in the nearest neighbor search. Count a node as visited if any field of it is accessed.

Output should show the point searched for, the point found, the distance to the point, and the number of nodes visited.

There are variant algorithms for constructing the tree. You can use a simple median strategy or implement something more efficient. Variants of the nearest neighbor search include nearest N neighbors, approximate nearest neighbor, and range searches. You do not have to implement these. The requirement for this task is specifically the nearest single neighbor. Also there are algorithms for inserting, deleting, and balancing k-d trees. These are also not required for the task.

D

Translation of: Go

The code is templated on the dimensionality and the coordinate type of the points. Points are values.
<lang d>// Implementation following pseudocode from
// "An intoductory tutorial on kd-trees" by Andrew W. Moore,
// Carnegie Mellon University, PDF accessed from:
// http://www.autonlab.org/autonweb/14665

import std.typecons, std.math, std.algorithm, std.random, std.range,

      std.typetuple, std.traits;

/// k-dimensional point with coordinates of floating point type F.
struct Point(size_t k, F) if (isFloatingPoint!F) {
    /// The k coordinates of the point.
    F[k] data;
    // Lets a Point be indexed and iterated like its coordinate array.
    alias data this;

    /**
     * Square of the euclidean distance.
     *
     * Params:
     *   q = the other point.
     * Returns: the sum over all k coordinates of the squared
     *          differences between this point and q, as a double.
     */
    double sqd(const ref Point!(k,F) q) const pure nothrow {
        double sum = 0;
        foreach (dim, pCoord; data)
            sum += (pCoord - q[dim]) ^^ 2;
        return sum;
    }
}

// Following field names in the paper.
// rangeElt would be whatever data is associated with the Point.
// We don't bother with it for this example.
struct KdNode(size_t k, F) {
    /// The point stored at this node.
    Point!(k,F) domElt;
    /// Index of the coordinate on which this node splits.
    int split;
    /// Child subtrees; null when a child is absent.
    KdNode!(k,F)* left, right;

    // This naive ctor is currently necessary
    this(in Point!(k,F) domElt_, in int split_,
         KdNode!(k,F)* left_, KdNode!(k,F)* right_) pure nothrow {
        this.domElt = domElt_;
        this.split = split_;
        this.left = left_;
        this.right = right_;
    }
}

/// k-dimensional rectangle, stored as its minimum and maximum corners.
struct Orthotope(size_t k, F) {
    /// Corner holding the smallest value of every coordinate.
    Point!(k,F) min;
    /// Corner holding the largest value of every coordinate.
    Point!(k,F) max;
}

/// A k-d tree: the root node together with the bounds of the point set.
struct KdTree(size_t k, F) {
    /// Root node; null when the tree was built from no points.
    KdNode!(k,F)* n;
    /// Bounding rectangle supplied by the caller.
    Orthotope!(k,F) bounds;

    // Constructs a KdTree from a list of points, also associating the
    // bounds of the tree. The bounds could be computed of course, but
    // in this example we know them already. The algorithm is table
    // 6.3 in the paper.
    //
    // Note: the points slice is sorted in place while the tree is built.
    this(Point!(k,F)[] pts, in Orthotope!(k,F) bounds_) {
        // Recursively builds the subtree for exset, splitting on
        // coordinate `split` and cycling through the coordinates.
        static KdNode!(k,F)* nk2(size_t split)(Point!(k,F)[] exset) {
            if (exset.empty)
                return null;
            // Pivot choosing procedure. We find median, then find
            // largest index of points with median value. This
            // satisfies the inequalities of steps 6 and 7 in the
            // algorithm.
            sort!((Point!(k,F) a, Point!(k,F) b)=> a[split] < b[split])
                 (exset);
            auto m = exset.length / 2;
            while (m+1 < exset.length &&
                   exset[m+1][split] == exset[m][split])
                m++;
            // Take the pivot only AFTER m has been advanced. Capturing
            // it before the loop would store the original median point
            // in the node while excluding exset[m] from both subtrees,
            // duplicating one point and losing another.
            immutable d = exset[m];
            // next split
            enum s2 = (split + 1 == k) ? 0 : split + 1;
            return new KdNode!(k,F)(d, split,
                                    nk2!s2(exset[0 .. m]),
                                    nk2!s2(exset[m+1 .. $]));
        }
        this.n = nk2!0(pts);
        this.bounds = bounds_;
    }
}

/**
 * Find nearest neighbor.
 *
 * Params:
 *   t = the tree to search.
 *   p = the point whose nearest neighbor is wanted.
 *
 * Returns: a tuple with the fields
 *   - nearest: the point within the tree that is nearest p.
 *   - distSqd: square of the distance to that point
 *     (infinity when no point was found).
 *   - nodesVisited: a count of the nodes visited in the search.
 */
auto findNearest(size_t k, F)(KdTree!(k,F) t, in Point!(k,F) p) pure nothrow {
    // Algorithm is table 6.4 from the paper, with the addition of
    // counting the number nodes visited.
    static Tuple!(Point!(k,F),"nearest", double,"distSqd",
                  int,"nodesVisited")
           nn(KdNode!(k,F)* kd, in Point!(k,F) target,
              Orthotope!(k,F) hr, double maxDistSqd) pure nothrow {
        // An empty subtree yields no candidate and counts no visit.
        if (kd == null)
            return typeof(return)(Point!(k,F)(), double.infinity, 0);
        int nodesVisited = 1;
        immutable s = kd.split;
        auto pivot = kd.domElt;
        // Cut the current rectangle at the pivot along coordinate s.
        auto leftHr = hr;
        auto rightHr = hr;
        leftHr.max[s] = pivot[s];
        rightHr.min[s] = pivot[s];
        // The side containing the target is searched first.
        KdNode!(k,F)* nearerKd, furtherKd;
        Orthotope!(k,F) nearerHr, furtherHr;
        if (target[s] <= pivot[s]) {
            nearerKd = kd.left;
            nearerHr = leftHr;
            furtherKd = kd.right;
            furtherHr = rightHr;
        } else {
            nearerKd = kd.right;
            nearerHr = rightHr;
            furtherKd = kd.left;
            furtherHr = leftHr;
        }
        auto n1 = nn(nearerKd, target, nearerHr, maxDistSqd);
        auto nearest = n1.nearest;
        auto distSqd = n1.distSqd;
        nodesVisited += n1.nodesVisited;
        if (distSqd < maxDistSqd)
            maxDistSqd = distSqd;
        // Squared distance from the target to the splitting plane. If
        // it exceeds the current bound, neither the pivot (which lies
        // on that plane) nor the further side can improve the result.
        auto d = (pivot[s] - target[s]) ^^ 2;
        if (d > maxDistSqd)
            return typeof(return)(nearest, distSqd, nodesVisited);
        // Consider the pivot itself as a candidate.
        d = pivot.sqd(target);
        if (d < distSqd) {
            nearest = pivot;
            distSqd = d;
            maxDistSqd = distSqd;
        }
        // Finally search the further side with the tightened bound.
        immutable n2 = nn(furtherKd, target, furtherHr, maxDistSqd);
        nodesVisited += n2.nodesVisited;
        if (n2.distSqd < distSqd) {
            nearest = n2.nearest;
            distSqd = n2.distSqd;
        }
        return typeof(return)(nearest, distSqd, nodesVisited);
    }
    return nn(t.n, p, t.bounds, double.infinity);
}

/// Prints the heading and the query point, then runs the nearest
/// neighbor search on kd and prints the point found, its distance
/// from p and the number of nodes visited.
void showNearest(size_t k, F)(in string heading, KdTree!(k,F) kd,
                              in Point!(k,F) p) {
    import std.stdio;

    writeln(heading, ":");
    writeln("Point:            ", p);

    immutable found = findNearest(kd, p);
    // The search reports the squared distance; take the root for display.
    immutable distance = sqrt(found.distSqd);

    writeln("Nearest neighbor: ", found.nearest);
    writeln("Distance:         ", distance);
    writeln("Nodes visited:    ", found.nodesVisited, "\n");
}

/// Returns a point whose k coordinates are each drawn with
/// uniform(0, 1), first coordinate first.
Point!(k,F) randomPoint(size_t k, F)() {
    Point!(k,F) pt;
    foreach (ref coord; pt.data)
        coord = uniform(cast(F)0, cast(F)1);
    return pt;
}

/// Returns an array of n points, each produced by randomPoint, in
/// generation order. Yields an empty array when n is not positive.
Point!(k,F)[] randomPoints(size_t k, F)(int n) {
    Point!(k,F)[] pts;
    if (n > 0)
        pts.reserve(n);
    foreach (_; 0 .. n)
        pts ~= randomPoint!(k,F)();
    return pts;
}

void main() {
    import std.conv, std.datetime;

    rndGen.seed(1); // For repeatable outputs.

    // Wikipedia example: six 2D points with double coordinates.
    alias TypeTuple!(2, double) Plane;
    alias Point!Plane P2;
    auto wikiPoints = [P2([2, 3]), P2([5, 4]), P2([9, 6]),
                       P2([4, 7]), P2([8, 1]), P2([7, 2])];
    auto wikiBounds = Orthotope!Plane(P2([0, 0]), P2([10, 10]));
    auto wikiTree = KdTree!Plane(wikiPoints, wikiBounds);
    showNearest("Wikipedia example data", wikiTree, P2([9, 2]));

    // Random example: N 3D points with float coordinates; the time
    // taken to generate the points and build the tree is reported.
    enum int N = 400_000;
    alias TypeTuple!(3, float) Space;
    alias Point!Space P3;
    auto unitCube = Orthotope!Space(P3([0, 0, 0]), P3([1, 1, 1]));
    StopWatch timer;
    timer.start();
    auto cloudTree = KdTree!Space(randomPoints!Space(N), unitCube);
    timer.stop();
    immutable heading = text("k-d tree with ", N,
                             " random 3D points (generation time: ",
                             timer.peek().msecs, "ms)");
    showNearest(heading, cloudTree, randomPoint!Space());
}</lang>

Output:
Wikipedia example data:
Point:            [9, 2]
Nearest neighbor: [8, 1]
Distance:         1.41421
Nodes visited:    3

k-d tree with 400000 random 3D points (generation time: 3694ms):
Point:            [0.22012, 0.984514, 0.698782]
Nearest neighbor: [0.225766, 0.978981, 0.69885]
Distance:         0.00790531
Nodes visited:    54

Go

<lang go>// Implementation following pseudocode from "An intoductory tutorial on kd-trees"
// by Andrew W. Moore, Carnegie Mellon University, PDF accessed from
// http://www.autonlab.org/autonweb/14665
package main

import (

   "fmt"
   "math"
   "math/rand"
   "sort"
   "time"

)

// point is a k-dimensional point.
type point []float64

// sqd returns the square of the euclidean distance between p and q.
// It indexes q with every index of p, so q must have at least as many
// coordinates as p.
func (p point) sqd(q point) float64 {
    var sum float64
    for dim, pCoord := range p {
        d := pCoord - q[dim]
        sum += d * d
    }
    return sum
}

// kdNode following field names in the paper. // rangeElt would be whatever data is associated with the point. we don't // bother with it for this example. type kdNode struct {

   domElt      point
   split       int
   left, right *kdNode

}

// kdTree holds the root node of a k-d tree together with the bounds
// that were supplied when the tree was built.
type kdTree struct {

   // n is the root node; nil when the tree holds no points.
   n      *kdNode
   // bounds is the rectangle handed to the nearest neighbor search.
   bounds hyperRect

}

// hyperRect is a k-dimensional rectangle given by two corner points.
type hyperRect struct {

   // min and max are the two corners; the search narrows max on the
   // left side of a split and min on the right side.
   min, max point

}

// Go slices are reference objects. The data must be copied if you want // to modify one without modifying the original. func (hr hyperRect) copy() hyperRect {

   return hyperRect{append(point{}, hr.min...), append(point{}, hr.max...)}

}

// newKd constructs a kdTree from a list of points, also associating the // bounds of the tree. The bounds could be computed of course, but in this // example we know them already. The algorithm is table 6.3 in the paper. func newKd(pts []point, bounds hyperRect) kdTree {

   var nk2 func([]point, int) *kdNode
   nk2 = func(exset []point, split int) *kdNode {
       if len(exset) == 0 {
           return nil
       }
       // pivot choosing procedure.  we find median, then find largest
       // index of points with median value.  this satisfies the
       // inequalities of steps 6 and 7 in the algorithm.
       sort.Sort(part{exset, split})
       m := len(exset) / 2
       d := exset[m]
       for m+1 < len(exset) && exset[m+1][split] == d[split] {
           m++
       }
       // next split
       s2 := split + 1
       if s2 == len(d) {
           s2 = 0
       }
       return &kdNode{d, split, nk2(exset[:m], s2), nk2(exset[m+1:], s2)}
   }
   return kdTree{nk2(pts, 0), bounds}

}

// a container type used for sorting. it holds the points to sort and // the dimension to use for the sort key. type part struct {

   pts   []point
   dPart int

}

// satisfy sort.Interface func (p part) Len() int { return len(p.pts) } func (p part) Less(i, j int) bool {

   return p.pts[i][p.dPart] < p.pts[j][p.dPart]

} func (p part) Swap(i, j int) { p.pts[i], p.pts[j] = p.pts[j], p.pts[i] }

// nearest. find nearest neighbor. return values are: // nearest neighbor--the point within the tree that is nearest p. // square of the distance to that point. // a count of the nodes visited in the search. func (t kdTree) nearest(p point) (best point, bestSqd float64, nv int) {

   return nn(t.n, p, t.bounds, math.Inf(1))

}

// algorithm is table 6.4 from the paper, with the addition of counting // the number nodes visited. func nn(kd *kdNode, target point, hr hyperRect,

   maxDistSqd float64) (nearest point, distSqd float64, nodesVisited int) {
   if kd == nil {
       return nil, math.Inf(1), 0
   }
   nodesVisited++
   s := kd.split
   pivot := kd.domElt
   leftHr := hr.copy()
   rightHr := hr.copy()
   leftHr.max[s] = pivot[s]
   rightHr.min[s] = pivot[s]
   targetInLeft := target[s] <= pivot[s]
   var nearerKd, furtherKd *kdNode
   var nearerHr, furtherHr hyperRect
   if targetInLeft {
       nearerKd, nearerHr = kd.left, leftHr
       furtherKd, furtherHr = kd.right, rightHr
   } else {
       nearerKd, nearerHr = kd.right, rightHr
       furtherKd, furtherHr = kd.left, leftHr
   }
   var nv int
   nearest, distSqd, nv = nn(nearerKd, target, nearerHr, maxDistSqd)
   nodesVisited += nv
   if distSqd < maxDistSqd {
       maxDistSqd = distSqd
   }
   d := pivot[s] - target[s]
   d *= d
   if d > maxDistSqd {
       return
   }
   if d = pivot.sqd(target); d < distSqd {
       nearest = pivot
       distSqd = d
       maxDistSqd = distSqd
   }
   tempNearest, tempSqd, nv := nn(furtherKd, target, furtherHr, maxDistSqd)
   nodesVisited += nv
   if tempSqd < distSqd {
       nearest = tempNearest
       distSqd = tempSqd
   }
   return

}

func main() {

   rand.Seed(time.Now().Unix())
   kd := newKd([]point{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}},
       hyperRect{point{0, 0}, point{10, 10}})
   showNearest("WP example data", kd, point{9, 2})
   kd = newKd(randomPts(3, 1000), hyperRect{point{0, 0, 0}, point{1, 1, 1}})
   showNearest("1000 random 3d points", kd, randomPt(3))

}

func randomPt(dim int) point {

   p := make(point, dim)
   for d := range p {
       p[d] = rand.Float64()
   }
   return p

}

func randomPts(dim, n int) []point {

   p := make([]point, n)
   for i := range p {
       p[i] = randomPt(dim) 
   } 
   return p

}

func showNearest(heading string, kd kdTree, p point) {

   fmt.Println()
   fmt.Println(heading)
   fmt.Println("point:           ", p)
   nn, ssq, nv := kd.nearest(p)
   fmt.Println("nearest neighbor:", nn)
   fmt.Println("distance:        ", math.Sqrt(ssq))
   fmt.Println("nodes visited:   ", nv)

}</lang>

Output:
WP example data
point:            [9 2]
nearest neighbor: [8 1]
distance:         1.4142135623730951
nodes visited:    3

1000 random 3d points
point:            [0.314731890562714 0.5908890147906868 0.2657722255021785]
nearest neighbor: [0.2541611609533609 0.5781168738628141 0.27829000365095274]
distance:         0.06315564613771865
nodes visited:    25