package correct_java_programs;
import java.util.*;

/**
 * Minimum spanning tree, computed with Kruskal's algorithm on top of a
 * disjoint-set forest (union by rank with path compression).
 */
public class MINIMUM_SPANNING_TREE {
    /**
     * Builds a minimum spanning tree (or a minimum spanning forest when the
     * edges do not form a connected graph) from the given weighted edges.
     *
     * <p>The input list is never modified: edges are sorted in a private copy,
     * so unmodifiable lists are accepted and the caller's ordering is kept.
     * The sort is stable, so edges of equal weight are considered in the order
     * in which they appear in {@code weightedEdges}.
     *
     * @param weightedEdges the edges of the graph; each edge supplies its two
     *                      endpoints ({@code node1}, {@code node2}) and an
     *                      integer {@code weight}
     * @return a new mutable set holding the edges selected for the tree
     * @throws NullPointerException if {@code weightedEdges} is {@code null}
     */
    public static Set<WeightedEdge> minimum_spanning_tree(List<WeightedEdge> weightedEdges) {
        Objects.requireNonNull(weightedEdges, "weightedEdges");

        Map<Node, Node> parent = new HashMap<>();
        Map<Node, Integer> rank = new HashMap<>();

        // Every endpoint starts out as the root of its own singleton set.
        for (WeightedEdge edge : weightedEdges) {
            parent.putIfAbsent(edge.node1, edge.node1);
            parent.putIfAbsent(edge.node2, edge.node2);
            rank.putIfAbsent(edge.node1, 0);
            rank.putIfAbsent(edge.node2, 0);
        }

        // Sort a copy by ascending weight so the caller's list stays untouched.
        List<WeightedEdge> sortedEdges = new ArrayList<>(weightedEdges);
        sortedEdges.sort(Comparator.comparingInt((WeightedEdge e) -> e.weight));

        Set<WeightedEdge> minSpanningTree = new HashSet<>();

        // A spanning tree over V nodes has at most V - 1 edges; once that many
        // merges have happened every node is in one set and no later edge can
        // be accepted, so the scan may stop early.
        int edgesNeeded = parent.size() - 1;
        int edgesAccepted = 0;

        for (WeightedEdge edge : sortedEdges) {
            if (edgesAccepted >= edgesNeeded) {
                break;
            }

            Node root1 = find(parent, edge.node1);
            Node root2 = find(parent, edge.node2);

            // find() always hands back the instance stored as a map value, so
            // comparing the two roots by reference is sufficient here.
            if (root1 != root2) {
                minSpanningTree.add(edge);
                linkRoots(parent, rank, root1, root2);
                edgesAccepted++;
            }
        }
        return minSpanningTree;
    }

    /**
     * Returns the representative (root) of the set containing {@code node},
     * re-pointing every node visited on the way directly at that root.
     *
     * @param parent the parent pointers of the disjoint-set forest
     * @param node   a node that has an entry in {@code parent}
     * @return the root of {@code node}'s set
     */
    private static Node find(Map<Node, Node> parent, Node node) {
        Node current = parent.get(node);
        if (current != node) {
            // Path compression: attach this node straight to the root.
            parent.put(node, find(parent, current));
        }
        return parent.get(node);
    }

    /**
     * Merges the two sets whose representatives are {@code root1} and
     * {@code root2}, hanging the lower-ranked root under the higher-ranked one.
     * Both arguments must already be roots, i.e. results of {@link #find}.
     *
     * @param parent the parent pointers of the disjoint-set forest
     * @param rank   the rank recorded for each node
     * @param root1  the representative of the first set
     * @param root2  the representative of the second set
     */
    private static void linkRoots(Map<Node, Node> parent, Map<Node, Integer> rank, Node root1, Node root2) {
        if (root1 == root2) {
            return;
        }

        int rank1 = rank.get(root1);
        int rank2 = rank.get(root2);

        if (rank1 < rank2) {
            parent.put(root1, root2);
        } else {
            parent.put(root2, root1);
            if (rank1 == rank2) {
                // Equal ranks: the surviving root's tree grew one level deeper.
                rank.put(root1, rank1 + 1);
            }
        }
    }
}