Skip to main content

Segment Tree

The segment tree is a data structure that can be used to solve range query problems.

  • Query: returns the maximum value (or another aggregate, such as the sum or minimum) over a range

Supports range aggregates (sum/min/max) and point or range updates, and can report the intervals that cover a given point.

Two operations are supported:

  • 'construct' to build the tree
  • 'query' to query the tree for a set of properties

One helper function, `maintain`, is also needed to keep each node consistent with its children.

Code

This is a class-free version of the segment tree that I implemented myself. You can use only the functions you need to construct and query the tree.

# A segment tree is essentially a list
def create_by_values(seg: list[Any], vals: list[Any], l: int, r: int, segIdx: int=1):
    """
    Fill the segment-tree node at segIdx, which covers vals[l..r] (inclusive), recursively.

    Node 1 is the root and the children of node i are 2*i and 2*i+1.
    Leaves copy the original values. Internal nodes are combined from their
    children by ``maintain``.

    :param seg: the segment tree (a pre-sized list, filled in place)
    :param vals: the original list
    :param l: the left index of the range (inclusive)
    :param r: the right index of the range (inclusive)
    :param segIdx: the index of the node to fill; defaults to the root (1)
    """
    if l > r:
        # Empty range (e.g. vals == []): nothing to store. Without this guard
        # the midpoint recursion never reaches a leaf and never terminates.
        return
    if l == r:
        # leaf: store the original value
        seg[segIdx] = vals[l]
        return
    m = (l + r) // 2
    # fill the left and right children first
    create_by_values(seg, vals, l, m, segIdx * 2)
    create_by_values(seg, vals, m + 1, r, segIdx * 2 + 1)
    # combine the children according to the property stored in the tree
    maintain(seg, segIdx)

def maintain(seg: list[Any], segIdx: int):
    """
    Recompute node segIdx of the segment tree from its two children.

    The children of node i are stored at 2*i and 2*i+1. This version stores
    the sum of the range. To store the range maximum instead (needed by the
    "leftmost value >= q" query example), use ``seg[segIdx] = max(left, right)``.

    :param seg: the segment tree, updated in place
    :param segIdx: the index of the node to recompute
    """
    left_child = 2 * segIdx
    left, right = seg[left_child], seg[left_child + 1]
    seg[segIdx] = left + right

def construct(vals: list[Any]) -> list[Any]:
    """
    Build a segment tree over vals and return it as a flat list.

    Node 1 is the root and covers vals[0..n-1]. Index 0 is unused.

    :param vals: the original list
    :return: the segment tree
    """
    n = len(vals)
    # Round the leaf count up to a power of two and double it. The result,
    # 2 * next_pow2(n), is less than 4n for n >= 1, and every node index
    # produced by halving [0, n-1] fits inside it.
    treeSize = 2 << (n - 1).bit_length()
    seg = [0] * treeSize
    if n == 0:
        # nothing to build; recursing over an empty range would never terminate
        return seg
    # build from the root (segIdx defaults to 1) over the whole range [0, n-1]
    create_by_values(seg, vals, 0, n - 1)
    return seg

def query(seg: list[int], l: int, r: int, q: int, segIdx: int=1) -> Any:
    """
    Find the leftmost index in [l, r] whose value is at least q, consume it,
    and return that index (or -1 if there is none).

    This is the example query used for LeetCode 3479. The chosen leaf is set
    to -1 so it cannot be picked again, and every node on the path is
    recomputed with ``maintain``. The pruning test ``seg[segIdx] < q`` is only
    sound when ``maintain`` stores the range maximum. With a sum aggregate,
    negative entries (including the -1 markers) can cause a qualifying value
    to be skipped. Because used entries become -1, q is assumed to be
    non-negative.

    :param seg: the segment tree
    :param l: the left index of the range covered by segIdx (inclusive)
    :param r: the right index of the range covered by segIdx (inclusive)
    :param q: the threshold to search for
    :param segIdx: the index of the node to query; defaults to the root (1)
    :return: the index of the consumed value, or -1 if none qualifies
    """
    if l > r or seg[segIdx] < q:
        # empty range, or no value in this subtree reaches q
        return -1
    if l == r:
        # leaf: mark the value as used and report its index
        seg[segIdx] = -1
        return l
    m = (l + r) // 2
    # search the left half first so the smallest index wins
    ret = query(seg, l, m, q, segIdx * 2)
    if ret == -1:
        ret = query(seg, m + 1, r, q, segIdx * 2 + 1)
    # a leaf below may have changed, so refresh this node
    maintain(seg, segIdx)
    return ret

Description

Runtime analysis

'construct' runs in $O(n)$ time and 'query' runs in $O(\log n)$ time.

Applications

LeetCode 3479

Now titled "Fruits Into Baskets III".

Details

Intuition: The key insight is to reduce the time needed to find, and then update, the lowest-index basket whose capacity is large enough.

Each node records the maximum capacity among the baskets in its range.

To ensure the minimum index is selected, we query the left half first. If the left half contains no valid basket, we query the right half.

Solution
class ListSegTree:
    # LeetCode's official segment tree, modified so that every node keeps the
    # maximum value of the segment it covers.
    def __init__(self, vals):
        self.n = len(vals)
        # round the leaf count up to a power of two and double it
        treeSize = 2 << (self.n - 1).bit_length()
        # Node layout (1-based):
        #   index 0 is unused: the children of node i are 2*i and 2*i+1, so
        #   node 0 would be its own left child;
        #   index 1 holds the max over [0, n-1];
        #   index 2 holds the max over [0, m] and index 3 the max over
        #   [m+1, n-1], where m = (n-1)//2, and so on recursively.
        self.seg = [0] * treeSize
        if self.n > 0:
            # an empty range would recurse forever, so only build non-empty trees
            self._construct(vals, 1, 0, self.n - 1)

    def _maintain(self, segIdx):
        # recompute node segIdx as the max of its two children
        self.seg[segIdx] = max(self.seg[segIdx * 2], self.seg[segIdx * 2 + 1])

    def _construct(self, vals, segIdx, l, r):
        # build the node at segIdx, which covers vals[l..r] inclusive
        if l == r:
            # leaf node: store the original value
            self.seg[segIdx] = vals[l]
            return
        m = (l + r) // 2
        self._construct(vals, segIdx * 2, l, m)
        self._construct(vals, segIdx * 2 + 1, m + 1, r)
        self._maintain(segIdx)

    def query(self, segIdx, l, r, q):
        # Return the smallest index in [l, r] whose value is >= q, or -1 if
        # there is none. The chosen value is marked as used (-1) and the path
        # is re-maintained. Assumes q >= 0, so used entries never qualify.
        if l > r or self.seg[segIdx] < q:
            # empty range, or the subtree max is below q: nothing found
            return -1
        if l == r:
            # leaf: mark the value as used and return its index
            self.seg[segIdx] = -1
            return l
        m = (l + r) // 2
        # search the left half first so the smallest index wins
        ret = self.query(segIdx * 2, l, m, q)
        if ret == -1:
            ret = self.query(segIdx * 2 + 1, m + 1, r, q)
        # a leaf below may have changed, so refresh this node
        self._maintain(segIdx)
        return ret

class Solution:
    def numOfUnplacedFruits(self, fruits: List[int], baskets: List[int]) -> int:
        """Count the fruits that cannot be placed in any basket (LeetCode 3479).

        Fruits are processed in order. Each one goes into the leftmost unused
        basket whose capacity is at least the fruit's quantity, and that
        basket is then marked as used.
        """
        if not fruits:
            return 0
        if not baskets:
            # no basket can hold anything
            return len(fruits)
        seg = ListSegTree(baskets)
        # the tree is built over baskets, so its range bound comes from baskets
        last = len(baskets) - 1
        unplaced = 0
        for qty in fruits:
            # query consumes the chosen basket; -1 means no basket fits
            if seg.query(1, 0, last, qty) == -1:
                unplaced += 1
        return unplaced