import os
import time
import argparse
import json
from pysdd.sdd import SddManager, Vtree
from collections import deque
from typing import List, Tuple, Dict
from tqdm import tqdm
from memory_profiler import memory_usage

def parse_input(data: Dict) -> Tuple[Dict[int, List[Dict[str, List[int]]]], List[int], List[int], int, List[List[int]], List[Dict[str, List[int]]]]:
    """
    Extract the compilation inputs from a decoded program description.

    The following entries of `data` are read:
        - `prob`: a dictionary with the keys `pfacts` (a list of entries) and
        `ads` (a list of lists of entries). Only the first element of each
        entry is used; it is taken as the ID of a probabilistic atom.
        - `facts`: an iterable whose elements are copied into a list.
        - `head_rules`: a dictionary mapping a head atom (converted with
        `int`) to a list of bodies, where each body is a dictionary with the
        keys `pos_body` and `neg_body`.
        - `metadata`: a dictionary containing `num_atoms`.
        - `disjoint_rules`: (optional) a list of groups of head atoms. When it
        is missing, a single group with every head of `head_rules` is used.
        - `loop_formulas`: (optional) a list of loop formulas. When it is
        missing, an empty list is used.

    Parameters
    ----------
    data : Dict
        A dictionary containing the program data.

    Returns
    -------
    rules : Dict[int, List[Dict[str, List[int]]]]
        For each head atom, the list of its bodies; every body keeps only its
        `pos_body` and `neg_body` entries.
    prob_ids : List[int]
        IDs taken from `pfacts`, followed by the IDs taken from `ads`.
    facts_ids : List[int]
        The content of `data["facts"]` as a list.
    num_atoms : int
        The value of `data["metadata"]["num_atoms"]`.
    grouped_rules : List[List[int]]
        The groups of head atoms (see `disjoint_rules` above).
    loop_formulas : List[Dict]
        The loop formulas (see `loop_formulas` above).
    """
    prob_section = data["prob"]
    prob_ids = [pfact[0] for pfact in prob_section["pfacts"]]
    for ad_group in prob_section["ads"]:
        prob_ids.extend(ad[0] for ad in ad_group)

    facts_ids = list(data["facts"])

    # Same content as data["head_rules"], but keyed by int and restricted to
    # the two body entries that the compiler uses
    rules = {}
    for raw_head, raw_bodies in data["head_rules"].items():
        head_id = int(raw_head)
        rules[head_id] = [
            {"pos_body": raw_body["pos_body"], "neg_body": raw_body["neg_body"]}
            for raw_body in raw_bodies
        ]

    # Fallback: one single group holding every head
    all_heads = [int(raw_head) for raw_head in data["head_rules"]]
    grouped_rules = data.get("disjoint_rules", [all_heads])

    loop_formulas = data.get("loop_formulas", [])
    num_atoms = data["metadata"]["num_atoms"]

    return rules, prob_ids, facts_ids, num_atoms, grouped_rules, loop_formulas

def compile_body(pos_body: List[int], neg_body: List[int], manager: SddManager) -> SddManager:
    """
    Build the conjunction of the literals of one rule body.

    Parameters
    ----------
    pos_body : List[int]
        Atoms that appear positively in the body.

    neg_body : List[int]
        Atoms that appear negated in the body.

    manager : SddManager
        The SDD manager used to create the literals.

    Returns
    -------
    SddManager
        The node obtained by conjoining every positive atom and the negation
        of every negative atom (the object returned by the manager
        operations). An empty body yields `manager.true()`.
    """
    conjunction = manager.true()

    # Nothing to conjoin: the empty body is trivially true
    if not (pos_body or neg_body):
        return conjunction

    # Atom 0 is not a valid variable (it is the integrity constraint marker)
    for body_atoms in (pos_body, neg_body):
        assert all(literal != 0 for literal in body_atoms)

    # Positive atoms first, then the negated ones
    signed_literals = list(pos_body) + [-atom for atom in neg_body]
    for signed in signed_literals:
        conjunction &= manager.literal(signed)

    return conjunction

def compile_bodies(bodies: List[Dict[str, List[int]]], manager: SddManager) -> SddManager:
    """
    Build the disjunction of all the bodies that share one head.

    Parameters
    ----------
    bodies : List[Dict[str, List[int]]]
        The bodies to combine. Each one is a dictionary with the keys
        `pos_body` and `neg_body`, as expected by `compile_body`.

    manager : SddManager
        The SDD manager.

    Returns
    -------
    SddManager
        The disjunction of the compiled bodies, or `manager.false()` when
        there are no bodies at all.
    """
    disjunction = manager.false()

    # An empty (or missing) list of bodies leaves the disjunction at False
    if bodies:
        for rule_body in bodies:
            body_circuit = compile_body(rule_body["pos_body"], rule_body["neg_body"], manager)
            disjunction |= body_circuit

    return disjunction


def compile_head(head: int, bodies: List[Dict[str, List[int]]], manager: SddManager) -> SddManager:
    """
    Compile all the rules that share the head atom `head`.

    Three cases are distinguished:
        - `head == 0`: integrity constraint; the result is the negation of
        the disjunction of the bodies.
        - no bodies: the result is the negative literal of `head`.
        - otherwise: the result is the equivalence between `head` and the
        disjunction of its bodies.

    Parameters
    ----------
    head : int
        The head atom of the rules (0 for an integrity constraint).

    bodies : List[Dict[str, List[int]]]
        The bodies of the rules. Each one is a dictionary with the keys
        `pos_body` and `neg_body`.

    manager : SddManager
        The SDD manager.

    Returns
    -------
    SddManager
        The node encoding the case that applies.
    """
    is_integrity_constraint = (head == 0)
    if is_integrity_constraint:
        # None of the bodies may hold
        return ~compile_bodies(bodies, manager)

    # Without bodies the head is encoded through its negative literal. The
    # original author notes two readings of this case: the head is a fact
    # whose ID is assumed to be stored negated (`~head <==> FALSE`, i.e. the
    # head is true), or the head is simply false because it is neither a
    # fact nor the head of a rule.
    if not bodies:
        return manager.literal(-head)

    head_literal = manager.literal(head)
    support = compile_bodies(bodies, manager)

    # head <==> support, written as "both true" or "both false"
    both_hold = head_literal & support
    none_holds = ~head_literal & ~support
    return both_hold | none_holds

def compile_facts(facts: List[int], manager: SddManager) -> SddManager:
    """
    Build the conjunction of the literals of a list of facts.

    Parameters
    ----------
    facts : List[int]
        The literals to conjoin; each value is passed as is to
        `manager.literal`.

    manager : SddManager
        The SDD manager.

    Returns
    -------
    SddManager
        The conjunction of all the fact literals, or `manager.true()` when
        `facts` is empty.
    """
    conjunction = manager.true()
    fact_literals = [manager.literal(fact) for fact in facts]
    for fact_literal in fact_literals:
        conjunction &= fact_literal
    return conjunction

def compile_program(circuit: SddManager, rules: List[Tuple[int, List[Dict[str, List[int]]]]], manager: SddManager, minimize: bool = False, alpha: float = 2) -> SddManager:
    """
    Incrementally conjoin a circuit with the compilation of each rule set.

    Parameters
    ----------
    circuit : SddManager
        The circuit to start from; each compiled head is conjoined to it.

    rules : List[Tuple[int, List[Dict[str, List[int]]]]
        Pairs `(head, bodies)`: the head atom and the list of its bodies,
        where each body is a dictionary with the keys `pos_body` and
        `neg_body`. A progress bar is shown while iterating over them.

    manager : SddManager
        The SDD manager.

    minimize : bool, optional
        Whether to apply dynamic minimization to the SDD (default is False).

    alpha : float, optional
        Growth factor: minimization runs when the circuit size exceeds
        `alpha` times the size recorded after the previous minimization
        (default is 2).

    Returns
    -------
    SddManager
        The resulting circuit. Every `ref()` made here is matched by a
        `deref()` before returning.
    """
    reference_size = circuit.size()
    for rule_head, rule_bodies in tqdm(rules):
        circuit &= compile_head(rule_head, rule_bodies, manager)
        # Keep the new circuit referenced while collecting and minimizing
        circuit.ref()
        manager.garbage_collect()
        if minimize:
            grew_too_much = circuit.size() > alpha * reference_size
            if grew_too_much:
                manager.minimize()
                # Measure the next growth against the minimized circuit
                reference_size = circuit.size()
        circuit.deref()
    return circuit

def non_incremental_compile(circuit: SddManager, rules: List[Tuple[int, List[Dict[str, List[int]]]]], manager: SddManager, minimize: bool = False, alpha: float = 2) -> SddManager:
    """
    Non-incrementally compile the conjoin operation between a given SDD circuit
    and a set of rules.

    This compilation is performed by following the proposed algorithm in
    "Separating Incremental and Non-Incremental Bottom-Up Compilation", by
    Alexis de Colnet (2023). More specifically, the algorithm follows the
    Proposition 3 (illustrated in Figure 2) of the paper; but is heavily
    inspired by Lemma 13 of the Appendix A1 of the same paper.

    Every rule set is first conjoined with `circuit` on its own (a
    "cluster"), and the clusters are then conjoined with each other.

    Parameters
    ----------
    circuit : SddManager
        The existing SDD circuit that will be conjoined with the new rules to
        form each cluster.

    rules : List[Tuple[int, List[Dict[str, List[int]]]]
        A list of tuples where the first element is an integer representing the
        head atom of a rule, and the second element is a list of dictionaries,
        where each dictionary has two keys, `pos_body` and `neg_body`, which are
        lists of integers representing the atoms in the positive and negative
        bodies of the rule, respectively.

    manager : SddManager
        The SDD manager.

    minimize : bool, optional
        Whether to apply dynamic minimization once, after all the clusters
        have been conjoined (default is False).

    alpha : float, optional
        Accepted for signature compatibility with `compile_program`; it is
        currently not used by this function (default is 2).

    Returns
    -------
    SddManager
        The updated SDD circuit. Every `ref()` made here is matched by a
        `deref()` before returning, so the caller must reference the result
        before triggering a garbage collection. When `rules` is empty the
        given `circuit` is returned unchanged.
    """
    # Ref the circuit to avoid it being garbage collected
    circuit.ref()

    # Conjoin each set of head rules with `circuit`
    clusters = []
    for head, bodies in rules:
        cluster = circuit & compile_head(head, bodies, manager)
        cluster.ref()
        clusters.append(cluster)
        manager.garbage_collect()

    # Conjoining with nothing must not change the circuit
    if not clusters:
        circuit.deref()
        return circuit

    # Incrementally conjoin the clusters
    # Since `circuit` and this set of rules are disjoint, we can conjoin them
    # with polynomial size guarantees
    # The first cluster is the initial accumulator (TRUE & c is c itself) and
    # the reference taken above is kept as the accumulator's reference
    result = clusters[0]
    for cluster in clusters[1:]:
        combined = result & cluster
        combined.ref()
        # Release both operands so that the collection below can reclaim
        # whatever is not part of the new accumulator
        result.deref()
        cluster.deref()
        manager.garbage_collect()
        result = combined

    circuit.deref()
    manager.garbage_collect()

    # TODO: investigate if it is possible to apply dynamic minimization on the
    # circuit after each conjoin operation of the clusters
    if minimize:
        manager.minimize()

    # `result` was referenced exactly once up to here
    result.deref()
    return result

def compile_loop(loop:Dict, manager: SddManager) -> SddManager:
    """
    Compile one loop formula into an SDD.

    The formula is the implication from "some atom of the loop holds" to
    "some conjunction listed in `r_l` holds".

    Parameters
    ----------
    loop : Dict
        A dictionary with two keys: `loop`, the atoms of the loop, and `r_l`,
        a list of lists of literals. Every inner list is compiled as a
        conjunction of its literals.

    manager : SddManager
        The SDD manager.

    Returns
    -------
    SddManager
        The node for `~(a_1 | ... | a_n) | (c_1 | ... | c_m)`, where the
        `a_i` are the loop atoms and the `c_j` the compiled conjunctions.
    """
    # Antecedent: at least one atom of the loop is true
    some_loop_atom = manager.false()
    for loop_atom in loop["loop"]:
        some_loop_atom |= manager.literal(loop_atom)

    # Consequent: at least one of the conjunctions in R_L is satisfied
    some_support = manager.false()
    for literals in loop["r_l"]:
        conjunction = manager.true()
        for literal in literals:
            conjunction &= manager.literal(literal)
        some_support |= conjunction

    # antecedent -> consequent
    return ~some_loop_atom | some_support

def compile_loop_formulas(circuit: SddManager, loop_formulas: List[Dict], manager: SddManager, minimize: bool = False, alpha: float = 2) -> SddManager:
    """
    Conjoin a circuit with every loop formula of a list.

    Parameters
    ----------
    circuit : SddManager
        The existing SDD circuit that will be conjoined with the loop formulas.

    loop_formulas : List[Dict]
        The loop formulas, each one in the format expected by `compile_loop`
        (a dictionary with the keys `loop` and `r_l`).

    manager : SddManager
        The SDD manager.

    minimize : bool, optional
        Whether to apply dynamic minimization: during the loop when the
        circuit has grown too much, and once more after all the formulas have
        been conjoined (default is False).

    alpha : float, optional
        Growth factor: minimization runs inside the loop when the circuit
        size exceeds `alpha` times the size recorded after the previous
        minimization (default is 2).

    Returns
    -------
    SddManager
        The updated SDD circuit. Every `ref()` made here is matched by a
        `deref()` before returning.
    """
    reference_size = circuit.size()
    for formula in loop_formulas:
        circuit &= compile_loop(formula, manager)
        # Keep the new circuit referenced while collecting and minimizing
        circuit.ref()
        manager.garbage_collect()
        if minimize:
            if circuit.size() > alpha * reference_size:
                manager.minimize()
                reference_size = circuit.size()
        circuit.deref()

    # Final clean-up pass over the finished circuit
    circuit.ref()
    if minimize:
        manager.minimize()
    manager.garbage_collect()
    circuit.deref()

    return circuit

def compile_upper(atoms: List[int], manager: SddManager, count: int, target: int):
    """
    Compile an `at most` cardinality constraint: the number of true atoms
    among `atoms`, plus `count`, must be less than or equal to `target`.

    The circuit is built as a chain of ITE (if-then-else) decisions over the
    atoms, in list order. Sub-circuits are memoized on the pair (position in
    the list, number of atoms already true), so each distinct state is
    compiled only once instead of once per path that reaches it.

    Parameters
    ----------
    atoms : List[int]
        The atoms that are counted.
    manager : SddManager
        The SDD manager.
    count : int
        The number of atoms already counted as true (0 for a fresh
        constraint).
    target : int
        The maximum number of true atoms allowed.

    Returns
    -------
    SddNode
        The compiled SDD.
    """
    num_atoms = len(atoms)
    compiled = {}

    def build(index: int, trues: int):
        # If the current count exceeds the target, the constraint is violated
        if trues > target:
            return manager.false()
        # If there are `k` atoms left and `k + trues` does not exceed the
        # target, the upper bound holds whatever the remaining atoms are
        if num_atoms - index <= target - trues:
            return manager.true()

        state = (index, trues)
        if state not in compiled:
            atom = manager.literal(atoms[index])
            # circuits for the cases in which the atom is true/false
            if_circuit = build(index + 1, trues + 1)
            else_circuit = build(index + 1, trues)
            # ITE (if then else) circuit
            compiled[state] = (atom & if_circuit) | (~atom & else_circuit)
        return compiled[state]

    return build(0, count)

def compile_lower(atoms: List[int], manager: SddManager, count: int, target: int):
    """
    Compile an `at least` cardinality constraint: the number of true atoms
    among `atoms`, plus `count`, must be greater than or equal to `target`.

    The circuit is built as a chain of ITE (if-then-else) decisions over the
    atoms, in list order. Sub-circuits are memoized on the pair (position in
    the list, number of atoms already true), so each distinct state is
    compiled only once instead of once per path that reaches it.

    Parameters
    ----------
    atoms : List[int]
        The atoms that are counted.
    manager : SddManager
        The SDD manager.
    count : int
        The number of atoms already counted as true (0 for a fresh
        constraint).
    target : int
        The minimum number of true atoms required.

    Returns
    -------
    SddNode
        The compiled SDD.
    """
    num_atoms = len(atoms)
    compiled = {}

    def build(index: int, trues: int):
        # If the lower bound has already been reached, the constraint holds
        if trues >= target:
            return manager.true()
        # If there are `k` atoms left and `k + trues` is less than `target`,
        # the bound cannot be reached even if every remaining atom is true
        if num_atoms - index < target - trues:
            return manager.false()

        state = (index, trues)
        if state not in compiled:
            atom = manager.literal(atoms[index])
            # circuits for the cases in which the atom is true/false
            if_circuit = build(index + 1, trues + 1)
            else_circuit = build(index + 1, trues)
            # ITE (if then else) circuit
            compiled[state] = (atom & if_circuit) | (~atom & else_circuit)
        return compiled[state]

    return build(0, count)

def compile_exactly_one(atoms: List[int], manager: SddManager):
    """
    Compile an exactly one constraint into an SDD: one of the given atoms is
    true and all the others are false.

    Parameters
    ----------
    atoms : List[int]
        A list of atoms.
    manager : SddManager
        The SDD manager.

    Returns
    -------
    SddNode
        The compiled SDD. An empty list yields `manager.false()`, since no
        atom can be the true one.
    """
    if not atoms:
        return manager.false()

    atom = manager.literal(atoms[0])

    # If the first atom is the true one, every remaining atom must be false
    if_circuit = manager.true()
    for a in atoms[1:]:
        if_circuit &= manager.literal(-a)

    # Otherwise the true atom must be found among the remaining ones
    else_circuit = compile_exactly_one(atoms[1:], manager)

    # return the ITE (if then else) circuit
    return (atom & if_circuit) | (~atom & else_circuit)

def compile_all_but_one(atoms: List[int], manager: SddManager):
    """
    Compile a constraint that exactly one of the given atoms is not true.

    Parameters
    ----------
    atoms : List[int]
        The atoms to compile.
    manager : SddManager
        The SDD manager.

    Returns
    -------
    SddNode
        The compiled SDD. An empty list yields `manager.false()`, since no
        atom can be the false one.
    """
    if not atoms:
        return manager.false()

    atom = manager.literal(atoms[0])

    # If the first atom is true, the false atom must be among the others
    if_circuit = compile_all_but_one(atoms[1:], manager)

    # If the first atom is the false one, every remaining atom must be true
    else_circuit = manager.true()
    for a in atoms[1:]:
        else_circuit &= manager.literal(a)

    # return the ITE (if then else) circuit
    return (atom & if_circuit) | (~atom & else_circuit)

def compile_constraint(constraint: Dict, manager: SddManager):
    """
    Compile one cardinality constraint into an SDD.

    Parameters
    ----------
    constraint : Dict
        The constraint description. The key `type` selects the kind of
        constraint (`upper`, `lower`, `exactly`, `exactly_one` or
        `all_but_one`), `atoms` holds the atoms it ranges over (default: empty
        list) and `target` the bound used by `upper`, `lower` and `exactly`
        (default: 0).
    manager : SddManager
        The SDD manager.

    Returns
    -------
    SddNode
        The compiled SDD.

    Raises
    ------
    ValueError
        If `type` is missing or is not one of the supported kinds.
    """
    kind = constraint.get("type")
    atoms = constraint.get("atoms", [])
    target = constraint.get("target", 0)

    def at_most():
        return compile_upper(atoms, manager, 0, target)

    def at_least():
        return compile_lower(atoms, manager, 0, target)

    def exactly():
        # Both bounds at once: upper bound first, then lower bound
        return at_most() & at_least()

    def exactly_one():
        return compile_exactly_one(atoms, manager)

    def all_but_one():
        return compile_all_but_one(atoms, manager)

    builders = {
        "upper": at_most,
        "lower": at_least,
        "exactly": exactly,
        "exactly_one": exactly_one,
        "all_but_one": all_but_one,
    }

    builder = builders.get(kind)
    if builder is None:
        # If no valid constraint type is specified, return error
        raise ValueError("Invalid constraint type")
    return builder()

def compile_cardinality_constraints(circuit: SddManager, constraints: List[Dict[str, List[int]]],
    manager: SddManager):
    """
    Conjoin a circuit with a list of cardinality constraints.

    The constraints are first conjoined with each other, and the resulting
    circuit is then conjoined with `circuit`.

    Parameters
    ----------
    circuit : SddManager
        The existing SDD circuit.
    constraints : List[Dict[str, List[int]]]
        The constraints, each one in the format expected by
        `compile_constraint`.
    manager : SddManager
        The SDD manager.

    Returns
    -------
    SddNode
        The conjunction of `circuit` and all the constraints. Every `ref()`
        made here is matched by a `deref()` on the same node before
        returning, so the caller must reference the result before triggering
        a garbage collection.
    """
    # Protect the incoming circuit from the collections below
    circuit.ref()

    cardinality_circuit = manager.true()
    # Whether `cardinality_circuit` currently carries a reference of ours
    # (the initial TRUE constant does not)
    holds_ref = False
    for constraint in constraints:
        updated = cardinality_circuit & compile_constraint(constraint, manager)
        updated.ref()
        # The previous accumulator is no longer needed on its own
        if holds_ref:
            cardinality_circuit.deref()
        cardinality_circuit = updated
        holds_ref = True
        manager.garbage_collect()

    # Reference the result before releasing its operands, so that the
    # collection cannot reclaim it
    result = circuit & cardinality_circuit
    result.ref()
    circuit.deref()
    if holds_ref:
        cardinality_circuit.deref()
    manager.garbage_collect()
    result.deref()
    return result

def compile(rules: Dict[int, List[Dict[str, List[int]]]], prob_ids: List[int], facts: List[int],
        num_atoms: int, grouped_rules: List[List[int]],
        loop_formulas: List[Dict[str, List]], vtree_type: str = "balanced",
        minimize: bool = False, alpha: float = 2, X_deterministic: bool = False
    ) -> Tuple[Dict, SddManager, Vtree]:
    """
    Compile an SDD from the given rules and collect compilation statistics.

    The first group of `grouped_rules` is compiled incrementally, every
    further group is compiled non-incrementally (by clustering), and the loop
    formulas are conjoined at the end.

    Parameters
    ----------
    rules : Dict[int, List[Dict[str, List[int]]]]
        A dictionary where the keys are integers representing the ID of the head
        atom of a set of rules, and the values are lists of dictionaries, where
        each dictionary has two keys, `pos_body` and `neg_body`, which are lists
        of integers representing the atoms in the positive and negative bodies
        of the rule, respectively.
    prob_ids : List[int]
        A list of integers representing the IDs of atoms that are present in
        probabilistic facts or annotated disjunctions. Only used when
        `X_deterministic` is True.
    facts : List[int]
        The facts of the program. Currently not used by the compilation.
    num_atoms : int
        The total number of atoms in the program.
    grouped_rules : List[List[int]]
        A list of lists of integers, where each list represents the head atoms
        of the rules in each group. This list should be ordered by the size
        (number of elements) of the groups, from largest to smallest. An empty
        list compiles no rules at all.
    loop_formulas : List[Dict[str, List]]
        The loop formulas, in the format expected by `compile_loop_formulas`.
    vtree_type : str, optional
        Type of vtree to use for the SDD compilation (default is "balanced").
    minimize : bool, optional
        Whether to apply dynamic minimization to the SDD (default is False).
    alpha : float, optional
        The threshold that determines when a circuit has grown too much and
        should be minimized (default is 2).
    X_deterministic : bool, optional
        Whether to compute is_X_var based on probabilistic facts and annotated
        disjunctions (default is False).

    Returns
    -------
    result : Dict
        A dictionary containing the compilation statistics: compilation time,
        circuit size (nodes and edges), model count, and compression rate
        (model count divided by the number of edges, or None when the circuit
        has no edges).
    circuit : SddManager
        The compiled SDD circuit.
    vtree : Vtree
        The vtree used for the SDD compilation.
    """

    is_X_var = None
    if X_deterministic:
        # is_X_var is a list of booleans of size num_atoms + 1,
        # where the i-th element is True if atom i is a probabilistic fact or an
        #  annotated disjunction (belongs to the X set), and False otherwise.
        is_X_var = [False] * (num_atoms + 1)
        for pid in prob_ids:
            is_X_var[pid] = True

    vtree = Vtree(var_count=num_atoms, vtree_type=vtree_type, is_X_var=is_X_var)
    manager = SddManager(vtree=vtree)

    start_time = time.time()
    # First, startup the circuit
    circuit = manager.true()

    # Second, compile the first group of head rules into an SDD
    # (there may be no group at all, in which case the circuit stays at True)
    queue = deque(grouped_rules)
    if queue:
        head_ids = queue.popleft()
        compilation_rules = [(head, rules[head]) for head in head_ids]
        circuit = compile_program(circuit, compilation_rules, manager, minimize=minimize, alpha=alpha)

    # If there are more groups (non-incremental compilation), we compile by clustering
    while queue:
        head_ids = queue.popleft()
        compilation_rules = [(head, rules[head]) for head in head_ids]
        circuit = non_incremental_compile(circuit, compilation_rules, manager, minimize=minimize, alpha=alpha)

    # Finally, we compile the loop formulas to throw away models due to positive cycles
    circuit = compile_loop_formulas(circuit, loop_formulas, manager, minimize=minimize, alpha=alpha)

    end_time = time.time()
    wmc = circuit.wmc(log_mode=False).propagate()

    num_edges = circuit.size()
    result = {
        "compilation_time": end_time - start_time,
        "circuit_size": {
            "nodes": circuit.count(),
            "edges": num_edges
        },
        "model_count": wmc,
        # The rate is undefined for a circuit without edges
        "compression_rate": wmc / num_edges if num_edges else None
    }

    return result, circuit, vtree

if __name__ == "__main__":
    # ---- Command line ----------------------------------------------------
    cli = argparse.ArgumentParser(description="Compile SDD from given JSON file.")
    cli.add_argument("json_file", type=str, help="Path to the input JSON file")
    cli.add_argument("vtree_type", type=str, nargs='?', default="b", choices=["b", "r"], help="Type of vtree (b: balanced, r: right)")
    cli.add_argument("--minimize", action="store_true", help="Apply dynamic minimization to the SDD")
    cli.add_argument("--X_deterministic", action="store_true", help="If True, compute is_X_var based on probabilistic facts and annotated disjunctions")
    cli.add_argument("--verbose", action="store_true", help="Enable verbose output")
    cli.add_argument("--alpha", type=float, default=2, help="Threshold to apply dynamic minimization")
    args = cli.parse_args()

    input_path = args.json_file

    # Expand the one-letter vtree code into the vtree type name
    vtree_names = {"b": "balanced", "r": "right"}
    vtree_type = vtree_names.get(args.vtree_type, "balanced")

    # ---- Load and parse the program ---------------------------------------
    with open(input_path, "r") as handle:
        data = json.load(handle)

    rules, prob_ids, facts_ids, num_atoms, grouped_rules, loop_formulas = parse_input(data)

    # ---- Compile while sampling the memory usage --------------------------
    def run_compilation():
        return compile(
            rules, prob_ids, facts_ids, num_atoms, grouped_rules, loop_formulas,
            vtree_type=vtree_type,
            minimize=args.minimize,
            alpha=args.alpha,
            X_deterministic=args.X_deterministic,
        )

    # With retval=True the value returned by the function comes back along
    # with the maximum memory usage
    peak_memory, (result, circuit, vtree) = memory_usage(
        (run_compilation,), max_usage=True, retval=True
    )
    result["max_memory_usage_mb"] = peak_memory

    if args.verbose:
        print(json.dumps(result, indent=4))

    # ---- Output locations, next to the input file -------------------------
    base_dir = os.path.dirname(input_path)
    stats_dir = os.path.join(base_dir, "stats")
    os.makedirs(stats_dir, exist_ok=True)
    sdd_dir = os.path.join(base_dir, "sdd")
    os.makedirs(sdd_dir, exist_ok=True)
    vtree_dir = os.path.join(base_dir, "vtree")
    os.makedirs(vtree_dir, exist_ok=True)

    # File name already includes if we used initialization heuristic or
    # non-incremental compilation; append the options of this run
    input_stem = os.path.splitext(os.path.basename(input_path))[0]
    prefix = f"{input_stem}_{vtree_type}"
    if args.X_deterministic:
        prefix += "_Xdet"
    if args.minimize:
        prefix += "_min"

    # ---- Write statistics, dot renderings and the SDD/vtree files ---------
    with open(f"{stats_dir}/{prefix}_stats.json", "w") as stats_file:
        json.dump(result, stats_file, indent=4)

    circuit.save_as_dot(f"{sdd_dir}/{prefix}_sdd.dot".encode())
    vtree.save_as_dot(f"{vtree_dir}/{prefix}_vtree.dot".encode())

    circuit.save(f"{sdd_dir}/{prefix}.sdd".encode())
    vtree.save(f"{vtree_dir}/{prefix}.vtree".encode())
