Range Tree
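A range tree is a data structure for orthogonal range reporting and counting, i.e. queries over axis-aligned rectangles (boxes).
- Query: reports the points that lie in the query range, or returns how many there are.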
Complexity:
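2D: $O(\log^2 n + k)$ query, or $O(\log n + k)$ with fractional cascading; $O(n \log n)$ build. In general a query takes $O(\log^d n + k)$, i.e. logarithmic raised to the power of $d$ plus the output size $k$, where $d$ is the dimension of the points.
Build time: $O(n \log^{d-1} n)$
Query time: $O(\log^d n + k)$, i.e. logarithmic raised to the power of $d$ plus the output size $k$, where $d$ is the dimension of the points.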
Hierarchical trees per dimension. Excellent for exact orthogonal range queries in low dimensions.
A range tree is a balanced binary search tree on the $x$-coordinates of a static point set, where every internal node also stores an auxiliary structure on the $y$-coordinates of the points in its subtree. In two dimensions the auxiliary structure is just a sorted array that supports binary search. An orthogonal range query first locates the canonical $x$-subtrees that together cover the query's $x$-interval, then reports points inside the $y$-range by binary searching into each canonical subtree's $y$-array.
Compared with a kd-tree, the range tree trades extra space ($O(n \log n)$ in 2D) for a sharper query bound of $O(\log^2 n + k)$, and with fractional cascading it drops to $O(\log n + k)$. It is the canonical data structure presented in de Berg et al. for orthogonal range reporting in low dimensions.
Codes
General d-dimensional Range Tree
from bisect import bisect_left, bisect_right
class RangeTree:
    """Static d-dimensional range tree for orthogonal range queries.

    For every coordinate ``dim`` a balanced binary tree is built over the
    points sorted by that coordinate. Each node of a tree on a non-final
    coordinate owns an associated tree (``assoc``) over the same points on
    coordinate ``dim + 1``; nodes on the final coordinate keep their points
    in a sorted array that is searched with ``bisect``.

    The point set is fixed at construction time (no insert/delete). All
    query bounds are inclusive.
    """

    class _Node:
        """One node of the tree built over coordinate ``dim``."""

        __slots__ = ("dim", "minv", "maxv", "left", "right", "assoc", "vals", "pts")

        def __init__(self, dim, minv, maxv, left, right, assoc, vals, pts):
            self.dim = dim
            # Smallest / largest value of coordinate ``dim`` in this subtree.
            self.minv = minv
            self.maxv = maxv
            self.left = left
            self.right = right
            # Tree over the same points on coordinate ``dim + 1``
            # (None when ``dim`` is the last coordinate).
            self.assoc = assoc
            # Last coordinate only: this subtree's values on ``dim``, sorted.
            self.vals = vals
            # Last coordinate only: this subtree's points, aligned with ``vals``.
            self.pts = pts

    def __init__(self, points):
        """Build the tree from ``points``.

        Args:
            points: iterable of equal-length coordinate sequences (tuples,
                lists, ...). Any iterable is accepted, including generators.

        Raises:
            ValueError: if the points do not all have the same dimension.
        """
        # Materialise once: the input is indexed and traversed several times,
        # which a one-shot iterator could not support.
        points = list(points)
        if not points:
            self.d = 0
            self.root = None
            return
        self.d = len(points[0])
        if any(len(p) != self.d for p in points):
            raise ValueError("All points must have the same dimension")
        self.root = self._build(points, 0)

    def _build(self, pts, dim):
        """Return the root of the tree over ``pts`` for coordinate ``dim``.

        Returns None for an empty point list.
        """
        if not pts:
            return None
        pts_sorted = sorted(pts, key=lambda p: p[dim])
        minv, maxv = pts_sorted[0][dim], pts_sorted[-1][dim]

        if len(pts_sorted) == 1:
            left = right = None
        else:
            mid = len(pts_sorted) // 2
            left = self._build(pts_sorted[:mid], dim)
            right = self._build(pts_sorted[mid:], dim)

        if dim == self.d - 1:
            # Final coordinate: keep the sorted values next to the points so
            # queries can binary search instead of descending further.
            vals = [p[dim] for p in pts_sorted]
            return self._Node(dim, minv, maxv, left, right, None, vals, pts_sorted)

        # Otherwise hang a tree over the next coordinate off this node.
        assoc = self._build(pts_sorted, dim + 1)
        return self._Node(dim, minv, maxv, left, right, assoc, None, None)

    def _check_bounds(self, lows, highs):
        """Raise ValueError unless both bound sequences have length ``d``."""
        if len(lows) != self.d or len(highs) != self.d:
            raise ValueError("Bounds must match point dimension")

    def query(self, lows, highs):
        """Return every point p with ``lows[i] <= p[i] <= highs[i]`` for all i.

        The returned list holds the stored point objects themselves; its
        order follows the tree traversal and is not otherwise specified.
        An empty tree yields ``[]`` without inspecting the bounds.

        Raises:
            ValueError: if ``lows`` or ``highs`` does not have length ``d``.
        """
        if self.root is None:
            return []
        self._check_bounds(lows, highs)
        res = []
        self._query_node(self.root, lows, highs, res)
        return res

    def _query_node(self, node, L, H, out):
        """Append to ``out`` the points under ``node`` inside the box [L, H]."""
        if node is None:
            return
        d = node.dim
        # Subtree lies entirely outside the query interval on this coordinate.
        if node.maxv < L[d] or node.minv > H[d]:
            return
        # Subtree lies entirely inside: coordinate ``d`` is settled.
        if L[d] <= node.minv and node.maxv <= H[d]:
            if d == self.d - 1:
                # ``vals`` is sorted, so the matches form one contiguous slice.
                i = bisect_left(node.vals, L[d])
                j = bisect_right(node.vals, H[d])
                out.extend(node.pts[i:j])
            else:
                self._query_node(node.assoc, L, H, out)
            return
        # Partial overlap: split the work between the two halves.
        self._query_node(node.left, L, H, out)
        self._query_node(node.right, L, H, out)

    def query_count(self, lows, highs):
        """Return how many points p satisfy ``lows[i] <= p[i] <= highs[i]``.

        An empty tree yields ``0`` without inspecting the bounds.

        Raises:
            ValueError: if ``lows`` or ``highs`` does not have length ``d``.
        """
        if self.root is None:
            return 0
        self._check_bounds(lows, highs)
        return self._query_count_node(self.root, lows, highs)

    def _query_count_node(self, node, L, H):
        """Return the number of points under ``node`` inside the box [L, H]."""
        if node is None:
            return 0
        d = node.dim
        if node.maxv < L[d] or node.minv > H[d]:
            return 0
        if L[d] <= node.minv and node.maxv <= H[d]:
            if d == self.d - 1:
                i = bisect_left(node.vals, L[d])
                j = bisect_right(node.vals, H[d])
                return j - i
            return self._query_count_node(node.assoc, L, H)
        return (
            self._query_count_node(node.left, L, H)
            + self._query_count_node(node.right, L, H)
        )
compact 2d version:
class RangeTree2D:
    """Static two-dimensional range tree.

    The primary tree splits the points by x; every node additionally keeps
    all points of its subtree sorted by y, so a node whose x-span is fully
    covered by a query can be answered with two binary searches.
    """

    class _Node:
        """Primary-tree node together with its y-sorted secondary array."""

        __slots__ = ("x_min", "x_max", "y_vals", "y_points", "left", "right")

        def __init__(self, x_min, x_max, y_vals, y_points, left, right):
            # x-span covered by the subtree rooted here.
            self.x_min, self.x_max = x_min, x_max
            # y-coordinates in ascending order, and the matching points in
            # the same order (y_points[k][1] == y_vals[k]).
            self.y_vals, self.y_points = y_vals, y_points
            self.left, self.right = left, right

    def __init__(self, points):
        """Build the tree from an iterable of ``(x, y)`` pairs (no updates)."""
        by_x = list(points)
        by_x.sort(key=lambda pt: pt[0])
        self.root = self._build(by_x)

    def _build(self, pts):
        """Recursively build the subtree for ``pts``, which is sorted by x."""
        if not pts:
            return None
        by_y = sorted(pts, key=lambda pt: pt[1])
        heights = [pt[1] for pt in by_y]
        lower = upper = None
        count = len(pts)
        if count != 1:
            half = count // 2
            lower = self._build(pts[:half])
            upper = self._build(pts[half:])
        return self._Node(pts[0][0], pts[-1][0], heights, by_y, lower, upper)

    def query(self, x1, x2, y1, y2):
        """Return all points with ``x1 <= x <= x2`` and ``y1 <= y <= y2``.

        An empty tree or an inverted interval yields an empty list.
        """
        found = []
        if self.root is None or x2 < x1 or y2 < y1:
            return found
        # Explicit stack; the right child is pushed first so that the left
        # child is processed first, i.e. points are reported left to right.
        pending = [self.root]
        while pending:
            node = pending.pop()
            if node is None:
                continue
            if node.x_max < x1 or node.x_min > x2:
                continue
            if x1 <= node.x_min and node.x_max <= x2:
                start = bisect_left(node.y_vals, y1)
                stop = bisect_right(node.y_vals, y2)
                found.extend(node.y_points[start:stop])
                continue
            pending.append(node.right)
            pending.append(node.left)
        return found

    def count(self, x1, x2, y1, y2):
        """Return the number of points in the rectangle [x1, x2] x [y1, y2]."""
        if self.root is None or x2 < x1 or y2 < y1:
            return 0

        def tally(node):
            # Number of matching points in the subtree rooted at ``node``.
            if node is None:
                return 0
            if node.x_max < x1 or node.x_min > x2:
                return 0
            if x1 <= node.x_min and node.x_max <= x2:
                return bisect_right(node.y_vals, y2) - bisect_left(node.y_vals, y1)
            return tally(node.left) + tally(node.right)

        return tally(self.root)
The code above is usable; one typical use case is LeetCode 3027, even though a range tree is not the optimal solution for a problem like that.
The script was written to solve that problem (LeetCode 3027).
It works.
Solve-style implementations
- Python
- C++
from bisect import bisect_left, bisect_right
def solve(xs):
    """Answer a batch of rectangle-reporting queries with a 2D range tree.

    ``xs`` is a pair ``(points, queries)``: ``points`` is a sequence of
    ``(x, y)`` pairs and each query is ``(x1, y1, x2, y2)`` with inclusive
    bounds. Returns one list per query holding the indices (into
    ``points``) of the points inside that rectangle, in ascending order.
    """
    points, queries = xs
    # Point indices ordered by x; every tree node covers a slice of this list.
    order = sorted(range(len(points)), key=lambda idx: points[idx][0])

    def make(start, stop):
        # Node layout: (x_low, x_high, ys_sorted, ids_sorted_by_y, children).
        if start >= stop:
            return None
        members = sorted(order[start:stop], key=lambda idx: points[idx][1])
        heights = [points[idx][1] for idx in members]
        if stop - start > 1:
            cut = (start + stop) // 2
            children = (make(start, cut), make(cut, stop))
        else:
            children = (None, None)
        x_low = points[order[start]][0]
        x_high = points[order[stop - 1]][0]
        return (x_low, x_high, heights, members, children)

    tree = make(0, len(order))

    answers = []
    for x1, y1, x2, y2 in queries:
        hits = []
        stack = [tree]
        while stack:
            node = stack.pop()
            if node is None:
                continue
            x_low, x_high, heights, members, children = node
            if x_high < x1 or x_low > x2:
                continue
            if x1 <= x_low and x_high <= x2:
                # Whole x-span is covered: slice the y-sorted ids directly.
                first = bisect_left(heights, y1)
                last = bisect_right(heights, y2)
                hits.extend(members[first:last])
                continue
            stack.extend(children)
        # Visiting order is irrelevant because each answer is sorted.
        hits.sort()
        answers.append(hits)
    return answers
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
// Static 2D range tree over integer points.
//
// Usage, as far as this struct shows: fill `pts`, fill `order`, then set
// `root = build(0, order.size())` and call `query(root, ...)`.
// NOTE(review): `order` is expected to hold indices into `pts` sorted by x
// (query() reads the first/last entry of a node's slice as its x-bounds);
// nothing in this struct fills it in, so the caller must.
struct RangeTree {
    std::vector<std::pair<int,int>> pts;

    struct Node {
        int lo, hi;              // half-open slice [lo, hi) of `order` covered by this node
        std::vector<int> ys;     // y-coordinates of the slice, ascending
        std::vector<int> ids;    // point indices, aligned with `ys`
        Node *l = nullptr, *r = nullptr;
    };

    std::vector<int> order;
    Node* root = nullptr;

    // Owns every node handed out by build(), so the tree is released
    // together with the RangeTree instead of leaking. The raw Node*
    // pointers (root, l, r) are non-owning views into this storage and stay
    // valid for the lifetime of this object. This also makes the struct
    // move-only: a copy would share nodes without owning them.
    std::vector<std::unique_ptr<Node>> owned;

    // Builds the subtree for order[lo, hi) and returns its root (nullptr for
    // an empty slice). The returned node is owned by this RangeTree.
    Node* build(int lo, int hi) {
        if (lo >= hi) return nullptr;
        std::unique_ptr<Node> holder(new Node{lo, hi});
        Node* n = holder.get();
        owned.push_back(std::move(holder));

        // Secondary structure: the slice's point ids sorted by y.
        n->ids.assign(order.begin() + lo, order.begin() + hi);
        std::sort(n->ids.begin(), n->ids.end(),
                  [&](int a, int b) { return pts[a].second < pts[b].second; });
        n->ys.reserve(n->ids.size());
        for (int i : n->ids) n->ys.push_back(pts[i].second);

        if (hi - lo > 1) {
            int mid = lo + (hi - lo) / 2;
            n->l = build(lo, mid);
            n->r = build(mid, hi);
        }
        return n;
    }

    // Appends to `out` the ids stored at `n` whose y lies in [y1, y2].
    void report(Node* n, int y1, int y2, std::vector<int>& out) const {
        if (!n) return;
        auto i = std::lower_bound(n->ys.begin(), n->ys.end(), y1) - n->ys.begin();
        auto j = std::upper_bound(n->ys.begin(), n->ys.end(), y2) - n->ys.begin();
        // For an inverted range (y1 > y2) i can exceed j; inserting such an
        // iterator range would be undefined behaviour, so report nothing.
        if (i >= j) return;
        out.insert(out.end(), n->ids.begin() + i, n->ids.begin() + j);
    }

    // Appends to `out` the ids of all points under `n` with
    // x1 <= x <= x2 and y1 <= y <= y2.
    void query(Node* n, int x1, int x2, int y1, int y2,
               std::vector<int>& out) const {
        if (!n) return;
        int loX = pts[order[n->lo]].first;
        int hiX = pts[order[n->hi - 1]].first;
        if (hiX < x1 || loX > x2) return;                      // disjoint in x
        if (x1 <= loX && hiX <= x2) { report(n, y1, y2, out); return; }  // covered in x
        query(n->l, x1, x2, y1, y2, out);
        query(n->r, x1, x2, y1, y2, out);
    }
};
Example
Description
Run time analysis
Build is $O(n \log n)$ and each 2D orthogonal range report takes $O(\log^2 n + k)$ time, where $k$ is the number of reported points. One $\log n$ factor comes from selecting the canonical $x$-subtrees and the second from binary searching each subtree's $y$-array.
Space analysis
$O(n \log n)$. Every point appears in the auxiliary $y$-array of each of its $O(\log n)$ ancestors in the $x$-tree.
Proof of correctness
A point lies in the query rectangle iff its $x$-coordinate lies in the query's $x$-interval and its $y$-coordinate lies in the query's $y$-interval. The canonical subtree decomposition partitions the $x$-interval into disjoint subtrees whose union is exactly the set of points whose $x$-coordinate is in range. Within each canonical subtree, the sorted $y$-array plus binary search reports exactly the points whose $y$-coordinate is in range. Since the canonical subtrees are disjoint, each qualifying point is reported exactly once. A full proof appears in de Berg et al., Chapter 5.
Extensions
Applications
References
- de Berg, Cheong, van Kreveld, Overmars. Computational Geometry: Algorithms and Applications, 3rd ed., Chapter 5 (Orthogonal Range Searching).
- Bentley, J. L. "Decomposable searching problems." Information Processing Letters, 1979.
- cp-algorithms — Range Tree