import argparse
import json
import pathlib
from copy import deepcopy
from datetime import datetime
from itertools import product
from math import ceil
from multiprocessing import Pool
from pathlib import Path

import numpy as np
import torch
import torch.multiprocessing as mp
from tqdm import tqdm

from MPE.data_collection.branching_rules import StrongBranchingCollector
from MPE.utils.utils import FileIO, GibbsSampler, GraphicalModel, ScipUtils

from .base import NeuralApi, Runner
from .scip_extensions import NeuralBranchingRule, NeuralNodeSel

# Directory the script is launched from; dataset paths are resolved relative to it.
CWD = pathlib.Path.cwd()
# Each network lives in its own sub-directory: DATASET_DIR / <network-name>.
DATASET_DIR = CWD.joinpath("MPE", "dataset")
# Keys into ``scip_results[0]["checkpoints"]`` for which SCIP results are
# exported (one output directory per value); presumably seconds — TODO confirm
# against the Runner base class.
TIMELIMITS = [10, 30, 60]


class ScipRunner(Runner):
    """Baseline runner that hands each MPE query straight to the solver.

    No variables are pre-assigned beyond the sample's own evidence, so the
    DAOOPT output directory is tagged ``depth-0-width-0``.
    """

    def __init__(
        self,
        pgm,
        uai_file_path,
        samples,
        sample_to_evid_var,
        qr,
        disable_cuts,
        artifacts_dir,
        **kwargs,
    ):
        super().__init__(
            pgm, uai_file_path, samples, sample_to_evid_var, qr, disable_cuts
        )
        # <artifacts>/timelimit-expts/daoopt/scip/scip/depth-0-width-0
        self.daoopt_results_dir = artifacts_dir.joinpath(
            "timelimit-expts", "daoopt", "scip", "scip", "depth-0-width-0"
        )
        self.daoopt_results_dir.mkdir(parents=True, exist_ok=True)

    def solve(self, sample_id):
        """Solve the MPE program for ``sample_id`` with SCIP.

        Appends one record (encoded evidence, solver output, wall-clock
        timestamps and elapsed seconds) to ``self.results``; returns None.
        """
        stamp_fmt = "%Y-%m-%d %H:%M:%S"
        began = datetime.now()
        observed_vars = self.sample_to_evid_var[sample_id]
        sample = self.samples[sample_id]
        evidence = {name: sample[name] for name in observed_vars}
        outcome = self.solve_mpe_program(evidence)
        finished = datetime.now()
        record = {
            "sample_id": sample_id,
            "evidence": self.encode_dict(evidence),
            "scip_results": outcome,
            "start_time": began.strftime(stamp_fmt),
            "end_time": finished.strftime(stamp_fmt),
            "time_elapsed": (finished - began).seconds,
        }
        self.results.append(record)

    def solve_daoopt(self, sample_id):
        """Solve the same evidence-only MPE query for ``sample_id`` with DAOOPT."""
        sample = self.samples[sample_id]
        observed_vars = self.sample_to_evid_var[sample_id]
        evidence = {name: sample[name] for name in observed_vars}
        self.solve_mpe_program_daoopt(self.daoopt_results_dir, evidence, sample_id)


class NeuralRunner(Runner):
    """Pre-assign variables chosen by a neural beam search, then solve the rest.

    For ``depth`` steps the network scores (variable, value) assignments for
    every partial sequence in the beam, and the ``beam_width`` highest-scoring
    extensions survive. The best surviving sequence is added to the sample's
    evidence before the reduced MPE program is solved. ``beam_width`` must be
    at least 1 whenever ``depth > 0``; a width of 0 empties the beam.
    """

    def __init__(
        self,
        pgm,
        uai_file_path,
        samples,
        sample_to_evid_var,
        qr,
        disable_cuts,
        depth,
        depth_percent,
        beam_width,
        config_name,
        artifacts_dir,
        network_dir=None,
        **kwargs,
    ):
        """Build the runner.

        Args:
            depth: number of beam-search steps (variables to pre-assign).
            depth_percent: depth as given on the CLI; used to name/locate
                result directories.
            beam_width: number of partial sequences kept after each step.
            config_name: ``"l2c_opt"`` or ``"l2c_rank"``; selects how network
                outputs become candidate assignments.
            artifacts_dir: results root, laid out by the CLI as
                ``<network_dir>/testing-artifacts``.
            network_dir: directory holding the trained network. Defaults to
                ``artifacts_dir.parent`` so the class does not depend on a
                global that only exists when this file runs as a script.
        """
        super().__init__(
            pgm,
            uai_file_path,
            samples,
            sample_to_evid_var,
            qr,
            disable_cuts,
        )
        self.config_name = config_name
        self.depth = depth
        self.depth_percent = depth_percent
        self.beam_width = beam_width
        if network_dir is None:
            network_dir = Path(artifacts_dir).parent
        self.neural_api = NeuralApi(network_dir, self.all_vars)
        self.daoopt_results_dir = (
            artifacts_dir
            / "timelimit-expts"
            / "daoopt"
            / "neural"
            / config_name
            / f"depth-{self.depth_percent}-width-{self.beam_width}"
        )
        self.daoopt_results_dir.mkdir(parents=True, exist_ok=True)
        self.scip_results = self.get_scip_results(
            artifacts_dir, "neural", config_name, depth_percent, self.beam_width
        )

    def _candidate_assignments(
        self, query_vars, optimality_pred, decimation_scores, b_idx
    ):
        """Turn the network outputs for beam entry ``b_idx`` into candidates.

        ``optimality_pred[b_idx]`` (and, for ``l2c_rank``,
        ``decimation_scores[b_idx]``) are viewed as ``(-1, 2)``: one row per
        query variable with a score for each of the two values — presumably in
        the order of ``query_vars``; TODO confirm against ``NeuralApi.predict``.

        Returns:
            Three equal-length lists ``(vars, values, scores)``.

        Raises:
            ValueError: if ``self.config_name`` is not a known configuration.
        """
        if self.config_name not in ("l2c_opt", "l2c_rank"):
            raise ValueError(f"Config: {self.config_name} not recognized")
        if not query_vars:
            # Nothing left to assign (also avoids indexing an empty sort below).
            return [], [], []

        qvs = torch.tensor(query_vars)
        cop_r = optimality_pred[b_idx].view(-1, 2)
        if self.config_name == "l2c_opt":
            # Every query variable, paired with its highest-scoring value.
            fqvs = qvs
            fscores, fqvalues = torch.max(cop_r, dim=1)
        else:  # "l2c_rank"
            cds_r = decimation_scores[b_idx].view(-1, 2)
            # Keep the ~10% of variables with the largest gap between their two
            # optimality scores ...
            diff_op_prob = torch.abs(cop_r[:, 0] - cop_r[:, 1])
            threshold = torch.sort(diff_op_prob, descending=True)[0][
                int(0.1 * diff_op_prob.size(0))
            ].item()
            # ... and whose preferred value agrees across both network heads.
            same_dvs = torch.argmax(cop_r, dim=1) == torch.argmax(cds_r, dim=1)
            filter_conditions = torch.logical_and(
                same_dvs, diff_op_prob >= threshold
            )
            fqvs = qvs[filter_conditions.cpu()]
            fscores, fqvalues = torch.max(cds_r[filter_conditions], dim=1)

        return (
            fqvs.numpy().tolist(),
            fqvalues.cpu().numpy().tolist(),
            fscores.cpu().numpy().tolist(),
        )

    def solve(self, sample_id):
        """Beam-search assignments for ``sample_id``, then solve with SCIP.

        Appends one record (evidence, chosen sequence, solver output, timings)
        to ``self.results``; returns None.
        """
        start_time = datetime.now()
        curr_sample = self.samples[sample_id]
        # Copy: chosen variables are appended below and must not leak into the
        # shared ``sample_to_evid_var`` entry (e.g. when a sample is re-run).
        base_evid_vars = list(self.sample_to_evid_var[sample_id])
        base_evidence = {var: curr_sample[var] for var in base_evid_vars}
        all_vars = set(self.all_vars)

        beam = [[]]
        for _ in range(self.depth):
            batch_evidence, batch_query_vars = [], []
            for seq in beam:
                evidence = deepcopy(base_evidence)
                evid_vars = set(base_evid_vars)
                for var, val in seq:
                    evid_vars.add(var)
                    evidence[var] = val
                batch_evidence.append(evidence)
                batch_query_vars.append(list(all_vars - evid_vars))

            optimality_pred, optimality_prob, decimation_scores = (
                self.neural_api.predict(batch_evidence, batch_query_vars)
            )

            all_candidates = []
            for b_idx in range(len(optimality_prob)):
                fqvs, fqvalues, fscores = self._candidate_assignments(
                    batch_query_vars[b_idx], optimality_pred, decimation_scores, b_idx
                )
                curr_seq = beam[b_idx]
                for var, value, score in zip(fqvs, fqvalues, fscores):
                    all_candidates.append((curr_seq + [(var, value)], score))

            if not all_candidates:
                break

            sorted_top_k = sorted(all_candidates, key=lambda x: x[1], reverse=True)[
                0 : self.beam_width
            ]
            beam = [candidate for candidate, _ in sorted_top_k]

        decision_end_time = datetime.now()

        best_sequence = beam[0]
        for var, val in best_sequence:
            base_evid_vars.append(var)
            base_evidence[var] = val

        query_vars = list(set(self.all_vars) - set(base_evid_vars))

        # Add the chosen vars/values to the evidence and solve the MPE program.
        scip_results = self.solve_mpe_program(base_evidence)
        end_time = datetime.now()
        self.results.append(
            {
                "sample_id": sample_id,
                "evid_vars": self.encode_list(base_evid_vars),
                "query_vars": self.encode_list(query_vars),
                "choice": self.encode_list(best_sequence),
                "evidence": self.encode_dict(base_evidence),
                "scip_results": scip_results,
                "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
                "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S"),
                "decision_time": (decision_end_time - start_time).seconds,
                "time_elapsed": (end_time - start_time).seconds,
            }
        )

    def solve_daoopt(self, sample_id):
        """Replay the choices stored for ``sample_id`` and solve with DAOOPT.

        Samples without a stored SCIP result are skipped.
        """
        curr_sample = self.samples[sample_id]
        base_evidence = {
            var: curr_sample[var] for var in self.sample_to_evid_var[sample_id]
        }
        scip_result = self.scip_results.get(sample_id, None)
        if scip_result is None:
            return

        # NOTE(review): ``eval`` executes whatever string is stored in the
        # results file (presumably written by ``encode_list`` in an earlier run
        # of ``solve``). Switch to ``ast.literal_eval`` once that encoding is
        # confirmed to be a plain Python literal.
        choices_made = eval(scip_result["choice"])
        for var, val in choices_made:
            base_evidence[var] = val

        self.solve_mpe_program_daoopt(self.daoopt_results_dir, base_evidence, sample_id)


class GraphHeuristicRunner(Runner):
    """Baseline that pre-assigns the highest-degree unobserved variables.

    The ``depth`` query variables with the largest degree in ``pgm.graph`` are
    fixed to their values in a Gibbs-sampler state conditioned on the evidence,
    then the reduced MPE program is solved.
    """

    def __init__(
        self,
        pgm,
        uai_file_path,
        samples,
        sample_to_evid_var,
        qr,
        disable_cuts,
        depth,
        depth_percent,
        artifacts_dir,
        **kwargs,
    ):
        super().__init__(
            pgm,
            uai_file_path,
            samples,
            sample_to_evid_var,
            qr,
            disable_cuts,
        )
        # Flat index for every (variable, value) pair over values {0, 1}.
        self.var_values_to_idx = {
            (var, val): idx
            for idx, (var, val) in enumerate(product(self.all_vars, range(2)))
        }
        self.idx_to_var_values = {v: k for k, v in self.var_values_to_idx.items()}
        self.depth = depth
        self.var_to_degrees = self._get_var_degrees(pgm)
        self.gibbs_chain = GibbsSampler(pgm)
        # Chain state without evidence; used as the starting state of every
        # per-sample chain in ``solve``.
        self.initial_state = self.gibbs_chain.run_chain(100)
        self.daoopt_results_dir = (
            artifacts_dir
            / "timelimit-expts"
            / "daoopt"
            / "graph"
            / "graph"
            / f"depth-{depth_percent}-width-{1}"
        )
        self.daoopt_results_dir.mkdir(parents=True, exist_ok=True)
        self.scip_results = self.get_scip_results(
            artifacts_dir, "graph", "graph", depth_percent, 1
        )

    def _get_var_degrees(self, pgm):
        """Map every node of ``pgm.graph`` to its degree."""
        graph = pgm.graph
        return {node: graph.degree(node) for node in graph.nodes}

    def solve(self, sample_id):
        """Apply the degree/Gibbs heuristic to ``sample_id`` and solve with SCIP.

        Appends one record (evidence, chosen sequence, solver output, timings)
        to ``self.results``; returns None.
        """
        start_time = datetime.now()
        curr_sample = self.samples[sample_id]
        # Copy: heuristic choices are appended below and must not leak into the
        # shared ``sample_to_evid_var`` entry.
        base_evid_vars = list(self.sample_to_evid_var[sample_id])
        self.gibbs_chain.evidence_vars = base_evid_vars
        base_evidence = {var: curr_sample[var] for var in base_evid_vars}
        query_vars = list(set(self.all_vars) - set(base_evid_vars))

        # Variables are chosen by max degree (stable sort keeps ties in
        # query_vars order); values come from a Gibbs chain given the evidence.
        state = self.gibbs_chain.run_chain(
            50, curr_state=self.initial_state, evidence=base_evidence
        )
        query_var_degrees = [(var, self.var_to_degrees[var]) for var in query_vars]
        sorted_degrees = sorted(query_var_degrees, key=lambda x: x[1], reverse=True)
        heuristic_sequence = [
            (var, state[var]) for var, _ in sorted_degrees[: self.depth]
        ]
        decision_end_time = datetime.now()

        for var, val in heuristic_sequence:
            base_evid_vars.append(var)
            base_evidence[var] = val

        scip_results = self.solve_mpe_program(base_evidence)
        end_time = datetime.now()
        self.results.append(
            {
                "sample_id": sample_id,
                "evid_vars": self.encode_list(base_evid_vars),
                "query_vars": self.encode_list(query_vars),
                "choice": self.encode_list(heuristic_sequence),
                "evidence": self.encode_dict(base_evidence),
                "scip_results": scip_results,
                "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
                "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S"),
                "decision_time": (decision_end_time - start_time).seconds,
                "time_elapsed": (end_time - start_time).seconds,
            }
        )

    def solve_daoopt(self, sample_id):
        """Replay the choices stored for ``sample_id`` and solve with DAOOPT.

        Samples without a stored SCIP result are skipped.
        """
        curr_sample = self.samples[sample_id]
        base_evidence = {
            var: curr_sample[var] for var in self.sample_to_evid_var[sample_id]
        }
        scip_result = self.scip_results.get(sample_id, None)
        if scip_result is None:
            return

        # NOTE(review): ``eval`` executes whatever string is stored in the
        # results file (presumably written by ``encode_list`` in ``solve``).
        # The stored values may be numpy scalars, so confirm the encoding
        # before switching to ``ast.literal_eval``.
        choices_made = eval(scip_result["choice"])
        for var, val in choices_made:
            base_evidence[var] = val

        self.solve_mpe_program_daoopt(self.daoopt_results_dir, base_evidence, sample_id)


class StrongBranchingRunner(Runner):
    """Beam search whose candidate assignments are scored by SCIP strong branching.

    For every partial sequence in the beam, the MPE program (with that sequence
    added to the evidence) is optimized with a ``StrongBranchingCollector``
    branching rule. The collected ``(var, value) -> score`` entries extend the
    beam. A sample is abandoned (no result recorded) as soon as one beam step
    takes ``time_for_single_decision_in_sec`` or longer.
    """

    def __init__(
        self,
        pgm,
        uai_file_path,
        samples,
        sample_to_evid_var,
        qr,
        disable_cuts,
        depth,
        depth_percent,
        beam_width,
        config_name,
        artifacts_dir,
        **kwargs,
    ):
        super().__init__(
            pgm,
            uai_file_path,
            samples,
            sample_to_evid_var,
            qr,
            disable_cuts,
        )
        self.depth = depth
        self.beam_width = beam_width
        # NOTE(review): the DAOOPT directory uses ``config_name`` while the SCIP
        # lookup below hard-codes "strong_branch"; these only agree when the
        # CLI is run with ``--config-name strong_branch``.
        self.daoopt_results_dir = (
            artifacts_dir
            / "timelimit-expts"
            / "daoopt"
            / "strong_branch"
            / config_name
            / f"depth-{depth_percent}-width-{self.beam_width}"
        )
        self.daoopt_results_dir.mkdir(parents=True, exist_ok=True)
        self.scip_results = self.get_scip_results(
            artifacts_dir, "strong_branch", "strong_branch", depth_percent, beam_width
        )
        # Strong branching is expensive, so cap the time spent on one beam step.
        self.time_for_single_decision_in_sec = 30

    def get_strong_branching_score(self, evidence):
        """Optimize the MPE program for ``evidence`` while collecting scores.

        Returns:
            ``(choices_score, fallback_method)`` as recorded by the
            ``StrongBranchingCollector`` branching rule.
        """
        program, _, _ = ScipUtils.get_mpe_program(self.pgm, evidence, self.disable_cuts)
        branching_rule = StrongBranchingCollector(program)
        program.includeBranchrule(
            branching_rule,
            "StrongBranchingCollector",
            "Collect strong branching scores",
            priority=1000000,
            maxdepth=-1,
            maxbounddist=1,
        )
        program.optimize()
        return branching_rule.choices_score, branching_rule.fallback_method

    def solve(self, sample_id):
        """Beam-search with strong-branching scores, then solve with SCIP.

        Appends one record to ``self.results``. It returns without recording
        anything if a beam step exceeds the per-decision time limit.
        """
        start_time = datetime.now()
        curr_sample = self.samples[sample_id]
        # Copy: chosen variables are appended below and must not leak into the
        # shared ``sample_to_evid_var`` entry.
        base_evid_vars = list(self.sample_to_evid_var[sample_id])
        base_evidence = {var: curr_sample[var] for var in base_evid_vars}
        fb_count = 0

        beam = [[]]
        sorted_top_k = []
        step_start = datetime.now()
        for _ in range(self.depth):
            all_candidates = []
            for seq in beam:
                evidence = deepcopy(base_evidence)
                for var, val in seq:
                    evidence[var] = val

                choice_score_dict, fb_used = self.get_strong_branching_score(evidence)
                if not choice_score_dict:
                    continue
                fb_count += int(fb_used)
                for choice, score in choice_score_dict.items():
                    # Variables already in ``seq`` were added to ``evidence``
                    # above, so this single check also prevents re-choosing them.
                    if choice[0] in evidence:
                        continue
                    all_candidates.append((seq + [choice], score))

            if not all_candidates:
                break

            sorted_top_k = sorted(all_candidates, key=lambda x: x[1], reverse=True)[
                0 : self.beam_width
            ]
            beam = [candidate for candidate, _ in sorted_top_k]
            step_seconds = (datetime.now() - step_start).total_seconds()
            if step_seconds >= self.time_for_single_decision_in_sec:
                # Exit immediately if a single decision exceeds the time limit.
                return
            step_start = datetime.now()

        decision_end_time = datetime.now()
        best_sequence = []
        if sorted_top_k:
            best_sequence = sorted_top_k[0][0]
            for var, val in best_sequence:
                base_evid_vars.append(var)
                base_evidence[var] = val
        else:
            # No candidate was ever produced; solve with the base evidence only.
            print(sample_id)

        query_vars = list(set(self.all_vars) - set(base_evid_vars))
        scip_results = self.solve_mpe_program(base_evidence)
        end_time = datetime.now()
        self.results.append(
            {
                "sample_id": sample_id,
                "fb_count": fb_count,
                "evid_vars": self.encode_list(base_evid_vars),
                "query_vars": self.encode_list(query_vars),
                "choice": self.encode_list(best_sequence),
                "evidence": self.encode_dict(base_evidence),
                "scip_results": scip_results,
                "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
                "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S"),
                "decision_time": (decision_end_time - start_time).seconds,
                "time_elapsed": (end_time - start_time).seconds,
            }
        )

    def solve_daoopt(self, sample_id):
        """Replay the choices stored for ``sample_id`` and solve with DAOOPT.

        Best effort: samples without a stored SCIP result are skipped, and any
        failure is reported (sample id + traceback) instead of propagating.
        """
        try:
            curr_sample = self.samples[sample_id]
            base_evidence = {
                var: curr_sample[var] for var in self.sample_to_evid_var[sample_id]
            }
            scip_result = self.scip_results.get(sample_id, None)
            if scip_result is None:
                return

            # NOTE(review): ``eval`` executes whatever string is stored in the
            # results file (presumably written by ``encode_list`` in ``solve``);
            # confirm the encoding and prefer ``ast.literal_eval``.
            choices_made = eval(scip_result["choice"])
            for var, val in choices_made:
                base_evidence[var] = val

            self.solve_mpe_program_daoopt(
                self.daoopt_results_dir, base_evidence, sample_id
            )
        except Exception:
            # Unlike the previous bare ``except:``, this lets KeyboardInterrupt and
            # SystemExit through and shows *why* the sample failed.
            import traceback

            print(sample_id)
            traceback.print_exc()


class NeuralSearchRunner(Runner):
    """Solve each query with SCIP guided by one round of network predictions.

    The network is queried once per sample with the base evidence. Its
    predictions configure a custom node selector, and branching scores are
    derived according to ``config_name`` (``"l2c_opt"`` or ``"l2c_rank"``).
    """

    def __init__(
        self,
        pgm,
        uai_file_path,
        samples,
        sample_to_evid_var,
        qr,
        disable_cuts,
        config_name,
        network_dir=None,
        **kwargs,
    ):
        """Build the runner.

        ``network_dir`` (directory holding the trained network) defaults to
        ``kwargs["artifacts_dir"].parent``, matching the CLI layout
        ``<network_dir>/testing-artifacts``. This replaces a dependency on a
        global that only exists when this file runs as a script.

        Raises:
            ValueError: if neither ``network_dir`` nor ``artifacts_dir`` is given.
        """
        super().__init__(
            pgm,
            uai_file_path,
            samples,
            sample_to_evid_var,
            qr,
            disable_cuts,
        )
        self.config_name = config_name
        if network_dir is None:
            artifacts_dir = kwargs.get("artifacts_dir")
            if artifacts_dir is None:
                raise ValueError(
                    "NeuralSearchRunner needs `network_dir` or `artifacts_dir`"
                )
            network_dir = Path(artifacts_dir).parent
        self.neural_api = NeuralApi(network_dir, self.all_vars)
        # NOTE(review): ``neural_bc`` receives branching scores in ``solve`` but
        # is never passed to ``solve_mpe_program`` (only ``nodesel`` is) —
        # confirm whether it should be included.
        self.neural_bc = NeuralBranchingRule()
        self.neural_node_sel = NeuralNodeSel()

    def solve(self, sample_id):
        """Predict once, configure the SCIP plugins, and solve ``sample_id``.

        Appends one record to ``self.results``.

        Raises:
            ValueError: if ``self.config_name`` is not a known configuration.
        """
        start_time = datetime.now()
        curr_sample = self.samples[sample_id]
        base_evid_vars = self.sample_to_evid_var[sample_id]
        base_evidence = {var: curr_sample[var] for var in base_evid_vars}
        query_vars = list(set(self.all_vars) - set(base_evid_vars))
        optimality_preds, optimality_probs, decimation_scores = self.neural_api.predict(
            [base_evidence], [query_vars]
        )
        self.neural_node_sel.set_domains(optimality_preds, query_vars)

        if self.config_name == "l2c_opt":
            # Score every (query var, value) pair, flattened in product() order.
            scores = {}
            for idx, item in enumerate(product(query_vars, range(2))):
                scores[item] = optimality_preds[0][idx].item()
            self.neural_bc.set_branching_scores(scores)
        elif self.config_name == "l2c_rank":
            qvs = torch.tensor(query_vars)
            cop_r = optimality_preds[0].view(-1, 2)
            cds_r = decimation_scores[0].view(-1, 2)
            # Keep the ~10% of variables with the largest gap between their two
            # optimality scores whose preferred value agrees across both heads.
            diff_op_prob = torch.abs(cop_r[:, 0] - cop_r[:, 1])
            threshold = torch.sort(diff_op_prob, descending=True)[0][
                int(0.1 * diff_op_prob.size(0))
            ].item()
            same_dvs = torch.argmax(cop_r, dim=1) == torch.argmax(cds_r, dim=1)
            filter_conditions = torch.logical_and(same_dvs, diff_op_prob >= threshold)
            fqvs = qvs[filter_conditions.cpu()].tolist()
            fscores, fqvalues = torch.max(cds_r[filter_conditions], dim=1)

            branching_scores = {
                (var, fqvalues[i].item()): fscores[i].item()
                for i, var in enumerate(fqvs)
            }
            self.neural_bc.set_branching_scores(branching_scores)
        else:
            raise ValueError(f"Config: {self.config_name} not recognized")

        scip_results = self.solve_mpe_program(
            base_evidence, nodesel=self.neural_node_sel
        )
        end_time = datetime.now()
        self.results.append(
            {
                "sample_id": sample_id,
                "evid_vars": self.encode_list(base_evid_vars),
                "query_vars": self.encode_list(query_vars),
                "evidence": self.encode_dict(base_evidence),
                "scip_results": scip_results,
                "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
                "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S"),
                "time_elapsed": (end_time - start_time).seconds,
            }
        )


if __name__ == "__main__":
    STR_TO_RUNNER = {
        "scip": ScipRunner,
        "neural_search": NeuralSearchRunner,
        "neural": NeuralRunner,
        "graph": GraphHeuristicRunner,
        "strong_branch": StrongBranchingRunner,
    }

    parser = argparse.ArgumentParser(description="Evaluate Neural Network based MPE")
    parser.add_argument("--network-name", required=True, dest="network_name")
    parser.add_argument("--qr", required=True, dest="qr", type=float)
    parser.add_argument("--num-workers", dest="num_workers", default=4, type=int)
    parser.add_argument(
        "--tasks-per-child", dest="tasks_per_child", default=5, type=int
    )
    parser.add_argument("--num-cases", dest="num_cases", default=1000, type=int)
    parser.add_argument(
        "--runner",
        dest="runner",
        required=True,
        type=str,
        choices=list(STR_TO_RUNNER.keys()),
    )
    parser.add_argument("--out-file", dest="out_file", required=True, type=str)
    parser.add_argument(
        "--run-daoopt", dest="run_daoopt", action="store_true", default=False
    )
    parser.add_argument("--test", dest="if_testing", default=False, action="store_true")
    parser.add_argument(
        "--test-sample-id",
        dest="test_sample_id",
        type=int,
        default=31,
        help="Sample solved in --test mode (must be among the processed samples).",
    )
    # Cuts are disabled by default. ``--disable-cuts`` is kept so existing command
    # lines still work; ``--enable-cuts`` is the way to turn cuts back on.
    parser.add_argument(
        "--disable-cuts", dest="disable_cuts", default=True, action="store_true"
    )
    parser.add_argument(
        "--enable-cuts", dest="disable_cuts", default=True, action="store_false"
    )
    parser.add_argument("--depth", dest="depth", required=False, type=int, default=0)
    parser.add_argument(
        "--beam-width", dest="beam_width", required=False, type=int, default=0
    )
    parser.add_argument(
        "--config-name",
        dest="config_name",
        type=str,
        required=True,
        choices=[
            "l2c_opt",
            "l2c_rank",
            "scip",
            "graph",
            "strong_branch",
        ],
    )
    args = parser.parse_args()
    # Kept at module level: runner classes may look ``network_dir`` up as a global.
    network_dir = Path(DATASET_DIR) / args.network_name
    artifacts_dir = network_dir / "testing-artifacts"
    artifacts_dir.mkdir(parents=True, exist_ok=True)
    uai_file_path = network_dir / "pgm-model.uai"
    pgm = GraphicalModel.from_uai_file(uai_file_path)
    no_zeroes_uai_file_path = network_dir / "pgm-model-no-zeroes.uai"
    if not no_zeroes_uai_file_path.exists():
        raise Exception("No zeros file present")

    all_vars = list(pgm.graph.nodes)
    num_vars = len(all_vars)
    num_query_vars = round(num_vars * args.qr)
    num_evid_vars = num_vars - num_query_vars
    # --depth is a percentage of the query variables.
    depth_count = ceil(args.depth / 100 * num_query_vars)

    samples = FileIO.read_samples(args.network_name, sample_file="test.txt")
    # Each line is "<sample_id>,<evid_var>,<evid_var>,..."; reused across runs so
    # every runner is evaluated on the same samples and evidence sets.
    sample_ids_to_process_file = artifacts_dir / "samples-to-process.txt"
    if sample_ids_to_process_file.exists():
        with open(sample_ids_to_process_file, "r") as f:
            sample_and_evid_vars = [line for line in f if line.strip()]

        sample_to_evid_var = {}
        for s in sample_and_evid_vars[0 : args.num_cases]:
            tokens = s.strip().split(",")
            # A line written with no evidence vars ("<id>,") has an empty token.
            sample_to_evid_var[int(tokens[0])] = [int(t) for t in tokens[1:] if t]
        sample_ids_to_process = list(sample_to_evid_var.keys())
    else:
        sample_to_evid_var = {}
        # Sampling without replacement cannot draw more ids than exist.
        num_cases = min(args.num_cases, len(samples))
        sample_ids_to_process = np.random.choice(
            list(samples.keys()), size=num_cases, replace=False
        ).tolist()

        with open(sample_ids_to_process_file, "w") as f:
            for sample_id in sample_ids_to_process:
                evid_vars = sorted(
                    np.random.choice(
                        np.array(all_vars), size=num_evid_vars, replace=False
                    ).tolist()
                )
                sample_to_evid_var[sample_id] = evid_vars
                f.write(f"{sample_id},{','.join(map(str, evid_vars))}\n")

    print(f"Processing: {len(sample_ids_to_process)} samples")

    if args.if_testing and args.test_sample_id not in sample_to_evid_var:
        parser.error(
            f"--test-sample-id {args.test_sample_id} is not among the processed samples"
        )

    solver = STR_TO_RUNNER[args.runner](
        pgm,
        no_zeroes_uai_file_path,
        samples,
        sample_to_evid_var,
        args.qr,
        args.disable_cuts,
        depth=depth_count,
        depth_percent=args.depth,
        beam_width=args.beam_width,
        config_name=args.config_name,
        artifacts_dir=artifacts_dir,
    )

    if args.run_daoopt:
        method = solver.solve_daoopt
        save_dir = "daoopt"
    else:
        method = solver.solve
        save_dir = "scip"

    if args.runner in ("neural", "neural_search"):
        # Torch models are not safe to use in forked children.
        mp.set_start_method("spawn", force=True)

    if args.if_testing:
        method(sample_id=args.test_sample_id)
        print(solver.results)
        raise SystemExit(0)

    # NOTE(review): ``method`` appends to ``solver.results`` inside the worker
    # processes. The export below relies on ``Runner.results`` (defined in
    # .base, not visible here) being shared with this parent process — confirm.
    with Pool(
        processes=args.num_workers,
        maxtasksperchild=args.tasks_per_child,
        initializer=None,
    ) as pool:
        for _ in tqdm(
            pool.imap(method, sample_ids_to_process),
            total=len(sample_ids_to_process),
        ):
            pass

    # Save SCIP results: one file per time limit, keeping only that checkpoint.
    if not args.run_daoopt:
        all_results = solver.results
        for tl in TIMELIMITS:
            scip_results_dir = (
                artifacts_dir
                / "timelimit-expts"
                / "scip"
                / args.runner
                / args.config_name
                / f"timelimit-{tl}"
                / f"depth-{args.depth}-width-{args.beam_width}"
            )
            scip_results_dir.mkdir(parents=True, exist_ok=True)
            tl_results = []
            for result in all_results:
                tlr = {}
                for k, v in result.items():
                    if k == "scip_results":
                        tlr[k] = [{"stats": v[0]["checkpoints"][tl]}]
                    else:
                        tlr[k] = v
                tl_results.append(tlr)
            with open(scip_results_dir / args.out_file, "w") as f:
                json.dump(tl_results, f)
