import torch
import math
from ortools.linear_solver import pywraplp
from typing import List, Tuple
import sys

from models.idx2word import Idx2Word
from preprocessing.utilities import create_atom

# The following provides the linear inequalities corresponding to the abductive proofs
# Let \phi_1, \dots, \phi_n be the number of proofs
# Each proof is a conjunction of facts, \phi_l = \f_{l,1}, \dots, \f_{l,n_l}
# The Tseitin transformation results in the following DNF formula
# \bigwedge_{l=1}^n \alpha_l (*)
# \wedge
# \bigwedge_{l=1}^n \alpha_l <-> \f_{l,1}, \dots, \f_{l,n_l}
# The second conjunction results in the following clauses:
# (Type I)
# \alpha_l -> \f_{l,1}, \dots, \f_{l,n_l} = \neg \alpha_l \vee (\f_{l,1}, \dots, \f_{l,n_l}) =
# (\neg \alpha_l \vee \f_{l,1} ) \wedge \dots \wedge (\neg \alpha_l \vee \f_{l,n_l} ) (**)
# (Type II)
# \f_{l,1}, \dots, \f_{l,n_l} -> \alpha_l =
# \neg \f_{l,1} \vee \dots \vee \neg \f_{l,n} \vee \alpha_l (***)

# Based on the above, we have the following linear equations:
# Inequality I: a_1 + \dots a_n \geq 1 <= computed out of (*)
# For each clause in (**), we have the following linear inequality:
# Inequality II: 1- \alpha_l + \f_{l,j} \geq 1
# For (***), we have the following linear inequality:
# Inequality III: 1 - \f_{l,1} + \dots + 1 - \f_{l,n} + \alpha_l >= 1

# Let \phi_1, \dots, \phi_n be the number of proofs
# Each proof is a conjunction of facts, \phi_l = \f_{l,1}, \dots, \f_{l,n_l}
# If the answer to the query is false, then this is equivalent to having the formula
# \neg \phi_1 \wedge \neg \phi_2 \wedge \dots \wedge \phi_n
# where \neg \phi_l = neg \f_{l,1} \vee \dots \vee \neg \f_{l,n_l}
# So, in that case, the formula is already in CNF form, and hence, the Tseitin transformation is not required.
# For each clause neg \f_{l,1} \vee \dots \vee \neg \f_{l,n_l},
# the linear program will include the inequality: 1 - \f_{l,1} + \dots + 1 - \f_{l,n_l} >= 1.

# We additionally, have the following linear inequalities to ensure that only one label each assigned to one
# for each input object x_i
# Inequality IV: q_{i,1} + \dots + q_{i,m} = 1, where q_{i,j} denotes that x_i is assigned to label j.

# In addition, we need to add linear inequaldities for the objective, and the distribution constraints,
# where each distribution constraint makes sure that the distribution of a class should be as in the emp_dist.

# The code assumes that we have two neural predicates: name and rela.
# proofs_n is a list of the candidate proofs corresponding to each input training sample.
# idx2word holds the list of possible names, attributes and relations.

# Implements the relaxation via https://developers.google.com/optimization/mip/mip_example#comparing_linear_and_integer_optimization
# If emp_dist == None, then we do not introduce distribution constraints to the lp.

def create_atom_helper(predicate, c, box):
    if predicate == "name":
        return create_atom(predicate, [c, box])
    elif predicate == "rela":
        return create_atom(predicate, [c, *box])
    
def ilp_pywrap(
    proofs_n: List[List[List[Tuple]]],
    predictions,
    idx2word: Idx2Word,
    distribution=None,
    epsilon_ilp=0.99,
    continuous_relaxation=False,
):
    epsilon = epsilon_ilp
    names_c = idx2word.get_names()
    relas_c = idx2word.get_relas()
    # n denotes the number of input training samples
    n = len(proofs_n)

    while True:
        # Create the mip solver with the SCIP backend.
        if not continuous_relaxation:
            solver = pywraplp.Solver.CreateSolver("SCIP")
        else:
            solver = pywraplp.Solver.CreateSolver("GLOP")
        if not solver:
            return

        # Each item of the list holds a dictionary that maps each atom in the proofs
        # for the item to a unique variable of the solver.
        if not continuous_relaxation:
            sample_to_vars = {
                s: {
                    create_atom_helper(predicate, c, box): solver.IntVar(
                        0, 1, create_atom_helper(predicate, c, box)
                    )
                    for predicate, classes in zip(
                        ["name", "rela"], [names_c, relas_c]
                    )
                    for box in predictions[s][predicate]
                    for c in classes
                }
                for s in range(n)
            }
        else:
            sample_to_vars = {
                s: {
                    create_atom_helper(predicate, c, box): solver.IntVar(
                        0, 1, create_atom_helper(predicate, c, box)
                    )
                    for predicate, classes in zip(
                        ["name", "rela"], [names_c, relas_c]
                    )
                    for box in predictions[s][predicate]
                    for c in classes
                }
                for s in range(n)
            }

        # "a_{s,l}" is the variable associated to the s-th training sample, l-th proof
        if not continuous_relaxation:
            alphas = [
                [solver.IntVar(0, 1, f"a{s}_{l}") for l in range(len(proofs_n[s]))]
                for s in range(n)
            ]
        else:
            alphas = [
                [solver.NumVar(0, 1, f"a{s}_{l}") for l in range(len(proofs_n[s]))]
                for s in range(n)
            ]

        # s in an index of training samples
        for s in range(n):
            v_s = sample_to_vars[s]
            alpha_s = alphas[s]
            # i is an index of the proofs of the s-th training sample
            for l in range(len(proofs_n[s])):
                # The facts that are true in the l-th proof
                # Each fact is mapped to variables of the solver
                f_l = proofs_n[s][l]

                # For each fact add inequality II: 1- \alpha_l + \f_{l,j} \geq 1
                [solver.Add(1 - alpha_s[l] + v_s[f_l[j]] >= 1) for j in range(len(f_l))]

                # Add inequality III: 1 - \f_{l,1} + \dots + 1 - \f_{l,n} + \alpha_i >= 1
                solver.Add(
                    sum([1 - v_s[f_l[j]] for j in range(len(f_l))]) + alpha_s[l] >= 1
                )

            # For the s-th training sample, add inequality I: a_1 + \dots a_n \geq 1 that varies over all the proofs for the s-th sample
            solver.Add(sum([alpha_s[l] for l in range(len(proofs_n[s]))]) >= 1)

        # For each input object x_j add
        # Inequality IV: x_{j,1} + \dots + x_{j,m} = 1.
        # These are the classification constraints
        for s in range(n):
            v_s = sample_to_vars[s]
            [
                solver.Add(
                    sum([v_s[create_atom_helper(predicate, c, box)] for c in classes]) == 1
                )
                for predicate, classes in zip(["name", "rela"], [names_c, relas_c])
                for box in predictions[s][predicate]
            ]

        # Create the distribution constraints
        if distribution is not None:
            name_facts = sum(
                [
                    len(predictions[s]["name"])
                    for s in range(n)
                ]
            )
            rela_facts = sum(
                [
                    len(predictions[s]["rela"])
                    for s in range(n)
                ]
            )

            # Enforce the distribution constraints
            for predicate, classes, num in zip(
                ["name", "rela"],
                [names_c, relas_c],
                [name_facts, rela_facts],
            ):
                for c in classes:
                    if predicate == "name":
                        idx = idx2word.name_to_idx(c)
                    if predicate == "rela":
                        idx = idx2word.rela_to_idx(c)
                    constraint = [
                        sample_to_vars[s][create_atom_helper(predicate, c, box)]
                        for s in range(n)
                        for box in predictions[s][predicate]
                    ]

                    solver.Add(
                        sum(constraint) >= distribution[predicate][idx].item() * num - epsilon
                    )
                    solver.Add(
                        sum(constraint) <= distribution[predicate][idx].item() * num + epsilon
                    )

        # Create the objective expression
        objective_expr = [
            sample_to_vars[s][create_atom_helper(predicate, cls_name, box)]
            * math.log(predictions[s][predicate][box][c].item())
            for predicate, classes in zip(["name", "rela"], [names_c, relas_c])
            for s in range(n)
            for box in predictions[s][predicate]
            for c, cls_name in enumerate(classes)
            if predictions[s][predicate][box][c] > 0
        ]

        solver.Minimize(-solver.Sum(objective_expr))

        # Prints the program in the console
        # print(
        #    solver.ExportModelAsLpFormat(False).replace("\\", "").replace(",_", ","),
        #    sep="\n",
        # )

        #file = open("lp.txt", "w")
        #file.write(solver.ExportModelAsLpFormat(False).replace("\\", "").replace(",_", ","))
        #file.close()

        print(f"Solving with {solver.SolverVersion()}")
        status = solver.Solve()

        if status == pywraplp.Solver.OPTIMAL:
            print("Solution:")
            print("Objective value =", solver.Objective().Value())
            print(f"Problem solved in {solver.wall_time():d} milliseconds")

            Q = [
                {
                    predicate: {
                        box: torch.Tensor(
                            [
                                sample_to_vars[s][
                                    create_atom_helper(predicate, c, box)
                                ].solution_value()
                                for c in classes
                            ]
                        )#.cuda()
                        for box in predictions[s][predicate]
                    }
                    for predicate, classes in zip(
                        ["name", "rela"], [names_c, relas_c]
                    )
                }
                for s in range(n)
            ]

            # Return a map from atoms to their pseudolabels
            return Q

        else:
            print(
                "The problem does not have an optimal solution. Increasing epsilon to {}".format(
                    epsilon + 1
                )
            )
            epsilon = epsilon + 1

