import ast
import functools
import json
import pathlib
import subprocess
from itertools import product
from multiprocessing import Manager

import torch

from MPE.train.models import BBAttentionNetwork
from MPE.utils.utils import ScipUtils

from .scip_extensions import CheckpointHandler

# Working directory captured once at import time; a later os.chdir() does
# not change it. Relative paths handed to external tools are anchored here.
CWD = pathlib.Path.cwd()
DATASET_DIR = CWD.joinpath("MPE", "dataset")

# Checkpoint times used by the SCIP runs: the largest value becomes SCIP's
# "limits/time" and every value is a checkpoint for the incumbent solution
# (presumably seconds, matching SCIP's time-limit unit).
TIMELIMITS = [10, 30, 60]


def repeat(n_times):
    """Decorate a solver method so it runs ``n_times`` and collects every run.

    The wrapped method must return ``(assignments, stats, checkpoints)``.
    Each run becomes a dict whose ``"assignments"`` entry is encoded with
    ``self.encode_list`` and whose ``"stats"``/``"checkpoints"`` entries are
    passed through unchanged. The wrapper returns the list of these dicts,
    one per run (empty when ``n_times`` is 0).
    """

    def decorator(func):
        # functools.wraps keeps the solver's __name__/__doc__ so the
        # decorated method stays introspectable (logging, help(), etc.).
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            results = []
            for _ in range(n_times):
                assignments, stats, checkpoints = func(self, *args, **kwargs)
                results.append(
                    {
                        "assignments": self.encode_list(assignments),
                        "stats": stats,
                        "checkpoints": checkpoints,
                    }
                )
            return results

        return wrapper

    return decorator


class Runner:
    """Base class for running MPE solvers on a PGM over a set of samples.

    Subclasses implement :meth:`solve`. This base class provides helpers to
    encode/decode results, load cached SCIP results, solve the SCIP MPE
    program (optionally with custom node-selection / branching plugins),
    and invoke the external ``daoopt`` binary.
    """

    def __init__(
        self,
        pgm,
        uai_file_path,
        samples,
        sample_to_evid_var,
        qr,
        disable_cuts=False,
    ):
        """Store the problem definition and derive query/evidence counts.

        Args:
            pgm: Model whose ``graph.nodes`` are the variable ids.
            uai_file_path: Path of the UAI model file; resolved against
                ``CWD`` when handed to ``daoopt``.
            samples: Samples to solve; stored unchanged.
            sample_to_evid_var: Stored unchanged for use by subclasses.
            qr: Fraction of variables that are query variables. The query
                count is ``round(num_vars * qr)``; the rest are evidence.
            disable_cuts: Forwarded to ``ScipUtils.get_mpe_program``.
        """
        self.pgm = pgm
        self.uai_file_path = uai_file_path
        self.all_vars = list(self.pgm.graph.nodes)
        self.samples = samples
        self.sample_to_evid_var = sample_to_evid_var
        self.disable_cuts = disable_cuts

        # Manager-backed proxy list, presumably so results appended from
        # worker processes are visible to the parent process.
        self.results = Manager().list()
        self.qr = qr
        self.num_vars = len(self.all_vars)
        self.num_query_vars = round(self.num_vars * self.qr)
        self.num_evid_vars = self.num_vars - self.num_query_vars

    def encode_dict(self, dict_):
        """Serialize a mapping as ``"k1=v1,k2=v2"`` in iteration order."""
        return ",".join([f"{k}={v}" for k, v in dict_.items()])

    def encode_list(self, list_):
        """Serialize a sequence as a comma-separated string of ``str(item)``."""
        return ",".join(map(str, list_))

    def decode_list(self, string, delimiter=","):
        """Inverse of :meth:`encode_list` for lists of Python literals.

        ``ast.literal_eval`` is used instead of ``eval`` so that decoding a
        string read from disk can never execute arbitrary code. Each token
        is stripped first, matching ``eval``'s tolerance of leading spaces.

        Raises:
            ValueError / SyntaxError: if a token is not a Python literal.
        """
        return [
            ast.literal_eval(token.strip()) for token in string.split(delimiter)
        ]

    def get_scip_results(
        self, artifacts_dir, runner, config_name, depth, width, timelimit=10
    ):
        """Load cached SCIP results for one configuration, keyed by sample id.

        Reads ``<artifacts_dir>/timelimit-expts/scip/<runner>/<config_name>/
        timelimit-<timelimit>/depth-<depth>-width-<width>/results.json``,
        which must contain a list of records that each have ``"sample_id"``.

        Returns:
            ``{sample_id: record}``, or ``None`` if the file does not exist.
        """
        results_file = (
            artifacts_dir
            / "timelimit-expts"
            / "scip"
            / runner
            / config_name
            / f"timelimit-{timelimit}"
            / f"depth-{depth}-width-{width}"
            / "results.json"
        )
        if not results_file.exists():
            return None

        with open(results_file, encoding="utf-8") as f:
            data = json.load(f)
        return {d["sample_id"]: d for d in data}

    @repeat(1)
    def solve_mpe_program(
        self, evidence, branching_rule=None, heuristic=None, nodesel=None
    ):
        """Build and solve the SCIP MPE program for one evidence assignment.

        The SCIP time limit is ``max(TIMELIMITS)``, and a
        ``CheckpointHandler`` saves the best solution at each time in
        ``TIMELIMITS``. An optional ``nodesel`` / ``branching_rule`` plugin
        is bound to the program and registered before solving.
        ``heuristic`` is accepted for interface compatibility but unused.

        Through ``@repeat(1)`` the caller receives a one-element list of
        ``{"assignments", "stats", "checkpoints"}`` dicts.
        """
        program, x_variables, z_variables = ScipUtils.get_mpe_program(
            self.pgm, evidence, disable_cuts=self.disable_cuts
        )
        program.setParam("limits/time", max(TIMELIMITS))

        checkpoint_handler = CheckpointHandler(program, checkpoint_times=TIMELIMITS)
        program.includeEventhdlr(
            checkpoint_handler,
            "Checkpoint Handler",
            "Saves best solution at specific times",
        )

        if nodesel:
            nodesel.scip = program
            # Large std/memsave priorities, presumably so this selector
            # outranks SCIP's built-in node selectors.
            program.includeNodesel(
                nodesel,
                "NeuralNodeSel",
                "Probabilistic Depth First Search Nodesel",
                1000000,
                1000000,
            )

        if branching_rule:
            branching_rule.scip = program
            program.includeBranchrule(
                branching_rule,
                "NeuralBranchingRule",
                "Optimality/Decimation based branching rule",
                priority=10000000,
                maxdepth=-1,
                maxbounddist=1,
            )

        program.optimize()
        stats = ScipUtils.get_statistics(program)

        # Variable ids are used as list indices, so they must be 0..n-1.
        # NOTE(review): presumably x_variables[var][0] is the indicator for
        # value 0, making 1 - x the variable's value — confirm against
        # ScipUtils.get_mpe_program.
        assignments = [0] * len(self.all_vars)
        for var in self.all_vars:
            assignments[var] = round(1 - program.getVal(x_variables[var][0]))
        return assignments, stats, checkpoint_handler.checkpoints

    def solve_mpe_program_daoopt(self, daoopt_dir, evidence, sample_id):
        """Run the external ``./daoopt`` binary on one evidence assignment.

        Writes the evidence as ``"<n> var1 val1 var2 val2 ..."`` to
        ``<daoopt_dir>/<sample_id>.uai.evid`` and captures daoopt's stdout
        in ``<daoopt_dir>/<sample_id>.out`` (stderr is inherited).

        Returns:
            The ``subprocess.CompletedProcess`` so callers can inspect
            ``returncode``; earlier callers that ignored the result still work.

        Raises:
            OSError (e.g. FileNotFoundError): if ``./daoopt`` cannot be
            executed from ``CWD``.
        """
        tokens = [str(len(evidence))]
        for var, value in evidence.items():
            tokens.extend([str(var), str(value)])

        evid_file_path = daoopt_dir / f"{sample_id}.uai.evid"
        with open(evid_file_path, "w") as f:
            f.write(" ".join(tokens) + "\n")

        log_out_path = CWD / daoopt_dir / f"{sample_id}.out"
        # Pass an argument list and an explicit stdout file instead of a
        # shell string with ``>`` redirection, so paths that contain spaces
        # or shell metacharacters reach daoopt verbatim.
        with open(log_out_path, "w") as log_out:
            return subprocess.run(
                [
                    "./daoopt",
                    "-f",
                    str(CWD / self.uai_file_path),
                    "-e",
                    str(CWD / evid_file_path),
                ],
                cwd=CWD,
                stdout=log_out,
            )

    def solve(self, sample_id):
        """Solve one sample; must be implemented by subclasses."""
        raise NotImplementedError


class NeuralApi:
    """Inference wrapper around a trained ``BBAttentionNetwork``.

    Every (variable, value) pair of the binary variables in ``all_vars`` is
    mapped to a token index. Evidence assignments and query variables are
    converted to index tensors and scored by the network.
    """

    def __init__(self, network_dir, all_vars, device="cuda"):
        """Build the token vocabulary and load the trained model.

        Args:
            network_dir: Path of the run directory; weights are read from
                ``training-artifacts/l2c/trained-models/best.pt`` beneath it.
            all_vars: Iterable of variable ids; each gets two tokens
                (values 0 and 1).
            device: Torch device for the model and input tensors. Defaults
                to ``"cuda"``, as before; pass ``"cpu"`` on GPU-less hosts.
        """
        self.device = device
        self.model_path = (
            network_dir / "training-artifacts" / "l2c" / "trained-models" / "best.pt"
        )
        self.var_values_to_idx = {
            (var, val): idx
            for idx, (var, val) in enumerate(product(all_vars, range(2)))
        }
        self.idx_to_var_values = {v: k for k, v in self.var_values_to_idx.items()}
        self.model = self.load_model()

    def load_model(self):
        """Build the network on ``self.device``, load weights, set eval mode.

        The architecture hyperparameters are fixed here. They must match the
        checkpoint, because ``load_state_dict`` (strict by default) raises on
        missing or mis-shaped parameters.
        """
        num_variable_value = len(self.var_values_to_idx)
        model = BBAttentionNetwork(
            num_variable_value + 1,  # one extra token reserved for padding
            embed_dim=256,
            num_layers=15,
            hidden_dim=512,
            output_size=1,
            padding_idx=num_variable_value,
        ).to(self.device)
        # map_location lets a checkpoint saved on one device load on another.
        state_dict = torch.load(
            self.model_path, map_location=self.device, weights_only=True
        )
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def get_batch(self, batch_evidence, batch_query_vars):
        """Convert evidence dicts and query-variable lists to index tensors.

        Args:
            batch_evidence: Sequence of ``{var: value}`` dicts.
            batch_query_vars: Sequence of query-variable lists, aligned with
                ``batch_evidence``.

        Returns:
            Dict of int32 tensors on ``self.device``:
            ``"evidence"`` (one token per evidence pair), ``"choices"`` (two
            tokens per query variable, value 0 then 1), and ``"all_choices"``
            (every token index, repeated for each row).

        Rows are stacked with ``torch.tensor``. Every evidence dict, and every
        query list, in one batch must therefore have the same length;
        otherwise torch raises ``ValueError``.
        """
        lookup = self.var_values_to_idx
        full_row = list(range(len(lookup)))
        evid_idxs, query_idxs, all_choices = [], [], []
        for be, bqv in zip(batch_evidence, batch_query_vars):
            evid_idxs.append([lookup[item] for item in be.items()])
            query_idxs.append([lookup[item] for item in product(bqv, range(2))])
            all_choices.append(full_row)

        def as_tensor(rows):
            return torch.tensor(rows, dtype=torch.int, device=self.device)

        return {
            "evidence": as_tensor(evid_idxs),
            "choices": as_tensor(query_idxs),
            "all_choices": as_tensor(all_choices),
        }

    def predict(self, batch_evidence, batch_query_vars):
        """Score every query (variable, value) choice without tracking grads.

        Returns:
            ``(copred, coprob, cds)``: raw optimality scores, their sigmoid
            probabilities, and decimation scores. Each is gathered at the
            batch's ``"choices"`` indices, giving shape
            ``(batch, 2 * num_query_vars)``.
        """
        batch = self.get_batch(batch_evidence, batch_query_vars)
        # gather() requires int64 indices; cast once and reuse.
        choice_idx = batch["choices"].long()
        with torch.no_grad():
            optimality_pred, decimation_scores = self.model(**batch)
            copred = optimality_pred.gather(1, choice_idx)
            # sigmoid is elementwise, so applying it after the gather gives
            # the same values as gathering the sigmoid of all scores.
            coprob = torch.sigmoid(copred)
            cds = decimation_scores.gather(1, choice_idx)
        return copred, coprob, cds
