from pyscipopt import SCIP_RESULT, Branchrule


class LogBranchingRule(Branchrule):
    """Passive branching rule that logs x-variable indices seen at the root node.

    The rule never branches itself: both LP and pseudo-solution callbacks return
    DIDNOTRUN. When called at depth 0 it scans the priority LP branching
    candidates and, for every candidate whose name contains "x", appends the
    integer in the third "_"-separated token of the name to
    ``root_branch_cands`` (first-seen order, no duplicates).
    """

    def __init__(self, scip):
        self.scip = scip
        # Ordered, duplicate-free indices collected at the root node.
        self.root_branch_cands = []

    def branchexeclp(self, allowaddcons):
        """Record root-node candidate indices; never performs any branching."""
        if self.model.getDepth() != 0:
            return {"result": SCIP_RESULT.DIDNOTRUN}

        # SCIP returns priority candidates first, so only the leading n_prio
        # entries are inspected. Implicit-integer candidates are normally not
        # branched on unless the user designates them as priorities.
        cands, _sols, _fracs, _ncands, n_prio, _nimpl = self.scip.getLPBranchCands()

        for cand in cands[:n_prio]:
            name = cand.name
            # Only names mentioning "x" carry an index worth logging.
            if "x" not in name:
                continue
            index = int(name.split("_")[2])
            if index not in self.root_branch_cands:
                self.root_branch_cands.append(index)

        return {"result": SCIP_RESULT.DIDNOTRUN}

    def branchexecps(self, allowaddcons):
        """Pseudo-solution branching is not handled by this rule."""
        return {"result": SCIP_RESULT.DIDNOTRUN}


class StrongBranchingCollector(Branchrule):
    """Branching rule that strong-branches on candidates and records their scores.

    Scores are stored in ``choices_score`` keyed by ``(i, j)`` integer pairs.

    - Candidates whose name contains "x" are scored directly; their pair comes
      from tokens 2 and 3 of the "_"-split name.
    - If no LP branching candidate contains "x", candidates containing "z" are
      scored instead and ``fallback_method`` is set to True. Each "z" score is
      attributed to every x pair listed after "true for " in the name of the
      first constraint that mentions the candidate.

    After a completed call the solve is interrupted via ``interruptSolve``, so the
    rule never branches itself.
    """

    # Iteration limit passed to each strong-branching LP solve; override on a
    # subclass or instance to change it.
    STRONGBRANCH_ITLIM = 200

    def __init__(self, scip):
        self.scip = scip
        # (i, j) -> strong-branching score of the candidate covering that pair.
        self.choices_score = {}
        # Set to True once a call found no "x" candidates and fell back to "z" ones.
        self.fallback_method = False

    @staticmethod
    def _x_var_vals(var_name, conss):
        """Return the list of (i, j) pairs that candidate ``var_name`` refers to.

        An "x" name contributes the single pair built from its "_"-tokens 2 and 3.
        Any other name is matched against ``conss``: the first constraint whose name
        contains the candidate name minus its first "_" token and also contains
        "true for " supplies the comma-separated x names after that marker. Each of
        those names contributes its "_"-tokens 1 and 2. An empty list is returned
        when no constraint matches, so no score is ever mis-attributed.
        """
        tokens = var_name.split("_")
        if "x" in var_name:
            return [(int(tokens[2]), int(tokens[3]))]

        associated_var = "_".join(tokens[1:])
        for cons in conss:
            cons_name = cons.name
            # Skip names lacking the marker instead of crashing with IndexError.
            if associated_var not in cons_name or "true for " not in cons_name:
                continue
            x_var_vals = []
            for x_name in cons_name.split("true for ")[1].split(","):
                x_ts = x_name.split("_")
                x_var_vals.append((int(x_ts[1]), int(x_ts[2])))
            return x_var_vals
        return []

    def _record_score(self, cand, downgain, upgain, x_var_vals):
        """Score ``cand`` from its down/up gains and store it under every pair in ``x_var_vals``."""
        score = self.scip.getBranchScoreMultiple(cand, [downgain, upgain])
        for pair in x_var_vals:
            self.choices_score[pair] = score

    def branchexeclp(self, allowaddcons):
        """Strong-branch on the scoring candidates, record their scores, then interrupt.

        Returns ``{"result": SCIP_RESULT.CUTOFF}`` as soon as a candidate has both
        strong-branching children infeasible. Otherwise it calls
        ``interruptSolve`` and returns ``{"result": SCIP_RESULT.DIDNOTRUN}``.
        Strong-branching mode is always ended before returning.
        """
        (
            branch_cands,
            _branch_cand_sols,
            _branch_cand_fracs,
            _ncands,
            _npriocands,
            _nimplcands,
        ) = self.scip.getLPBranchCands()

        num_nodes = self.scip.getNNodes()
        lpobjval = self.scip.getLPObjVal()

        scoring_cands = [cand for cand in branch_cands if "x" in cand.name]
        if not scoring_cands:
            self.fallback_method = True
            scoring_cands = [cand for cand in branch_cands if "z" in cand.name]

        # The constraint set does not change during this callback; fetch it once
        # instead of once per candidate.
        conss = self.scip.getConss()

        self.scip.startStrongbranch()
        try:
            for cand in scoring_cands:
                # Computed fresh per candidate so a previous candidate's pairs can
                # never receive this candidate's score.
                x_var_vals = self._x_var_vals(cand.name, conss)

                # The candidate may already have been strong-branched at this node
                # (events can make SCIP re-process the node); reuse that result.
                if self.scip.getVarStrongbranchNode(cand) == num_nodes:
                    down, up, _downvalid, _upvalid, _, lastlpobjval = (
                        self.scip.getVarStrongbranchLast(cand)
                    )
                    downgain = max(down - lastlpobjval, 0)
                    upgain = max(up - lastlpobjval, 0)
                    self._record_score(cand, downgain, upgain, x_var_vals)
                    continue

                (
                    down,
                    up,
                    downvalid,
                    upvalid,
                    downinf,
                    upinf,
                    _downconflict,
                    _upconflict,
                    lperror,
                ) = self.scip.getVarStrongbranch(
                    cand, self.STRONGBRANCH_ITLIM, idempotent=False
                )

                # On an LP error stop scoring further candidates.
                if lperror:
                    break

                # Both children infeasible: the current node can be cut off.
                if downinf and upinf:
                    return {"result": SCIP_RESULT.CUTOFF}

                # Only a valid, feasible child contributes a gain.
                downgain = max(down - lpobjval, 0) if downvalid and not downinf else 0
                upgain = max(up - lpobjval, 0) if upvalid and not upinf else 0
                self._record_score(cand, downgain, upgain, x_var_vals)
        finally:
            # Every startStrongbranch must be paired with endStrongbranch, including
            # on the CUTOFF return and on exceptions.
            self.scip.endStrongbranch()

        self.scip.interruptSolve()
        return {"result": SCIP_RESULT.DIDNOTRUN}

    def branchexecps(self, allowaddcons):
        """Pseudo-solution branching is not handled by this rule."""
        return {"result": SCIP_RESULT.DIDNOTRUN}
