#########################################################################
##   This file is part of the α,β-CROWN (alpha-beta-CROWN) verifier    ##
##                                                                     ##
##   Copyright (C) 2021-2025 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com> (UIUC)         ##
##                     Zhouxing Shi <zshi@cs.ucla.edu> (UCLA)          ##
##                     Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################

import torch
import numpy as np

from heuristics.base import NeuronBranchingHeuristic
from heuristics.utils import compute_ratio, get_preact_params
from utils import get_reduce_op, get_batch_size_from_masks


class BabsrBranching(NeuronBranchingHeuristic):
    """BaBSR-style neuron branching heuristic (the "kfsb" score family).

    ``get_branching_decisions`` ranks unstable neurons with ``babsr_score``.
    ``babsr_score_new`` and ``babsr_score_old`` are experimental variants
    that additionally use a pattern taken from ``nlp_info``. They share the
    same call convention but are not called by ``get_branching_decisions``.
    """

    def __init__(self, net):
        super().__init__(net)
        # Counts consecutive intercept-score (backup) decisions made on
        # layer 0. get_branching_decisions stops using the backup score once
        # it reaches 2, and resets it whenever another path is taken.
        self.icp_score_counter = 0

    @staticmethod
    def _normalize_scores(scores, normal_score_idx, reduced_score_idx,
                          larger_is_better=True,
                          small_score_threshold=1e-4, big_constant=1e6):
        """Scale down scores in the reduced set so they are no better than the normal set.

        Args:
            scores: per-neuron scores of one layer.
            normal_score_idx: 0/1 mask of the preferred neurons.
            reduced_score_idx: 0/1 mask of the de-prioritized neurons.
            larger_is_better: True for scores where higher is better (valid
                scores positive). False for scores where more negative is
                better (valid scores negative).
            small_score_threshold: floor on the magnitude of the worst normal
                score. It is also added to the denominator, which guards
                against a zero best-reduced score.
            big_constant: value that pushes masked-out entries out of the
                min/max reduction over the normal set.

        Returns:
            ``scores`` with normal entries unchanged, reduced entries
            multiplied by a per-example ratio clamped to at most 1, and
            entries in neither mask zeroed.
        """
        if larger_is_better:
            thresh = small_score_threshold
            # Masked-out entries become 0; valid scores are positive.
            best_score_in_reduced_set = torch.max(
                scores * reduced_score_idx, dim=1).values
            # Masked-out entries become +big_constant so they never win the min.
            worst_score_in_normal_set = torch.clamp_min(torch.min(
                scores * normal_score_idx
                + (1.0 - normal_score_idx) * big_constant,
                dim=1).values, thresh)
        else:
            thresh = -small_score_threshold
            # Masked-out entries become 0; valid scores are negative.
            best_score_in_reduced_set = torch.min(
                scores * reduced_score_idx, dim=1).values
            worst_score_in_normal_set = torch.clamp_max(torch.max(
                scores * normal_score_idx
                - (1.0 - normal_score_idx) * big_constant,
                dim=1).values, thresh)
        ratio = torch.clamp_max(
            worst_score_in_normal_set / (best_score_in_reduced_set + thresh),
            1.0)
        return (scores * normal_score_idx
                + scores * reduced_score_idx * ratio.unsqueeze(1))

    @staticmethod
    def _prioritization_masks(lA, batch, number_bounds, this_layer_mask,
                              prioritize_alphas):
        """Build the (normal, reduced) masks used by ``prioritize_alphas``.

        Returns ``(None, None)`` for 'none'. For 'positive', entries with
        ``lA >= 0`` are preferred. For 'negative', entries with ``lA <= 0``
        are preferred. Both masks are restricted to ``this_layer_mask``.

        Raises:
            ValueError: if ``prioritize_alphas`` is not a known mode.
        """
        if prioritize_alphas == 'none':
            return None, None
        if prioritize_alphas == 'positive':
            normal, reduced = lA >= 0, lA < 0
        elif prioritize_alphas == 'negative':
            normal, reduced = lA <= 0, lA > 0
        else:
            raise ValueError(f'Unknown prioritize_alphas parameter {prioritize_alphas}')
        return (normal.view(batch, number_bounds, -1) * this_layer_mask,
                reduced.view(batch, number_bounds, -1) * this_layer_mask)

    def babsr_score_new(self, nlp_info, lower_bounds, upper_bounds, lAs,
                        mask, reduce_op, number_bounds, prioritize_alphas='none',
                        pattern_lambda=0.1):
        """BaBSR scores plus an additive boost for neurons active in an NLP pattern.

        The base scores are exactly those of ``babsr_score``. Afterwards,
        ``pattern_lambda`` is added to the score of every neuron that is
        unstable (``mask`` == 1) and whose value in ``nlp_info['z']``
        exceeds 1e-5.

        Args:
            nlp_info: None or a dict-like with key 'z'. ``z[i]`` is matched to
                the i-th entry of ``self.net.net.relus`` and is flattened
                after dropping a leading singleton dim. If None (or without
                'z'), no boost is applied.
            pattern_lambda: additive bonus for pattern-consistent unstable
                neurons. Note that it is not scaled relative to the scores.
            Other args: as in ``babsr_score``.

        Side effects:
            Sets ``self.nlp_pattern`` to a dict that maps pre-activation node
            name to a flat boolean CPU tensor.

        Returns:
            (score, intercept_tb), with the same structure as ``babsr_score``.
        """
        # Key the pattern by pre-activation node name so it matches the split
        # layers below. Integer keys never matched, which silently disabled
        # the boost.
        pre_relu_layer_names = [relu_layer.inputs[0].name
                                for relu_layer in self.net.net.relus]
        z_values = None if nlp_info is None else nlp_info.get('z')
        self.nlp_pattern = {}
        if z_values is not None:
            for relu_i, name in enumerate(pre_relu_layer_names):
                self.nlp_pattern[name] = (
                    z_values[relu_i].squeeze(0).reshape(-1) > 1e-5).cpu()

        score, intercept_tb = self.babsr_score(
            nlp_info, lower_bounds, upper_bounds, lAs, mask, reduce_op,
            number_bounds, prioritize_alphas)
        if not self.nlp_pattern:
            return score, intercept_tb

        # babsr_score returns one entry per split node, in split_nodes order.
        for j, split_node in enumerate(self.net.split_nodes):
            layer = self.net.split_activations[split_node.name][0][0]
            key = layer.inputs[0].name
            pattern = self.nlp_pattern.get(key)
            if pattern is None:
                continue
            layer_score = score[j]
            n_neurons = layer_score.shape[1]
            pattern = pattern.to(device=layer_score.device, dtype=layer_score.dtype)
            # Defensively align the pattern length with the score width BEFORE
            # combining it with the per-example mask (otherwise broadcasting fails).
            if pattern.shape[0] >= n_neurons:
                pattern = pattern[:n_neurons]
            else:
                pattern = torch.nn.functional.pad(
                    pattern, (0, n_neurons - pattern.shape[0]))
            unstable_mask = mask[key].to(device=layer_score.device,
                                         dtype=layer_score.dtype)
            # 1 where the neuron is unstable AND active in the NLP pattern.
            pattern_boost_mask = pattern.unsqueeze(0) * unstable_mask
            score[j] = layer_score + pattern_lambda * pattern_boost_mask

        return score, intercept_tb

    def babsr_score_old(
            self, nlp_info, lower_bounds, upper_bounds, lAs,
            mask, reduce_op, number_bounds,
            sigma=0.25, gamma=0.1, eta=0.5, eps=1e-12,
            prioritize_alphas='none'):
        """Pattern-aligned KFSB scores built from directional SR/FSB estimates.

        For every neuron, two directional estimates are computed. Each is
        the absolute value of a bias term plus the intercept term, masked to
        unstable neurons and averaged over ``number_bounds``:

        - ``s_right`` scales the bias term by ``ratio_temp_0 - 1``.
        - ``s_left`` scales the bias term by ``ratio_temp_0``.

        When ``nlp_info['z']`` is available, z > 0 prefers the right branch.
        The score is then ``g_pref - gamma * g_opp / (g_pref + eps)``.
        Without it, the score is ``max(s_left, s_right)``.

        Args:
            nlp_info: None or a dict-like with key 'z'. ``z[i]`` is matched to
                the i-th split layer in forward order, the same convention as
                ``babsr_score_new`` (TODO confirm against the producer).
            reduce_op, sigma, eta: accepted for interface compatibility; unused.
            gamma: weight of the opposite-direction penalty.
            eps: small constant guarding the division.
            Other args: as in ``babsr_score``.

        Returns:
            (score_pa, intercept_tb): lists with one (batch, n_neurons) tensor
            per split layer, in ``self.net.split_nodes`` order.

            - ``score_pa`` is multiplied by ``mask``, so it is zero on stable
              neurons.
            - ``intercept_tb`` is computed like ``babsr_score``'s backup score.

        Raises:
            ValueError: if ``prioritize_alphas`` is not a known mode.
        """
        batch = get_batch_size_from_masks(mask)
        num_layers = len(self.net.split_nodes)
        z_values = None if nlp_info is None else nlp_info.get('z')
        score_pa, intercept_tb = [], []

        for layer_i, split_node in enumerate(reversed(self.net.split_nodes)):
            assert len(self.net.split_activations[split_node.name]) == 1
            layer = self.net.split_activations[split_node.name][0][0]
            key = layer.inputs[0].name
            lA = lAs[layer.name]
            this_mask = mask[key].unsqueeze(1)  # broadcasts over number_bounds
            normal_idx, reduced_idx = self._prioritization_masks(
                lA, batch, number_bounds, this_mask, prioritize_alphas)

            ratio_temp_0, ratio_temp_1 = compute_ratio(lower_bounds[key], upper_bounds[key])
            ratio_temp_0 = ratio_temp_0.unsqueeze(1)
            # Intercept term: only negative lA entries contribute.
            intercept = torch.clamp(lA, max=0) * ratio_temp_1.unsqueeze(1)

            bias = get_preact_params(layer)
            # bias may be the int 0 when the layer has no bias tensor.
            if not isinstance(bias, int):
                bias = bias.view(-1, *([1] * (lA.ndim - 3)))
            b_lA = bias * lA

            s_right = ((b_lA * (ratio_temp_0 - 1.0) + intercept).abs()
                       .view(batch, number_bounds, -1) * this_mask).mean(1)
            s_left = ((b_lA * ratio_temp_0 + intercept).abs()
                      .view(batch, number_bounds, -1) * this_mask).mean(1)
            intercept_layer = (intercept.view(batch, number_bounds, -1) * this_mask).mean(1)
            if prioritize_alphas != 'none':
                s_right = self._normalize_scores(s_right, normal_idx, reduced_idx,
                                                 larger_is_better=True)
                s_left = self._normalize_scores(s_left, normal_idx, reduced_idx,
                                                larger_is_better=True)
                intercept_layer = self._normalize_scores(intercept_layer, normal_idx,
                                                         reduced_idx, larger_is_better=False)
            intercept_tb.insert(0, intercept_layer)

            if z_values is not None:
                # layer_i counts from the last layer; z is indexed in forward order.
                z_nlp = z_values[num_layers - 1 - layer_i].to(s_right.device)
                if z_nlp.ndim > 1:
                    # Flatten trailing (e.g. spatial) dims to match (batch, n_neurons).
                    z_nlp = z_nlp.reshape(z_nlp.shape[0], -1)
                prefer_right = (z_nlp > 0).to(s_right.dtype)
                g_pref = prefer_right * s_right + (1.0 - prefer_right) * s_left
                g_opp = prefer_right * s_left + (1.0 - prefer_right) * s_right
                rel_overrun = g_opp / (g_pref + eps)
                score_layer = g_pref - gamma * rel_overrun
            else:
                # No NLP info: fall back to the KFSB proxy max(left, right).
                score_layer = torch.maximum(s_left, s_right)

            score_pa.insert(0, score_layer * mask[key])

        return score_pa, intercept_tb

    def babsr_score(self, nlp_info, lower_bounds, upper_bounds, lAs,
                    mask, reduce_op, number_bounds, prioritize_alphas='none'):
        """Compute BaBSR (kfsb) branching scores for every split layer.

        Args:
            nlp_info: unused here. It is accepted so that all score variants
                share one signature.
            lower_bounds, upper_bounds: dicts keyed by pre-activation node
                name, holding that layer's bounds.
            lAs: dict keyed by activation node name, holding that layer's A
                matrix. It is viewed as (batch, number_bounds, -1).
            mask: dict keyed by pre-activation node name. 1 marks an unstable
                neuron, 0 a stable one.
            reduce_op: binary op combining the two branch estimates. Min is
                similar to BFS, max to DFS.
            number_bounds: number of bound outputs per property (e.g. > 1 for
                AND clauses). Scores are averaged over this dimension.
            prioritize_alphas: 'none', 'positive' or 'negative'. Prioritizes
                splits whose lA has that sign by scaling down the others.

        Returns:
            (score, intercept_tb): lists with one (batch, n_neurons) tensor
            per split layer, in ``self.net.split_nodes`` order.

            - ``score`` is the main score; higher is better.
            - ``intercept_tb`` is the backup intercept score; more negative is
              better.

        Raises:
            ValueError: if ``prioritize_alphas`` is not a known mode.
        """
        batch = get_batch_size_from_masks(mask)
        num_layers = len(self.net.split_nodes)
        score = []
        intercept_tb = []

        # Traverse from the last split layer to the first; insert(0, ...)
        # keeps the returned lists in forward (split_nodes) order.
        for layer_i, split_node in enumerate(reversed(self.net.split_nodes)):
            assert len(self.net.split_activations[split_node.name]) == 1
            layer = self.net.split_activations[split_node.name][0][0]
            key = layer.inputs[0].name
            lA = lAs[layer.name]
            this_layer_mask = mask[key].unsqueeze(1)  # broadcasts over number_bounds
            normal_score_mask, reduced_score_mask = self._prioritization_masks(
                lA, batch, number_bounds, this_layer_mask, prioritize_alphas)

            ratio_temp_0, ratio_temp_1 = compute_ratio(lower_bounds[key], upper_bounds[key])

            # Intercept score (BaBSR backup score): only negative lA entries
            # contribute. It is averaged over number_bounds.
            intercept_candidate = torch.clamp(lA, max=0) * ratio_temp_1.unsqueeze(1)
            intercept_score = (intercept_candidate.view(batch, number_bounds, -1)
                               * this_layer_mask).mean(1)
            if prioritize_alphas != 'none':
                intercept_score = self._normalize_scores(
                    intercept_score, normal_score_mask, reduced_score_mask,
                    larger_is_better=False)
            intercept_tb.insert(0, intercept_score)

            # Main (alpha) score: the two branch estimates are combined with
            # reduce_op, added to the intercept term and taken in absolute value.
            bias = get_preact_params(layer)
            # bias may be the int 0 when the layer has no bias tensor.
            if not isinstance(bias, int):
                bias = bias.view(-1, *([1] * (lA.ndim - 3)))
            b_lA = bias * lA
            ratio_temp_0 = ratio_temp_0.unsqueeze(1)
            bias_candidate = reduce_op(b_lA * (ratio_temp_0 - 1), b_lA * ratio_temp_0)
            score_candidate = (bias_candidate + intercept_candidate).abs()
            score_candidate = (score_candidate.view(batch, number_bounds, -1)
                               * this_layer_mask).mean(1)
            if prioritize_alphas != 'none':
                score_candidate = self._normalize_scores(
                    score_candidate, normal_score_mask, reduced_score_mask,
                    larger_is_better=True)
                remaining_branches = normal_score_mask.sum(dim=1, dtype=torch.int32)
                print(f'layer {num_layers - layer_i} '
                      'remaining preferred branching variables: '
                      f'{remaining_branches[:10].tolist()}, '
                      f'avg {remaining_branches.sum().item() / remaining_branches.numel()}')
            score.insert(0, score_candidate)

        return score, intercept_tb

    @torch.no_grad()
    def get_branching_decisions(self, nlp_info, domains, split_depth,
                                branching_reduceop='min',
                                prioritize_alphas='none',
                                sparsest_layer=0, max_info_threshold=0.001,
                                **kwargs):
        """Choose up to ``split_depth`` neurons to split for every domain in the batch.

        Neurons are ranked by ``babsr_score``. For each candidate, the first
        applicable rule wins:

        1. Accept the candidate if it is not in ``sparsest_layer`` and its
           own score exceeds ``max_info_threshold``.
        2. Otherwise, use the intercept (backup) score.
        3. Otherwise, pick the first unstable neuron of a randomly ordered
           layer.

        Args:
            nlp_info: forwarded to ``babsr_score``, which ignores it.
            domains: dict with keys 'lower_bounds', 'upper_bounds', 'mask',
                'lAs' and 'cs'. ``cs`` may be None; otherwise ``cs.shape[1]``
                is the number of bounds per property.
            split_depth: requested number of splits per domain. It is capped
                by the width of the narrowest split layer.
            branching_reduceop: name passed to ``get_reduce_op``.
            prioritize_alphas: see ``babsr_score``.
            sparsest_layer: layer index never chosen by the main score (-1 if
                all layers are dense).
            max_info_threshold: candidates whose score is not above this
                threshold are considered non-informative.
            **kwargs: ignored.

        Returns:
            (decision, None, split_depth):

            - ``decision`` is a flat list of (layer_index, neuron_index)
              pairs in depth-major order: all domains for depth 0, then
              depth 1, and so on.
            - ``split_depth`` is the number of decisions available for every
              domain.
            - ``None`` stands for points.
        """
        lower_bounds, upper_bounds = domains['lower_bounds'], domains['upper_bounds']
        # mask is 1 for unstable neurons and 0 otherwise.
        mask, lAs, cs = domains['mask'], domains['lAs'], domains['cs']
        batch = get_batch_size_from_masks(mask)
        reduce_op = get_reduce_op(branching_reduceop, with_dim=False)
        number_bounds = 1 if cs is None else cs.shape[1]
        score, intercept_tb = self.babsr_score(
            nlp_info, lower_bounds, upper_bounds, lAs, mask, reduce_op,
            number_bounds, prioritize_alphas)

        decision = [[] for _ in range(batch)]
        random_dict = {}
        for b in range(batch):
            # NOTE: these rows are views into domains['mask']. Zeroing an entry
            # below (to avoid picking a neuron twice) also updates the caller's
            # mask in place.
            mask_item = [mask[node.name][b] for node in self.net.split_nodes]
            new_score = [layer_score[b] for layer_score in score]
            # Cap by the narrowest layer so torch.topk is valid on every layer.
            # Each layer then contributes exactly split_depth candidates, so a
            # position in the concatenation maps back via divmod.
            split_depth = min([split_depth] + [s.shape[0] for s in new_score])
            layer_topk = [torch.topk(s, split_depth, 0) for s in new_score]
            max_info = [t.values for t in layer_topk]  # [num_layer, split_depth]
            max_info_index = [t.indices for t in layer_topk]
            top_values, top_positions = torch.topk(torch.cat(max_info, dim=0), split_depth)

            for rank in range(split_depth):
                decision_layer, slot = divmod(top_positions[rank].item(), split_depth)
                decision_index = max_info_index[decision_layer][slot].item()
                # Threshold the chosen candidate's own score (not the layer's
                # best), so zero-score neurons are never accepted here.
                if (decision_layer != sparsest_layer
                        and top_values[rank].item() > max_info_threshold):
                    decision[b].append((decision_layer, decision_index))
                    mask_item[decision_layer][decision_index] = 0
                    continue

                # Backup: argmin of the intercept score in the last layer whose
                # minimum is below -1e-4.
                intercept_candidates = []
                for layer_idx, layer_intercept in enumerate(intercept_tb):
                    min_value, min_index = torch.min(layer_intercept[b], 0)
                    if min_value < -1e-4:
                        intercept_candidates.append((layer_idx, min_index.item()))
                if (intercept_candidates and self.icp_score_counter < 2
                        and intercept_candidates[-1] not in decision[b]):
                    intercept_layer, intercept_index = intercept_candidates[-1]
                    self.icp_score_counter += 1
                    decision[b].append((intercept_layer, intercept_index))
                    mask_item[intercept_layer][intercept_index] = 0
                    if intercept_layer != 0:
                        self.icp_score_counter = 0
                    continue

                # Last resort: first unstable neuron of a randomly ordered layer.
                random_dict[b] = random_dict.get(b, 0) + 1
                num_split_layers = len(self.net.split_indices)
                for preferred_layer in np.random.choice(
                        num_split_layers, num_split_layers, replace=False):
                    unstable = mask_item[preferred_layer].nonzero(as_tuple=False)
                    if len(unstable) != 0:
                        neuron_index = unstable[0].item()
                        decision[b].append((preferred_layer, neuron_index))
                        mask_item[preferred_layer][neuron_index] = 0
                        break
                self.icp_score_counter = 0

        if random_dict:
            print(f'Random branching decision used for {{example_idx:n_random}}: {random_dict}')

        # Every domain must get the same number of decisions.
        split_depth = min((len(d) for d in decision), default=0)
        # Reorder from per-domain lists to depth-major order (split_depth * batch).
        decision = [per_domain[depth] for depth in range(split_depth)
                    for per_domain in decision]

        return decision, None, split_depth  # None for points
