"""Simple union-find (disjoint-set forest) with union by size. Suitable only for small sets.

``find`` uses path compression.

https://en.wikipedia.org/wiki/Disjoint-set_data_structure
"""
from dataclasses import dataclass

Vertex = int


@dataclass
class VertexInfo:
    """Bookkeeping record kept for every vertex in a partition."""

    # Vertex this one points to; a vertex whose parent is itself is a root.
    parent: Vertex
    # Number of vertices in the tree rooted here (only consulted for roots).
    size: int


# Maps each known vertex to its bookkeeping record.
Partition = dict[Vertex, VertexInfo]


def make_set(x: Vertex, partition: Partition):
    """Add *x* to *partition* as a new singleton set.

    The new set has *x* as its own root and a size of 1. If *x* is already
    present, *partition* is left untouched, so repeated calls are harmless.
    """
    if x not in partition:
        # A singleton contains exactly one vertex. Starting at 0 would keep
        # every size at 0 through all unions and defeat union by size.
        partition[x] = VertexInfo(parent=x, size=1)


def find(x: Vertex, partition: Partition) -> Vertex:
    """Return the root of the set containing *x*, compressing the path.

    Every vertex visited on the way up is re-pointed directly at the root,
    so subsequent lookups along the same path take a single step.

    Raises KeyError if *x* is not in *partition*.
    """
    info = partition[x]
    if info.parent == x:
        # x is its own parent: it is the root.
        return x
    root = find(info.parent, partition)
    info.parent = root  # path compression
    return root


def union(x: Vertex, y: Vertex, partition: Partition):
    """Merge the sets containing *x* and *y* using union by size.

    The root with the smaller recorded size is attached beneath the other
    root (on a tie, the root of *x* stays on top), and the surviving root's
    size absorbs the attached one's. Does nothing when *x* and *y* already
    share a root.
    """
    root_x = find(x, partition)
    root_y = find(y, partition)
    if root_x == root_y:
        return  # already in the same set

    # Pick `big` as the root whose recorded size is not smaller.
    if partition[root_x].size < partition[root_y].size:
        big, small = root_y, root_x
    else:
        big, small = root_x, root_y

    partition[small].parent = big
    partition[big].size += partition[small].size
