import argparse
import random
import re
from collections import defaultdict
from itertools import product
from pathlib import Path

import networkx as nx
import numpy as np
from pyscipopt import Model
from tqdm import tqdm

CWD = Path.cwd()
# Root folder containing one sub-directory per network (used by FileIO and
# the command-line entry point).
DATASET_DIR = CWD.joinpath("MPE", "dataset")

# Base SCIP settings applied to every program built by ScipUtils.get_mpe_program.
SCIP_PARAMS = {
    "limits/time": 60,  # time limit (seconds, per SCIP parameter docs)
    "display/verblevel": 0,  # suppress solver output
    "limits/memory": 20000,  # memory limit (MB, per SCIP parameter docs)
}

# Extra settings applied when get_mpe_program(..., disable_cuts=True):
# every listed parameter is set to 0 (no presolving rounds, no cuts).
DISABLE_CUTS_PARAMS = dict.fromkeys(
    (
        "presolving/maxrounds",
        "separating/maxcuts",
        "separating/maxcutsroot",
    ),
    0,
)


class FileIO:
    """Helpers for reading per-network dataset files."""

    @staticmethod
    def read_samples(
        network_name, sample_file="samples.txt", delimiter=" ", *, dataset_dir=None
    ):
        """Read integer samples for ``network_name``.

        The file read is ``<dataset_dir>/<network_name>/<sample_file>``, where
        ``dataset_dir`` defaults to ``DATASET_DIR``. Every non-blank line holds
        one sample whose values are separated by ``delimiter``. Surrounding
        whitespace on each line is ignored, which covers both the trailing
        ``" \\n"`` written by ``GibbsSampler.save_samples_to_file`` and plain
        ``"\\n"`` line endings.

        Args:
            network_name: sub-directory of the dataset root for this network.
            sample_file: file name inside that sub-directory.
            delimiter: separator between values on a line.
            dataset_dir: optional override of the dataset root directory.

        Returns:
            dict mapping the 0-based sample index (blank lines are skipped and
            not counted) to the sample as ``list[int]``.

        Raises:
            FileNotFoundError: if the sample file does not exist.
            ValueError: if a value cannot be parsed as an integer.
        """
        root = DATASET_DIR if dataset_dir is None else Path(dataset_dir)
        sample_path = root / network_name / sample_file
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not sample_path.exists():
            raise FileNotFoundError(f"Sample file not found: {sample_path}")

        with open(sample_path) as f:
            sample_lines = [line.strip() for line in f.read().splitlines()]

        non_blank = (line for line in sample_lines if line)
        return {
            idx: [int(value) for value in line.split(delimiter)]
            for idx, line in enumerate(non_blank)
        }


class ScipUtils:
    """Helpers for building and inspecting pyscipopt models."""

    @staticmethod
    def get_variable_by_name(program, name):
        """Return the first variable in ``program`` whose ``name`` equals *name*.

        Raises:
            ValueError: if no such variable exists.
        """
        match = next((var for var in program.getVars() if var.name == name), None)
        if match is None:
            raise ValueError(f"Variable with name: {name} not found")
        return match

    @staticmethod
    def get_constraint_by_name(program, name):
        """Return the first constraint in ``program`` whose ``name`` equals *name*.

        Raises:
            ValueError: if no such constraint exists.
        """
        match = next((cons for cons in program.getConss() if cons.name == name), None)
        if match is None:
            raise ValueError(f"Constraint with name: {name} not found")
        return match

    @staticmethod
    def get_statistics(program):
        """Collect solve statistics from a (solved) pyscipopt model into a dict.

        Keys, in order: status, num_nodes, objective_value,
        best_objective_value, solving_time, primal_bound, dual_bound, gap,
        num_solutions_found.
        """
        incumbent = program.getBestSol()
        incumbent_value = program.getSolObjVal(incumbent)

        stats = {}
        stats["status"] = program.getStatus()
        stats["num_nodes"] = program.getNTotalNodes()
        stats["objective_value"] = program.getObjVal()
        stats["best_objective_value"] = incumbent_value
        stats["solving_time"] = program.getSolvingTime()
        stats["primal_bound"] = program.getPrimalbound()
        stats["dual_bound"] = program.getDualbound()
        stats["gap"] = program.getGap()
        stats["num_solutions_found"] = program.getNSolsFound()
        return stats

    @staticmethod
    def get_mpe_program(pgm, evidence=None, disable_cuts=False):
        """Build a SCIP program whose optimum is an MPE assignment of *pgm*.

        Encoding (binary variables only; values range over ``range(2)``):
          * ``x_{node}_{v}`` is 1 iff ``node`` takes value ``v``; exactly one of
            ``x_{node}_0`` / ``x_{node}_1`` is 1.
          * ``z_{c}_{k}`` is the AND of the ``x`` literals for the ``k``-th joint
            assignment (``itertools.product`` order) of the ``c``-th factor scope.
          * The objective maximizes ``sum(z * log(factor entry))``.
          * Each ``evidence`` item ``var: value`` forces ``x_{var}_{value} == 1``.

        Args:
            pgm: model exposing ``graph.nodes`` and ``parameters`` (scope tuple ->
                ndarray), e.g. a ``GraphicalModel``.
            evidence: optional ``{var: value}`` mapping.
            disable_cuts: if true, additionally apply ``DISABLE_CUTS_PARAMS``.

        Returns:
            ``(model, x_variables, z_variables)`` where ``x_variables[node][v]`` is
            the indicator variable and ``z_variables[c][k]`` is a dict with keys
            ``"var"`` (the AND variable) and ``"expression"`` (its ``x`` literals).
        """
        scip = Model(createscip=True)
        scopes = pgm.parameters.keys()

        x_variables = defaultdict(dict)
        for node in pgm.graph.nodes:
            for value in (0, 1):
                x_variables[node][value] = scip.addVar(
                    f"x_{node}_{value}", vtype="BINARY"
                )
            # One-hot: each node takes exactly one value.
            scip.addCons(x_variables[node][0] + x_variables[node][1] == 1)

        # All z variables (and the objective) are created before any AND
        # constraint, so the model's creation order is preserved.
        z_variables = defaultdict(dict)
        objective = 0
        for clique_idx, scope in enumerate(scopes):
            table = pgm.parameters[scope]
            assignments = product(range(2), repeat=len(scope))
            for val_idx, assignment in enumerate(assignments):
                weight = np.log(table[assignment])
                indicator = scip.addVar(f"z_{clique_idx}_{val_idx}", vtype="BINARY")
                z_variables[clique_idx][val_idx] = {"var": indicator}
                objective += indicator * weight

        for clique_idx, scope in enumerate(scopes):
            assignments = product(range(2), repeat=len(scope))
            for val_idx, assignment in enumerate(assignments):
                entry = z_variables[clique_idx][val_idx]
                indicator = entry["var"]
                literals = [
                    x_variables[scope[pos]][bit] for pos, bit in enumerate(assignment)
                ]
                literal_names = [f"{literal.name}" for literal in literals]
                entry["expression"] = literals
                scip.addConsAnd(
                    literals,
                    indicator,
                    name=f"Set {indicator.name} true for {', '.join(literal_names)}",
                )

        if evidence:
            for var, value in evidence.items():
                scip.addCons(
                    x_variables[var][value] == 1,
                    name=f"Set evidence x_{var}_{value}==1",
                )

        scip.setParams(SCIP_PARAMS)
        if disable_cuts:
            scip.setParams(DISABLE_CUTS_PARAMS)
        scip.setObjective(objective, sense="maximize")
        return scip, x_variables, z_variables


class GraphicalModel:
    """Discrete (binary) graphical model read from / written to UAI format.

    Attributes:
        type: network type string from the first preamble line of the file.
        graph: networkx graph with one node per variable and an edge between
            every pair of variables that share a factor.
        parameters: dict mapping a factor scope (tuple of variable ids) to an
            ndarray with one axis per scope variable.
        num_cliques: number of factors declared in the UAI preamble.
        cliques: when built by ``from_uai_file``, maps each variable id to the
            list of scopes (keys of ``parameters``) that contain it.
        domains: ndarray of per-variable cardinalities.
    """

    def __init__(self, graph, parameters, type, domains, num_cliques, cliques=None):
        # `type` shadows the builtin but is kept as the public keyword name.
        self.type = type
        self.graph = graph
        self.parameters = parameters
        self.num_cliques = num_cliques
        self.cliques = cliques
        self.domains = domains

    @staticmethod
    def read_uai_file(uai_file):
        """Split a UAI file into preamble lines and raw factor-table chunks.

        Chunks are separated by one or more blank (or whitespace-only) lines.
        Chunks starting with ``"c"`` are dropped (presumably comment chunks).

        Returns:
            ``(preamble, tables)``: the first chunk split into lines, and the
            remaining non-empty chunks as strings.
        """
        with open(uai_file, "r") as f:
            data = re.split(r"\n\s*\n", f.read())
        data = [chunk for chunk in data if not chunk.startswith("c")]
        preamble = data[0].split("\n")
        tables = [chunk for chunk in data[1:] if chunk]
        return preamble, tables

    @staticmethod
    def _get_factor(table, domains):
        """Parse one UAI factor table into an ndarray of shape ``tuple(domains)``.

        ``table`` holds the entry count followed by the entries. The entries
        are in row-major order: the last scope variable varies fastest.

        If any entry is exactly zero, every entry is increased by 1e-6 for
        numerical stability (later code takes logs). Each slice along the
        last axis is then renormalized to sum to 1, so for every assignment
        of the other scope variables the last variable's entries form a
        distribution.
        NOTE(review): that matches CPT semantics; for MARKOV factors this
        rescales the potential — confirm that is intended.

        Raises:
            ValueError: if the number of entries does not match ``domains``.
        """
        tokens = table.split()
        # tokens[0] is the declared entry count; the shape comes from domains.
        rows = tokens[1:]
        shape = tuple(int(d) for d in domains)
        expected = int(np.prod(shape))
        if len(rows) != expected:
            raise ValueError(
                f"Factor table has {len(rows)} entries, expected {expected} "
                f"for domains {shape}"
            )
        factor = np.array(rows, dtype=float).reshape(shape)

        if np.any(factor == 0.0):
            factor += 1e-6
            if factor.ndim:
                # keepdims keeps the divisor aligned with the leading axes; a
                # plain axis=-1 sum would broadcast against the wrong axes.
                factor /= np.sum(factor, axis=-1, keepdims=True)
        return factor

    @classmethod
    def from_uai_file(cls, uai_file_path):
        """Build a model from a binary UAI file.

        Factors that share an identical scope are merged by multiplying
        their tables. Storing them separately is impossible because
        ``parameters`` is keyed by scope.

        Raises:
            ValueError: on a non-binary variable, an inconsistent preamble, or
                a table count that does not match the declared factor count.
        """
        preamble, tables = cls.read_uai_file(uai_file_path)
        network_type = preamble[0].strip()
        graph = nx.Graph()

        num_variables = int(preamble[1])
        graph.add_nodes_from(range(num_variables))

        domains = np.array([int(card) for card in preamble[2].split()])
        if len(domains) != num_variables:
            raise ValueError(
                f"Expected {num_variables} cardinalities, got {len(domains)}"
            )
        if np.any(domains != 2):
            raise ValueError("Parser only implemented for binary data")

        num_cliques = int(preamble[3])
        cliques = preamble[4 : 4 + num_cliques]
        if len(cliques) != num_cliques:
            raise ValueError(
                f"Preamble declares {num_cliques} factors but lists {len(cliques)}"
            )

        if len(tables) != num_cliques:
            # Blank-line splitting did not give one chunk per factor; retry by
            # regrouping the non-empty lines of all table chunks.
            str_tables = "\n".join(tables)
            unlinked_tables = list(filter(None, str_tables.split("\n")))
            tables = []
            if len(unlinked_tables) == 2 * num_cliques:
                # Two lines per factor: entry count, then the entries.
                for i in range(num_cliques):
                    table_domain = unlinked_tables[2 * i]
                    table_probs = unlinked_tables[2 * i + 1]
                    tables.append("\n".join((table_domain, table_probs)))
            elif len(unlinked_tables) == num_cliques:
                # One line per factor: entry count followed by the entries.
                for i in range(num_cliques):
                    split_table_def = unlinked_tables[i].split()
                    table_domain = split_table_def[0]
                    table_probs = " ".join(split_table_def[1:])
                    tables.append("\n".join((table_domain, table_probs)))

        if len(tables) != num_cliques:
            raise ValueError(
                f"Found {len(tables)} factor tables, expected {num_cliques}"
            )

        vars_to_cliques = defaultdict(list)
        parameters = {}
        for clique_line, table in zip(cliques, tables):
            clique_info = clique_line.split()
            num_vars_clique = int(clique_info[0])
            scope = tuple(int(v) for v in clique_info[1:])
            if len(scope) != num_vars_clique:
                raise ValueError(
                    f"Scope line {clique_line!r} declares {num_vars_clique} "
                    f"variables but lists {len(scope)}"
                )

            factor = cls._get_factor(table, domains[list(scope)])
            if scope in parameters:
                # Duplicate scope: combine potentials instead of overwriting.
                parameters[scope] = parameters[scope] * factor
                continue

            parameters[scope] = factor
            for pos, var in enumerate(scope):
                vars_to_cliques[var].append(scope)
                for neighbour in scope[pos + 1 :]:
                    graph.add_edge(var, neighbour)

        return cls(graph, parameters, network_type, domains, num_cliques, vars_to_cliques)

    def save_model(self, file_path):
        """Write the model to ``file_path`` in UAI format.

        The factor count in the header is ``len(self.parameters)``, so it
        always matches the scope lines written after it.

        Raises:
            ValueError: if a table's size disagrees with its scope's domains
                (checked before the file is opened, so no partial file is left).
        """
        scopes = list(self.parameters)
        flat_tables = []
        for scope in scopes:
            flat = np.asarray(self.parameters[scope]).flatten()
            expected = int(np.prod([self.domains[v] for v in scope]))
            if len(flat) != expected:
                raise ValueError(
                    f"Factor {scope} has {len(flat)} entries, expected {expected}"
                )
            flat_tables.append(flat)

        with open(file_path, "w") as f:
            f.write(f"{self.type}\n")
            f.write(f"{len(self.graph.nodes)}\n")
            f.write(" ".join(map(str, self.domains)) + "\n")
            f.write(f"{len(scopes)}\n")
            for scope in scopes:
                f.write(" ".join(map(str, (len(scope), *scope))) + "\n")

            f.write("\n")
            for flat in flat_tables:
                f.write(f"{len(flat)}\n")
                f.write(" ".join(map(str, flat)))
                f.write("\n\n")


class GibbsSampler:
    """Gibbs sampler over a ``GraphicalModel``.

    One sweep visits variables in index order. Each variable that is not
    evidence is resampled from its conditional distribution, given the
    current values of the others. That conditional is computed from the
    factors listed in ``graphical_model.cliques[var]``.
    """

    def __init__(
        self, graphical_model: "GraphicalModel", burn_in: int = 1000, seed: int = 42
    ):
        # NOTE: this seeds NumPy's *global* RNG (kept so a given seed
        # reproduces the same chains as before).
        np.random.seed(seed)
        self.graphical_model = graphical_model
        self.parameters = graphical_model.parameters
        self.num_vars = len(self.graphical_model.graph.nodes)
        self.burn_in = burn_in
        # Smoothing added to conditionals that contain exact zeros.
        self.epsilon = 1e-6
        self.init_state = self._get_initial_sample()
        # Variables clamped by the most recent run_chain(..., evidence=...).
        self.evidence_vars = None

    def run_chain(self, sample_count, curr_state=None, evidence=None):
        """Run ``sample_count`` Gibbs sweeps and return the final state.

        Args:
            sample_count: number of full sweeps.
            curr_state: starting state (modified in place when evidence is
                given); defaults to a copy of ``init_state``.
            evidence: optional ``{var: value}``. Those variables are set to
                ``value`` and recorded in ``evidence_vars`` so later sweeps
                do not resample them.
        """
        if curr_state is None:
            curr_state = np.copy(self.init_state)

        if evidence:
            for var, value in evidence.items():
                curr_state[var] = value
            # Without this, get_new_state would resample (and overwrite) them.
            self.evidence_vars = set(evidence)

        for _ in range(sample_count):
            curr_state = self.get_new_state(curr_state)
        return curr_state

    def get_prob_var_given_neigbors(self, new_state, var):
        """Return P(var | other variables in ``new_state``) as a normalized array.

        Multiplies, over every factor containing ``var``, the slice of the
        factor obtained by fixing all other scope variables to their values
        in ``new_state``. A slice containing zeros is first smoothed with
        ``epsilon`` and renormalized.
        """
        domain_size_var = self.graphical_model.domains[var]
        var_cliques = self.graphical_model.cliques[var]
        prob_var_given_neigbors = np.ones(domain_size_var)
        for clique in var_cliques:
            clique_idx = tuple(
                slice(None) if clique_variable == var else new_state[clique_variable]
                for clique_variable in clique
            )
            # Integer/slice indexing returns a *view* into the stored factor;
            # copy so the in-place smoothing below never alters the model.
            conditional_factor = np.array(
                self.parameters[clique][clique_idx], dtype=float
            )
            if np.any(conditional_factor == 0):
                conditional_factor += self.epsilon
                conditional_factor /= np.sum(conditional_factor)

            prob_var_given_neigbors *= conditional_factor
        prob_var_given_neigbors /= np.sum(prob_var_given_neigbors)
        return prob_var_given_neigbors

    def _get_new_value(self, new_state, var):
        """Draw a new value for ``var`` from its conditional distribution."""
        domain_size_var = self.graphical_model.domains[var]
        prob_var_given_neighbors = self.get_prob_var_given_neigbors(new_state, var)
        return np.random.choice(
            range(domain_size_var), 1, p=prob_var_given_neighbors
        )[0]

    def get_new_state(self, curr_state):
        """Return a copy of ``curr_state`` after one sweep (evidence left fixed)."""
        new_state = np.copy(curr_state)
        frozen = self.evidence_vars or ()
        for var in range(len(curr_state)):
            if var in frozen:
                continue
            # Later variables condition on values already updated this sweep.
            new_state[var] = self._get_new_value(new_state, var)
        return new_state

    def _get_initial_sample(self):
        """Return a uniformly random binary state for all variables."""
        return np.random.randint(0, 2, size=self.num_vars)

    def sample_batch(self, dataset_size):
        """Yield ``dataset_size`` successive chain states after burn-in.

        Starts from ``init_state``, discards ``burn_in`` sweeps, then yields
        the state after each further sweep (one sweep per yielded sample).
        """
        new_state = np.copy(self.init_state)

        # Context managers close the progress bars even if the consumer
        # stops iterating early.
        with tqdm(total=self.burn_in) as pbar:
            for _ in range(self.burn_in):
                new_state = self.run_chain(1, new_state)
                pbar.update(1)

        num_samples = 0
        with tqdm(total=dataset_size) as pbar:
            while num_samples < dataset_size:
                new_state = self.run_chain(1, new_state)
                num_samples += 1
                pbar.update(1)
                yield new_state

    def save_samples_to_file(self, dataset_size, save_path, append=False):
        """Draw ``dataset_size`` samples and write the distinct ones to a file.

        Each sample is written as space-separated values followed by ``" \\n"``.
        Duplicates are skipped, so fewer than ``dataset_size`` lines may be
        written.

        Args:
            dataset_size: number of samples to draw.
            save_path: output file.
            append: if true, samples already in ``save_path`` are read first
                (and not written again) and new ones are appended; otherwise
                the file is overwritten.
        """
        samples_dedupe = set()
        if append:
            with open(save_path) as f:
                for line in f.read().splitlines():
                    line = line.strip()
                    if line:
                        samples_dedupe.add(line)

        with open(save_path, "a" if append else "w") as f:
            for sample in self.sample_batch(dataset_size):
                str_sample = " ".join(map(str, sample.tolist()))
                if str_sample not in samples_dedupe:
                    samples_dedupe.add(str_sample)
                    f.write(f"{str_sample} \n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Parse the given uai file")
    parser.add_argument(
        "--uaifile",
        required=False,
        help="UAI model to read (default: <dataset dir>/<network-name>/pgm-model.uai)",
    )
    parser.add_argument("--network-name", required=True, dest="network_name")
    parser.add_argument("--test-type", dest="test_type", help="'samples' or 'scip'")
    parser.add_argument(
        "--qr",
        type=float,
        help="required for --test-type scip; a fraction 1 - qr of variables becomes evidence",
    )
    parser.add_argument("--burn-in", default=1000, dest="burn_in", type=int)
    parser.add_argument("--sample-size", dest="sample_size", type=int, default=0)
    args = parser.parse_args()

    # Validate before the (possibly slow) model parsing: the scip branch
    # computes 1 - qr and would otherwise fail with a TypeError/ValueError.
    if args.test_type == "scip" and (args.qr is None or not 0.0 <= args.qr <= 1.0):
        parser.error("--qr must be given and lie in [0, 1] for --test-type scip")

    network_dir = DATASET_DIR / args.network_name
    network_dir.mkdir(parents=True, exist_ok=True)
    # --uaifile was previously accepted but ignored; honour it when given.
    uai_file_path = (
        Path(args.uaifile) if args.uaifile else network_dir / "pgm-model.uai"
    )

    # Parse the UAI file (zero entries are smoothed while parsing) and write
    # the parsed model back out.
    pgm = GraphicalModel.from_uai_file(uai_file_path)
    pgm.save_model(network_dir / "pgm-model-no-zeroes.uai")

    if args.test_type == "samples":
        # Save Gibbs samples to train.txt, de-duplicating against the
        # existing file when there is one.
        if args.sample_size > 0:
            save_path = network_dir / "train.txt"
            GibbsSampler(pgm, burn_in=args.burn_in).save_samples_to_file(
                args.sample_size, save_path, append=save_path.exists()
            )
    elif args.test_type == "scip":
        # Clamp a random fraction (1 - qr) of the variables to the values of
        # a random training sample, then solve the MPE program.
        evidence_ratio = 1 - args.qr
        nodes = list(pgm.graph.nodes)
        evid_vars = random.sample(nodes, k=int(evidence_ratio * len(nodes)))
        samples = FileIO.read_samples(args.network_name, sample_file="train.txt")
        sample = random.choice(list(samples.values()))
        evidence = {var: int(sample[var]) for var in evid_vars}
        program, x_vars, z_vars = ScipUtils.get_mpe_program(
            pgm, evidence, disable_cuts=True
        )
        program.setParam("display/verblevel", 4)
        program.setParam("limits/time", 3600)
        program.setParam("heuristics/simplerounding/priority", 1000000)
        program.optimize()
    else:
        raise ValueError(f"Unrecognized test type: {args.test_type!r}")
