import argparse
import json
import random
import uuid
from copy import deepcopy
from datetime import datetime
from multiprocessing import Pool
from pathlib import Path

import numpy as np
import pyscipopt as scip
from tqdm import tqdm

from MPE.utils.utils import FileIO, GraphicalModel, ScipUtils

from .branching_rules import StrongBranchingCollector

# Both paths are resolved once, at import time, relative to the process's
# current working directory.
CWD = Path.cwd()
DATASET_DIR = CWD.joinpath("MPE", "dataset")


def encode_dict(dict_):
    """Serialise a mapping as comma-separated ``key=value`` items, in iteration order."""
    parts = []
    for key, value in dict_.items():
        parts.append(f"{key}={value}")
    return ",".join(parts)


def decode_dict(dict_):
    """Parse a string produced by ``encode_dict`` into ``(var, val)`` integer pairs.

    Args:
        dict_: Comma-separated ``"var=val"`` items, e.g. ``"3=1,7=0"``. Despite the
            name, this is the encoded string, not a dict.

    Returns:
        A list of ``(int(var), int(val))`` tuples in input order. An empty string
        (which is what ``encode_dict({})`` produces) decodes to ``[]``.

    Raises:
        ValueError: if an item is not exactly ``var=val`` or a side is not an int.
    """
    if not dict_:
        return []
    pairs = []
    for item in dict_.split(","):
        var, val = item.split("=")
        pairs.append((int(var), int(val)))
    return pairs


def encode_list(list_):
    """Serialise an iterable as the comma-joined ``str()`` of its items."""
    return ",".join(str(item) for item in list_)


def decode_list(string):
    """Parse a comma-separated string of integers (as written by ``encode_list``).

    Returns:
        The list of ints. An empty string (what ``encode_list([])`` produces)
        decodes to ``[]`` instead of raising.

    Raises:
        ValueError: if any item is not an integer literal.
    """
    if not string:
        return []
    return [int(token) for token in string.split(",")]


class DepthDataGenerator:
    """Generate learning-to-branch training records for MPE queries on one network.

    Every public generation method takes a ``sample_id``, draws a random
    evidence/query split from that sample, solves the resulting MPE program(s)
    with SCIP via ``ScipUtils``, and writes one JSON record to
    ``DATASET_DIR / network_name / "l2c-dataset" / "<uuid4>.json"``.

    The record under construction is kept in ``self.record``, so one instance
    must not run two generations at the same time.
    """

    # Root statuses under which a sample is kept for data collection.
    ACCEPTED_STATUS = {"optimal", "bestsollimit", "timelimit"}
    # Statuses treated as "solved" when scoring children.
    OPTIMAL_STATUS = {"optimal", "bestsollimit"}
    # strftime format of the start/end timestamps stored in each record.
    TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(
        self,
        pgm,
        network_name,
        samples,
        variable_strategies=None,
        max_depth=None,
        max_time_per_sample=1800,
    ):
        """
        Args:
            pgm: Graphical model; ``pgm.graph.nodes`` are used as variable ids.
            network_name: Sub-directory of ``DATASET_DIR`` for this network.
            samples: ``sample_id -> assignment`` mapping; each assignment is
                indexed by variable id.
            variable_strategies: ``{strategy_name: probability}`` for
                ``generate_dataset_sequential`` (names as in ``select_variable``).
            max_depth: Number of depths ``generate_dataset_sequential`` explores.
            max_time_per_sample: Wall-clock budget in seconds, checked between
                child solves in ``generate_dataset_ns``.
        """
        self.pgm = pgm
        self.all_vars = list(self.pgm.graph.nodes)
        self.samples = samples
        self.network_name = network_name
        self.network_dir = Path(DATASET_DIR) / network_name
        self.save_dir = self.network_dir / "l2c-dataset"
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.max_time_per_sample = max_time_per_sample
        self.variable_strategies = variable_strategies
        self.max_depth = max_depth

    def _random_float_range(self, min_val, max_val):
        """Return a uniform random float in ``[min_val, max_val)``."""
        return random.uniform(min_val, max_val)

    def _init_record(
        self, sample_id, qr, num_query_vars, num_evid_vars, data_collection_strategy
    ):
        """Reset ``self.record`` with the query properties of this generation.

        Args:
            sample_id: Key into ``self.samples``.
            qr: Query ratio (fraction of variables that are query variables).
            num_query_vars: Number of query variables.
            num_evid_vars: Number of evidence variables.
            data_collection_strategy: Label stored in the record; ``"sequential"``
                also creates the per-depth ``"depths"`` list.
        """
        self.record = {
            "props": {
                "sample_id": sample_id,
                "network_name": self.network_name,
                "sample": encode_list(self.samples[sample_id]),
                "qr": qr,
                "num_query_vars": num_query_vars,
                "num_evid_vars": num_evid_vars,
                "data_collection_strategy": data_collection_strategy,
                "max_time_per_sample": self.max_time_per_sample,
                "choice_cands_strategy": "strong_branching",
            },
        }
        if data_collection_strategy == "sequential":
            self.record["depths"] = []

    def _draw_query(self, sample_id, data_collection_strategy):
        """Draw a random evidence/query split for ``sample_id`` and start a record.

        The query ratio ``qr`` is drawn from [0.25, 0.75) and rounded to two
        decimals. ``round(num_vars * qr)`` variables are query variables. The
        remaining variables, chosen without replacement, become evidence fixed to
        their values in the sample.

        Returns:
            ``(base_evidence, num_query_vars)``, where ``base_evidence`` maps each
            evidence variable to its sample value.
        """
        qr = round(self._random_float_range(0.25, 0.75), 2)
        num_vars = len(self.all_vars)
        num_query_vars = round(num_vars * qr)
        num_evid_vars = num_vars - num_query_vars

        self._init_record(
            sample_id, qr, num_query_vars, num_evid_vars, data_collection_strategy
        )

        evid_vars = sorted(
            np.random.choice(
                np.array(self.all_vars), size=num_evid_vars, replace=False
            ).tolist()
        )
        query_vars = list(set(self.all_vars) - set(evid_vars))
        self.record["evid_vars"] = encode_list(evid_vars)
        self.record["query_vars"] = encode_list(query_vars)

        curr_sample = self.samples[sample_id]
        base_evidence = {var: curr_sample[var] for var in evid_vars}
        return base_evidence, num_query_vars

    def _save_record(self, start_time):
        """Stamp timing fields on ``self.record`` and dump it to a new JSON file."""
        end_time = datetime.now()
        self.record["start_time"] = start_time.strftime(self.TIME_FORMAT)
        self.record["end_time"] = end_time.strftime(self.TIME_FORMAT)
        # total_seconds(): timedelta.seconds silently drops whole days.
        self.record["time_required"] = int((end_time - start_time).total_seconds())
        with open(self.save_dir / f"{uuid.uuid4()}.json", "w", encoding="utf-8") as f:
            json.dump(self.record, f)

    def _get_root_branch_cands(self, evidence, extra_params=None):
        """Solve once with ``StrongBranchingCollector`` and return candidate variables.

        Args:
            evidence: ``{var: value}`` assignments fixed in the program.
            extra_params: Optional SCIP parameter dict applied before solving.

        Returns:
            Distinct candidate variables, highest strong-branching score first.
        """
        program, _, _ = ScipUtils.get_mpe_program(self.pgm, evidence)
        if extra_params:
            program.setParams(extra_params)
        sbr = StrongBranchingCollector(program)
        program.includeBranchrule(
            sbr,
            "sbr",
            "Branching rule to log root decisions",
            priority=10000000,
            maxdepth=-1,
            maxbounddist=1,
        )
        program.optimize()
        ranked = sorted(sbr.choices_score.items(), key=lambda item: item[1], reverse=True)
        # NOTE(review): assumes the first element of each choices_score key is
        # the variable id; confirm against StrongBranchingCollector.
        # dict.fromkeys de-duplicates while keeping the best-score-first order
        # (a set would throw away the ranking computed above).
        return list(dict.fromkeys(key[0] for key, _ in ranked))

    def _solve_mpe_program(self, evidence, objective_limit=None, extra_params=None):
        """Build and solve the MPE program for ``evidence``.

        Args:
            evidence: ``{var: value}`` assignments fixed in the program.
            objective_limit: If not None, passed to SCIP as
                ``objective_limit - 1e-5`` via ``setObjlimit``, and ``limits/bestsol``
                is set to 1.
            extra_params: Optional SCIP parameter dict applied before solving.

        Returns:
            ``(assignments, stats)``. ``assignments[var]`` is the solved value of
            every variable. ``stats`` comes from ``ScipUtils.get_statistics``; this
            class reads its ``status``, ``num_nodes`` and ``objective_value`` keys.
        """
        program, x_variables, _ = ScipUtils.get_mpe_program(self.pgm, evidence)
        # `is not None` so that a limit of 0 is honoured (0 is falsy).
        if objective_limit is not None:
            program.setObjlimit(objective_limit - 1e-5)
            program.setParam("limits/bestsol", 1)

        if extra_params:
            program.setParams(extra_params)

        program.optimize()
        stats = ScipUtils.get_statistics(program)
        # NOTE(review): assumes variable ids are 0..n-1 and binary, with
        # x_variables[var][0] the indicator of value 0 — confirm in ScipUtils.
        assignments = [0] * len(self.all_vars)
        for var in self.all_vars:
            assignments[var] = round(1 - program.getVal(x_variables[var][0]))
        return assignments, stats

    def _get_choices_var_val(self, root_assignments, root_branch_cands):
        """Pair each candidate with its value in ``root_assignments``, in random order."""
        choices_var_val = [(cand, root_assignments[cand]) for cand in root_branch_cands]
        random.shuffle(choices_var_val)
        return choices_var_val

    def get_child_score(self, root_stats, parent_stats, child_stats):
        """Score a child solve (one extra fixed variable) against its parent.

        A higher objective value counts as better.

        Returns:
            3: the parent hit the time limit and the child was solved optimally.
            2: both were optimal and the child needed fewer nodes.
            1: both hit the time limit and the child has a higher objective, or
               the same objective (to 4 decimals) with fewer nodes.
            -2: the parent was optimal and the child was not.
            -1: every other case (more nodes, worse objective, unknown status).
        """
        parent_status, child_status = parent_stats["status"], child_stats["status"]
        parent_nodes, child_nodes = parent_stats["num_nodes"], child_stats["num_nodes"]
        parent_objective = parent_stats["objective_value"]
        child_objective = child_stats["objective_value"]

        if parent_status in self.OPTIMAL_STATUS:
            if child_status not in self.OPTIMAL_STATUS:
                return -2  # Declining status
            # Sanity check: children are fixed to values from the reference
            # assignment, so presumably their optimum cannot fall below the root's.
            assert round(child_objective, 3) >= round(root_stats["objective_value"], 3)
            # Only the node count can improve in this case.
            return 2 if child_nodes < parent_nodes else -1

        if parent_status == "timelimit":
            if child_status in self.OPTIMAL_STATUS:
                return 3  # Improving status
            if child_status != "timelimit":
                return -1  # Unknown status
            if child_objective > parent_objective:
                return 1  # Improving objective
            if round(child_objective, 4) == round(parent_objective, 4):
                return 1 if child_nodes < parent_nodes else -1  # Nodes decide
            return -1  # Declining objective

        return -1

    def select_variable(self, choice_records, root_status, k=3, strategy="greedy"):
        """Return the index in ``choice_records`` of the record chosen by ``strategy``.

        Records are ranked best-first. When ``root_status`` is accepted, ranking
        uses ``num_nodes`` (fewest first); otherwise it uses ``objective_value``
        (highest first).

        Strategies:
            ``greedy``: best record with a positive ``assigned_score``.
            ``random``: uniformly random record with a positive score.
            ``bad``: worst-ranked record with a positive score.
            ``worst``: worst-ranked record overall.

        Args:
            k: Unused; kept for backward compatibility.

        Raises:
            ValueError: if ``choice_records`` is empty or no record passes the
                strategy's score filter.
            KeyError: for an unknown ``strategy``.
        """
        if not choice_records:
            raise ValueError("choice_records must not be empty")

        # NOTE(review): the only caller in this file returns early unless the
        # root status is accepted, so the objective branch is currently
        # unreachable. OPTIMAL_STATUS may have been intended here — confirm.
        if root_status in self.ACCEPTED_STATUS:
            key, descending = "num_nodes", False  # fewer nodes is better
        else:
            key, descending = "objective_value", True  # higher objective is better

        strategy_to_behaviour = {
            "greedy": {"positive_only": True, "index": 0},
            "random": {"positive_only": True, "index": None},
            "bad": {"positive_only": True, "index": -1},
            "worst": {"positive_only": False, "index": -1},
        }
        behaviour = strategy_to_behaviour[strategy]

        # get_child_score reports improvements as 1, 2 or 3, so "better" means > 0.
        candidates = [
            (idx, record["stats"][key])
            for idx, record in enumerate(choice_records)
            if not behaviour["positive_only"] or record["assigned_score"] > 0
        ]
        if not candidates:
            raise ValueError(f"No choice record matches strategy {strategy!r}")

        ranked = [
            idx
            for idx, _ in sorted(candidates, key=lambda cand: cand[1], reverse=descending)
        ]

        index = behaviour["index"]
        if strategy == "random":
            index = np.random.choice(range(len(ranked)), size=1).item()
        return ranked[index]

    def generate_dataset_sequential(self, sample_id):
        """Grow the evidence by one branching decision per depth and record each depth.

        At each of up to ``self.max_depth`` depths, a shuffled prefix of the
        current branching candidates is tried. Each candidate is fixed to its
        value in the root assignment, solved, and scored with
        ``get_child_score``. If any child scores positively, one child is picked
        with ``select_variable`` (strategy drawn from ``self.variable_strategies``).
        Its variable joins the evidence and its candidates feed the next depth.
        Otherwise the walk stops. One JSON record is written per completed call.

        Raises:
            ValueError: if ``variable_strategies`` or ``max_depth`` is not set.
        """
        if not self.variable_strategies or self.max_depth is None:
            raise ValueError(
                "generate_dataset_sequential requires `variable_strategies` and "
                "`max_depth` to be set on the generator"
            )

        start_time = datetime.now()
        base_evidence, num_query_vars = self._draw_query(sample_id, "sequential")
        variable_strategy = np.random.choice(
            list(self.variable_strategies.keys()),
            size=1,
            p=list(self.variable_strategies.values()),
        ).item()
        self.record["variable_choosing_strategy"] = variable_strategy

        uninterrupted_assignments, uninterrupted_stats = self._solve_mpe_program(
            base_evidence
        )
        if uninterrupted_stats["status"] != "optimal":
            print("Could not solve the problem optimally")
            return
        self.record["optimal_objective"] = uninterrupted_stats["objective_value"]

        root_assignments, root_stats = self._solve_mpe_program(base_evidence)
        # _solve_mpe_program only returns (assignments, stats), so branching
        # candidates come from a separate strong-branching solve.
        root_branch_cands = self._get_root_branch_cands(base_evidence)
        root_status = root_stats["status"]
        self.record["root_status"] = root_status
        num_nodes_root = root_stats["num_nodes"]
        num_nodes_least = num_nodes_root
        self.record["root_record"] = {
            "evidence": encode_dict(base_evidence),
            "assignments": encode_list(root_assignments),
            "root_stats": root_stats,
            "uninterrupted_stats": uninterrupted_stats,
            "uninterrupted_assignments": encode_list(uninterrupted_assignments),
        }
        if root_status not in self.ACCEPTED_STATUS:
            print("Root status not acceptable", root_status)
            return

        optimal_assignments = deepcopy(root_assignments)
        curr_parent_stats = deepcopy(root_stats)
        choice_list = []

        # Stop a depth once this many better AND this many worse choices were
        # seen. max(1, ...) keeps networks with < 10 query variables from
        # collecting nothing at all.
        threshold_ct = max(1, int(0.1 * num_query_vars))
        max_choices_allowed = 2 * threshold_ct
        chosen_var = None
        curr_depth = -1  # stays -1 when max_depth == 0
        for curr_depth in range(self.max_depth):
            depth_record = {"depth_stats": curr_parent_stats}
            choices_var_val = self._get_choices_var_val(
                optimal_assignments, root_branch_cands
            )
            depth_record["total_choices"] = len(choices_var_val)

            choice_records = []
            better_choices_ct, bad_choices_ct = 0, 0
            for choice_variable, choice_val in choices_var_val[:max_choices_allowed]:
                if better_choices_ct >= threshold_ct and bad_choices_ct >= threshold_ct:
                    break

                evidence = {**base_evidence, choice_variable: choice_val}
                assignments, child_stats = self._solve_mpe_program(evidence)
                branch_cands = self._get_root_branch_cands(evidence)
                # The previously chosen variable is now evidence and must not be
                # offered as a branching candidate again.
                assert chosen_var not in branch_cands
                assigned_score = self.get_child_score(
                    root_stats, curr_parent_stats, child_stats
                )
                choice_records.append(
                    {
                        "variable": choice_variable,
                        "value": choice_val,
                        "evidence": encode_dict(evidence),
                        "assignments": encode_list(assignments),
                        "stats": child_stats,
                        "root_branch_cands": encode_list(branch_cands),
                        "assigned_score": assigned_score,
                    }
                )
                if assigned_score > 0:
                    better_choices_ct += 1
                else:
                    bad_choices_ct += 1

            depth_record["choices_tried"] = len(choice_records)
            depth_record["better_choices_ct"] = better_choices_ct
            depth_record["bad_choices_ct"] = bad_choices_ct
            depth_record["choices"] = choice_records
            if better_choices_ct < 1:
                depth_record["chosen"] = None
                self.record["depths"].append(depth_record)
                break

            new_choice_idx = self.select_variable(
                choice_records, root_status, strategy=variable_strategy
            )
            chosen = choice_records[new_choice_idx]
            chosen_var, chosen_value = chosen["variable"], chosen["value"]
            num_nodes_least = chosen["stats"]["num_nodes"]
            base_evidence[chosen_var] = chosen_value
            if chosen["root_branch_cands"]:
                root_branch_cands = decode_list(chosen["root_branch_cands"])
            else:
                root_branch_cands.remove(chosen_var)
            choice_list.append((chosen_var, chosen_value))
            curr_parent_stats = deepcopy(chosen["stats"])
            depth_record["chosen"] = {"variable": chosen_var, "value": chosen_value}
            self.record["depths"].append(depth_record)

        # Same "var=val" format as encode_dict, so decode_dict can read it back.
        self.record["choice_list"] = ",".join(
            f"{var}={val}" for var, val in choice_list
        )
        self.record["max_depth_reached"] = curr_depth
        self.record["num_nodes_least"] = num_nodes_least
        self.record["num_nodes_root"] = num_nodes_root
        self._save_record(start_time)

    def generate_dataset_ns(self, sample_id):
        """Score single-variable fixings at the root ("non-sequential" ranking data).

        The query is solved once without instrumentation (400 s limit, must
        reach "optimal") and once with a 60 s, presolve-off, cut-free
        configuration. Strong-branching candidates, fixed to their values in the
        optimal assignment, are then each solved under that configuration and
        scored with ``get_child_score``. Collection stops once enough better and
        worse choices were seen, or once ``max_time_per_sample`` has elapsed.
        """
        extra_params = {
            "presolving/maxrounds": 0,
            "separating/maxcuts": 0,
            "separating/maxcutsroot": 0,
            "limits/time": 60,
        }
        start_time = datetime.now()
        base_evidence, num_query_vars = self._draw_query(sample_id, "non_sequential")

        root_branch_cands = self._get_root_branch_cands(base_evidence, extra_params)
        if not root_branch_cands:
            print("No root branch candidates found")
            return

        uninterrupted_assignments, uninterrupted_stats = self._solve_mpe_program(
            base_evidence, extra_params={"limits/time": 400}
        )
        if uninterrupted_stats["status"] != "optimal":
            print("Could not solve the problem optimally")
            return
        self.record["optimal_objective"] = uninterrupted_stats["objective_value"]

        root_assignments, root_stats = self._solve_mpe_program(
            base_evidence, extra_params=extra_params
        )
        root_status = root_stats["status"]
        self.record["root_status"] = root_status
        num_nodes_root = root_stats["num_nodes"]
        num_nodes_least = num_nodes_root
        self.record["root_record"] = {
            "evidence": encode_dict(base_evidence),
            "assignments": encode_list(root_assignments),
            "root_stats": root_stats,
            "uninterrupted_stats": uninterrupted_stats,
            "uninterrupted_assignments": encode_list(uninterrupted_assignments),
        }
        if root_status not in self.ACCEPTED_STATUS:
            print("Root status not acceptable", root_status)
            return

        optimal_assignments = deepcopy(uninterrupted_assignments)
        curr_parent_stats = deepcopy(root_stats)

        # Stop once this many better AND this many worse choices were seen.
        # max(1, ...) keeps networks with < 10 query variables from collecting
        # nothing at all.
        threshold_ct = max(1, int(0.1 * num_query_vars))
        max_choices_allowed = 3 * threshold_ct
        try:
            depth_record = {"depth_stats": curr_parent_stats}
            choices_var_val = self._get_choices_var_val(
                optimal_assignments, root_branch_cands
            )
            depth_record["total_choices"] = len(choices_var_val)

            choice_records = []
            better_choices_ct, bad_choices_ct = 0, 0
            for choice_variable, choice_val in choices_var_val[:max_choices_allowed]:
                if better_choices_ct >= threshold_ct and bad_choices_ct >= threshold_ct:
                    break
                elapsed = (datetime.now() - start_time).total_seconds()
                if elapsed > self.max_time_per_sample:
                    break

                evidence = {**base_evidence, choice_variable: choice_val}
                assignments, child_stats = self._solve_mpe_program(
                    evidence, extra_params=extra_params
                )
                assigned_score = self.get_child_score(
                    root_stats, curr_parent_stats, child_stats
                )
                choice_records.append(
                    {
                        "variable": choice_variable,
                        "value": choice_val,
                        "evidence": encode_dict(evidence),
                        "assignments": encode_list(assignments),
                        "stats": child_stats,
                        "assigned_score": assigned_score,
                    }
                )
                if assigned_score > 0:
                    better_choices_ct += 1
                else:
                    bad_choices_ct += 1
                num_nodes_least = min(num_nodes_least, child_stats["num_nodes"])

            depth_record["choices_tried"] = len(choice_records)
            depth_record["better_choices_ct"] = better_choices_ct
            depth_record["bad_choices_ct"] = bad_choices_ct
            depth_record["choices"] = choice_records
            self.record["depths_record"] = depth_record
        except Exception as err:
            # Best effort: one failing sample must not abort the whole pool run.
            print(f"Sample {sample_id}: choice collection failed: {err!r}")
            return

        self.record["num_nodes_least"] = num_nodes_least
        self.record["num_nodes_root"] = num_nodes_root
        self._save_record(start_time)

    def optimal_vs_non_optimal(self, sample_id):
        """Solve one random query without limits and save it if it reaches "optimal"."""
        start_time = datetime.now()
        base_evidence, _ = self._draw_query(sample_id, "optimal_vs_non_optimal")

        uninterrupted_assignments, uninterrupted_stats = self._solve_mpe_program(
            base_evidence
        )
        if uninterrupted_stats["status"] != "optimal":
            print("Could not solve the problem optimally")
            return

        self.record["root_record"] = {
            "evidence": encode_dict(base_evidence),
            "uninterrupted_stats": uninterrupted_stats,
            "uninterrupted_assignments": encode_list(uninterrupted_assignments),
        }
        self.record["optimal_objective"] = uninterrupted_stats["objective_value"]
        self._save_record(start_time)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate Imitation Learning from Strong branching scores"
    )
    parser.add_argument("--network-name", required=True, dest="network_name")
    parser.add_argument("--num-workers", dest="num_workers", default=4, type=int)
    parser.add_argument("--num-samples", dest="num_samples", default=1000, type=int)
    parser.add_argument(
        "--data-strategy",
        dest="data_strategy",
        type=str,
        default="ranking",
        choices=["ranking", "optimality"],
    )
    args = parser.parse_args()

    network_dir = DATASET_DIR / args.network_name
    pgm = GraphicalModel.from_uai_file(network_dir / "pgm-model.uai")

    samples = FileIO.read_samples(
        args.network_name, sample_file="train.txt", delimiter=" "
    )
    sample_keys = list(samples.keys())
    # Sampling without replacement cannot draw more ids than there are samples.
    num_samples = min(args.num_samples, len(sample_keys))
    if num_samples < args.num_samples:
        print(f"Only {len(sample_keys)} samples available; using all of them")
    sample_ids_to_process = np.random.choice(
        sample_keys, size=num_samples, replace=False
    ).tolist()

    generator = DepthDataGenerator(pgm, args.network_name, samples)
    if args.data_strategy == "ranking":
        print("Generating the Non-sequential data")
        func = generator.generate_dataset_ns
        tasks_per_child = 1
    elif args.data_strategy == "optimality":
        print("Generating the optimal vs non optimal data")
        func = generator.optimal_vs_non_optimal
        tasks_per_child = 4
    else:
        raise ValueError("Unrecognized Strategy")

    print(f"Processing: {len(sample_ids_to_process)} samples")
    # With the fork start method, every worker inherits the parent's NumPy
    # global RNG state. The stdlib `random` module reseeds itself after fork;
    # NumPy does not. Without this reseed, all workers would draw the same
    # evidence permutation.
    with Pool(
        processes=args.num_workers,
        maxtasksperchild=tasks_per_child,
        initializer=np.random.seed,
    ) as pool:
        # Workers write their own JSON files and return None; only progress matters.
        for _ in tqdm(
            pool.imap_unordered(func, sample_ids_to_process),
            total=len(sample_ids_to_process),
        ):
            pass
