import time

import torch
from pyscipopt import SCIP_EVENTTYPE, SCIP_RESULT, Branchrule, Eventhdlr, Nodesel

from MPE.utils.utils import ScipUtils


class CheckpointHandler(Eventhdlr):
    """Event handler that keeps solver-statistics snapshots for a set of time budgets.

    On every caught event, each checkpoint time ``t`` that has not been exceeded
    yet receives a fresh ``ScipUtils.get_statistics(model)`` snapshot. After
    solving, ``checkpoints[t]`` therefore holds the statistics taken at the last
    caught event whose elapsed time was ``<= t``.

    Elapsed time is wall-clock time measured from the first event this handler
    receives, so any time spent before that first event is not counted.
    """

    def __init__(self, model, checkpoint_times):
        """
        :param model: solver model whose events are caught and whose statistics
            are recorded.
        :param checkpoint_times: times (in seconds) for which snapshots are kept;
            stored in ascending order.
        """
        self.model = model
        self.checkpoint_times = sorted(checkpoint_times)
        # checkpoint time -> statistics snapshot (filled in eventexec)
        self.checkpoints = {}
        # Set lazily on the first event.
        self.start_time = None

    def eventinit(self):
        """Subscribe to new-best-solution and node-focus events."""
        for event_type in (SCIP_EVENTTYPE.BESTSOLFOUND, SCIP_EVENTTYPE.NODEFOCUSED):
            self.model.catchEvent(event_type, self)

    def eventexec(self, event):
        """Refresh the snapshot of every checkpoint whose time has not passed yet."""
        if self.start_time is None:
            self.start_time = time.time()
        elapsed = time.time() - self.start_time

        for checkpoint in self.checkpoint_times:
            already_passed = elapsed > checkpoint
            if not already_passed:
                self.checkpoints[checkpoint] = ScipUtils.get_statistics(self.model)


class NeuralBranchingRule(Branchrule):
    """Branching rule that branches on the LP candidate with the highest supplied score.

    Scores are provided through :meth:`set_branching_scores`. The solving model
    is taken from ``scip`` when the caller has assigned it, otherwise from the
    rule's ``model`` attribute. If no priority LP candidate has a score,
    ``DIDNOTRUN`` is returned so that SCIP can fall back to other branching rules.
    """

    def __init__(self):
        self.scip = None
        # candidate variable name -> score; None until set_branching_scores is called
        self.branching_scores = None

    def set_branching_scores(self, scores):
        """Store branching scores keyed by candidate variable name.

        :param scores: mapping whose keys are ``(var, val)`` pairs and whose values
            are scores compared with ``max``. Each key is stored under the name
            ``f"t_x_{var}_{val}"`` -- presumably the transformed-problem name of
            variable ``x_{var}_{val}``; TODO confirm against the model builder.
        """
        self.branching_scores = {
            f"t_x_{var}_{val}": score for (var, val), score in scores.items()
        }

    def branchexeclp(self, allowaddcons):
        """Branch on the best-scored priority LP branching candidate.

        :param allowaddcons: unused; part of the pyscipopt callback signature.
        :return: ``{"result": SCIP_RESULT.BRANCHED}`` after branching, or
            ``{"result": SCIP_RESULT.DIDNOTRUN}`` when no candidate has a score
            (including when no scores were ever supplied).
        """
        scores = self.branching_scores
        if not scores:
            # Nothing to rank by: defer to other rules instead of failing on None.
            return {"result": SCIP_RESULT.DIDNOTRUN}

        # Use one model handle for both querying and branching.
        scip = self.scip if self.scip is not None else self.model
        # The second-to-last element is the number of priority candidates,
        # which occupy the front of the candidate list.
        branch_cands, *_, npriocands, _ = scip.getLPBranchCands()

        scored = {
            idx: scores[branch_cands[idx].name]
            for idx in range(npriocands)
            if branch_cands[idx].name in scores
        }
        if not scored:
            return {"result": SCIP_RESULT.DIDNOTRUN}

        # max() returns the first (lowest-index) candidate among equal scores.
        best_idx = max(scored, key=scored.get)
        scip.branchVar(branch_cands[best_idx])
        return {"result": SCIP_RESULT.BRANCHED}

    def branchexecps(self, allowaddcons):
        """Pseudo-solution branching is not supported; always defer to other rules."""
        return {"result": SCIP_RESULT.DIDNOTRUN}


class NeuralNodeSel(Nodesel):
    """Depth-first node selector that orders same-depth nodes by predicted values.

    Deeper nodes are always processed first. Between two nodes at the same depth,
    the node whose branching bound change moves the branched variable towards the
    value recorded in ``domains`` (see :meth:`set_domains`) comes first.

    The solving model is taken from ``scip`` when the caller has assigned it,
    otherwise from the selector's ``model`` attribute.
    """

    # SCIP_BOUNDTYPE values as returned by ``BoundChange.getBoundtype()``.
    _BOUNDTYPE_LOWER = 0  # bound change raises the lower bound ("up" branch)
    _BOUNDTYPE_UPPER = 1  # bound change lowers the upper bound ("down" branch)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.scip = None
        # transformed variable name -> preferred value (1 or 0); None until set_domains
        self.domains = None

    def set_domains(self, optimality_score, query_vars):
        """Derive a preferred 0/1 value for the indicator variables of each query variable.

        ``optimality_score[0]`` is reshaped into rows of two scores, and the
        argmax column ``val`` of row ``i`` is taken as the predicted value of
        ``query_vars[i]``. Then ``domains["t_x_{var}_{val}"] = 1`` and
        ``domains["t_x_{var}_{1-val}"] = 0``.

        :param optimality_score: tensor-like whose first entry holds two scores
            per query variable, in the same order as ``query_vars``.
        :param query_vars: identifiers of the query variables. Unpaired entries
            on either side are ignored (``zip`` truncates).
        """
        self.domains = {}
        # reshape (unlike view) also accepts non-contiguous score tensors.
        assignments = torch.argmax(
            optimality_score[0].reshape(-1, 2).cpu(), dim=1
        ).tolist()
        for var, val in zip(query_vars, assignments):
            self.domains[f"t_x_{var}_{val}"] = 1
            self.domains[f"t_x_{var}_{1 - val}"] = 0

    def nodeselect(self):
        """Decide which of the leaves from the branching tree to process next."""
        scip = self.scip if self.scip is not None else self.model
        selnode = scip.getBestNode()
        return {"selnode": selnode}

    @staticmethod
    def _first_boundchg(node):
        """Return the first bound change of ``node``, or None if it has none."""
        domchg = node.getDomchg()
        if domchg is None:
            return None
        boundchgs = domchg.getBoundchgs()
        return boundchgs[0] if boundchgs else None

    def nodecomp(self, node1, node2):
        """
        Compare two leaves of the current branching tree.

        It returns the following values:

        value < 0, if node 1 comes before (is better than) node 2
        value = 0, if both nodes are equally good
        value > 0, if node 1 comes after (is worse than) node 2.

        Deeper nodes come first. At equal depth, the first bound change of each
        node is inspected. If the branched variable has a preferred value in
        ``domains``, the node whose bound change points towards that value comes
        first. Every other case is a tie.
        """
        depth_1 = node1.getDepth()
        depth_2 = node2.getDepth()
        if depth_1 != depth_2:
            return -1 if depth_1 > depth_2 else 1

        bdchg1 = self._first_boundchg(node1)
        bdchg2 = self._first_boundchg(node2)
        if bdchg1 is None or bdchg2 is None:
            # No branching bound change to compare: no preference.
            return 0

        node1_var, node2_var = bdchg1.getVar(), bdchg2.getVar()
        bd_type1, bd_type2 = bdchg1.getBoundtype(), bdchg2.getBoundtype()
        # Same-depth open nodes are expected to be siblings created by
        # branching on the same variable.
        assert node1_var.name == node2_var.name

        if "z" in node1_var.name:
            return 0

        # Variables without a prediction (e.g. branched on by a fallback rule)
        # are not in ``domains``; treat them as ties instead of raising KeyError.
        var_val = (self.domains or {}).get(node1_var.name)
        if var_val == 1:
            preferred = self._BOUNDTYPE_LOWER  # for a binary var: up branch -> 1
        elif var_val == 0:
            preferred = self._BOUNDTYPE_UPPER  # for a binary var: down branch -> 0
        else:
            return 0

        if bd_type1 == bd_type2:
            # Keep the comparator antisymmetric: nodecomp(a, b) == -nodecomp(b, a).
            return 0
        if bd_type1 == preferred:
            return -1
        if bd_type2 == preferred:
            return 1
        return 0
