Segment Tree
The segment tree is a data structure that can be used to solve range query problems.
- Query: returns the maximum value (or other local aggregates) in the range
Range aggregate (sum/min/max), point/range updates; can report intervals covering a point
Two operations are supported:
- 'construct' to build the tree
- 'query' to query the tree for a given set of properties
One hidden helper function is needed to maintain the tree.
A segment tree is a balanced binary tree on the index interval of an array, where each node stores an aggregate (sum, min, max, gcd, ...) of the subarray it represents. The root covers the whole array; every internal node splits its interval in half and delegates to its two children. Any range decomposes into canonical subtree intervals, so range queries and point updates both run in logarithmic time.
The implementation below keeps a sum segment tree and supports two operations: a point update that sets index $i$ to a new value $v$, and a range sum query on the inclusive range $[l, r]$. It is the bread-and-butter data structure for competitive programming range problems.
Codes
- Python-classfree
- Python-class
- C++-class
- Python
- C++
This is my own class-free version of the segment tree; you only need the functions below to construct the tree and to query it.
# A segment tree is essentially a list
def create_by_values(seg: list, vals: list, l: int, r: int, segIdx: int = 1):
    """
    Recursively fill the subtree rooted at ``segIdx`` from ``vals[l..r]``.

    :param seg: the flat list backing the segment tree (1-based node indices)
    :param vals: the original list of values
    :param l: left end of the range covered by this node (inclusive)
    :param r: right end of the range covered by this node (inclusive)
    :param segIdx: index in ``seg`` of the node being built
    """
    if l != r:
        mid = (l + r) // 2
        left_child = 2 * segIdx
        # Build both halves first, then derive this node from its children.
        create_by_values(seg, vals, l, mid, left_child)
        create_by_values(seg, vals, mid + 1, r, left_child + 1)
        # ``maintain`` defines which aggregate the tree stores.
        maintain(seg, segIdx)
    else:
        # Leaf: the node covers exactly one element.
        seg[segIdx] = vals[l]
def maintain(seg: list, segIdx: int):
    """
    Recompute the node at ``segIdx`` from its two children.

    :param seg: the flat list backing the segment tree
    :param segIdx: index of the node to recompute
    """
    left_child = 2 * segIdx
    right_child = left_child + 1
    # Sum tree by default. For a max tree use instead:
    #   seg[segIdx] = max(seg[left_child], seg[right_child])
    seg[segIdx] = seg[left_child] + seg[right_child]
def construct(vals: list) -> list:
    """
    Build a segment tree over ``vals`` and return its backing list.

    :param vals: the original list of values
    :return: a list of length ``2 * p`` where ``p`` is the smallest power of
        two that is >= ``len(vals)`` (at least 1); node 1 is the root and
        index 0 is unused
    """
    n = len(vals)
    # Twice the next power of two is enough for every node index reached by
    # the recursive halving in ``create_by_values``.
    treeSize = 2 << max(n - 1, 0).bit_length()
    seg = [0] * treeSize
    # An empty input has no leaves; recursing on the range [0, -1] would
    # never terminate, so return the all-zero tree instead.
    if n:
        # Argument order is (seg, vals, l, r, segIdx): the root is node 1 and
        # covers the whole inclusive index range [0, n - 1].
        create_by_values(seg, vals, 0, n - 1, 1)
    return seg
def query(seg: list, l: int, r: int, q: int, segIdx: int=1):
    """
    Template for a recursive top-down query that prefers the left subtree.

    NOTE: as written there is no active base case, so every call recurses on
    ``(l, m)`` again and the function never returns. Enable base cases such
    as the commented example below before using it.

    :param seg: the segment tree
    :param l: the left index of the range covered by this node (inclusive)
    :param r: the right index of the range covered by this node (inclusive)
    :param q: the query value you need to pass to the query function
    :param segIdx: the index of the node to query
    :return: self defined return value; -1 from the left subtree is treated
        as "not found" and triggers a search of the right subtree
    """
    # EXAMPLE: find the smallest index whose value is at least q and consume it.
    # This relies on each node storing the maximum of its range, i.e. on the
    # commented max example in ``maintain`` being enabled instead of the sum.
    # if seg[segIdx]<q:
    #     # no value in this range is large enough
    #     return -1
    # if l==r:
    #     # mark the leaf as used (-1) and return its index in the original list
    #     seg[segIdx] = -1
    #     return l
    m=(l+r)//2
    # search the left half first so the smallest matching index wins
    ret = query(seg, l, m, q, segIdx*2)
    if ret==-1:
        ret=query(seg, m+1, r, q, segIdx*2+1)
    # if any value is modified, we need to maintain the tree
    maintain(seg, segIdx)
    return ret
class ListSegTree:
    """
    Fill-in template for a recursive, list-backed segment tree.

    Adapted from the LeetCode editorial segment tree. The intended use is to
    keep the maximum of every range and answer "first index whose value is
    at least q" queries. The places marked TODO are left for the reader;
    until they are filled in, internal nodes stay 0 and ``query`` simply
    descends to the leftmost leaf.
    """

    def __init__(self, vals):
        """Allocate the backing list and build the tree over ``vals``."""
        self.n = len(vals)
        # Twice the next power of two >= n holds every node index used by
        # the recursive halving in ``_construct``.
        treeSize = 2 << max(self.n - 1, 0).bit_length()
        # Index 0 is unused: with 1-based numbering the children of node i
        # are 2*i and 2*i+1. Node 1 covers [0, n-1]; node 2 covers the left
        # half [0, (n-1)//2] and node 3 the remaining right half.
        self.seg = [0] * treeSize
        # An empty input has no leaves; recursing on [0, -1] would never
        # terminate, so skip construction in that case.
        if self.n:
            self._construct(vals, 1, 0, self.n - 1)

    def _maintain(self, segIdx):
        """Recompute node ``segIdx`` from its two children (placeholder)."""
        # TODO: implement this yourself, e.g. for a max tree:
        # self.seg[segIdx] = max(self.seg[segIdx * 2], self.seg[segIdx * 2 + 1])
        pass

    def _construct(self, vals, segIdx, l, r):
        """Build the node ``segIdx`` covering the inclusive range [l, r]."""
        # Base case: a leaf stores the single value it covers.
        if l == r:
            self.seg[segIdx] = vals[l]
            return
        m = (l + r) // 2
        self._construct(vals, segIdx * 2, l, m)
        self._construct(vals, segIdx * 2 + 1, m + 1, r)
        self._maintain(segIdx)

    def query(self, segIdx, l, r, q):
        """
        Search the node ``segIdx`` covering [l, r] for the query value ``q``.

        Intended contract once the TODOs are filled in: return the first
        index whose value is at least ``q`` (marking it as used), or -1.
        """
        if self.seg[segIdx] < q:
            # No value in this range is large enough.
            # TODO: implement this yourself:
            # return -1
            pass
        if l == r:
            # Reached a leaf: return its index in the original list.
            # TODO: implement this yourself, mark the leaf as used:
            # self.seg[segIdx] = -1
            return l
        m = (l + r) // 2
        # Search the left half first so the smallest index wins.
        ret = self.query(segIdx * 2, l, m, q)
        if ret == -1:
            ret = self.query(segIdx * 2 + 1, m + 1, r, q)
        # A leaf may have been modified, so refresh this node.
        self._maintain(segIdx)
        return ret
// Fill-in template for a recursive, vector-backed segment tree, adapted from
// the LeetCode editorial version. The intended use is to keep the maximum of
// every range and answer "first index whose value is at least q" queries.
// The places marked TODO are left for the reader; until they are filled in,
// internal nodes stay 0 and query() simply descends to the leftmost leaf.
class ListSegTree{
private:
    int n;
    std::vector<int> seg;

    // Recompute node segIdx from its two children (placeholder).
    void _maintain(int segIdx){
        // TODO: implement this yourself, e.g. for a max tree:
        // seg[segIdx]=std::max(seg[segIdx*2],seg[segIdx*2+1]);
        (void)segIdx;
    }

    // Build the node segIdx covering the inclusive range [l, r].
    void _construct(std::vector<int>& vals, int segIdx, int l, int r){
        // base case: a leaf stores the single value it covers
        if(l==r){
            seg[segIdx]=vals[l];
            return;
        }
        int m=(l+r)/2;
        _construct(vals,segIdx*2,l,m);
        _construct(vals,segIdx*2+1,m+1,r);
        _maintain(segIdx);
    }

public:
    ListSegTree(std::vector<int>& vals){
        n=vals.size();
        // Twice the next power of two >= n holds every node index used by
        // the recursive halving in _construct.
        int treeSize=1;
        while(treeSize<n) treeSize<<=1;
        treeSize<<=1;
        // Index 0 is unused: with 1-based numbering the children of node i
        // are 2*i and 2*i+1. Node 1 covers [0, n-1]; node 2 covers the left
        // half [0, (n-1)/2] and node 3 the remaining right half.
        seg=std::vector<int>(treeSize,0);
        // An empty input has no leaves; recursing on [0, -1] would never
        // terminate, so skip construction in that case.
        if(n>0){
            _construct(vals,1,0,n-1);
        }
    }

    // Search the node segIdx covering [l, r] for the query value q.
    // Intended contract once the TODOs are filled in: return the first index
    // whose value is at least q (marking it as used), or -1.
    int query(int segIdx, int l, int r, int q){
        if(seg[segIdx]<q){
            // no value in this range is large enough
            // TODO: implement this yourself:
            // return -1;
        }
        if(l==r){
            // reached a leaf: return its index in the original vector
            // TODO: implement this yourself, mark the leaf as used:
            // seg[segIdx] = -1;
            return l;
        }
        int m=(l+r)/2;
        // search the left half first so the smallest index wins
        int ret=query(segIdx*2,l,m,q);
        if(ret==-1){
            ret=query(segIdx*2+1,m+1,r,q);
        }
        // a leaf may have been modified, so refresh this node
        _maintain(segIdx);
        return ret;
    }
};
def solve(xs):
    """
    Run point-update / range-sum queries on an iterative segment tree.

    ``xs`` is a pair ``(array, queries)``. A query whose first element is
    ``"update"`` sets ``array[q[1]]`` to ``q[2]``; any other query returns
    the sum over the inclusive index range ``[q[1], q[2]]``. The sums are
    returned in query order.
    """
    array, queries = xs
    n = len(array)
    # Smallest power of two that is >= max(n, 1); leaves live at [size, 2*size).
    size = 1 << max(n - 1, 0).bit_length()
    seg = [0] * (2 * size)
    seg[size:size + n] = array
    # Fill internal nodes bottom-up; node k has children 2k and 2k+1.
    for node in reversed(range(1, size)):
        seg[node] = seg[2 * node] + seg[2 * node + 1]

    def update(i, v):
        """Set leaf ``i`` to ``v`` and refresh every ancestor."""
        pos = size + i
        seg[pos] = v
        pos >>= 1
        while pos:
            seg[pos] = seg[2 * pos] + seg[2 * pos + 1]
            pos >>= 1

    def query(l, r):
        """Return the sum over the inclusive range [l, r]."""
        total = 0
        lo = l + size
        hi = r + size + 1  # half-open upper bound
        while lo < hi:
            if lo & 1:
                # lo is a right child: take it and step to the next subtree.
                total += seg[lo]
                lo += 1
            if hi & 1:
                # hi - 1 is a left child lying inside the range: take it.
                hi -= 1
                total += seg[hi]
            lo >>= 1
            hi >>= 1
        return total

    out = []
    for q in queries:
        kind = q[0]
        if kind == "update":
            update(q[1], q[2])
            continue
        out.append(query(q[1], q[2]))
    return out
#include <vector>
// Iterative (bottom-up) sum segment tree over a fixed-size array.
// Leaves live at indices [size, 2*size); node k has children 2k and 2k+1.
struct SegTree {
    int size;
    std::vector<long long> seg;

    // Build the tree over `a`, padding up to the next power of two.
    SegTree(const std::vector<int>& a) {
        const int count = (int)a.size();
        size = 1;
        while (size < count) {
            size *= 2;
        }
        seg.assign(2 * size, 0);
        for (int idx = 0; idx < count; ++idx) {
            seg[size + idx] = a[idx];
        }
        for (int node = size - 1; node >= 1; --node) {
            seg[node] = seg[2 * node] + seg[2 * node + 1];
        }
    }

    // Set element i to v and refresh every ancestor of its leaf.
    void update(int i, long long v) {
        int pos = i + size;
        seg[pos] = v;
        pos /= 2;
        while (pos) {
            seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
            pos /= 2;
        }
    }

    // Return the sum over the inclusive range [l, r].
    long long query(int l, int r) const {
        long long total = 0;
        int lo = l + size;
        int hi = r + size + 1;  // half-open upper bound
        while (lo < hi) {
            if (lo & 1) {
                // lo is a right child: take it and step to the next subtree
                total += seg[lo];
                ++lo;
            }
            if (hi & 1) {
                // hi - 1 is a left child lying inside the range: take it
                --hi;
                total += seg[hi];
            }
            lo /= 2;
            hi /= 2;
        }
        return total;
    }
};
Example
Description
Run time analysis
Build is $O(n)$ bottom-up, and both point update and range query run in $O(\log n)$. A range touches at most two nodes per level of the tree, so the iterative "climb from both ends" loop visits $O(\log n)$ nodes.
Space analysis
$O(n)$. The iterative implementation pads the array up to the next power of two and stores the segments in a flat array of twice that padded length.
Proof of correctness
A range $[l, r]$ decomposes uniquely into the canonical intervals visited by the iterative loop: at each level, the leftmost and rightmost uncovered cells are absorbed into the result if they are right/left children of their parents, after which the loop moves up to the parents. Every element of $[l, r]$ is accounted for exactly once because the left pointer only advances past covered cells and the right pointer only retreats past them. Point updates restore the sum invariant along the single root-to-leaf path that contains the updated index, leaving all other paths unchanged.
Extensions
Applications
Leetcode 3479
Now called fruit baskets III.
Details
Intuition
The key insight here is to reduce the time needed to find and update the lowest-index basket with enough capacity. We record in each node the maximum capacity of the baskets in its range.
To ensure the minimum index is selected, we query the left part first. If the left part is not valid, we query the right part.
Solution
- Python
- C++
class ListSegTree:
    """
    Recursive, list-backed max segment tree (adapted from the LeetCode
    editorial version).

    Each node stores the maximum of the range it covers. ``query`` finds the
    first index whose value is at least ``q`` and consumes it by setting the
    leaf to -1.
    """

    def __init__(self, vals):
        """Allocate the backing list and build the tree over ``vals``."""
        self.n = len(vals)
        # Twice the next power of two >= n holds every node index used by
        # the recursive halving in ``_construct``.
        treeSize = 2 << max(self.n - 1, 0).bit_length()
        # Index 0 is unused: with 1-based numbering the children of node i
        # are 2*i and 2*i+1. Node 1 covers [0, n-1]; node 2 covers the left
        # half [0, (n-1)//2] and node 3 the remaining right half.
        self.seg = [0] * treeSize
        # An empty input has no leaves; recursing on [0, -1] would never
        # terminate, so skip construction in that case.
        if self.n:
            self._construct(vals, 1, 0, self.n - 1)

    def _maintain(self, segIdx):
        """Recompute node ``segIdx`` as the maximum of its two children."""
        self.seg[segIdx] = max(self.seg[segIdx * 2], self.seg[segIdx * 2 + 1])

    def _construct(self, vals, segIdx, l, r):
        """Build the node ``segIdx`` covering the inclusive range [l, r]."""
        # Base case: a leaf stores the single value it covers.
        if l == r:
            self.seg[segIdx] = vals[l]
            return
        m = (l + r) // 2
        self._construct(vals, segIdx * 2, l, m)
        self._construct(vals, segIdx * 2 + 1, m + 1, r)
        self._maintain(segIdx)

    def query(self, segIdx, l, r, q):
        """
        Find and consume the first index in [l, r] whose value is >= ``q``.

        Call as ``query(1, 0, n - 1, q)`` to search the whole tree.

        :param segIdx: node to search; it must cover exactly [l, r]
        :param l: left end of the node's range (inclusive)
        :param r: right end of the node's range (inclusive)
        :param q: the minimum value required
        :return: the smallest matching index, or -1 if no value is >= ``q``
        """
        if self.seg[segIdx] < q:
            # The maximum of this range is too small: nothing matches.
            return -1
        if l == r:
            # Matching leaf: mark it as used so it cannot be chosen again.
            self.seg[segIdx] = -1
            return l
        m = (l + r) // 2
        # Search the left half first so the smallest index wins.
        ret = self.query(segIdx * 2, l, m, q)
        if ret == -1:
            ret = self.query(segIdx * 2 + 1, m + 1, r, q)
        # A leaf below was consumed, so refresh this node's maximum.
        self._maintain(segIdx)
        return ret
class Solution:
    def numOfUnplacedFruits(self, fruits: List[int], baskets: List[int]) -> int:
        """
        Count the fruits that cannot be placed in any basket.

        Each fruit, in order, takes the lowest-index unused basket whose
        capacity is at least the fruit's size; ``ListSegTree.query`` returns
        -1 when no such basket remains.
        """
        n = len(fruits)
        if not n:
            return 0
        # NOTE(review): the tree is built over ``baskets`` but queried with
        # bounds derived from ``len(fruits)`` -- presumably both lists always
        # have the same length; confirm against the problem constraints.
        tree = ListSegTree(baskets)
        return sum(1 for fruit in fruits if tree.query(1, 0, n - 1, fruit) == -1)
// Recursive, vector-backed max segment tree (adapted from the LeetCode
// editorial version). Each node stores the maximum of the range it covers;
// query() finds the first index whose value is at least q and consumes it
// by setting the leaf to -1.
class ListSegTree{
private:
    int n;
    std::vector<int> seg;

    // Recompute node segIdx as the maximum of its two children.
    void _maintain(int segIdx){
        seg[segIdx]=std::max(seg[segIdx*2],seg[segIdx*2+1]);
    }

    // Build the node segIdx covering the inclusive range [l, r].
    void _construct(std::vector<int>& vals, int segIdx, int l, int r){
        // base case: a leaf stores the single value it covers
        if(l==r){
            seg[segIdx]=vals[l];
            return;
        }
        int m=(l+r)/2;
        _construct(vals,segIdx*2,l,m);
        _construct(vals,segIdx*2+1,m+1,r);
        _maintain(segIdx);
    }

public:
    ListSegTree(std::vector<int>& vals){
        n=vals.size();
        // Twice the next power of two >= n holds every node index used by
        // the recursive halving in _construct.
        int treeSize = 1;
        while (treeSize < n) treeSize <<= 1;
        treeSize <<= 1;
        // Index 0 is unused: with 1-based numbering the children of node i
        // are 2*i and 2*i+1. Node 1 covers [0, n-1]; node 2 covers the left
        // half [0, (n-1)/2] and node 3 the remaining right half.
        seg=std::vector<int>(treeSize,0);
        // An empty input has no leaves; recursing on [0, -1] would never
        // terminate, so skip construction in that case.
        if(n>0){
            _construct(vals,1,0,n-1);
        }
    }

    // Find and consume the first index in [l, r] whose value is >= q.
    // Call as query(1, 0, n - 1, q) to search the whole tree; segIdx must be
    // the node covering exactly [l, r]. Returns -1 when nothing matches.
    int query(int segIdx, int l, int r, int q){
        if(seg[segIdx]<q){
            // the maximum of this range is too small: nothing matches
            return -1;
        }
        if(l==r){
            // matching leaf: mark it as used so it cannot be chosen again
            seg[segIdx] = -1;
            return l;
        }
        int m=(l+r)/2;
        // search the left half first so the smallest index wins
        int ret=query(segIdx*2,l,m,q);
        if(ret==-1){
            ret=query(segIdx*2+1,m+1,r,q);
        }
        // a leaf below was consumed, so refresh this node's maximum
        _maintain(segIdx);
        return ret;
    }
};
class Solution{
public:
    // Count the fruits that cannot be placed in any basket. Each fruit, in
    // order, takes the lowest-index unused basket whose capacity is at least
    // the fruit's size; ListSegTree::query returns -1 when none remains.
    int numOfUnplacedFruits(std::vector<int>& fruits, std::vector<int>& baskets){
        int n=fruits.size();
        if(n==0){
            return 0;
        }
        // NOTE(review): the tree is built over `baskets` but queried with
        // bounds derived from fruits.size() -- presumably both vectors always
        // have the same length; confirm against the problem constraints.
        ListSegTree seg(baskets);
        int res=0;
        for(int fruit : fruits){
            if(seg.query(1,0,n-1,fruit)==-1){
                res++;
            }
        }
        return res;
    }
};  // a class definition must end with ';' -- it was missing in the original
References
- de Berg, Cheong, van Kreveld, Overmars. Computational Geometry: Algorithms and Applications, 3rd ed., Chapter 10.
- Bentley, J. L. "Solutions to Klee's rectangle problems." Technical report, 1977.
- cp-algorithms — Segment Tree