K-d tree
This page uses content from Wikipedia. The original article was at K-d tree. The list of authors can be seen in the page history. As with Rosetta Code, the text of Wikipedia is available under the GNU FDL. (See links for details on variance)
A k-d tree (short for k-dimensional tree) is a space-partitioning data structure for organizing points in a k-dimensional space. k-d trees are a useful data structure for several applications, such as searches involving a multidimensional search key (e.g. range searches and nearest neighbor searches). k-d trees are a special case of binary space partitioning trees.
k-d trees are not suitable, however, for efficiently finding the nearest neighbor in high-dimensional spaces. As a general rule, if the dimensionality is k, the number of points in the data, N, should satisfy N ≫ 2^k. Otherwise, when k-d trees are used with high-dimensional data, most of the points in the tree will be evaluated, the efficiency is no better than exhaustive search, and other methods such as approximate nearest-neighbor search are used instead.
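To make the rule of thumb concrete, the short Python sketch below compares 2^k with a data set size. The function names and the factor of 10 standing in for "≫" are assumptions made for illustration only; the rule itself does not fix a particular factor.

```python
def points_threshold(k, factor=10):
    """Return factor * 2**k, a rough stand-in for 'N >> 2^k'.

    The default factor of 10 is an illustrative assumption.
    """
    return factor * 2 ** k


def kd_tree_plausible(n, k, factor=10):
    """True if n points comfortably exceed 2**k under the assumed factor."""
    return n >= points_threshold(k, factor)


if __name__ == "__main__":
    n = 1000  # size of the random data set used in the task
    for k in (2, 3, 10, 20):
        print("k =", k, " 2^k =", 2 ** k,
              " plausible for N =", n, ":", kd_tree_plausible(n, k))
```

Under these assumptions, 1000 points are ample for k = 3 (2^3 = 8) but not for k = 10 (2^10 = 1024), which is roughly where an exhaustive scan becomes competitive.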
Task: Construct a k-d tree and perform a nearest neighbor search for two example data sets:
- The Wikipedia example data of [(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)].
- 1000 3-d points uniformly distributed in a 3-d cube.
For the Wikipedia example, find the nearest neighbor to point (9, 2). For the random data, pick a random location and find the nearest neighbor.
In addition, instrument your code to count the number of nodes visited in the nearest neighbor search. Count a node as visited if any field of it is accessed.
Output should show the point searched for, the point found, the distance to the point, and the number of nodes visited.
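As a baseline for checking results, an exhaustive scan reports the same four items while visiting every point. The following is a minimal Python sketch; the names `brute_force_nearest` and `show` are hypothetical helpers, not part of the task.

```python
import math


def brute_force_nearest(points, target):
    """Exhaustive scan: return (nearest point, distance, points examined)."""
    best, best_sqd, examined = None, math.inf, 0
    for p in points:
        examined += 1  # every point counts as visited in a linear scan
        sqd = sum((a - b) ** 2 for a, b in zip(p, target))
        if sqd < best_sqd:
            best, best_sqd = p, sqd
    return best, math.sqrt(best_sqd), examined


def show(heading, target, result):
    """Print the four items the task asks for."""
    found, dist, visited = result
    print(heading)
    print("point searched for:", target)
    print("point found:       ", found)
    print("distance:          ", dist)
    print("nodes visited:     ", visited)


if __name__ == "__main__":
    data = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    show("Wikipedia example data (exhaustive scan)", (9, 2),
         brute_force_nearest(data, (9, 2)))
```

A k-d tree search should return the same point and distance while visiting fewer nodes than the scan does.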
There are variant algorithms for constructing the tree. You can use a simple median strategy or implement something more efficient. Variants of the nearest neighbor search include nearest N neighbors, approximate nearest neighbor, and range searches. You do not have to implement these. The requirement for this task is specifically the nearest single neighbor. Also there are algorithms for inserting, deleting, and balancing k-d trees. These are also not required for the task.
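The simple median strategy and the single nearest neighbor search can be sketched in Python 3 as follows. This is an illustration under stated assumptions, not a reference solution: the names `Node`, `build` and `nearest` are hypothetical, the tree is built by sorting at every level (simple rather than efficient), and a node counts as visited when the search reads its point.

```python
import math
from operator import itemgetter


class Node:
    """One tree node: a point, its splitting axis, and two subtrees."""
    __slots__ = ("point", "axis", "left", "right")

    def __init__(self, point, axis, left, right):
        self.point, self.axis = point, axis
        self.left, self.right = left, right


def build(points, depth=0):
    """Build a k-d tree by sorting on one axis and splitting at the median."""
    if not points:
        return None
    axis = depth % len(points[0])  # cycle through the coordinates
    pts = sorted(points, key=itemgetter(axis))
    mid = len(pts) // 2
    return Node(pts[mid], axis,
                build(pts[:mid], depth + 1),
                build(pts[mid + 1:], depth + 1))


def nearest(root, target):
    """Return (nearest point, distance, nodes visited)."""
    best = [None, math.inf]  # [point, squared distance]
    visited = 0

    def search(node):
        nonlocal visited
        if node is None:
            return
        visited += 1
        sqd = sum((a - b) ** 2 for a, b in zip(node.point, target))
        if sqd < best[1]:
            best[0], best[1] = node.point, sqd
        diff = target[node.axis] - node.point[node.axis]
        if diff <= 0:
            near, far = node.left, node.right
        else:
            near, far = node.right, node.left
        search(near)
        # Cross the splitting plane only if a closer point could lie beyond it.
        if diff * diff < best[1]:
            search(far)

    search(root)
    return best[0], math.sqrt(best[1]), visited


if __name__ == "__main__":
    data = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    print(nearest(build(data), (9, 2)))
```

The pruning test compares the squared distance to the splitting plane with the best squared distance found so far, so square roots are only taken once at the end.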
D
Points are values; the code is templated on the dimensionality of the points and the floating-point type of the coordinates. Instead of sorting, it uses the faster topN, which partitions the points array into two halves around their median.
<lang d>// Implementation following pseudocode from
// "An intoductory tutorial on kd-trees" by Andrew W. Moore,
// Carnegie Mellon University, PDF accessed from:
// http://www.autonlab.org/autonweb/14665

import std.typecons, std.math, std.algorithm, std.random, std.range,
       std.traits;

/// k-dimensional point.
struct Point(size_t k, F) if (isFloatingPoint!F) {
    F[k] data;
    alias data this;

    /// Square of the euclidean distance.
    double sqd(const ref Point!(k,F) q) pure nothrow {
        double sum = 0;
        foreach (dim, pCoord; data)
            sum += (pCoord - q[dim]) ^^ 2;
        return sum;
    }
}

// Following field names in the paper.
// rangeElt would be whatever data is associated with the Point.
// We don't bother with it for this example.
struct KdNode(size_t k, F) {
    Point!(k,F) domElt;
    immutable int split;
    KdNode!(k,F)* left, right;

    // This naive ctor is currently necessary
    this(in Point!(k,F) domElt_, in int split_,
         KdNode!(k,F)* left_, KdNode!(k,F)* right_) pure nothrow {
        this.domElt = domElt_;
        this.split = split_;
        this.left = left_;
        this.right = right_;
    }
}

struct Orthotope(size_t k, F) { /// k-dimensional rectangle.
    Point!(k,F) min, max;
}

struct KdTree(size_t k, F) {
    KdNode!(k,F)* n;
    Orthotope!(k,F) bounds;

    // Constructs a KdTree from a list of points, also associating the
    // bounds of the tree. The bounds could be computed of course, but
    // in this example we know them already. The algorithm is table
    // 6.3 in the paper.
    this(Point!(k,F)[] pts, in Orthotope!(k,F) bounds_) {
        static KdNode!(k,F)* nk2(size_t split)(Point!(k,F)[] exset) {
            if (exset.empty)
                return null;

            // Pivot choosing procedure. We find median, then find
            // largest index of points with median value. This
            // satisfies the inequalities of steps 6 and 7 in the
            // algorithm.
            auto m = exset.length / 2;
            //topN!((p, q) => p[split] < q[split])(exset, m);
            topN!((Point!(k,F) p, Point!(k,F) q)=> p[split] < q[split])
                 (exset, m);
            immutable d = exset[m];
            while (m+1 < exset.length && exset[m+1][split] == d[split])
                m++;

            enum nextSplit = (split + 1) % d.length;//cycle coordinates
            return new KdNode!(k,F)(d, split,
                                    nk2!nextSplit(exset[0 .. m]),
                                    nk2!nextSplit(exset[m + 1 .. $]));
        }
        this.n = nk2!0(pts);
        this.bounds = bounds_;
    }
}

/**
Find nearest neighbor. Return values are:
  nearest neighbor--the point within the tree that is nearest p.
  square of the distance to that point.
  a count of the nodes visited in the search.
*/
auto findNearest(size_t k, F)(KdTree!(k,F) t, in Point!(k,F) p) pure nothrow {
    // Algorithm is table 6.4 from the paper, with the addition of
    // counting the number of nodes visited.
    static Tuple!(Point!(k,F),"nearest",
                  double,"distSqd",
                  int,"nodesVisited")
    nn(KdNode!(k,F)* kd, in Point!(k,F) target,
       Orthotope!(k,F) hr, double maxDistSqd) pure nothrow {
        if (kd == null)
            return typeof(return)(Point!(k,F)(), double.infinity, 0);

        int nodesVisited = 1;
        immutable s = kd.split;
        auto pivot = kd.domElt;
        auto leftHr = hr;
        auto rightHr = hr;
        leftHr.max[s] = pivot[s];
        rightHr.min[s] = pivot[s];

        KdNode!(k,F)* nearerKd, furtherKd;
        Orthotope!(k,F) nearerHr, furtherHr;
        if (target[s] <= pivot[s]) {
            //nearerKd, nearerHr = kd.left, leftHr
            //furtherKd, furtherHr = kd.right, rightHr
            nearerKd = kd.left;
            nearerHr = leftHr;
            furtherKd = kd.right;
            furtherHr = rightHr;
        } else {
            //nearerKd, nearerHr = kd.right, rightHr
            //furtherKd, furtherHr = kd.left, leftHr
            nearerKd = kd.right;
            nearerHr = rightHr;
            furtherKd = kd.left;
            furtherHr = leftHr;
        }

        auto n1 = nn(nearerKd, target, nearerHr, maxDistSqd);
        auto nearest = n1.nearest;
        auto distSqd = n1.distSqd;
        nodesVisited += n1.nodesVisited;

        if (distSqd < maxDistSqd)
            maxDistSqd = distSqd;
        auto d = (pivot[s] - target[s]) ^^ 2;
        if (d > maxDistSqd)
            return typeof(return)(nearest, distSqd, nodesVisited);
        d = pivot.sqd(target);
        if (d < distSqd) {
            nearest = pivot;
            distSqd = d;
            maxDistSqd = distSqd;
        }

        immutable n2 = nn(furtherKd, target, furtherHr, maxDistSqd);
        nodesVisited += n2.nodesVisited;
        if (n2.distSqd < distSqd) {
            nearest = n2.nearest;
            distSqd = n2.distSqd;
        }

        return typeof(return)(nearest, distSqd, nodesVisited);
    }

    return nn(t.n, p, t.bounds, double.infinity);
}

void showNearest(size_t k, F)(in string heading, KdTree!(k,F) kd,
                              in Point!(k,F) p) {
    import std.stdio;
    writeln(heading, ":");
    writeln("Point: ", p);
    immutable n = kd.findNearest(p);
    writeln("Nearest neighbor: ", n.nearest);
    writeln("Distance: ", sqrt(n.distSqd));
    writeln("Nodes visited: ", n.nodesVisited, "\n");
}

Point!(k,F) randomPoint(size_t k, F)() {
    typeof(return) result;
    foreach (i; 0 .. k)
        result[i] = uniform(cast(F)0, cast(F)1);
    return result;
}

Point!(k,F)[] randomPoints(size_t k, F)(int n) {
    return iota(n).map!(_ => randomPoint!(k,F)())().array();
}

void main() {
    import std.conv, std.datetime, std.typetuple;
    rndGen.seed(1); // For repeatable outputs.

    alias TypeTuple!(2, double) D2;
    alias Point!D2 P;
    auto kd1 = KdTree!D2([P([2, 3]), P([5, 4]), P([9, 6]),
                          P([4, 7]), P([8, 1]), P([7, 2])],
                         Orthotope!D2(P([0, 0]), P([10, 10])));
    showNearest("Wikipedia example data", kd1, P([9, 2]));

    enum int N = 400_000;
    alias TypeTuple!(3, float) F3;
    alias Point!F3 Q;
    StopWatch sw;
    sw.start();
    auto kd2 = KdTree!F3(randomPoints!F3(N),
                         Orthotope!F3(Q([0, 0, 0]), Q([1, 1, 1])));
    sw.stop();
    showNearest(text("k-d tree with ", N,
                     " random 3D points (construction time: ",
                     sw.peek().msecs, "ms)"), kd2, randomPoint!F3());
}</lang>
- Output:
Wikipedia example data:
Point: [9, 2]
Nearest neighbor: [8, 1]
Distance: 1.41421
Nodes visited: 3

k-d tree with 400000 random 3D points (construction time: 1011ms):
Point: [0.22012, 0.984514, 0.698782]
Nearest neighbor: [0.225766, 0.978981, 0.69885]
Distance: 0.00790531
Nodes visited: 54
Go
<lang go>// Implementation following pseudocode from "An intoductory tutorial on kd-trees"
// by Andrew W. Moore, Carnegie Mellon University, PDF accessed from
// http://www.autonlab.org/autonweb/14665
package main

import (
	"fmt"
	"math"
	"math/rand"
	"sort"
	"time"
)

// point is a k-dimensional point.
type point []float64

// sqd returns the square of the euclidean distance.
func (p point) sqd(q point) float64 {
	var sum float64
	for dim, pCoord := range p {
		d := pCoord - q[dim]
		sum += d * d
	}
	return sum
}

// kdNode following field names in the paper.
// rangeElt would be whatever data is associated with the point. we don't
// bother with it for this example.
type kdNode struct {
	domElt      point
	split       int
	left, right *kdNode
}

type kdTree struct {
	n      *kdNode
	bounds hyperRect
}

type hyperRect struct {
	min, max point
}

// Go slices are reference objects. The data must be copied if you want
// to modify one without modifying the original.
func (hr hyperRect) copy() hyperRect {
	return hyperRect{append(point{}, hr.min...), append(point{}, hr.max...)}
}

// newKd constructs a kdTree from a list of points, also associating the
// bounds of the tree. The bounds could be computed of course, but in this
// example we know them already. The algorithm is table 6.3 in the paper.
func newKd(pts []point, bounds hyperRect) kdTree {
	var nk2 func([]point, int) *kdNode
	nk2 = func(exset []point, split int) *kdNode {
		if len(exset) == 0 {
			return nil
		}
		// pivot choosing procedure. we find median, then find largest
		// index of points with median value. this satisfies the
		// inequalities of steps 6 and 7 in the algorithm.
		sort.Sort(part{exset, split})
		m := len(exset) / 2
		d := exset[m]
		for m+1 < len(exset) && exset[m+1][split] == d[split] {
			m++
		}
		// next split
		s2 := split + 1
		if s2 == len(d) {
			s2 = 0
		}
		return &kdNode{d, split, nk2(exset[:m], s2), nk2(exset[m+1:], s2)}
	}
	return kdTree{nk2(pts, 0), bounds}
}

// a container type used for sorting. it holds the points to sort and
// the dimension to use for the sort key.
type part struct {
	pts   []point
	dPart int
}

// satisfy sort.Interface
func (p part) Len() int { return len(p.pts) }
func (p part) Less(i, j int) bool {
	return p.pts[i][p.dPart] < p.pts[j][p.dPart]
}
func (p part) Swap(i, j int) { p.pts[i], p.pts[j] = p.pts[j], p.pts[i] }

// nearest. find nearest neighbor. return values are:
//    nearest neighbor--the point within the tree that is nearest p.
//    square of the distance to that point.
//    a count of the nodes visited in the search.
func (t kdTree) nearest(p point) (best point, bestSqd float64, nv int) {
	return nn(t.n, p, t.bounds, math.Inf(1))
}

// algorithm is table 6.4 from the paper, with the addition of counting
// the number of nodes visited.
func nn(kd *kdNode, target point, hr hyperRect,
	maxDistSqd float64) (nearest point, distSqd float64, nodesVisited int) {
	if kd == nil {
		return nil, math.Inf(1), 0
	}
	nodesVisited++
	s := kd.split
	pivot := kd.domElt
	leftHr := hr.copy()
	rightHr := hr.copy()
	leftHr.max[s] = pivot[s]
	rightHr.min[s] = pivot[s]
	targetInLeft := target[s] <= pivot[s]
	var nearerKd, furtherKd *kdNode
	var nearerHr, furtherHr hyperRect
	if targetInLeft {
		nearerKd, nearerHr = kd.left, leftHr
		furtherKd, furtherHr = kd.right, rightHr
	} else {
		nearerKd, nearerHr = kd.right, rightHr
		furtherKd, furtherHr = kd.left, leftHr
	}
	var nv int
	nearest, distSqd, nv = nn(nearerKd, target, nearerHr, maxDistSqd)
	nodesVisited += nv
	if distSqd < maxDistSqd {
		maxDistSqd = distSqd
	}
	d := pivot[s] - target[s]
	d *= d
	if d > maxDistSqd {
		return
	}
	if d = pivot.sqd(target); d < distSqd {
		nearest = pivot
		distSqd = d
		maxDistSqd = distSqd
	}
	tempNearest, tempSqd, nv := nn(furtherKd, target, furtherHr, maxDistSqd)
	nodesVisited += nv
	if tempSqd < distSqd {
		nearest = tempNearest
		distSqd = tempSqd
	}
	return
}

func main() {
	rand.Seed(time.Now().Unix())
	kd := newKd([]point{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}},
		hyperRect{point{0, 0}, point{10, 10}})
	showNearest("WP example data", kd, point{9, 2})
	kd = newKd(randomPts(3, 1000), hyperRect{point{0, 0, 0}, point{1, 1, 1}})
	showNearest("1000 random 3d points", kd, randomPt(3))
}

func randomPt(dim int) point {
	p := make(point, dim)
	for d := range p {
		p[d] = rand.Float64()
	}
	return p
}

func randomPts(dim, n int) []point {
	p := make([]point, n)
	for i := range p {
		p[i] = randomPt(dim)
	}
	return p
}

func showNearest(heading string, kd kdTree, p point) {
	fmt.Println()
	fmt.Println(heading)
	fmt.Println("point: ", p)
	nn, ssq, nv := kd.nearest(p)
	fmt.Println("nearest neighbor:", nn)
	fmt.Println("distance: ", math.Sqrt(ssq))
	fmt.Println("nodes visited: ", nv)
}</lang>
- Output:
WP example data
point: [9 2]
nearest neighbor: [8 1]
distance: 1.4142135623730951
nodes visited: 3

1000 random 3d points
point: [0.314731890562714 0.5908890147906868 0.2657722255021785]
nearest neighbor: [0.2541611609533609 0.5781168738628141 0.27829000365095274]
distance: 0.06315564613771865
nodes visited: 25
Python
<lang python>from random import seed, random
from time import clock
from operator import itemgetter
from collections import namedtuple
from math import sqrt
from copy import deepcopy

def sqd(p1, p2):
    return sum((c1 - c2) ** 2 for c1,c2 in zip(p1, p2))

class Kd_node(object):
    __slots__ = ["dom_elt", "split", "left", "right"]
    def __init__(self, dom_elt_, split_, left_, right_):
        self.dom_elt = dom_elt_
        self.split = split_
        self.left = left_
        self.right = right_

class Orthotope(object):
    def __init__(self, mi, ma):
        self.min, self.max = mi, ma

class Kd_tree(object):
    def __init__(self, pts, bounds_):
        def nk2(split, exset):
            if not exset:
                return None
            exset.sort(key=itemgetter(split))
            m = len(exset) // 2
            d = exset[m]
            while m+1 < len(exset) and exset[m+1][split] == d[split]:
                m += 1

            s2 = (split + 1) % len(d) # cycle coordinates
            return Kd_node(d, split, nk2(s2, exset[:m]),
                                     nk2(s2, exset[m+1:]))
        self.n = nk2(0, pts)
        self.bounds = bounds_

T3 = namedtuple("T3", "nearest dist_sqd nodes_visited")

def find_nearest(k, t, p):
    def nn(kd, target, hr, max_dist_sqd):
        if kd is None:
            return T3([0.0] * k, float("inf"), 0)

        nodes_visited = 1
        s = kd.split
        pivot = kd.dom_elt
        left_hr = deepcopy(hr)
        right_hr = deepcopy(hr)
        left_hr.max[s] = pivot[s]
        right_hr.min[s] = pivot[s]

        if target[s] <= pivot[s]:
            nearer_kd, nearer_hr = kd.left, left_hr
            further_kd, further_hr = kd.right, right_hr
        else:
            nearer_kd, nearer_hr = kd.right, right_hr
            further_kd, further_hr = kd.left, left_hr

        n1 = nn(nearer_kd, target, nearer_hr, max_dist_sqd)
        nearest = n1.nearest
        dist_sqd = n1.dist_sqd
        nodes_visited += n1.nodes_visited

        if dist_sqd < max_dist_sqd:
            max_dist_sqd = dist_sqd
        d = (pivot[s] - target[s]) ** 2
        if d > max_dist_sqd:
            return T3(nearest, dist_sqd, nodes_visited)
        d = sqd(pivot, target)
        if d < dist_sqd:
            nearest = pivot
            dist_sqd = d
            max_dist_sqd = dist_sqd

        n2 = nn(further_kd, target, further_hr, max_dist_sqd)
        nodes_visited += n2.nodes_visited
        if n2.dist_sqd < dist_sqd:
            nearest = n2.nearest
            dist_sqd = n2.dist_sqd

        return T3(nearest, dist_sqd, nodes_visited)

    return nn(t.n, p, t.bounds, float("inf"))

def show_nearest(k, heading, kd, p):
    print heading + ":"
    print "Point: ", p
    n = find_nearest(k, kd, p)
    print "Nearest neighbor:", n.nearest
    print "Distance: ", sqrt(n.dist_sqd)
    print "Nodes visited: ", n.nodes_visited, "\n"

def random_point(k):
    return [random() for _ in xrange(k)]

def random_points(k, n):
    return [random_point(k) for _ in xrange(n)]

if __name__ == "__main__":
    seed(1)
    P = lambda *coords: list(coords)
    kd1 = Kd_tree([P(2, 3),P(5, 4),P(9, 6),P(4, 7),P(8, 1),P(7, 2)],
                  Orthotope(P(0, 0), P(10, 10)))
    show_nearest(2, "Wikipedia example data", kd1, P(9, 2))

    N = 400000
    t0 = clock()
    kd2 = Kd_tree(random_points(3, N), Orthotope(P(0,0,0), P(1,1,1)))
    t1 = clock()
    text = lambda *parts: "".join(map(str, parts))
    # k must match the dimensionality of the points (3 here).
    show_nearest(3, text("k-d tree with ", N,
                         " random 3D points (generation time: ",
                         t1-t0, "s)"),
                 kd2, random_point(3))</lang>
- Output:
Wikipedia example data:
Point: [9, 2]
Nearest neighbor: [8, 1]
Distance: 1.41421356237
Nodes visited: 3

k-d tree with 400000 random 3D points (generation time: 14.8755565302s):
Point: [0.066694022911324868, 0.13692213852082813, 0.94939167224227283]
Nearest neighbor: [0.067027753280507252, 0.14407354836507069, 0.94543775920177597]
Distance: 0.00817847583914
Nodes visited: 33
REXX
<lang rexx> /*REXX program to find the nearest point in K-d space to another point. */
numeric digits 100                     /*as seen on "The Hitchhiker's   */
                                       /*Guide to the Galaxy", space    */
                                       /*is very very big. So big digs.*/
/*─────────────────────────────────────perform the 1st task (out of two)*/
somePoints='[(2,3) (5,4) (9,6) (4,7) (8,1) (7,2)]'
call build_Kd_tree somePoints
target='(9,2)'
call search_Kd_tree target,'point','target' /*find nearest point to targ*/
/*─────────────────────────────────────perform the 2nd task (out of two)*/
stars=; maxp=1e5; #stars=1000          /*use a 100k x 100k x 100k cube.*/
  do j=1 for #stars                    /*gen 1k random stars within cube*/
  stars=stars multi_dimensional_space(3,maxp)   /*gen a star address.*/
  end   /*j*/
Astar=multi_dimensional_space(3,maxp)  /*gen a single star address ... */
stars='['strip(stars)"]"
call build_Kd_tree stars
call search_Kd_tree Astar,'star',"target star" /*nearest star to Astar.*/
/*─────────────────────────────────────choose a point in middle of cube.*/
target='('maxp%2 maxp%2 maxp%2")"; target=translate(target,','," ")
call search_Kd_tree target,'star',"cube center" /*find centermost star.*/
/*───────────────────────────────find nearest fly in a hypercube to food*/
flies=; maxp=100; #flys=10000          /*use a 100x100x100x100 hypercage*/
  do j=1 for #flys                     /*gen 10k random flies in cage.  */
  flies=flies multi_dimensional_space(4,maxp)   /*gen a fly's address*/
  end   /*j*/
poop=multi_dimensional_space(4,maxp)   /*gen an address of fly's food. */
flies='['strip(flies)"]"
call build_Kd_tree flies               /*quick, before they find food. */
call search_Kd_tree poop,'fly',"food"  /*find the nearest fly to food. */
exit
/*───────────────────────────────subroutines────────────────────────────*/
sqrt: procedure;parse arg x; if x=0 then return 0; d=digits(); numeric digits 11; g=$sqguess()
      do j=0 while p>9; m.j=p; p=p%2+1; end; do k=j+5 to 0 by -1; if m.k>11 then numeric digits m.k
      g=.5*(g+x/g); end; numeric digits d; return g/1
$sqguess: numeric form scientific; m.=11; p=d+d%4+2
          parse value format(x,2,1,,0) 'E0' with g 'E' _ .; return g*.5'E'_%2
pure: return space(translate(arg(1),,'][)('arg(2)))
/*───────────────────────────────MULTI_DIMENSIONAL_SPACE subroutine─────*/
multi_dimensional_space: procedure; arg Kdims,mx; do k=1 for Kdims
                                    if k==1 then _='('
                                            else _=_','
                                    _=_||random(0,mx)
                                    end   /*k*/
return _')'
/*───────────────────────────────BUILD_KD_TREE subroutine───────────────*/
build_Kd_tree: parse arg p,@.; p=pure(p); n=words(p)
Lside=0; Rside=0; s=0; dimensions=words(pure(word(pure(p),1),','))
  do j=1 for n; _=word(pure(word(p,j),','),1)
  s=s+_
  end   /*j*/
avg=s/n
  do k=1 for n
  _=translate(word(p,k),,','); _1=word(_,1)
  if _1<avg then do; Lside=Lside+1; @.0.Lside=_; end
            else do; Rside=Rside+1; @.1.Rside=_; end
  end   /*k*/
return
/*───────────────────────────────SEARCH_KD_TREE subroutine──────────────*/
search_Kd_tree: parse arg targ,what,where; tarp=pure(targ,','); md='¬'
  do dim=1 for dimensions; tar.dim=word(tarp,dim); end
side=tar.1>=a                   /*choose which side of node tree to use*/
if side then #nodes=Rside       /*using "right" side.  */
        else #nodes=Lside       /*using "left" side.   */
  do nodes=1 for #nodes; _=0    /*traipse through node.*/
    do j=1 to dimensions        /*proc. each dimension.*/
    pt.j=word(@.side.nodes,j)   /*get the Nth dimension*/
    _=_+(tar.j-pt.j)**2         /*calculate distance.  */
    end   /*j*/
  if md=='¬' then md=_          /*first time through ? */
  if _<=md then do; md=_; nn=nodes; end   /*found a nearer point.*/
  if _==0 then leave            /*we're lucky, we hit a spot-on match. */
  end   /*nodes*/
dist=format(sqrt(md),,4)/1      /*calculate the distance between points*/
neighbor='('translate(@.side.nn,",",' ')")"
say left(,11+length(what)) where 'is' targ
say 'nearest' what "to" where 'is' neighbor "with a distance of" dist
say nodes-1 'nodes were visited.' /*adjust nodes because of DO loop exit*/
say copies('─',79)
return
</lang>
- Output:
target is (9,2)
nearest point to target is (5,4) with a distance of 4.4721
3 nodes were visited.
───────────────────────────────────────────────────────────────────────────────
target star is (72974,29757,53483)
nearest star to target star is (74876,31074,57820) with a distance of 4915.4514
506 nodes were visited.
───────────────────────────────────────────────────────────────────────────────
cube center is (50000,50000,50000)
nearest star to cube center is (44806,50379,54720) with a distance of 7028.4904
494 nodes were visited.
───────────────────────────────────────────────────────────────────────────────
food is (64,19,77,52)
nearest fly to food is (50,20,70,58) with a distance of 16.7929
4972 nodes were visited.
───────────────────────────────────────────────────────────────────────────────