import os
import json
import argparse
import networkx as nx

def build_reverse_dependency_graph(head_rules):
    """Build the reversed dependency graph of a set of rules.

    Args:
        head_rules: Mapping from a head atom ID (a key convertible with
            ``int``) to a list of bodies. Each body is a dict holding the
            lists ``"pos_body"`` and ``"neg_body"`` of atom IDs.

    Returns:
        A ``networkx.DiGraph`` with one edge ``atom -> head`` for every atom
        that occurs in a body of ``head``. Edges point from a body atom to
        the head it helps derive, so the descendants of an atom are the
        heads reachable from it through rule bodies.
    """
    graph = nx.DiGraph()
    for head_key, rule_bodies in head_rules.items():
        head_atom = int(head_key)
        for rule_body in rule_bodies:
            # Positive and negative occurrences contribute the same kind of
            # edge; the polarity of the dependency is not recorded.
            for part in ("pos_body", "neg_body"):
                graph.add_edges_from(
                    (atom, head_atom) for atom in rule_body[part]
                )
    return graph

def count_descendants(reverse_graph, start_node):
    """Return how many nodes are reachable from ``start_node``.

    Args:
        reverse_graph: Directed graph to search (a ``networkx.DiGraph``).
        start_node: Node whose descendants are counted.

    Returns:
        The number of descendants of ``start_node`` in ``reverse_graph``
        (the node itself is not counted). A node that is not part of the
        graph has no outgoing edges and therefore yields ``0``.
    """
    # nx.descendants raises NetworkXError for a node missing from the graph.
    # Atoms that never occur in a rule recorded in the graph are legitimate
    # inputs here, so report them as having no descendants instead.
    if start_node not in reverse_graph:
        return 0
    return len(nx.descendants(reverse_graph, start_node))

def reassign_ids(program_dict, verbose=False):
    """Renumber the atoms of a program according to their descendant counts.

    Atoms are ranked by the number of descendants they have in the reversed
    dependency graph built from ``program_dict["head_rules"]`` (rules whose
    head key is ``"0"`` are left out of that graph). Probabilistic atoms
    (probabilistic facts and annotated-disjunction atoms) receive the lowest
    new IDs, ordered by decreasing descendant count; the remaining atoms
    follow, ordered the same way. ID ``0`` always stays ``0``.

    Args:
        program_dict: Program description with the keys ``"head_rules"``,
            ``"rules"``, ``"prob"`` (``"pfacts"`` and ``"ads"``),
            ``"facts"``, ``"metadata"`` (with ``"num_atoms"``),
            ``"exactly_one_constraints"`` and ``"loop_formulas"``. The key
            ``"atom_mapping"`` is read only when ``verbose`` is true.
        verbose: When true, print the descendant count of every atom.

    Returns:
        A new dict with the same sections as the input (except
        ``"atom_mapping"``, which is not copied), in which every atom ID is
        replaced by its new ID. ``"metadata"`` is the same object as in the
        input.
    """
    # Rules with head "0" are excluded from the dependency graph.
    rules_by_head = {}
    for head_key, head_bodies in program_dict["head_rules"].items():
        if head_key != "0":
            rules_by_head[head_key] = head_bodies

    prob_section = program_dict["prob"]
    pfact_atoms = [atom for atom, _prob in prob_section["pfacts"]]
    ad_atoms = []
    for ad in prob_section["ads"]:
        for atom, _prob in ad:
            ad_atoms.append(atom)

    graph = build_reverse_dependency_graph(rules_by_head)

    num_atoms = program_dict["metadata"]["num_atoms"]

    every_atom = set(range(1, num_atoms + 1))
    prob_atoms = set(pfact_atoms + ad_atoms)
    logic_atoms = every_atom - prob_atoms

    prob_count = {}
    for atom in prob_atoms:
        prob_count[atom] = count_descendants(graph, atom)
    logic_count = {}
    for atom in logic_atoms:
        logic_count[atom] = count_descendants(graph, atom)

    # Optionally report the number of descendants of each atom.
    if verbose:
        atom_names = program_dict["atom_mapping"]
        reports = (
            ("Probability atoms:", prob_count),
            ("Logic atoms:", logic_count),
        )
        for title, counts in reports:
            print(title)
            for atom, count in counts.items():
                print(f"{atom_names[str(atom)]}: {count}")

    # Probabilistic atoms come first, then the logic atoms; each group is
    # ordered by decreasing number of descendants (the sort is stable, so
    # ties keep the iteration order of the count dicts).
    ordered_atoms = sorted(
        prob_count, key=lambda atom: prob_count[atom], reverse=True
    )
    ordered_atoms = ordered_atoms + sorted(
        logic_count, key=lambda atom: logic_count[atom], reverse=True
    )

    # Old atom ID -> position in the new variable order (starting at 1).
    new_id_map = dict(zip(ordered_atoms, range(1, num_atoms + 1)))
    # False is always mapped to 0, as it is not a variable.
    new_id_map[0] = 0

    def remap(atoms):
        """Translate a sequence of old atom IDs into a list of new IDs."""
        return [new_id_map[atom] for atom in atoms]

    def remap_literal(literal):
        """Translate a signed literal, keeping its sign."""
        mapped = new_id_map[abs(literal)]
        return -mapped if literal < 0 else mapped

    remapped_head_rules = {}
    for head_key, head_bodies in program_dict["head_rules"].items():
        old_head = int(head_key)
        if old_head >= 0:
            new_head = new_id_map[old_head]
        else:
            new_head = -new_id_map[-old_head]

        # Heads that end up negative keep an empty list of bodies.
        if new_head >= 0:
            remapped_head_rules[new_head] = [
                {
                    "pos_body": remap(body["pos_body"]),
                    "neg_body": remap(body["neg_body"]),
                }
                for body in head_bodies
            ]
        else:
            remapped_head_rules[new_head] = []

    remapped_rules = {}
    for rule_type in ("normal", "disjunctive", "choice"):
        converted = []
        remapped_rules[rule_type] = converted
        for rule in program_dict["rules"][rule_type]:
            # Normal rules have a single head atom; the other kinds carry a
            # list of head atoms.
            if rule_type == "normal":
                rule_head = new_id_map[rule["head"]]
            else:
                rule_head = remap(rule["head"])
            rule_body = {
                "pos": remap(rule["body"]["pos"]),
                "neg": remap(rule["body"]["neg"]),
            }
            new_rule = {
                "head": rule_head,
                "body": rule_body,
                "text": rule["text"],
            }
            if rule_type == "choice":
                new_rule["lower"] = rule["lower"]
                new_rule["upper"] = rule["upper"]
            converted.append(new_rule)

    remapped_pfacts = []
    for atom, prob in prob_section["pfacts"]:
        remapped_pfacts.append((new_id_map[atom], prob))
    remapped_facts = remap(program_dict["facts"])
    remapped_ads = []
    for ad in prob_section["ads"]:
        remapped_ads.append([(new_id_map[atom], prob) for atom, prob in ad])

    metadata = program_dict["metadata"]

    remapped_constraints = []
    for constraint in program_dict["exactly_one_constraints"]:
        remapped_constraints.append(remap(constraint))

    # Loop formulas are a list of dictionaries, each with the keys "loop"
    # (a list of atoms) and "r_l" (a list of lists of signed literals).
    remapped_loops = []
    for loop in program_dict["loop_formulas"]:
        remapped_loops.append(
            {
                "loop": remap(loop["loop"]),
                "r_l": [
                    [remap_literal(literal) for literal in literals]
                    for literals in loop["r_l"]
                ],
            }
        )

    return {
        "rules": remapped_rules,
        "head_rules": remapped_head_rules,
        "metadata": metadata,
        "prob": {
            "pfacts": remapped_pfacts,
            "ads": remapped_ads,
        },
        "facts": remapped_facts,
        "exactly_one_constraints": remapped_constraints,
        "loop_formulas": remapped_loops,
    }


def main():
    """Command-line entry point.

    Reads the JSON program given as ``input_file``, reassigns its atom IDs
    with ``reassign_ids`` and writes the result next to the input, using the
    input path without its extension followed by ``_init.json``.
    """
    arg_parser = argparse.ArgumentParser(description="Reassign IDs in a PASP program based on dependency graph.")
    arg_parser.add_argument("input_file", help="Path to the input JSON file")
    arg_parser.add_argument("--verbose", action="store_true", help="Print the number of descendants for each atom")
    options = arg_parser.parse_args()

    input_path = options.input_file
    # Drop the extension so the output lands beside the input file.
    stem, _extension = os.path.splitext(input_path)
    output_path = f"{stem}_init.json"

    with open(input_path, 'r') as source:
        program = json.load(source)

    reassigned = reassign_ids(program, verbose=options.verbose)

    with open(output_path, 'w') as sink:
        json.dump(reassigned, sink, indent=4)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
