
Disjoint Set

The disjoint set is a data structure that keeps track of a partition of a set into disjoint (non-overlapping) subsets.

There are two operations:

  • find(x): Return the root (representative) of the set containing x.
  • union(x, y): Merge the sets containing x and y into one set.

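In this implementation, each set is stored as a tree in an array: parentNode[i] holds the parent of element i, and a root (an element that is its own parent) represents its whole set.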
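To make the representation concrete, here is a small Python sketch that recovers the partition from a parent array by following parent links up to each root. The helper name `groups` is hypothetical and not from the original. For example, `[1, 1, 1, 4, 4]` encodes the sets {0, 1, 2} and {3, 4}: elements 0 and 2 point at root 1, which is its own parent, and element 3 points at root 4.

```python
from collections import defaultdict


def groups(parent: list[int]) -> list[set[int]]:
    """Group elements by the root reached by following parent links.

    A root is an element that is its own parent; every element of a tree
    belongs to the set represented by that tree's root.
    """
    by_root = defaultdict(set)
    for x in range(len(parent)):
        root = x
        while parent[root] != root:  # climb until reaching a self-parent
            root = parent[root]
        by_root[root].add(x)
    # Sort by smallest member so the output order is deterministic.
    return sorted(by_root.values(), key=min)


print(groups([1, 1, 1, 4, 4]))  # [{0, 1, 2}, {3, 4}]
```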

Code

Initialize Parent Array

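```java
// N is the number of elements; each one starts as its own root (a singleton set).
int[] parentNode = new int[N];
for (int i = 0; i < N; i++) {
    parentNode[i] = i;
}
```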

Find Root

Two elements a and b belong to the same set exactly when they have the same root, i.e. when findParent(a) == findParent(b).

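```java
// Returns the root of i's set and compresses the path along the way:
// every node visited is re-pointed directly at the root.
private int findParent(int i) {
    if (parentNode[i] != i) {
        parentNode[i] = findParent(parentNode[i]);
    }
    return parentNode[i];
}
```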
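The recursive `findParent` can run out of call stack on a very long chain before path compression has had a chance to flatten it. In Java, the depth is bounded by the thread's stack size. In Python, the default recursion limit is 1000. A common alternative is an iterative two-pass find: first walk up to the root, then re-point every node on the path directly at it. The sketch below is in Python, and the names `find` and `parent` are illustrative rather than the original author's code.

```python
def find(parent: list[int], x: int) -> int:
    """Return the root of x's tree, compressing the path iteratively."""
    # Pass 1: walk up the parent links to locate the root.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Pass 2: point every node on the path directly at the root.
    while parent[x] != root:
        nxt = parent[x]
        parent[x] = root
        x = nxt
    return root


# A chain 4 -> 3 -> 2 -> 1 -> 0, where 0 is the root.
parent = [0, 0, 1, 2, 3]
print(find(parent, 4))  # 0
print(parent)           # [0, 0, 0, 0, 0]
```

After a single call, every node on the chain points straight at the root, so later finds on any of them take one step.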

Union

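```java
// Merges the sets containing a and b by attaching a's root under b's root.
// If a and b already share a root, the assignment is a no-op.
private void union(int a, int b) {
    int rootA = findParent(a);
    int rootB = findParent(b);
    parentNode[rootA] = rootB;
}
```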
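The `union` above always attaches `rootA` under `rootB`, so an unlucky call order can build tall trees that path compression must later flatten. A widely used refinement is union by size, which attaches the smaller tree under the larger one. The Python sketch below illustrates that variant and is not part of the original code. The class name `DisjointSet` and the `size` array are assumptions. Combined with path compression, union by size keeps the amortized cost per operation near constant (inverse Ackermann).

```python
class DisjointSet:
    """Disjoint set with path compression and union by size (a sketch)."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))  # each element starts as its own root
        self.size = [1] * n           # size[r] is meaningful only for roots r

    def find(self, x: int) -> int:
        # Same recursive path compression as findParent above.
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, a: int, b: int) -> bool:
        """Merge the sets of a and b; return False if already in one set."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra            # make ra the root of the larger tree
        self.parent[rb] = ra           # attach the smaller tree under it
        self.size[ra] += self.size[rb]
        return True


ds = DisjointSet(5)
ds.union(0, 1)
ds.union(2, 3)
ds.union(1, 3)
print(ds.find(0) == ds.find(2))  # True
print(ds.find(0) == ds.find(4))  # False
```

Because union by size keeps every tree's height at O(log n), the recursive `find` stays shallow here even without the iterative rewrite.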

Applications

LeetCode 3607 (Power Grid Maintenance)

This problem is an excellent application of the disjoint set.

Intuition

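The key idea is to process the queries in reverse order.

Going forward, stations only go offline, so the smallest online id in a grid can only increase. A minimum stored at each root would have to be recomputed every time the current minimum goes offline. In reverse, every offline event becomes an online event, so the smallest online id in a grid can only decrease, and a running minimum at each root is enough.

Taking a station offline does not change which grid it belongs to, so the solution unions every cable once up front. A station can receive several offline queries, so the solution counts them per station. During the reverse pass, a station comes back online when its count returns to 0 (the point just before its first offline query). At that moment, it only needs to update the minimum stored at its root.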

Solution
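```python
from typing import List


class Solution:
    def processQueries(self, c: int, connections: List[List[int]], queries: List[List[int]]) -> List[int]:
        # Reverse union-find: replay the queries backwards so that stations
        # only ever come back online.
        parent = [i for i in range(c)]
        # min_value[r]: smallest online station (0-indexed) in the component
        # rooted at r; c + 1 is a sentinel meaning "none online".
        min_value = [c + 1 for _ in range(c)]

        def find(x):
            if parent[x] != x:
                parent[x] = find(parent[x])  # path compression
            return parent[x]

        def union(x, y):
            # Attach y's root under x's root and merge their minimums.
            x = find(x)
            y = find(y)
            if x != y:
                parent[y] = x
                min_value[x] = min(min_value[x], min_value[y])

        # on[i] counts the offline queries for station i; 0 means online.
        on = [0] * c
        for code, i in queries:
            if code == 2:
                on[i - 1] += 1

        # Offline stations still belong to their grid, so every cable
        # is unioned up front.
        adj = [[] for _ in range(c)]
        for a, b in connections:
            adj[a - 1].append(b - 1)
            adj[b - 1].append(a - 1)
        for i in range(c):
            if on[i] == 0:  # still online after all queries
                min_value[find(i)] = min(min_value[find(i)], i)
            for j in adj[i]:
                union(i, j)

        result = []
        for code, i in queries[::-1]:
            i -= 1
            if code == 1:
                if on[i] == 0:
                    # Station i is online, so it resolves the check itself.
                    result.append(i + 1)
                else:
                    val = min_value[find(i)]
                    result.append(val + 1 if val != c + 1 else -1)
            else:
                # Undo one offline query; the station is back online once
                # none of its offline queries remain.
                on[i] -= 1
                if on[i] == 0:
                    min_value[find(i)] = min(min_value[find(i)], i)
        # Answers were collected in reverse order.
        return result[::-1]
```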
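The reverse pass is easy to get subtly wrong, so for small inputs it may help to cross-check it against a direct forward simulation. The sketch below is such a reference and is not part of the original solution. The function name `brute_force` is hypothetical. Like the solution, it assumes that offline stations stay in their grid. It answers each type-1 query with a breadth-first search, which costs O(c + len(connections)) per query, so it is only meant for testing.

```python
from collections import deque


def brute_force(c: int, connections: list[list[int]], queries: list[list[int]]) -> list[int]:
    """Answer the queries by direct simulation; stations are 1-indexed."""
    adj = [[] for _ in range(c + 1)]
    for a, b in connections:
        adj[a].append(b)
        adj[b].append(a)
    online = [True] * (c + 1)
    result = []
    for code, x in queries:
        if code == 2:
            online[x] = False  # offline stations stay in their grid
        elif online[x]:
            result.append(x)
        else:
            # BFS over the whole grid, walking through offline stations too.
            seen = {x}
            queue = deque([x])
            best = -1
            while queue:
                u = queue.popleft()
                if online[u] and (best == -1 or u < best):
                    best = u
                for v in adj[u]:
                    if v not in seen:
                        seen.add(v)
                        queue.append(v)
            result.append(best)
    return result


print(brute_force(5, [[1, 2], [2, 3], [3, 4], [4, 5]],
                  [[1, 3], [2, 1], [1, 1], [2, 2], [1, 2]]))  # [3, 2, 3]
```

For a quick randomized test, pass the same arguments to `Solution().processQueries` and compare the two lists.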