
Disjoint Set

The disjoint set is a data structure that keeps track of a partition of a set into disjoint (non-overlapping) subsets.

There are two operations:

  • find(x): Find the root of the set containing x.
  • union(x, y): Union the sets containing x and y.

In this implementation, we use an array to store the parent of each element.

A disjoint set union (DSU), also known as union-find, maintains a partition of $n$ elements into disjoint groups under two operations: find(x) returns a canonical representative of the group containing $x$, and union(x, y) merges the two groups containing $x$ and $y$. Two elements are in the same group if and only if they have the same representative.
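A minimal sketch of the parent-array representation, in Python. The forest in `parent` is a hand-built illustration (not from the original). `find` here deliberately omits both optimizations, so it only follows pointers up to the root. `same_set` is a hypothetical helper name for the "same representative" test.

```python
# parent[i] is the parent of element i; a root is its own parent.
# This hand-built forest encodes the partition {0, 1, 2} and {3, 4}:
#   0 is the root, 1 points to 0, 2 points to 1; 3 is the root, 4 points to 3.
parent = [0, 0, 1, 3, 3]


def find(x: int) -> int:
    """Follow parent pointers until reaching a root (no optimizations)."""
    while parent[x] != x:
        x = parent[x]
    return x


def same_set(a: int, b: int) -> bool:
    """Two elements share a group exactly when they share a representative."""
    return find(a) == find(b)


print(find(2))         # 0
print(same_set(2, 1))  # True
print(same_set(2, 4))  # False
```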

Each group is stored as a rooted tree whose nodes point to their parent and whose root points to itself. find walks parent pointers to the root; union links the root of one tree under the root of another. With two standard optimizations — path compression in find and union by rank (or size) in union — the amortized cost per operation becomes essentially constant, making DSU the data structure of choice for problems that need to answer connectivity queries interleaved with edge additions.
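The sketch below shows why the linking rule matters. It uses two hypothetical helpers, `build` and `tree_height`, to run the same sequence of unions twice. The first run uses the naive rule, which always links the first root under the second. The second run uses union by size. After each run it measures the resulting tree height. Path compression is left out on purpose, so only the linking rule differs between the two runs.

```python
def tree_height(parent: list[int]) -> int:
    """Largest number of parent hops from any node to its root."""
    best = 0
    for x in range(len(parent)):
        hops = 0
        while parent[x] != x:
            x = parent[x]
            hops += 1
        best = max(best, hops)
    return best


def build(n: int, by_size: bool) -> list[int]:
    """Apply union(0, 1), union(1, 2), ..., union(n-2, n-1); return the parent array."""
    parent = list(range(n))
    size = [1] * n

    def find(x: int) -> int:
        while parent[x] != x:  # no path compression, to isolate the linking rule
            x = parent[x]
        return x

    for i in range(n - 1):
        ra, rb = find(i), find(i + 1)
        if ra == rb:
            continue
        if by_size:
            if size[ra] < size[rb]:
                ra, rb = rb, ra  # make ra the root of the larger tree
            parent[rb] = ra      # hang the smaller tree under the larger one
            size[ra] += size[rb]
        else:
            parent[ra] = rb      # naive: always link a's root under b's root
    return parent


print(tree_height(build(1000, by_size=False)))  # 999: a single chain
print(tree_height(build(1000, by_size=True)))   # 1: every node hangs off the root
```

Under union by size, a node's depth grows only when its tree is attached to a tree at least as large. Each such step at least doubles the size of the node's tree, so the height never exceeds $\log_2 n$. The naive rule has no such guarantee.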

Code

Initialize Parent Array

```java
int[] parentNode = new int[N];
for (int i = 0; i < N; i++) {
    parentNode[i] = i; // every element starts as the root of its own set
}
```

Find Root

Two elements a and b are in the same set exactly when they have the same root, that is, when findParent(a) == findParent(b) returns true.

```java
private int findParent(int i) {
    if (parentNode[i] != i) {
        // Path compression: point i directly at the root of its tree.
        parentNode[i] = findParent(parentNode[i]);
    }
    return parentNode[i];
}
```
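The recursive path compression in findParent can be seen on a hand-built chain. Below is a Python sketch of the same logic. The names `parent_node` and `find_parent` mirror the Java names, and the chain itself is illustrative. A single call on the deepest node re-points every node on its path directly at the root.

```python
# A chain 4 -> 3 -> 2 -> 1 -> 0, where 0 is the root.
parent_node = [0, 0, 1, 2, 3]


def find_parent(i: int) -> int:
    """Return the root of i, re-pointing every node on the path straight at it."""
    if parent_node[i] != i:
        parent_node[i] = find_parent(parent_node[i])
    return parent_node[i]


print(find_parent(4))  # 0
print(parent_node)     # [0, 0, 0, 0, 0]: the whole path now points at the root
```

The recursion depth equals the path length, so a long chain can exceed Python's default recursion limit. The iterative path halving in the full DSU class avoids that.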

Union

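```java
private void union(int a, int b) {
    int rootA = findParent(a);
    int rootB = findParent(b);
    // Link a's root under b's root. There is no union by rank here;
    // if rootA == rootB this assignment is a harmless no-op.
    parentNode[rootA] = rootB;
}
```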
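Union only has work to do when the two roots differ. A variant that reports whether it merged anything can therefore count connected components directly: start from $n$ singleton sets and subtract one per successful merge. Below is a sketch, with `count_components` as a hypothetical helper and example edges chosen for illustration. Like the Java union, it links a's root under b's root, and its find uses path halving.

```python
def count_components(n: int, edges: list[tuple[int, int]]) -> int:
    """Count connected components of an undirected graph on nodes 0..n-1."""
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    def union(a: int, b: int) -> bool:
        ra, rb = find(a), find(b)
        if ra == rb:
            return False  # already in the same set: nothing to merge
        parent[ra] = rb   # link a's root under b's root, as in the Java union
        return True

    components = n
    for a, b in edges:
        if union(a, b):
            components -= 1  # each successful merge joins two components
    return components


print(count_components(5, [(0, 1), (1, 2), (3, 4)]))  # 2
print(count_components(4, [(0, 1), (1, 0)]))          # 3
```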

Full DSU class

```python
class DSU:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.rank[ra] < self.rank[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        if self.rank[ra] == self.rank[rb]:
            self.rank[ra] += 1
        return True


def solve(xs):
    n, ops = xs
    dsu = DSU(n)
    answers = []
    for op in ops:
        if op[0] == "union":
            dsu.union(op[1], op[2])
        else:  # "find"
            answers.append(dsu.find(op[1]))
    return answers
```

Example
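The interactive runner from the original page did not survive extraction. In its place, here is a plain run of `solve` on a hand-picked input (not from the original). The input is a pair `(n, ops)`, where each op is `("union", a, b)` or `("find", x)`. The DSU class and `solve` are repeated so the snippet runs on its own.

```python
class DSU:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.rank[ra] < self.rank[rb]:
            ra, rb = rb, ra  # attach the lower-rank root under the higher one
        self.parent[rb] = ra
        if self.rank[ra] == self.rank[rb]:
            self.rank[ra] += 1
        return True


def solve(xs):
    n, ops = xs
    dsu = DSU(n)
    answers = []
    for op in ops:
        if op[0] == "union":
            dsu.union(op[1], op[2])
        else:  # "find"
            answers.append(dsu.find(op[1]))
    return answers


# Hand-picked operations on 5 elements.
ops = [
    ("union", 0, 1),
    ("union", 1, 2),
    ("find", 2),
    ("find", 3),
    ("union", 3, 4),
    ("find", 4),
    ("union", 2, 4),
    ("find", 3),
]
print(solve((5, ops)))  # [0, 3, 3, 0]
```

The particular roots reported (0 and 3 here) follow from this class's union-by-rank tie-breaking. Only whether two finds return the same root is meaningful.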


Description

Run time analysis

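$O(\alpha(n))$ amortized per operation, where $\alpha$ is the inverse Ackermann function. A sequence of $m$ operations on $n$ elements runs in $O(m\,\alpha(n))$ total. This is effectively linear, because $\alpha(n) \le 4$ for every $n$ that fits in the physical universe. The bound assumes both optimizations, as in the full DSU class (path halving plus union by rank). The simpler Java union, which links roots without regard to rank, does not guarantee it on its own.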

Space analysis

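$O(n)$. Two arrays of length $n$ hold the parent pointers and the rank of each element. The full class's find is iterative (path halving), so it uses no recursion. The recursive Java findParent, by contrast, uses call-stack space proportional to the path length.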

Proof of correctness

Correctness of the partition is a straightforward invariant: after any sequence of operations, the parent pointers form a forest whose trees are exactly the current equivalence classes. This holds because union merges two trees only at their roots, and find never changes which root a node ultimately points to. The amortized $O(\alpha(n))$ bound is due to Tarjan, and Tarjan and van Leeuwen extended it to variants such as the path halving used above. The argument starts from union by rank: a tree of rank $r$ has at least $2^r$ nodes, so rank is at most $\log_2 n$. Combining this height bound with a potential-function argument over compressed ancestor chains yields the inverse Ackermann bound. A full proof is in Tarjan's 1975 paper.
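The rank invariant can be checked empirically. The sketch below copies the full DSU class and adds a `size` array purely as instrumentation (an addition, not part of the original class). `check_rank_bound` is a hypothetical helper that applies random unions. After each union, it verifies that every root of rank $r$ owns at least $2^r$ nodes. At the end, it checks that no rank exceeds $\lfloor \log_2 n \rfloor$.

```python
import random


class DSU:
    """Union by rank with path halving, plus a size array used only for checking."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.rank = [0] * n
        self.size = [1] * n  # instrumentation: node count of each root's tree

    def find(self, x: int) -> int:
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> bool:
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.rank[ra] < self.rank[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]
        if self.rank[ra] == self.rank[rb]:
            self.rank[ra] += 1
        return True


def check_rank_bound(n: int, steps: int, seed: int = 0) -> bool:
    """After every random union, each root of rank r must own at least 2**r nodes."""
    rng = random.Random(seed)
    dsu = DSU(n)
    for _ in range(steps):
        dsu.union(rng.randrange(n), rng.randrange(n))
        for x in range(n):
            if dsu.parent[x] == x and dsu.size[x] < 2 ** dsu.rank[x]:
                return False
    # Since 2**rank <= size <= n, no rank can exceed floor(log2(n)).
    return max(dsu.rank) <= n.bit_length() - 1


print(check_rank_bound(200, 500))  # True
```

Path halving only re-points non-root nodes, so it never changes a root's size or rank. That is why the invariant survives compression.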

Extensions

Applications

LeetCode 3607

This problem is an excellent application of the disjoint set.

Intuition

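The key idea is to process the queries in reverse order.

Taking a station offline never changes the grid. An offline station still belongs to its component, so every connection can be unioned once, up front. The hard part is answering "smallest online station in this component" while stations keep going offline: removing the current minimum from a component cannot be handled cheaply.

Read backwards, every "go offline" query becomes "come back online", so the set of online stations only grows. Bringing a station back online then only lowers its component's minimum, min_value[find(i)], which a single comparison can update.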

Solution
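```python
from typing import List


class Solution:
    def processQueries(self, c: int, connections: List[List[int]], queries: List[List[int]]) -> List[int]:
        # Reverse union-find: walk the queries backwards so stations only come back online.
        parent = [i for i in range(c)]
        # min_value[root]: smallest online station (0-based) in the root's component;
        # c + 1 is a sentinel meaning "no online station".
        min_value = [c + 1 for _ in range(c)]

        def find(x):
            # Recursive find with path compression.
            if parent[x] != x:
                parent[x] = find(parent[x])
            return parent[x]

        def union(x, y):
            # Hang y's root under x's root and merge the two components' minima.
            x = find(x)
            y = find(y)
            if x != y:
                parent[y] = x
                min_value[x] = min(min_value[x], min_value[y])

        # on[i] counts the offline (type-2) queries for station i;
        # 0 means station i is online after the whole forward sequence.
        on = [0] * c
        for code, i in queries:
            if code == 2:
                on[i - 1] += 1
        # The grid never changes, so union every connection once, up front.
        adj = [[] for _ in range(c)]
        for a, b in connections:
            adj[a - 1].append(b - 1)
            adj[b - 1].append(a - 1)
        for i in range(c):
            if on[i] == 0:
                min_value[find(i)] = min(min_value[find(i)], i)
            for j in adj[i]:
                union(i, j)
        result = []
        for code, i in queries[::-1]:
            i -= 1
            if code == 1:
                if on[i] == 0:
                    # Station i is online at this point, so it handles its own check.
                    result.append(i + 1)
                else:
                    val = min_value[find(i)]
                    result.append(val + 1 if val != c + 1 else -1)
            else:
                # Undo one offline query; once none remain, station i is online again.
                on[i] -= 1
                if on[i] == 0:
                    min_value[find(i)] = min(min_value[find(i)], i)
        return result[::-1]
```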
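A slow forward simulation is a useful cross-check when testing the reverse approach. The sketch below assumes the query semantics that the solution above encodes. A type-2 query takes station x offline without changing connectivity. A type-1 query on x returns x if it is online, otherwise the smallest online id in x's component, or -1 if there is none. `process_queries_brute` is a hypothetical name. It runs one BFS per type-1 query, so it suits only small inputs.

```python
from collections import deque


def process_queries_brute(c: int, connections: list[list[int]], queries: list[list[int]]) -> list[int]:
    """Forward simulation: BFS over the unchanging grid for every type-1 query."""
    adj = [[] for _ in range(c + 1)]  # 1-based station ids; index 0 unused
    for a, b in connections:
        adj[a].append(b)
        adj[b].append(a)
    online = [True] * (c + 1)
    result = []
    for code, x in queries:
        if code == 2:
            online[x] = False  # offline stations still carry connectivity
            continue
        if online[x]:
            result.append(x)
            continue
        # Search x's component for the smallest online station.
        seen = {x}
        queue = deque([x])
        best = -1
        while queue:
            u = queue.popleft()
            if online[u] and (best == -1 or u < best):
                best = u
            for v in adj[u]:
                if v not in seen:
                    seen.add(v)
                    queue.append(v)
        result.append(best)
    return result


print(process_queries_brute(5, [[1, 2], [2, 3], [3, 4], [4, 5]],
                            [[1, 3], [2, 1], [1, 1], [2, 2], [1, 2]]))  # [3, 2, 3]
print(process_queries_brute(3, [], [[1, 1], [2, 1], [1, 1]]))          # [1, -1]
```

Comparing its output with `Solution().processQueries` on small random inputs is a quick way to catch off-by-one errors in the 1-based/0-based conversions.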

References

  • Cormen, Leiserson, Rivest, Stein. Introduction to Algorithms, 3rd ed., Chapter 21 (Data Structures for Disjoint Sets).
  • Tarjan, R. E. "Efficiency of a Good But Not Linear Set Union Algorithm". Journal of the ACM 22(2), 1975.
  • cp-algorithms: Disjoint Set Union.