"""Algorithms for partitioning items into disjoint sets (the connected
components of a graph) using union-find.
"""

from dataclasses import dataclass
from typing import Optional, Callable, List


@dataclass()
class Point:
    """A node in the union-find forest.

    A point whose ``parent`` equals its own ``value`` is the root of its
    set; otherwise ``parent`` identifies the next point up the tree.
    """

    # Identifier of this point; expected to equal its index in the list
    # of points (that is how find_disjoint_sets builds them).
    value: int
    # Identifier (index) of the parent point; equal to ``value`` for a root.
    parent: int
    
    
def find(points: List["Point"], cp: int) -> int:
    """Return the index of the root of the set containing ``points[cp]``.

    A point is a root when its ``parent`` equals its own ``value``. Every
    point visited on the way to the root is re-pointed directly at the
    root (path compression), so later lookups are cheaper.

    The walk is iterative rather than recursive. ``union`` does no
    balancing, so parent chains can grow as long as the number of points.
    A recursive walk would then exceed Python's recursion limit.

    Args:
        points: The forest. ``parent`` fields are used as indices into it,
            so this assumes ``points[i].value == i`` for every point.
        cp: Index of the point whose root is wanted.

    Returns:
        Index of the root point of ``cp``'s set.
    """
    # Pass 1: climb the parent chain to locate the root.
    root = cp
    while points[root].parent != points[root].value:
        root = points[root].parent

    # Pass 2: re-point every node on the path straight at the root.
    node = cp
    while points[node].parent != points[node].value:
        next_node = points[node].parent
        points[node].parent = root
        node = next_node

    return root
        

def union(points: List[Point], a: int, b: int) -> None:
    """Merge the set holding ``a`` with the set holding ``b``.

    The root of ``a``'s set is re-parented beneath the root of ``b``'s
    set. No balancing (by rank or size) is performed.

    Args:
        points: The forest of points, indexed by position.
        a: Index of a point whose set root gets attached to the other root.
        b: Index of a point whose set root becomes the merged root.
    """
    # Resolve a's root before b's, then hang a's root off b's root.
    moved_root, kept_root = find(points, a), find(points, b)
    points[moved_root].parent = kept_root

    
def find_disjoint_sets(
    num_points: int,
    criteria: Callable[[int, int], bool],
) -> List[List[int]]:
    """Partition the indices ``0 .. num_points - 1`` into disjoint sets.

    Two indices land in the same set when they are linked by a chain of
    pairs for which ``criteria`` returned true. The result is therefore
    the connected components of the graph whose edges are those pairs.

    Args:
        num_points: Number of items, identified by index
            ``0 .. num_points - 1``. Zero or a negative value yields ``[]``.
        criteria: Called as ``criteria(i, j)`` for every ordered pair of
            distinct indices, so both ``(i, j)`` and ``(j, i)`` are asked.
            A true result links ``i`` and ``j``.

    Returns:
        A list of groups, each a list of indices. The order of the groups
        and of the items within a group is not part of the contract.
    """
    # Every point starts as a singleton set: its parent is itself.
    points = [Point(value, value) for value in range(num_points)]

    # Both orderings are checked, so a non-symmetric criteria still links
    # i and j when either direction holds. This is O(n^2) criteria calls.
    for i in range(num_points):
        for j in range(num_points):
            if i != j and criteria(i, j):
                union(points, i, j)

    # Bucket each index under its set's root; dicts keep insertion order.
    grouped = {}
    for i in range(num_points):
        grouped.setdefault(find(points, i), []).append(i)

    return list(grouped.values())