import os
import json
import argparse
import numpy as np
import networkx as nx
from collections import defaultdict

def parse_json(json_fp):
    """Load and return the JSON document stored at ``json_fp``.

    The file is decoded explicitly as UTF-8 (the encoding JSON mandates)
    rather than with the platform's locale-dependent default.

    Args:
        json_fp: path to the JSON file.

    Returns:
        The deserialized JSON value (typically a dict).
    """
    with open(json_fp, "r", encoding="utf-8") as file:
        return json.load(file)

def build_graph(head_rules, num_atoms):
    """Build an undirected "shared atom" graph over the rules.

    Each entry of ``head_rules`` becomes one node, numbered by its position
    in the mapping's iteration order. Two nodes are joined by an edge when
    their atom sets have at least one atom in common.

    Args:
        head_rules: mapping from a head id (a string accepted by ``int``) to
            a list of bodies; each body is a dict with ``"pos_body"`` and
            ``"neg_body"`` iterables of atoms.
        num_atoms: accepted for interface compatibility; not used here.

    Returns:
        ``(graph, nodes)`` where ``graph`` is a ``networkx.Graph`` whose
        nodes carry ``head`` and ``atom_set`` attributes, and ``nodes`` maps
        each node index to the tuple ``(head, atom_set)``.
    """
    graph = nx.Graph()
    nodes = {}

    for idx, (raw_head, rule_bodies) in enumerate(head_rules.items()):
        head_atom = int(raw_head)
        # Facts have negative signed heads, so the head is recorded in the
        # atom set by its absolute value.
        atoms = {abs(head_atom)}
        for rule_body in rule_bodies:
            atoms.update(rule_body["pos_body"])
            atoms.update(rule_body["neg_body"])
        nodes[idx] = (head_atom, atoms)
        graph.add_node(idx, head=head_atom, atom_set=atoms)

    # Connect every pair of rules whose atom sets are not disjoint; pairs are
    # visited in ascending (left, right) order.
    indices = list(nodes)
    for pos, left in enumerate(indices):
        left_atoms = nodes[left][1]
        for right in indices[pos + 1:]:
            if not left_atoms.isdisjoint(nodes[right][1]):
                graph.add_edge(left, right)

    return graph, nodes

def find_disjoint_groups(graph, nodes):
    """Partition the rule graph into groups of rule heads.

    If the graph already has several connected components, each component
    becomes a group. Otherwise a minimum node cut is computed; the connected
    components left after removing the cut become groups, and the cut nodes
    themselves are appended as a final group (even if that cut is empty).

    The input ``graph`` is not modified.

    Args:
        graph: ``networkx.Graph`` whose nodes are keys of ``nodes``.
        nodes: mapping from node index to ``(head, atom_set)``; only the
            ``head`` element is used to build the output.

    Returns:
        A list of groups, each a list of head values. An empty graph yields
        an empty list.
    """
    # networkx connectivity routines (used by minimum_node_cut) raise on the
    # null graph, so there is nothing to partition here.
    if graph.number_of_nodes() == 0:
        return []

    components = list(nx.connected_components(graph))
    if len(components) > 1:
        # Already disconnected: use the components directly as the groups.
        return [[nodes[node][0] for node in component] for component in components]

    # Single component: split it with a minimum vertex cut.
    k_cut = nx.minimum_node_cut(graph)

    # Remove the cut from a copy so the caller's graph is left untouched.
    remainder = graph.copy()
    remainder.remove_nodes_from(k_cut)

    groups = [
        [nodes[node][0] for node in component]
        for component in nx.connected_components(remainder)
    ]
    # The removed (cut) nodes form their own, final group.
    groups.append([nodes[node][0] for node in k_cut])
    return groups

def main():
    """Command-line entry point.

    Reads the input JSON file and builds a shared-atom graph from every head
    rule except head "0". The graph is split into groups with
    ``find_disjoint_groups``, and the groups are stored under the
    ``"disjoint_rules"`` key. The result is written to
    ``<input path without extension>_non-incremental.json``.
    """
    parser = argparse.ArgumentParser(description="Split rules into disjoint groups based on non-shared atoms.")
    parser.add_argument("input_file", help="Path to the input JSON file")
    parser.add_argument("--verbose", action="store_true", help="Print detailed output")
    args = parser.parse_args()

    fp = args.input_file
    file_name = os.path.splitext(fp)[0]

    data = parse_json(fp)
    all_head_rules = data["head_rules"]
    # Head "0" is attached to a group after partitioning, so keep it out of the graph.
    head_rules = {head: bodies for head, bodies in all_head_rules.items() if head != "0"}
    num_atoms = data["metadata"]["num_atoms"]

    graph, nodes = build_graph(head_rules, num_atoms)

    groups = find_disjoint_groups(graph, nodes)
    if args.verbose:
        print(f"Number of disjoint groups: {len(groups)}")
        print(f"Groups: {groups}")

    # If the head == "0" is present, it represents the cardinality constraints
    # and it belongs to the last group. When partitioning produced no groups
    # (e.g. "0" is the only head), it forms a group of its own instead of
    # failing on an empty list.
    if "0" in all_head_rules:
        if groups:
            groups[-1].append(0)
        else:
            groups.append([0])

    data["disjoint_rules"] = groups

    with open(f"{file_name}_non-incremental.json", "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)

if __name__ == "__main__":
    main()
