K-d tree

From Rosetta Code
Revision as of 17:28, 6 June 2014 by Rdm (talk | contribs) (J: fix a bug in node visit counting, and redo timings)
Task
K-d tree
You are encouraged to solve this task according to the task description, using any language you may know.
This page uses content from Wikipedia. The original article was at K-d tree. The list of authors can be seen in the page history. As with Rosetta Code, the text of Wikipedia is available under the GNU FDL. (See links for details on variance)

A k-d tree (short for k-dimensional tree) is a space-partitioning data structure for organizing points in a k-dimensional space. k-d trees are a useful data structure for several applications, such as searches involving a multidimensional search key (e.g. range searches and nearest neighbor searches). k-d trees are a special case of binary space partitioning trees.

k-d trees are not suitable, however, for efficiently finding the nearest neighbor in high-dimensional spaces. As a general rule, if the dimensionality is $k$, the number of points in the data, $N$, should satisfy $N \gg 2^k$. Otherwise, when k-d trees are used with high-dimensional data, most of the points in the tree will be evaluated and the efficiency is no better than exhaustive search, so other methods such as approximate nearest-neighbor search are used instead.

Task: Construct a k-d tree and perform a nearest neighbor search for two example data sets:

  1. The Wikipedia example data of [(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)].
  2. 1000 3-d points uniformly distributed in a 3-d cube.

For the Wikipedia example, find the nearest neighbor to point (9, 2). For the random data, pick a random location and find the nearest neighbor.

In addition, instrument your code to count the number of nodes visited in the nearest neighbor search. Count a node as visited if any field of it is accessed.

Output should show the point searched for, the point found, the distance to the point, and the number of nodes visited.

There are variant algorithms for constructing the tree. You can use a simple median strategy or implement something more efficient. Variants of the nearest neighbor search include nearest N neighbors, approximate nearest neighbor, and range searches. You do not have to implement these. The requirement for this task is specifically the nearest single neighbor. Also there are algorithms for inserting, deleting, and balancing k-d trees. These are also not required for the task.

C

Using a Quickselect-style median algorithm. Compared to unbalanced trees (random insertion), it takes slightly longer (maybe half a second or so) to construct a million-node tree, though an average lookup visits about 1/3 fewer nodes. <lang c>#include <stdio.h>

  1. include <stdlib.h>
  2. include <string.h>
  3. include <math.h>
  4. include <time.h>
  1. define MAX_DIM 3

/* One k-d tree node: a point's coordinates (room for MAX_DIM values; the
   functions below use only the first `dim` of them) and the two subtrees.
   make_tree() links these nodes in place inside a caller-owned array. */
struct kd_node_t{ double x[MAX_DIM]; struct kd_node_t *left, *right; };

inline double dist(struct kd_node_t *a, struct kd_node_t *b, int dim) { double t, d = 0; while (dim--) { t = a->x[dim] - b->x[dim]; d += t * t; } return d; }

/* see quickselect method */ struct kd_node_t* find_median(struct kd_node_t *start, struct kd_node_t *end, int idx) { if (end <= start) return NULL; if (end == start + 1) return start;

inline void swap(struct kd_node_t *x, struct kd_node_t *y) { double tmp[MAX_DIM]; memcpy(tmp, x->x, sizeof(tmp)); memcpy(x->x, y->x, sizeof(tmp)); memcpy(y->x, tmp, sizeof(tmp)); }

struct kd_node_t *p, *store, *md = start + (end - start) / 2; double pivot; while (1) { pivot = md->x[idx];

swap(md, end - 1); for (store = p = start; p < end; p++) { if (p->x[idx] < pivot) { if (p != store) swap(p, store); store++; } } swap(store, end - 1);

/* median has duplicate values */ if (store->x[idx] == md->x[idx]) return md;

if (store > md) end = store; else start = store; } }

/* Build a k-d tree in place over the array t[0..len).
 * find_median() moves the median along axis i into position and it becomes
 * the subtree root; the elements before it form the left subtree and those
 * after it the right subtree. The axis advances cyclically (mod dim) at
 * each level. Returns the root, or a null pointer when there is nothing to
 * build. */
struct kd_node_t* make_tree(struct kd_node_t *t, int len, int i, int dim)
{
    struct kd_node_t *root;
    int next_axis, left_len;

    if (len == 0)
        return NULL;

    root = find_median(t, t + len, i);
    if (root == NULL)
        return NULL;

    next_axis = (i + 1) % dim;
    left_len = root - t;
    root->left = make_tree(t, left_len, next_axis, dim);
    root->right = make_tree(root + 1, len - left_len - 1, next_axis, dim);
    return root;
}

/* Count of nodes examined by nearest(). It is only ever incremented here;
 * set it to 0 before a search to obtain that search's count. */
int visited;

/* Recursive nearest-neighbour search.
 *   root       subtree to search (a null pointer is ignored)
 *   nd         query point
 *   i          splitting coordinate at root's depth
 *   dim        number of coordinates in use
 *   best       in/out: closest node so far; set *best to NULL before the
 *              first call so that the first node examined is accepted
 *   best_dist  in/out: squared distance from nd to *best
 * The far subtree is skipped when the squared distance from nd to root's
 * splitting plane is at least the best squared distance found so far. */
void nearest(struct kd_node_t *root, struct kd_node_t *nd, int i, int dim,
             struct kd_node_t **best, double *best_dist)
{
    struct kd_node_t *first, *second;
    double d, dx;

    if (root == NULL)
        return;

    d = dist(root, nd, dim);
    dx = root->x[i] - nd->x[i];
    visited++;

    if (*best == NULL || d < *best_dist) {
        *best_dist = d;
        *best = root;
    }

    /* an exact match cannot be improved upon */
    if (*best_dist == 0)
        return;

    i = (i + 1 >= dim) ? 0 : i + 1;

    /* descend first into the side of the splitting plane that holds nd */
    if (dx > 0) {
        first = root->left;
        second = root->right;
    } else {
        first = root->right;
        second = root->left;
    }

    nearest(first, nd, i, dim, best, best_dist);
    if (dx * dx >= *best_dist)
        return;
    nearest(second, nd, i, dim, best, best_dist);
}

  1. define N 1000000
  2. define rand1() (rand() / (double)RAND_MAX)
  3. define rand_pt(v) { v.x[0] = rand1(); v.x[1] = rand1(); v.x[2] = rand1(); }

int main(void) { int i; struct kd_node_t wp[] = { Template:2, 3, Template:5, 4, Template:9, 6, Template:4, 7, Template:8, 1, Template:7, 2 }; struct kd_node_t this = Template:9, 2; struct kd_node_t *root, *found, *million; double best_dist;

root = make_tree(wp, sizeof(wp) / sizeof(wp[1]), 0, 2);

visited = 0; found = 0; nearest(root, &this, 0, 2, &found, &best_dist);

printf(">> WP tree\nsearching for (%g, %g)\n" "found (%g, %g) dist %g\nseen %d nodes\n\n", this.x[0], this.x[1], found->x[0], found->x[1], sqrt(best_dist), visited);

million = calloc(N, sizeof(struct kd_node_t)); srand(time(0)); for (i = 0; i < N; i++) rand_pt(million[i]);

root = make_tree(million, N, 0, 3); rand_pt(this);

visited = 0; found = 0; nearest(root, &this, 0, 3, &found, &best_dist);

printf(">> Million tree\nsearching for (%g, %g, %g)\n" "found (%g, %g, %g) dist %g\nseen %d nodes\n", this.x[0], this.x[1], this.x[2], found->x[0], found->x[1], found->x[2], sqrt(best_dist), visited);

/* search many random points in million tree to see average behavior. tree size vs avg nodes visited: 10 ~ 7 100 ~ 16.5 1000 ~ 25.5 10000 ~ 32.8 100000 ~ 38.3 1000000 ~ 42.6 10000000 ~ 46.7 */ int sum = 0, test_runs = 100000; for (i = 0; i < test_runs; i++) { found = 0; visited = 0; rand_pt(this); nearest(root, &this, 0, 3, &found, &best_dist); sum += visited; } printf("\n>> Million tree\n" "visited %d nodes for %d random findings (%f per lookup)\n", sum, test_runs, sum/(double)test_runs);

// free(million);

return 0;

}</lang>output

>> WP tree
searching for (9, 2)
found (8, 1) dist 1.41421
seen 3 nodes

>> Million tree
searching for (0.29514, 0.897237, 0.941998)
found (0.296093, 0.896173, 0.948082) dist 0.00624896
seen 44 nodes

>> Million tree
visited 4271442 nodes for 100000 random findings (42.714420 per lookup)

D

Translation of: Go

Points are values, and the code is templated on the dimensionality of the points and the floating-point type of the coordinates. Instead of sorting, it uses the faster topN, which partitions the points array into two halves around their median. <lang d>// Implementation following pseudocode from // "An introductory tutorial on kd-trees" by Andrew W. Moore, // Carnegie Mellon University, PDF accessed from: // http://www.autonlab.org/autonweb/14665

import std.typecons, std.math, std.algorithm, std.random, std.range,

      std.traits, core.memory;

/// k-dimensional point.
struct Point(size_t k, F) if (isFloatingPoint!F) {
    F[k] data;
    alias data this; // Kills DMD std.algorithm.swap inlining.
                     // Define opIndexAssign and opIndex for dmd.
    enum size_t length = k;

    /// Square of the euclidean distance between this point and q.
    double sqd(in ref Point!(k, F) q) const pure nothrow @nogc {
        double sum = 0;
        foreach (immutable dim, immutable pCoord; data)
            sum += (pCoord - q[dim]) ^^ 2;
        return sum;
    }
}

// Following field names in the paper.
// rangeElt would be whatever data is associated with the Point.
// We don't bother with it for this example.
struct KdNode(size_t k, F) {
    Point!(k, F) domElt;       // the point stored at this node
    immutable int split;       // index of the coordinate this node splits on
    typeof(this)* left, right; // subtrees below and above the split
}

/// k-dimensional rectangle, described by its two opposite corners.
struct Orthotope(size_t k, F) {
    Point!(k, F) min; /// per-coordinate lower bounds
    Point!(k, F) max; /// per-coordinate upper bounds
}

/// k-d tree over Point!(k, F) values.
struct KdTree(size_t k, F) {
    KdNode!(k, F)* n;        // root node, null for an empty tree
    Orthotope!(k, F) bounds; // bounding box given to the constructor

    // Constructs a KdTree from a list of points, also associating the
    // bounds of the tree. The bounds could be computed of course, but
    // in this example we know them already. The algorithm is table
    // 6.3 in the paper. Note: pts is reordered in place by topN.
    this(Point!(k, F)[] pts, in Orthotope!(k, F) bounds_) pure {
        static KdNode!(k, F)* nk2(size_t split)(Point!(k, F)[] exset)
        pure {
            if (exset.empty)
                return null;
            if (exset.length == 1)
                return new KdNode!(k, F)(exset[0], split, null, null);
            // Pivot choosing procedure. Partition around the median, then
            // advance to the largest index of a point with the median value.
            // This satisfies the inequalities of steps 6 and 7 in the
            // algorithm. The pivot must be read after advancing:
            // reading it first put the pivot into the left slice a second
            // time and dropped exset[m] whenever the median was duplicated.
            auto m = exset.length / 2;
            topN!((p, q) => p[split] < q[split])(exset, m);
            immutable medianValue = exset[m][split];
            while (m + 1 < exset.length && exset[m + 1][split] == medianValue)
                m++;
            immutable d = exset[m];
            enum nextSplit = (split + 1) % d.length; // cycle coordinates
            return new KdNode!(k, F)(d, split,
                                     nk2!nextSplit(exset[0 .. m]),
                                     nk2!nextSplit(exset[m + 1 .. $]));
        }
        this.n = nk2!0(pts);
        this.bounds = bounds_;
    }
}

/** Find nearest neighbor. Return values are:

 nearest neighbor--the point within the tree that is nearest p.
 square of the distance to that point.
 a count of the nodes visited in the search.
*/
auto findNearest(size_t k, F)(KdTree!(k, F) t, in Point!(k, F) p) pure nothrow @nogc {
    // Algorithm is table 6.4 from the paper, with the addition of
    // counting the number of nodes visited. hr is the bounding box of the
    // subtree kd; maxDistSqd is the squared-distance bound passed down.
    static Tuple!(Point!(k, F), "nearest",
                  F, "distSqd",
                  int, "nodesVisited")
           nn(KdNode!(k, F)* kd, in Point!(k, F) target,
              Orthotope!(k, F) hr, F maxDistSqd) pure nothrow @nogc {
        if (kd == null)
            return typeof(return)(Point!(k, F)(), F.infinity, 0);
        int nodesVisited = 1;
        immutable s = kd.split;
        auto pivot = kd.domElt;
        auto leftHr = hr;
        auto rightHr = hr;
        leftHr.max[s] = pivot[s];
        rightHr.min[s] = pivot[s];
        KdNode!(k, F)* nearerKd, furtherKd;
        Orthotope!(k, F) nearerHr, furtherHr;
        if (target[s] <= pivot[s]) {
            nearerKd = kd.left;
            nearerHr = leftHr;
            furtherKd = kd.right;
            furtherHr = rightHr;
        } else {
            nearerKd = kd.right;
            nearerHr = rightHr;
            furtherKd = kd.left;
            furtherHr = leftHr;
        }
        auto n1 = nn(nearerKd, target, nearerHr, maxDistSqd);
        auto nearest = n1.nearest;
        auto distSqd = n1.distSqd;
        nodesVisited += n1.nodesVisited;
        if (distSqd < maxDistSqd)
            maxDistSqd = distSqd;
        // Splitting plane farther than the bound: neither the pivot nor
        // anything on the far side can be closer.
        auto d = (pivot[s] - target[s]) ^^ 2;
        if (d > maxDistSqd)
            return typeof(return)(nearest, distSqd, nodesVisited);
        d = pivot.sqd(target);
        if (d < distSqd) {
            nearest = pivot;
            distSqd = d;
            maxDistSqd = distSqd;
        }
        immutable n2 = nn(furtherKd, target, furtherHr, maxDistSqd);
        nodesVisited += n2.nodesVisited;
        if (n2.distSqd < distSqd) {
            nearest = n2.nearest;
            distSqd = n2.distSqd;
        }
        return typeof(return)(nearest, distSqd, nodesVisited);
    }
    return nn(t.n, p, t.bounds, F.infinity);
}

/// Prints `heading`, the query point `p`, and the nearest neighbor of `p`
/// in `kd` together with its distance and the number of nodes visited.
void showNearest(size_t k, F)(in string heading, KdTree!(k, F) kd,
                              in Point!(k, F) p) {
    import std.stdio: write;

    write(heading, ":", "\n");
    write("Point:            ", p, "\n");
    immutable found = kd.findNearest(p);
    write("Nearest neighbor: ", found.nearest, "\n");
    write("Distance:         ", sqrt(found.distSqd), "\n");
    write("Nodes visited:    ", found.nodesVisited, "\n", "\n");
}

void main() {

   // Returns a point whose k coordinates are drawn with uniform(F(0), F(1)).
   static Point!(k, F) randomPoint(size_t k, F)() {
       typeof(return) result;
       foreach (immutable i; 0 .. k)
           result[i] = uniform(F(0), F(1));
       return result;
   }
   // Returns an array of n random points.
   static Point!(k, F)[] randomPoints(size_t k, F)(in size_t n) {
       return n.iota.map!(_ => randomPoint!(k, F)).array;
   }
   import std.stdio, std.conv, std.datetime, std.typetuple;
   rndGen.seed(1); // For repeatable outputs.
   // Wikipedia example: 2-D double points inside the box [0,10] x [0,10].
   alias D2 = TypeTuple!(2, double);
   alias P = Point!D2;
   auto kd1 = KdTree!D2([P([2, 3]), P([5, 4]), P([9, 6]),
                         P([4, 7]), P([8, 1]), P([7, 2])],
                        Orthotope!D2(P([0, 0]), P([10, 10])));
   showNearest("Wikipedia example data", kd1, P([9, 2]));
   // N random 3-D float points in the unit cube; construction is timed
   // with the GC disabled.
   enum int N = 400_000;
   alias F3 = TypeTuple!(3, float);
   alias Q = Point!F3;
   StopWatch sw;
   GC.disable;
   sw.start;
   auto kd2 = KdTree!F3(randomPoints!F3(N),
                        Orthotope!F3(Q([0, 0, 0]), Q([1, 1, 1])));
   sw.stop;
   GC.enable;
   showNearest(text("k-d tree with ", N,
                    " random 3D ", F3[1].stringof,
                    " points (construction time: ",
                    sw.peek.msecs, " ms)"), kd2, randomPoint!F3);
   // Average number of visited nodes (and total time) over M random queries.
   sw.reset;
   sw.start;
   enum int M = 10_000;
   size_t visited = 0;
   foreach (immutable _; 0 .. M) {
       immutable n = kd2.findNearest(randomPoint!F3);
       visited += n.nodesVisited;
   }
   sw.stop;
   writefln("Visited an average of %0.2f nodes on %d searches " ~
            "in %d ms.", visited / double(M), M, sw.peek.msecs);

}</lang>

Output, using the ldc2 compiler:
Wikipedia example data:
Point:            [9, 2]
Nearest neighbor: [8, 1]
Distance:         1.41421
Nodes visited:    3

k-d tree with 400000 random 3D float points (construction time: 250 ms):
Point:            [0.22012, 0.984514, 0.698782]
Nearest neighbor: [0.225766, 0.978981, 0.69885]
Distance:         0.00790531
Nodes visited:    54

Visited an average of 43.10 nodes on 10000 searches in 33 ms.

Faster Alternative Version

Translation of: C

This version performs fewer lookups. Compiled with DMD it is about twice as slow as the C version; compiled with ldc2 it is a little faster than the C version compiled with gcc. <lang d>import std.stdio, std.algorithm, std.math, std.random, std.typetuple;

// Number of coordinates stored in every KdNode.
enum maxDim = 3;

// Iota!stop is the compile-time sequence 0, 1, ..., stop - 1 (empty when
// stop <= 0); `foreach (i; Iota!n)` is therefore expanded at compile time.
template Iota(int stop) {

   static if (stop <= 0)
       alias Iota = TypeTuple!();
   else
       alias Iota = TypeTuple!(Iota!(stop - 1), stop - 1);

}

// One k-d tree node: the point's coordinates and the two subtree links,
// which makeTree sets while building the tree in place.
struct KdNode {

   double[maxDim] x;
   KdNode* left, right;

}

/**
 * Quickselect (see the C version) on coordinate `idx`: reorders `nodes` so
 * that the middle element holds the median, every element before it has
 * x[idx] <= the median and every element after it has x[idx] >= it.
 * Returns a pointer to that element, or null for an empty slice.
 *
 * A three-way partition keeps that ordering even when several points share
 * the median coordinate. `end` is computed as ptr + length, so an empty
 * slice does not trigger a range error.
 */
KdNode* findMedian(size_t idx)(KdNode[] nodes) pure nothrow @nogc {
    auto start = nodes.ptr;
    auto end = nodes.ptr + nodes.length;
    if (end <= start)
        return null;
    if (end == start + 1)
        return start;
    KdNode* md = start + (end - start) / 2;
    while (true) {
        immutable double pivot = md.x[idx];
        KdNode* lt = start; // [start, lt): x[idx] <  pivot
        KdNode* gt = end;   // [gt, end):   x[idx] >  pivot
        KdNode* p = start;  // [lt, p):     x[idx] == pivot
        while (p < gt) {
            if (p.x[idx] < pivot) {
                swap(p.x, lt.x); // Swaps the whole arrays x.
                lt++;
                p++;
            } else if (p.x[idx] > pivot) {
                gt--;
                swap(p.x, gt.x);
            } else {
                p++;
            }
        }
        // The pivot value is present, so [lt, gt) is non-empty and each
        // narrowing branch strictly shrinks [start, end).
        if (md < lt)
            end = lt;
        else if (md >= gt)
            start = gt;
        else
            return md;
    }
}

/// Builds a k-d tree in place over `nodes`, splitting on coordinate `i` at
/// this level and cycling through `dim` coordinates below it. Returns the
/// root (the median chosen by findMedian), or null for an empty slice.
KdNode* makeTree(size_t dim, size_t i)(KdNode[] nodes) pure nothrow @nogc {
    if (nodes.length == 0)
        return null;
    KdNode* root = findMedian!i(nodes);
    if (root == null)
        return null;
    enum nextAxis = (i + 1) % dim;
    immutable size_t mid = root - nodes.ptr;
    root.left = makeTree!(dim, nextAxis)(nodes[0 .. mid]);
    root.right = makeTree!(dim, nextAxis)(nodes[mid + 1 .. $]);
    return root;
}

// Recursive nearest-neighbour search below `root` for point `nd`, where `i`
// is the splitting coordinate at root's depth. `best`/`bestDist` carry the
// closest node so far and its squared distance (pass best = null to start);
// `nVisited` is incremented once per non-null node examined.
void nearest(size_t dim)(in KdNode* root,

                        in ref KdNode nd,
                        in size_t i,
                        ref const(KdNode)* best,
                        ref double bestDist,
                        ref size_t nVisited) pure nothrow @nogc {
   // Squared Euclidean distance over the first `dim` coordinates.
   static double dist(in ref KdNode a, in ref KdNode b)
   pure nothrow @nogc {
       typeof(KdNode.x[0]) result = 0;
       foreach (i; Iota!dim)
           result += (a.x[i] - b.x[i]) ^^ 2;
       return result;
   }
   if (root == null)
       return;
   immutable double d = dist(*root, nd);
   immutable double dx = root.x[i] - nd.x[i];
   immutable double dx2 = dx ^^ 2;
   nVisited++;
   if (!best || d < bestDist) {
       bestDist = d;
       best = root;
   }
   // An exact match (distance 0) cannot be beaten: stop searching.
   if (!bestDist)
       return;
   immutable i2 = (i + 1 >= dim) ? 0 : i + 1;
   // Nearer side first; the far side only if the splitting plane is closer
   // than the best squared distance found so far.
   nearest!dim(dx > 0 ? root.left : root.right,
               nd, i2, best, bestDist, nVisited);
   if (dx2 >= bestDist)
       return;
   nearest!dim(dx > 0 ? root.right : root.left,
               nd, i2, best, bestDist, nVisited);

}

// Fills the first `dim` coordinates of v with values from rng.uniform01.
void randPt(size_t dim=3)(ref KdNode v, ref Xorshift rng) nothrow @nogc {

   foreach (immutable i; Iota!dim)
       v.x[i] = rng.uniform01;

}

// Searches the 2-D Wikipedia example tree for the point (9, 2) and prints
// the result, the distance and the number of nodes seen.
void smallTest() {

   KdNode[] wp = [{[2, 3]}, {[5, 4]}, {[9, 6]},
                  {[4, 7]}, {[8, 1]}, {[7, 2]}];
   KdNode thisPt = {[9, 2]};
   KdNode* root = makeTree!(2, 0)(wp);
   const(KdNode)* found = null;
   double bestDist = 0;
   size_t nVisited = 0;
   nearest!2(root, thisPt, 0, found, bestDist, nVisited);
   writefln("WP tree:\n  Searching for %s\n" ~
            "  Found %s, dist = %g\n  Seen %d nodes.\n",
            thisPt.x[0..2], found.x[0..2], sqrt(bestDist), nVisited);

}

// Builds a tree of N random 3-D points (Xorshift seeded with 1), runs one
// reported search, then averages the visited-node count over testRuns
// random queries.
void bigTest() {

   enum N = 1_000_000;
   enum testRuns = 100_000;
   auto bigTree = new KdNode[N];
   auto rng = Xorshift(1);
   foreach (ref node; bigTree)
       randPt(node, rng);
   KdNode* root = makeTree!(3, 0)(bigTree);
   KdNode thisPt;
   randPt(thisPt, rng);
   const(KdNode)* found = null;
   double bestDist = 0;
   size_t nVisited = 0;
   nearest!3(root, thisPt, 0, found, bestDist, nVisited);
   writefln("Big tree (%d nodes):\n  Searching for %s\n"~
            "  Found %s, dist = %g\n  Seen %d nodes.",
            N, thisPt.x, found.x, sqrt(bestDist), nVisited);
   size_t sum = 0;
   foreach (immutable _; 0 .. testRuns) {
       // found = null makes nearest accept the first node it visits, so
       // bestDist needs no reset.
       found = null;
       nVisited = 0;
       randPt(thisPt, rng);
       nearest!3(root, thisPt, 0, found, bestDist, nVisited);
       sum += nVisited;
   }
   writefln("\nBig tree:\n  Visited %d nodes for %d random "~
            "searches (%.2f per lookup).",
            sum, testRuns, sum / double(testRuns));

}

/// Entry point: the Wikipedia example first, then the random benchmark.
void main() {
    smallTest();
    bigTest();
}</lang>

Output:
WP tree:
  Searching for [9, 2]
  Found [8, 1], dist = 1.41421
  Seen 3 nodes.

Big tree (1000000 nodes):
  Searching for [0.225893, 0.725471, 0.486279]
  Found [0.220761, 0.729613, 0.489134], dist = 0.00718703
  Seen 35 nodes.

Big tree:
  Visited 4267592 nodes for 100000 random searches (42.68 per lookup).

Go

<lang go>// Implmentation following pseudocode from "An intoductory tutorial on kd-trees" // by Andrew W. Moore, Carnegie Mellon University, PDF accessed from // http://www.autonlab.org/autonweb/14665 package main

import (

   "fmt"
   "math"
   "math/rand"
   "sort"
   "time"

)

// point is a k-dimensional point.
type point []float64

// sqd returns the square of the euclidean distance. func (p point) sqd(q point) float64 {

   var sum float64
   for dim, pCoord := range p {
       d := pCoord - q[dim]
       sum += d * d
   }
   return sum

}

// kdNode following field names in the paper. // rangeElt would be whatever data is associated with the point. we don't // bother with it for this example. type kdNode struct {

   domElt      point
   split       int
   left, right *kdNode

}

// kdTree is a k-d tree: its root node and the bounding hyper-rectangle
// passed to newKd.
type kdTree struct {

   n      *kdNode
   bounds hyperRect

}

// hyperRect is an axis-aligned box: min and max hold the per-coordinate
// lower and upper bounds.
type hyperRect struct {

   min, max point

}

// Go slices are reference objects. The data must be copied if you want // to modify one without modifying the original. func (hr hyperRect) copy() hyperRect {

   return hyperRect{append(point{}, hr.min...), append(point{}, hr.max...)}

}

// newKd constructs a kdTree from a list of points, also associating the // bounds of the tree. The bounds could be computed of course, but in this // example we know them already. The algorithm is table 6.3 in the paper. func newKd(pts []point, bounds hyperRect) kdTree {

   var nk2 func([]point, int) *kdNode
   nk2 = func(exset []point, split int) *kdNode {
       if len(exset) == 0 {
           return nil
       }
       // pivot choosing procedure.  we find median, then find largest
       // index of points with median value.  this satisfies the
       // inequalities of steps 6 and 7 in the algorithm.
       sort.Sort(part{exset, split})
       m := len(exset) / 2
       d := exset[m]
       for m+1 < len(exset) && exset[m+1][split] == d[split] {
           m++
       }
       // next split
       s2 := split + 1
       if s2 == len(d) {
           s2 = 0
       }
       return &kdNode{d, split, nk2(exset[:m], s2), nk2(exset[m+1:], s2)}
   }
   return kdTree{nk2(pts, 0), bounds}

}

// a container type used for sorting. it holds the points to sort and // the dimension to use for the sort key. type part struct {

   pts   []point
   dPart int

}

// satisfy sort.Interface func (p part) Len() int { return len(p.pts) } func (p part) Less(i, j int) bool {

   return p.pts[i][p.dPart] < p.pts[j][p.dPart]

} func (p part) Swap(i, j int) { p.pts[i], p.pts[j] = p.pts[j], p.pts[i] }

// nearest. find nearest neighbor. return values are: // nearest neighbor--the point within the tree that is nearest p. // square of the distance to that point. // a count of the nodes visited in the search. func (t kdTree) nearest(p point) (best point, bestSqd float64, nv int) {

   return nn(t.n, p, t.bounds, math.Inf(1))

}

// algorithm is table 6.4 from the paper, with the addition of counting // the number nodes visited. func nn(kd *kdNode, target point, hr hyperRect,

   maxDistSqd float64) (nearest point, distSqd float64, nodesVisited int) {
   if kd == nil {
       return nil, math.Inf(1), 0
   }
   nodesVisited++
   s := kd.split
   pivot := kd.domElt
   leftHr := hr.copy()
   rightHr := hr.copy()
   leftHr.max[s] = pivot[s]
   rightHr.min[s] = pivot[s]
   targetInLeft := target[s] <= pivot[s]
   var nearerKd, furtherKd *kdNode
   var nearerHr, furtherHr hyperRect
   if targetInLeft {
       nearerKd, nearerHr = kd.left, leftHr
       furtherKd, furtherHr = kd.right, rightHr
   } else {
       nearerKd, nearerHr = kd.right, rightHr
       furtherKd, furtherHr = kd.left, leftHr
   }
   var nv int
   nearest, distSqd, nv = nn(nearerKd, target, nearerHr, maxDistSqd)
   nodesVisited += nv
   if distSqd < maxDistSqd {
       maxDistSqd = distSqd
   }
   d := pivot[s] - target[s]
   d *= d
   if d > maxDistSqd {
       return
   }
   if d = pivot.sqd(target); d < distSqd {
       nearest = pivot
       distSqd = d
       maxDistSqd = distSqd
   }
   tempNearest, tempSqd, nv := nn(furtherKd, target, furtherHr, maxDistSqd)
   nodesVisited += nv
   if tempSqd < distSqd {
       nearest = tempNearest
       distSqd = tempSqd
   }
   return

}

// main builds the Wikipedia example tree and a tree of 1000 random 3-D
// points in the unit cube, and prints a nearest-neighbor search on each.
func main() {

   // seeded from the clock, so the random part differs between runs
   rand.Seed(time.Now().Unix())
   kd := newKd([]point{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}},
       hyperRect{point{0, 0}, point{10, 10}})
   showNearest("WP example data", kd, point{9, 2})
   kd = newKd(randomPts(3, 1000), hyperRect{point{0, 0, 0}, point{1, 1, 1}})
   showNearest("1000 random 3d points", kd, randomPt(3))

}

func randomPt(dim int) point {

   p := make(point, dim)
   for d := range p {
       p[d] = rand.Float64()
   }
   return p

}

func randomPts(dim, n int) []point {

   p := make([]point, n)
   for i := range p {
       p[i] = randomPt(dim) 
   } 
   return p

}

// showNearest prints heading and p, then the nearest neighbor of p in kd,
// its distance (square root of the squared distance) and the number of
// nodes visited.
func showNearest(heading string, kd kdTree, p point) {

   fmt.Println()
   fmt.Println(heading)
   fmt.Println("point:           ", p)
   // the local nn shadows the package-level function nn in this scope
   nn, ssq, nv := kd.nearest(p)
   fmt.Println("nearest neighbor:", nn)
   fmt.Println("distance:        ", math.Sqrt(ssq))
   fmt.Println("nodes visited:   ", nv)

}</lang>

Output:
WP example data
point:            [9 2]
nearest neighbor: [8 1]
distance:         1.4142135623730951
nodes visited:    3

1000 random 3d points
point:            [0.314731890562714 0.5908890147906868 0.2657722255021785]
nearest neighbor: [0.2541611609533609 0.5781168738628141 0.27829000365095274]
distance:         0.06315564613771865
nodes visited:    25

J

As a general rule, tree algorithms are a bad idea in J. That said, here's an implementation:

<lang J>coclass 'kdnode'
NB. Class kdnode: one node of a k-d tree. Nodes built from fewer than 4
NB. points are leaves holding all of them; other nodes hold one point plus
NB. two child nodes, Left and Right.

 create=:3 :0
   NB. y: table of points, one point per row.
   NB. Axis: floor(2^.#y) modulo the number of columns of y.
   Axis=: ({:$y)|<.2^.#y
   if. 4>#y do.
     Leaf=:1
     Points=: y
   else.
     Leaf=:0
     NB. Sort rows on their coordinates rotated left by Axis (column Axis first).
     data=. y /: Axis|."1 y
     n=. >.-:#data
     NB. Keep row n (n = ceiling of half the row count) here; rows before it
     NB. go to Left and rows after it to Right.
     Points=: ,:n{data
     Left=: conew&'kdnode' n{.data
     Right=: conew&'kdnode' (1+n)}.data
   end.
 )
 NB. Euclidean distance, applied row by row (rank 1).
 distance=: +/&.:*:@:-"1
 NB. Search this subtree for the point nearest y; returns distance;point.
 NB. Every call increments the global visit counter N_base_.
 nearest=:3 :0
   NB. Monadic form: start the search with left argument _ 0.
   _ 0 nearest y
 :
   N_base_=:N_base_+1
   NB. Closest of this node's own points.
   dists=. Points distance y
   ndx=. (i. <./) dists
   nearest=. ndx { Points
   range=. ndx { dists
   if. Leaf do.
     range;nearest return.
   else.
     NB. NOTE(review): x starts as the two-item list _ 0, so d0 and the
     NB. comparisons below are two-item values of which if. tests only the
     NB. first item -- confirm this is intended.
     d0=. x <. range
     p0=. nearest
     NB. Exact match: distance 0 and y itself.
     if. d0=0 do. 0;y return. end.
     NB. Search first the child on y's side: y sorts before this node's point
     NB. (same rotated ordering as in create) -> Left, otherwise Right.
     if. 0={./:Axis|."1 y,Points do.
       'dist pnt'=.x nearest__Left y
       if. dist > d0 do. d0;p0 return. end.
       NB. NOTE(review): the other child is searched when dist < x (the
       NB. caller's bound), not when the distance from y to the splitting
       NB. plane is below the best distance found; this differs from the
       NB. usual k-d pruning rule and may explain the 561 of 1000 nodes
       NB. visited in the example below. Verify.
       if. dist < x do.
         'dist2 pnt2'=. x nearest__Right y
         if. dist2 < dist do. dist2;pnt2 return. end.
       end.
     else.
       'dist pnt'=. x nearest__Right y
       if. dist > d0 do. d0;p0 return. end.
       if. dist < x do.
         'dist2 pnt2'=. x nearest__Left y
         if. dist2 < dist do. dist2;pnt2 return. end.
       end.
     end.
   end.
   dist;pnt return.
 )

coclass 'kdtree'
NB. Class kdtree: holds the root kdnode; nearest also reports how many
NB. nodes the search visited.

 create=:3 :0
   NB. y: table of points, one point per row.
   root=: conew&'kdnode' y
 )
 NB. Returns distance;nodes visited;nearest point for the query point y.
 nearest=:3 :0
   NB. Reset the global counter that kdnode's nearest increments.
   N_base_=:0
   'dist point'=. nearest__root y
   dist;N_base_;point
 )</lang>

And here's example use:

<lang J> tree=:conew&'kdtree' (2,3), (5,4), (9,6), (4,7), (8,1),: (7,2)

  nearest__tree 9 2

┌───────┬─┬───┐
│1.41421│3│8 1│
└───────┴─┴───┘</lang>

The first box is distance from argument point to selected point. The second box is the number of nodes visited. The third box is the selected point.

Here's the bigger problem:

<lang J> tree=:conew&'kdtree' dataset=:?1000 3$0

  nearest__tree pnt

┌─────────┬───┬──────────────────────────┐
│0.0387914│561│0.978082 0.767632 0.392523│
└─────────┴───┴──────────────────────────┘</lang>

So, why are trees "generally a bad idea in J"?

First off, that's a lot of code, and it took time to write. Let's assume that that time was free. Let's also assume that the time taken to build the tree structure was free. We're going to use this tree billions of times. Now what?

Well, let's compare the above implementation to a brute force implementation for time. Here's a "visit all nodes" implementation. It should give us the same kinds of results but we will claim that each candidate point is a node so we'll be visiting a lot more "nodes":

<lang J>build0=:3 :0
 NB. Brute-force baseline "build": store the points (one per row) in the
 NB. global noun data.

 data=: y

)

NB. Euclidean distance (square root of the summed squared differences), rank 1.
distance=: +/&.:*:@:-"1

nearest0=:3 :0

 NB. Brute force: measure the distance from y to every row of data and take
 NB. the closest. Returns distance;number of points examined (#data);point.
 nearest=. data {~ (i. <./) |data distance y
 (nearest distance y);(#data);nearest

)</lang>

Here are the numbers we get:

<lang J> build0 (2,3), (5,4), (9,6), (4,7), (8,1),: (7,2)

  nearest0 9 2

┌───────┬─┬───┐
│1.41421│6│8 1│
└───────┴─┴───┘

  build0 dataset
  nearest0 pnt

┌─────────┬────┬──────────────────────────┐
│0.0387914│1000│0.978082 0.767632 0.392523│
└─────────┴────┴──────────────────────────┘</lang>

But what about timing?

<lang J> tree=:conew&'kdtree' (2,3), (5,4), (9,6), (4,7), (8,1),: (7,2)

  timespacex 'nearest__tree 9 2'

0.000262674 15616

  build0 (2,3), (5,4), (9,6), (4,7), (8,1),: (7,2)
  timespacex 'nearest0 9 2'

3.62419e_5 6016</lang>

The kdtree implementation is about seven times slower than the brute force implementation for this small dataset (0.000263 s versus 0.0000362 s per search). How about the bigger dataset?

<lang J> tree=:conew&'kdtree' dataset

  timespacex 'nearest__tree pnt'

0.0169044 48128

  build0 dataset
  timespacex 'nearest0 pnt'

0.00140702 22144</lang>

On the bigger dataset, the kdtree implementation is over ten times slower than the brute force implementation.

Exercise for the student: work out the cost of cpu time and decide how big of a dataset you need for the time it takes to implement the kdtree to pay back the investment of time needed to implement it. Also work out how many times the tree has to be accessed to more than cover the time needed to build it.

See also: wp:KISS_principle

Perl 6

Translation of: Python

<lang perl6>class Kd_node {

   has $.d;       # the point stored at this node
   has $.split;   # index of the coordinate this node splits on
   has $.left;    # subtree built from the points before the pivot
   has $.right;   # subtree built from the points after the pivot

}

# Axis-aligned box (hyper-rectangle).
class Orthotope {

   has $.min;   # per-coordinate lower bounds, e.g. [0, 0]
   has $.max;   # per-coordinate upper bounds, e.g. [10, 10]

}

class Kd_tree {
    has $.n;        # root Kd_node built by nk2
    has $.bounds;   # Orthotope passed to new, stored unchanged
    method new($pts, $bounds) { self.bless(n => nk2(0,$pts), bounds => $bounds) }
    # Build a subtree from @e splitting on coordinate $split (table 6.3 of
    # the paper): sort, take the median, recurse on both halves.
    sub nk2($split, @e) {
        return () unless @e;
        my @exset = @e.sort(*.[$split]);
        my $m = +@exset div 2;
        # Advance to the largest index sharing the median's coordinate and
        # only then take the pivot: taking it first put the pivot into the
        # left subtree a second time and dropped @exset[$m] whenever the
        # median value was duplicated.
        my $median = @exset[$m][$split];
        while $m+1 < @exset and @exset[$m+1][$split] eqv $median {
            ++$m;
        }
        my @d = @exset[$m][];
        my $s2 = ($split + 1) % @d; # cycle coordinates
        Kd_node.new: :@d, :$split,
                left  => nk2($s2, @exset[0 ..^ $m]),
                right => nk2($s2, @exset[$m ^.. *]);
    }
}

# Result of a nearest-neighbour search.
class T3 {

   has $.nearest;             # closest point found
   has $.dist_sqd = Inf;      # its squared distance to the target
   has $.nodes_visited = 0;   # number of tree nodes examined

}

sub find_nearest($k, $t, @p) {

   # Nearest-neighbour search in Kd_tree $t for point @p (table 6.4 of the
   # paper). Returns a T3 with the nearest point, its squared distance and
   # the number of nodes visited. $k only sizes the placeholder point
   # returned for an empty subtree.
   return nn($t.n, @p, $t.bounds, Inf);
   sub nn($kd, @target, $hr, $max_dist_sqd is copy) {
       # Empty subtree: placeholder point, infinite distance, 0 visits.
       return T3.new(:nearest([0.0 xx $k])) unless $kd;
       my $nodes_visited = 1;
       my $s = $kd.split;
       my $pivot = $kd.d;
       # NOTE(review): .clone is shallow, so $left_hr/$right_hr probably share
       # their min/max arrays with $hr and the assignments below modify them
       # as well. $hr is only passed down, never used for pruning, so the
       # result is unaffected -- verify before relying on these rectangles.
       my $left_hr = $hr.clone;
       my $right_hr = $hr.clone;
       $left_hr.max[$s] = $pivot[$s];
       $right_hr.min[$s] = $pivot[$s];
       my $nearer_kd;
       my $further_kd;
       my $nearer_hr;
       my $further_hr;
       if @target[$s] <= $pivot[$s] {
           ($nearer_kd, $nearer_hr) = $kd.left, $left_hr;
           ($further_kd, $further_hr) = $kd.right, $right_hr;
       }
       else {
           ($nearer_kd, $nearer_hr) = $kd.right, $right_hr;
           ($further_kd, $further_hr) = $kd.left, $left_hr;
       }

       my $n1 = nn($nearer_kd, @target, $nearer_hr, $max_dist_sqd);
       my $nearest = $n1.nearest;
       my $dist_sqd = $n1.dist_sqd;
       $nodes_visited += $n1.nodes_visited;
       if $dist_sqd < $max_dist_sqd {
           $max_dist_sqd = $dist_sqd;
       }
       # Splitting plane farther than the bound: neither the pivot nor the
       # far side can hold a closer point.
       my $d = ($pivot[$s] - @target[$s]) ** 2;
       if $d > $max_dist_sqd {
           return T3.new(:$nearest, :$dist_sqd, :$nodes_visited);
       }
       $d = [+] (@$pivot Z- @target) X** 2;
       if $d < $dist_sqd {
           $nearest = $pivot;
           $dist_sqd = $d;
           $max_dist_sqd = $dist_sqd;
       }
       my $n2 = nn($further_kd, @target, $further_hr, $max_dist_sqd);
       $nodes_visited += $n2.nodes_visited;
       if $n2.dist_sqd < $dist_sqd {
           $nearest = $n2.nearest;
           $dist_sqd = $n2.dist_sqd;
       }
       T3.new(:$nearest, :$dist_sqd, :$nodes_visited);
   }

}

sub show_nearest($k, $heading, $kd, @p) {

   # Print the heading and query point, run find_nearest, then print the
   # neighbour, its distance (square root of dist_sqd) and nodes visited.
   print qq:to/END/;
       $heading:
       Point:            [@p.join(',')]
       END
   my $n = find_nearest($k, $kd, @p);
   print qq:to/END/;
       Nearest neighbor: [$n.nearest.join(',')]
       Distance:         &sqrt($n.dist_sqd)
       Nodes visited:    $n.nodes_visited()
       
       END

}

# random_point($k): array of $k values from rand.
# random_points($k, $n): array of $n such points.
sub random_point($k) { [rand xx $k] } sub random_points($k, $n) { [random_point($k) xx $n] }

# Build the Wikipedia example tree and a tree of $N random 3-D points, and
# show a nearest-neighbour search on each.
sub MAIN {
    my $kd1 = Kd_tree.new([[2, 3],[5, 4],[9, 6],[4, 7],[8, 1],[7, 2]],
                  Orthotope.new(:min([0, 0]), :max([10, 10])));
    show_nearest(2, "Wikipedia example data", $kd1, [9, 2]);
    my $N = 1000;
    my $t0 = now;
    my $kd2 = Kd_tree.new(random_points(3, $N), Orthotope.new(:min([0,0,0]), :max([1,1,1])));
    my $t1 = now;
    # This tree holds 3-D points, so k is 3 (it was passed as 2); find_nearest
    # uses k to size the placeholder point for empty subtrees.
    show_nearest(3,
        "k-d tree with $N random 3D points (generation time: {$t1 - $t0}s)",
         $kd2, random_point(3));
}</lang>

Output:
Wikipedia example data:
Point:            [9,2]
Nearest neighbor: [8,1]
Distance:         1.4142135623731
Nodes visited:    3

k-d tree with 1000 random 3D points (generation time: 67.0934954s):
Point:            [0.765565651400664,0.223251226280109,0.00536717765240979]
Nearest neighbor: [0.758919336088656,0.228895111242011,0.0383284709862686]
Distance:         0.0340950700678338
Nodes visited:    23

Python

Translation of: D

<lang python>from random import seed, random from time import clock from operator import itemgetter from collections import namedtuple from math import sqrt from copy import deepcopy


def sqd(p1, p2):
    """Return the squared Euclidean distance between points p1 and p2.

    Coordinates are paired positionally; if one point is longer, its extra
    coordinates are ignored (pairing stops at the shorter point).
    """
    return sum(map(lambda a, b: (a - b) ** 2, p1, p2))


class KdNode(object):
    """A node of a k-d tree.

    dom_elt -- the point stored in this node
    split   -- index of the coordinate this node splits on
    left    -- subtree (KdNode or None) holding points on the low side of the split
    right   -- subtree (KdNode or None) holding points on the high side
    """

    __slots__ = ("dom_elt", "split", "left", "right")

    def __init__(self, dom_elt, split, left, right):
        self.dom_elt = dom_elt
        self.split = split
        self.left = left
        self.right = right


class Orthotope(object):
    """Axis-aligned box (hyperrectangle) given by its min and max corner points."""

    __slots__ = ("min", "max")

    def __init__(self, mi, ma):
        self.min, self.max = mi, ma


class KdTree(object):
    """k-d tree built by recursive median splitting.

    pts    -- sequence of points (indexable, all of the same dimension); the
              caller's sequence is not modified
    bounds -- box enclosing the points (an Orthotope in this file); stored
              as ``self.bounds``

    ``self.n`` is the root KdNode, or None when pts is empty.  In every node,
    the left subtree holds points whose split coordinate is <= the node's,
    and the right subtree holds points with a strictly greater one.
    """

    __slots__ = ("n", "bounds")

    def __init__(self, pts, bounds):
        def nk2(split, exset):
            if not exset:
                return None
            # sorted() builds a new list, so the caller's pts is never reordered.
            exset = sorted(exset, key=itemgetter(split))
            m = len(exset) // 2
            # Advance to the last point sharing the median's split coordinate
            # so every tie lands in the left subtree (the search goes left on <=).
            while m + 1 < len(exset) and exset[m + 1][split] == exset[m][split]:
                m += 1
            # The pivot must be read *after* adjusting m; taking it before the
            # loop drops exset[m] from the tree and stores the old median twice.
            d = exset[m]
            s2 = (split + 1) % len(d)  # cycle coordinates
            return KdNode(d, split, nk2(s2, exset[:m]), nk2(s2, exset[m + 1:]))

        self.n = nk2(0, pts)
        self.bounds = bounds

# Result of a nearest-neighbour search: the closest point found, its squared
# distance to the query point, and the number of tree nodes visited.
T3 = namedtuple("T3", ["nearest", "dist_sqd", "nodes_visited"])


def find_nearest(k, t, p):
    """Find the point stored in k-d tree t that is nearest to query point p.

    k -- dimension of the points; only used to size the placeholder point
         returned for an empty (sub)tree
    t -- tree whose root node is ``t.n``; nodes expose dom_elt, split,
         left and right
    p -- query point

    Returns T3(nearest, dist_sqd, nodes_visited): the closest stored point,
    its squared Euclidean distance to p, and how many tree nodes the search
    visited.  For an empty tree, nearest is [0.0] * k and dist_sqd is inf.

    Pruning uses only the distance to each splitting plane.  The per-node
    hyperrectangles of the original version were deep-copied on every visit
    but never consulted, so they are no longer built (results are unchanged).
    """

    def nn(kd, target, max_dist_sqd):
        # Empty subtree: nothing found and no node visited.
        if kd is None:
            return T3([0.0] * k, float("inf"), 0)
        nodes_visited = 1
        s = kd.split
        pivot = kd.dom_elt
        # Search the side of the splitting plane that contains target first.
        if target[s] <= pivot[s]:
            nearer_kd, further_kd = kd.left, kd.right
        else:
            nearer_kd, further_kd = kd.right, kd.left
        n1 = nn(nearer_kd, target, max_dist_sqd)
        nearest = n1.nearest
        dist_sqd = n1.dist_sqd
        nodes_visited += n1.nodes_visited
        if dist_sqd < max_dist_sqd:
            max_dist_sqd = dist_sqd
        # If the splitting plane is farther than the best bound so far, neither
        # the pivot nor anything on the far side can improve on it.
        d = (pivot[s] - target[s]) ** 2
        if d > max_dist_sqd:
            return T3(nearest, dist_sqd, nodes_visited)
        d = sqd(pivot, target)
        if d < dist_sqd:
            nearest = pivot
            dist_sqd = d
            max_dist_sqd = dist_sqd
        n2 = nn(further_kd, target, max_dist_sqd)
        nodes_visited += n2.nodes_visited
        if n2.dist_sqd < dist_sqd:
            nearest = n2.nearest
            dist_sqd = n2.dist_sqd
        return T3(nearest, dist_sqd, nodes_visited)

    return nn(t.n, p, float("inf"))


def show_nearest(k, heading, kd, p):
    """Print p, its nearest neighbour in kd, their distance and nodes visited."""
    print(heading + ":")
    print("Point:           ", p)
    result = find_nearest(k, kd, p)
    rows = (
        ("Nearest neighbor:", result.nearest),
        ("Distance:        ", sqrt(result.dist_sqd)),
    )
    for label, value in rows:
        print(label, value)
    print("Nodes visited:   ", result.nodes_visited, "\n")


def random_point(k):
    """Return a list of k independent samples from random(), i.e. in [0, 1)."""
    return list(map(lambda _: random(), range(k)))


def random_points(k, n):
    """Return a list of n points, each generated by random_point(k)."""
    points = []
    for _ in range(n):
        points.append(random_point(k))
    return points

if __name__ == "__main__":

   seed(1)
   P = lambda *coords: list(coords)
   kd1 = KdTree([P(2, 3), P(5, 4), P(9, 6), P(4, 7), P(8, 1), P(7, 2)],
                 Orthotope(P(0, 0), P(10, 10)))
   show_nearest(2, "Wikipedia example data", kd1, P(9, 2))
   N = 400000
   t0 = clock()
   kd2 = KdTree(random_points(3, N), Orthotope(P(0, 0, 0), P(1, 1, 1)))
   t1 = clock()
   text = lambda *parts: "".join(map(str, parts))
   show_nearest(2, text("k-d tree with ", N,
                        " random 3D points (generation time: ",
                        t1-t0, "s)"),
                kd2, random_point(3))</lang>
Output:
Wikipedia example data:
Point:            [9, 2]
Nearest neighbor: [8, 1]
Distance:         1.41421356237
Nodes visited:    3

k-d tree with 400000 random 3D points (generation time: 14.8755565302s):
Point:            [0.066694022911324868, 0.13692213852082813, 0.94939167224227283]
Nearest neighbor: [0.067027753280507252, 0.14407354836507069, 0.94543775920177597]
Distance:         0.00817847583914
Nodes visited:    33

Racket

The following code is optimized for readability. <lang racket>

  #lang racket
;; A tree consists of a point p, a left subtree l and a right subtree r.

(struct tree (p l r) #:transparent)

;; If a node is at depth d, then the points in its left subtree l have a
;; (d mod k)'th coordinate less than the same coordinate of its point p, and
;; the points in its right subtree r have one greater than or equal to it
;; (see split-points).

(define (kdtree d k ps)
  (cond [(empty? ps) #f] ; #f represents an empty subtree
        [else (define-values (p l r) (split-points ps (modulo d k)))
              (tree p (kdtree (+ d 1) k l) (kdtree (+ d 1) k r))]))

;; Sort ps on coordinate d and use the element at index (length ps)/2 as the
;; median.  Return three values: the first point whose coordinate is >= the
;; median (the pivot), the points strictly below the median, and the
;; remaining points at or above it.
(define (split-points ps d)
  (define (coord v) (vector-ref v d))
  (define ordered (sort ps < #:key coord))
  (define pivot-value (coord (list-ref ordered (quotient (length ps) 2))))
  (define-values (below at-or-above)
    (partition (λ (v) (< (coord v) pivot-value)) ordered))
  (values (car at-or-above) below (cdr at-or-above)))
;; The bounding box of a subtree: per-coordinate lower bounds (mins) and
;; upper bounds (maxs), each stored as a vector.

(struct bb (mins maxs) #:transparent)

;; A k-dimensional box with no limit in any coordinate.
(define (infinite-bb k)
  (bb (make-vector k -inf.0) (make-vector k +inf.0)))

;; A copy of h with fresh mins/maxs vectors, so that mutating the copy with
;; vector-set! leaves h unchanged.
(define/match (copy-bb h)
  [((bb mins maxs))
   (bb (vector-copy mins) (vector-copy maxs))])

;; Squared Euclidean distance between points v and w (no square root taken).
(define (dist v w)
  (for/sum ([a v] [b w])
    (let ([delta (- a b)]) (* delta delta))))

;; Does box hr come within squared distance r of point g?
(define (intersects? g r hr)
  (define clamped (closest-in-hr g hr))
  (<= (dist clamped g) r))

;; The point of box hr closest to g: each coordinate of g clamped into
;; [min, max] of that coordinate.
(define (closest-in-hr g hr)
  (for/vector ([coord g] [lo (bb-mins hr)] [hi (bb-maxs hr)])
    (cond [(<= coord lo) lo]
          [(< lo coord hi) coord]
          [else hi])))

;; Split box hr at value x along coordinate d.  Return two copies: the first
;; with its max in coordinate d set to x, the second with its min set to x.
;; hr itself is not modified, because both halves are built with copy-bb.
(define (split-bb hr d x)
  (let ([lo-half (copy-bb hr)]
        [hi-half (copy-bb hr)])
    (vector-set! (bb-maxs lo-half) d x)
    (vector-set! (bb-mins hi-half) d x)
    (values lo-half hi-half)))

;; Node-visit counter, for statistics only.  Each definition sits on its own
;; line: a trailing ";" comment would otherwise swallow the definitions after it.
(define visits 0)
(define (visit) (set! visits (+ visits 1)))
(define (reset-visits) (set! visits 0))
;; Undo one visit; nearest-neighbor uses this so empty subtrees are not counted.
(define (regret-visit) (set! visits (- visits 1)))

;; Nearest point to g in k-d tree t of k-dimensional points.  Calls visit once
;; per tree node examined (empty subtrees are un-counted via regret-visit).
(define (nearest-neighbor g t k)
  (define (nearer? p q) (< (dist p g) (dist q g)))
  (define (nearest p q) (if (nearer? p q) p q))
  ;; Placeholder for an empty subtree: infinitely far from every point, with
  ;; the data's dimension k (this used to be a hard-coded 3-D vector).
  (define far-away (make-vector k +inf.0))
  (define (nn d t bb)
    (visit)
    (define (ref p) (vector-ref p (modulo d k)))
    (match t
      [#f (regret-visit) far-away]
      [(tree p l r)
       (define-values (lbb rbb) (split-bb bb (modulo d k) (ref p)))
       ;; Search the side of the splitting plane containing g first.
       (define-values (near near-bb far far-bb)
         (if (< (ref g) (ref p))
             (values l lbb r rbb)
             (values r rbb l lbb)))
       (define n (nearest p (nn (+ d 1) near near-bb)))
       ;; Visit the far side only if its box lies within the current best
       ;; squared distance of g.
       (if (intersects? g (dist n g) far-bb)
           (nearest n (nn (+ d 1) far far-bb))
           n)]))
  (nn 0 t (infinite-bb k)))

</lang> Tests: <lang racket> (define (wikipedia-test)
  (define t (kdtree 0 2 '(#(2 3) #(5 4) #(9 6) #(4 7) #(8 1) #(7 2))))
  (reset-visits)
  (define n (nearest-neighbor #(9 2) t 2))
  (displayln "Wikipedia Test")
  (displayln (~a "Nearest neighbour to (9,2) is: " n))
  ;; dist returns the squared distance; report the Euclidean distance.
  (displayln (~a "Distance: " (sqrt (dist n #(9 2)))))
  (displayln (~a "Visits: " visits "\n")))

;; Build a tree of n random points in [0,1)^k, search for the neighbour of
;; (0.75, ..., 0.75), and print it next to a brute-force control answer.
(define (test k n)
  (define (random!) (for/vector ([_ k]) (random)))
  (define points (for/list ([_ n]) (random!)))
  (define t (kdtree 0 k points))
  (reset-visits)
  (define target (for/vector ([_ k]) 0.75))
  (define nb (nearest-neighbor target t k))
  ;; Brute-force answer, used to check the tree search.
  (define nb-control (argmin (λ (p) (dist p target)) points))
  (displayln (~a n " points in R^" k " test"))
  (displayln (~a "Nearest neighbour to " target " is: \n\t\t" nb))
  (displayln (~a "Control: \t" nb-control))
  ;; dist returns the squared distance; report the Euclidean distance.
  (displayln (~a "Distance: \t" (sqrt (dist nb target))))
  (displayln (~a "Control: \t" (sqrt (dist nb-control target))))
  (displayln (~a "Visits: \t" visits)))

;; Run the Wikipedia example, then the random 3-D test twice (each run draws fresh random points).
(wikipedia-test) (test 3 1000) (test 3 1000) </lang> Output: <lang racket> Wikipedia Test Nearest neighbour to (9,2) is: #(8 1) Distance: 2 Visits: 3

1000 points in R^3 test Nearest neighbour to #(0.75 0.75 0.75) is: #(0.8092534479975508 0.7507095851813429 0.7706494651024903) Control: #(0.8092534479975508 0.7507095851813429 0.7706494651024903) Distance: 0.003937875019747008 Control: 0.003937875019747008 Visits: 83

1000 points in R^3 test Nearest neighbour to #(0.75 0.75 0.75) is: #(0.7775581478448806 0.7806612633582072 0.7396664367640902) Control: #(0.7775581478448806 0.7806612633582072 0.7396664367640902) Distance: 0.0018063471125121851 Control: 0.0018063471125121851 Visits: 39 </lang>

Tcl

Translation of: Python
Library: TclOO

<lang tcl>package require TclOO

oo::class create KDTree {
    # t   -- root node: 0 for an empty tree, else {1 point splitAxis left right}
    # dim -- number of coordinates per point
    variable t dim

    constructor {points} {
        set t [my Build 0 $points 0 end]
        set dim [llength [lindex $points 0]]
    }

    # Build a (sub)tree from elements $from..$to of $exset, splitting on
    # coordinate $split and cycling through the coordinates level by level.
    method Build {split exset from to} {
        set exset [lsort -index $split -real [lrange $exset $from $to]]
        if {![llength $exset]} {
            return 0
        }
        set m [expr {[llength $exset] / 2}]
        set v [lindex $exset $m $split]
        # Advance to the last point sharing the median's split coordinate so
        # every tie ends up in the left subtree, matching FN's "<=" test.
        while {$m + 1 < [llength $exset] && [lindex $exset [expr {$m + 1}] $split] == $v} {
            incr m
        }
        # Read the pivot AFTER adjusting m; reading it before the loop drops
        # the point at the final m and stores the old median twice.
        set d [lindex $exset $m]
        set s [expr {($split + 1) % [llength $d]}]
        list 1 $d $split \
            [my Build $s $exset 0 [expr {$m - 1}]] \
            [my Build $s $exset [expr {$m + 1}] end]
    }

    # Return {nearestPoint distance nodesVisited} for query point $p.
    method findNearest {p} {
        lassign [my FN $t $p inf] p d2 count
        return [list $p [expr {sqrt($d2)}] $count]
    }

    # Recursive search; returns {nearestPoint squaredDistance nodesVisited}.
    method FN {kd target maxDist2} {
        if {[lindex $kd 0] == 0} {
            return [list [lrepeat $dim 0.0] inf 0]
        }
        set nodesVisited 1
        lassign $kd -> pivot s
        # Search the side of the splitting plane containing the target first.
        if {[lindex $target $s] <= [lindex $pivot $s]} {
            set nearerKD [lindex $kd 3]
            set furtherKD [lindex $kd 4]
        } else {
            set nearerKD [lindex $kd 4]
            set furtherKD [lindex $kd 3]
        }
        lassign [my FN $nearerKD $target $maxDist2] nearest dist2 count
        incr nodesVisited $count
        if {$dist2 < $maxDist2} {
            set maxDist2 $dist2
        }
        # Prune when the splitting plane is farther than the best bound so far.
        set d2 [expr {([lindex $pivot $s] - [lindex $target $s])**2}]
        if {$d2 > $maxDist2} {
            return [list $nearest $dist2 $nodesVisited]
        }
        set d2 0.0
        foreach pp $pivot tp $target {
            set d2 [expr {$d2 + ($pp - $tp)**2}]
        }
        if {$d2 < $dist2} {
            set nearest $pivot
            set maxDist2 [set dist2 $d2]
        }
        lassign [my FN $furtherKD $target $maxDist2] fNearest fDist2 count
        incr nodesVisited $count
        if {$fDist2 < $dist2} {
            set nearest $fNearest
            set dist2 $fDist2
        }
        return [list $nearest $dist2 $nodesVisited]
    }
}</lang> Demonstration code: <lang tcl>proc showNearest {heading tree point} {
    puts ${heading}:
    puts "Point:            \[[join $point ,]\]"
    lassign [$tree findNearest $point] nearest distance count
    puts "Nearest neighbor: \[[join $nearest ,]\]"
    puts "Distance:         $distance"
    puts "Nodes visited:    $count"
}

# Return a list of $k uniform random coordinates in [0,1) (empty for k = 0).
proc randomPoint k {
    set p {}
    for {set j 0} {$j < $k} {incr j} {
        lappend p [::tcl::mathfunc::rand]
    }
    return $p
}

# Return a list of $n random points of dimension $k (empty for n = 0).
proc randomPoints {k n} {
    set ps {}
    for {set i 0} {$i < $n} {incr i} {
        lappend ps [randomPoint $k]
    }
    return $ps
}

# Wikipedia example.
KDTree create kd1 {{2 3} {5 4} {9 6} {4 7} {8 1} {7 2}}
showNearest "Wikipedia example data" kd1 {9 2}
kd1 destroy
puts ""

# 1000 random 3-D points.
set N 1000
set t [time {KDTree create kd2 [randomPoints 3 $N]}]
showNearest "k-d tree with $N random 3D points (generation time: [lindex $t 0] us)" kd2 [randomPoint 3]
kd2 destroy
puts ""

# One million random 3-D points, plus the average time of 10000 searches.
set N 1000000
set t [time {KDTree create kd2 [randomPoints 3 $N]}]
showNearest "k-d tree with $N random 3D points (generation time: [lindex $t 0] us)" kd2 [randomPoint 3]
puts "Search time: [time {kd2 findNearest [randomPoint 3]} 10000]"
kd2 destroy</lang>

Output:
Wikipedia example data:
Point:            [9,2]
Nearest neighbor: [8,1]
Distance:         1.4142135623730951
Nodes visited:    3

k-d tree with 1000 random 3D points (generation time: 11908 us):
Point:            [0.8480196329057308,0.6659702466176685,0.961934903153188]
Nearest neighbor: [0.8774737389187672,0.7011300077201472,0.8920397525150514]
Distance:         0.0836007490668497
Nodes visited:    29

k-d tree with 1000000 random 3D points (generation time: 19643366 us):
Point:            [0.10923849936073576,0.9714587558859301,0.30731017482807405]
Nearest neighbor: [0.10596616664247875,0.9733627601402638,0.3079096774141815]
Distance:         0.0038331184393709545
Nodes visited:    22
Search time:      289.894755 microseconds per iteration