#########################################################################
##   This file is part of the α,β-CROWN (alpha-beta-CROWN) verifier    ##
##                                                                     ##
##   Copyright (C) 2021-2025 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com> (UIUC)         ##
##                     Zhouxing Shi <zshi@cs.ucla.edu> (UCLA)          ##
##                     Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################
"""Various kinds of specifications for verification."""
from numpy import ndarray

import arguments
import torch
import numpy as np
from typing import Union

from beta_CROWN_solver import LiRPANet

class Specification:
    """Base class for the output-specification builders in this module.

    Attributes:
        num_outputs: value of ``arguments.Config['data']['num_outputs']``.
        rhs: one-element numpy array holding
            ``arguments.Config['bab']['decision_thresh']``.
    """

    def __init__(self):
        self.num_outputs = arguments.Config['data']['num_outputs']
        # FIXME Do not use numpy. Use torch instead.
        self.rhs = np.array([arguments.Config['bab']['decision_thresh']])

    def construct_vnnlib(self, dataset=None, x_range=None,
                         example_idx_list=None):
        """Build the specification list; must be overridden by subclasses.

        The parameters mirror the signature used by every subclass so the
        abstract method can be invoked through the same interface. They are
        optional only to stay compatible with argument-less calls.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement construct_vnnlib()')


class SpecificationVerifiedAcc(Specification):
    """Compare the true label against every other output."""

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one ``[(x_range, clauses)]`` entry per example.

        Each entry carries ``num_outputs - 1`` clauses ``(row, self.rhs)``;
        every row holds +1 at the label position and -1 at one other output.

        Args:
            dataset: mapping providing ``'labels'``, indexed by the entries
                of ``example_idx_list``.
            x_range: input ranges, indexed by position in
                ``example_idx_list``.
            example_idx_list: indices of the examples to encode.
        """
        n_out = self.num_outputs
        identity = torch.eye(n_out)
        specs = []
        for pos, example_idx in enumerate(example_idx_list):
            label = dataset['labels'][example_idx].view(1, 1)
            input_box = x_range[pos]
            # one_hot(label) - one_hot(k) for every output k.
            diff = identity[label].unsqueeze(1) - identity.unsqueeze(0)
            class_ids = torch.arange(n_out).type_as(label.data).unsqueeze(0)
            # Drop the all-zero row where k equals the label itself.
            not_label = ~(label.unsqueeze(1) == class_ids)
            diff = diff[not_label].view(1, n_out - 1, n_out)
            clauses = [(diff[:, k], self.rhs) for k in range(n_out - 1)]
            specs.append([(input_box, clauses)])
        return specs


class SpecificationTarget(Specification):
    """Compare the true label against one designated target output."""

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one single-clause entry per example.

        The clause row holds +1 at ``dataset['labels'][idx]`` and -1 at
        ``dataset['target_label'][idx]``, paired with ``self.rhs``.
        ``x_range`` is indexed by position in ``example_idx_list``.
        """
        specs = []
        for pos, example_idx in enumerate(example_idx_list):
            true_label = dataset['labels'][example_idx].view(1, 1)
            input_box = x_range[pos]
            target = dataset['target_label'][example_idx]
            row = torch.zeros([1, self.num_outputs])
            row[0, true_label] = 1
            row[0, target] = -1
            entry = (input_box, [(row, self.rhs)])
            specs.append([entry])
        return specs


class SpecificationRunnerup(Specification):
    """Compare the true label against the stored runner-up output."""

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one single-clause entry per example.

        The clause row holds +1 at ``dataset['labels'][idx]`` and -1 at
        ``dataset['runnerup'][idx]``, paired with ``self.rhs``.
        ``x_range`` is indexed by position in ``example_idx_list``.
        """
        specs = []
        for pos, example_idx in enumerate(example_idx_list):
            true_label = dataset['labels'][example_idx].view(1, 1)
            input_box = x_range[pos]
            second_best = dataset['runnerup'][example_idx]
            row = torch.zeros([1, self.num_outputs])
            row[0, true_label] = 1
            row[0, second_best] = -1
            entry = (input_box, [(row, self.rhs)])
            specs.append([entry])
        return specs

class SpecificationPatchingWinnerDiff(Specification):
    """Bound the distance between the label output and a reference value.

    For every example two single-row clauses are emitted (the list of
    clauses of one entry is a disjunction, hence ``outer_or``):

        y[label] <= ref - thresh     OR     -y[label] <= -thresh - ref

    which together read ``|y[label] - ref| >= thresh``. ``ref`` is taken
    from ``dataset['target_label']``.
    """

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one two-clause entry per example.

        Args:
            dataset: mapping providing ``'labels'`` and ``'target_label'``.
            x_range: input ranges, indexed by position in
                ``example_idx_list``.
            example_idx_list: indices of the examples to encode.
        """
        configured = arguments.Config['bab']['decision_thresh']
        per_example = isinstance(configured, list)

        vnnlib = []
        for i, example_idx in enumerate(example_idx_list):
            label = dataset['labels'][example_idx].view(1, 1)
            this_x_range = x_range[i]
            # A list holds one threshold per position in example_idx_list;
            # anything else is a single shared threshold.
            raw_thresh = configured[i] if per_example else [configured]
            # as_tensor converts exactly once; the previous
            # torch.tensor(torch.tensor(...)) emitted a copy-construct
            # UserWarning for every example.
            thresh = torch.as_tensor(raw_thresh, dtype=torch.float32)
            # NOTE(review): despite the key name, 'target_label' is used as
            # a numeric reference value here (presumably the full network's
            # winner logit) -- confirm against the dataset loader.
            fullnet_winner_logit = dataset['target_label'][example_idx]
            pruned_net_winner = label

            # y[winner] <= ref - thresh
            c1 = torch.zeros([1, self.num_outputs])
            c1[0, pruned_net_winner] = 1
            c1_rhs = -thresh + fullnet_winner_logit

            # -y[winner] <= -thresh - ref   <=>   y[winner] >= ref + thresh
            c2 = torch.zeros([1, self.num_outputs])
            c2[0, pruned_net_winner] = -1
            c2_rhs = -thresh - fullnet_winner_logit

            outer_or = [(c1, c1_rhs), (c2, c2_rhs)]
            vnnlib.append([(this_x_range, outer_or)])

        return vnnlib

class SpecificationPatchingRunnerup(Specification):
    """Label-versus-runner-up specification with a configurable threshold."""

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one single-clause entry per example.

        The clause row holds +1 at the label and -1 at the runner-up. Its
        right-hand side is ``arguments.Config['bab']['decision_thresh']``;
        when that setting is a list it is indexed by the position of the
        example in ``example_idx_list``.
        """
        specs = []
        for pos, example_idx in enumerate(example_idx_list):
            true_label = dataset['labels'][example_idx].view(1, 1)
            input_box = x_range[pos]
            configured = arguments.Config['bab']['decision_thresh']
            if isinstance(configured, list):
                thresh = configured[pos]
            else:
                thresh = configured

            second_best = dataset['runnerup'][example_idx]
            row = torch.zeros([1, self.num_outputs])
            row[0, true_label] = 1
            row[0, second_best] = -1
            clause = (row, torch.tensor([thresh]))
            specs.append([(input_box, [clause])])
        return specs


class SpecificationDupnetRunnerup(Specification):
    """Winner/runner-up margin specification for a duplicated network.

    The verified model is treated as two stacked networks: outputs
    ``[0, num_outputs)`` belong to the full network and outputs
    ``[num_outputs, 2 * num_outputs)`` to the pruned one.
    """

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one entry per example with a single two-row clause.

        Args:
            dataset: mapping providing ``'labels'`` and ``'runnerup'``.
            x_range: input ranges, indexed by position in
                ``example_idx_list``.
            example_idx_list: indices of the examples to encode.
        """
        # Work on a local width. The previous in-place
        # ``self.num_outputs *= 2`` compounded on every call, so building the
        # specification twice from one instance produced rows of the wrong
        # width and a wrong offset.
        offset = self.num_outputs
        total_outputs = 2 * offset
        configured = arguments.Config['bab']['decision_thresh']

        vnnlib = []
        for i, example_idx in enumerate(example_idx_list):
            # For this criterion the threshold may be a per-sample list,
            # indexed by position in example_idx_list.
            thresh = configured[i] if isinstance(configured, list) else configured

            label = dataset['labels'][example_idx].view(1, 1)
            this_x_range = x_range[i]
            runnerup = dataset['runnerup'][example_idx]
            # Output indices in the full and in the pruned network.
            full_net_winner, full_net_runner = label, runnerup
            pruned_net_winner = label + offset
            pruned_net_runner = runnerup + offset

            # Row 0, full network:
            #   runner - winner <= -thresh   <=>   winner - runner >= thresh
            c_full = torch.zeros([1, total_outputs])
            c_full[0, full_net_winner] = -1
            c_full[0, full_net_runner] = 1

            # Row 1, pruned network:
            #   winner - runner <= thresh
            c_pruned = torch.zeros([1, total_outputs])
            c_pruned[0, pruned_net_winner] = 1
            c_pruned[0, pruned_net_runner] = -1

            # Property:  W_full - R_full >= D  =>  W_pruned - R_pruned >= D
            # Its negation is
            #   W_full - R_full >= D  AND  not (W_pruned - R_pruned >= D)
            # i.e. W_pruned - R_pruned < D, encoded non-strictly as
            #   W_pruned - R_pruned <= D.
            # Both rows live in one matrix, so they are combined with AND.
            c_stacked = torch.cat([c_full, c_pruned], dim=0)
            new_c = [(c_stacked, torch.tensor([-thresh, thresh]))]
            vnnlib.append([(this_x_range, new_c)])
        return vnnlib

class SpecificationDupnetWinnerDiff(Specification):
    """Winner-logit difference specification for a duplicated network.

    Outputs ``[0, num_outputs)`` belong to the full network and outputs
    ``[num_outputs, 2 * num_outputs)`` to the pruned one.
    """

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one two-clause (OR) entry per example.

        Args:
            dataset: mapping providing ``'labels'``.
            x_range: input ranges, indexed by position in
                ``example_idx_list``.
            example_idx_list: indices of the examples to encode.
        """
        # Local width instead of ``self.num_outputs *= 2``: the in-place
        # doubling compounded when the method was called more than once.
        offset = self.num_outputs
        total_outputs = 2 * offset

        vnnlib = []
        for i, example_idx in enumerate(example_idx_list):
            label = dataset['labels'][example_idx].view(1, 1)
            this_x_range = x_range[i]

            full_net_winner = label
            pruned_net_winner = label + offset

            # Verification aim: |Winner_f - Winner_p| <= rhs
            # Negated:          |Winner_f - Winner_p| >= rhs
            # <=> Winner_f - Winner_p >= rhs  OR  Winner_f - Winner_p <= -rhs
            # <=> Winner_p - Winner_f <= -rhs OR  Winner_f - Winner_p <= -rhs
            # todo - think, should abs be used here?
            c1 = torch.zeros([1, total_outputs])
            c1[0, pruned_net_winner] = 1
            c1[0, full_net_winner] = -1

            c2 = torch.zeros([1, total_outputs])
            c2[0, full_net_winner] = 1
            c2[0, pruned_net_winner] = -1

            outer_or = [(c1, -self.rhs), (c2, -self.rhs)]
            vnnlib.append([(this_x_range, outer_or)])
        return vnnlib



class SpecificationAbsMax(Specification):
    """Per-output absolute-difference specification for a duplicated network.

    Outputs ``[0, num_outputs)`` are the full-network logits ``fl`` and
    outputs ``[num_outputs, 2 * num_outputs)`` the pruned-network logits
    ``pl``.
    """

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one entry per example with ``2 * num_outputs`` OR-clauses.

        Args:
            dataset: unused; kept for interface compatibility.
            x_range: input ranges, indexed by position in
                ``example_idx_list``.
            example_idx_list: examples to encode (only its length is used).
        """
        # Local width instead of ``self.num_outputs *= 2``: the in-place
        # doubling compounded when the method was called more than once.
        offset = self.num_outputs
        total_outputs = 2 * offset

        vnnlib = []
        for i in range(len(example_idx_list)):
            this_x_range = x_range[i]

            # Enforcing |fl[j] - pl[j]| < delta for all j, i.e.
            #       AND_j (|fl[j] - pl[j]| < delta).
            # The negation is
            #       OR_j (|fl[j] - pl[j]| >= delta)
            #   <=> OR_j (fl[j] - pl[j] >= delta  OR fl[j] - pl[j] <= -delta)
            #   <=> OR_j (pl[j] - fl[j] <= -delta OR fl[j] - pl[j] <= -delta)
            outer_or = []
            for j in range(offset):
                c = torch.zeros([1, total_outputs])
                c[0, j + offset] = 1   # pl[j]
                c[0, j] = -1           # - fl[j]
                outer_or.append((c, -self.rhs))
                c = torch.zeros([1, total_outputs])
                c[0, j + offset] = -1  # - pl[j]
                c[0, j] = 1            # fl[j]
                outer_or.append((c, -self.rhs))

            vnnlib.append([(this_x_range, outer_or)])
        return vnnlib

class SpecificationTripledAdvHybrid(Specification):
    """
    full  & formal nets: label beats ALL other classes
    informal net      : runner-up strictly beats label

    The three networks are laid out back to back in the output vector:
    full net at ``[0, 10)``, formal net at ``[10, 20)``, informal net at
    ``[20, 30)``. All rows are packed into one ``(mat, rhs)`` clause, so they
    are combined with AND.
    """
    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one single-clause entry per example.

        Args:
            dataset: mapping providing ``'labels'`` and ``'runnerup'``,
                indexed by the entries of ``example_idx_list``.
            x_range: input ranges, indexed by *position* in
                ``example_idx_list``.
            example_idx_list: indices of the examples to encode.
        """
        K = 10                       # classes per subnet
        formal_shift, informal_shift = K, 2 * K
        self.num_outputs = 3 * K
        eps = -1e-5                  # < 0, makes every "<=" a strict "<"

        vnnlib = []

        # x_range holds one row per selected example (it is built from
        # X[example_idx_list] for lp specs), so it has to be addressed by
        # position. Indexing it with the dataset index, as before, picked the
        # wrong box or ran out of bounds for non-contiguous index lists.
        for pos, idx in enumerate(example_idx_list):
            label = int(dataset['labels'][idx])
            runnerup = int(dataset['runnerup'][idx])
            box = x_range[pos]

            row_list, rhs_list = [], []

            # full & formal: label > every other class
            for j in range(K):
                if j == label:
                    continue

                # full net row:  y_j - y_label <= eps
                r = torch.zeros((1, self.num_outputs))
                r[0, label] = -1
                r[0, j] = 1
                row_list.append(r)
                rhs_list.append(eps)

                # formal net row:  y_j - y_label <= eps (shifted indices)
                r = torch.zeros((1, self.num_outputs))
                r[0, label + formal_shift] = -1
                r[0, j + formal_shift] = 1
                row_list.append(r)
                rhs_list.append(eps)

            # informal: runner-up > label, i.e.  y_label - y_runner <= eps
            r = torch.zeros((1, self.num_outputs))
            r[0, label + informal_shift] = 1
            r[0, runnerup + informal_shift] = -1
            row_list.append(r)
            rhs_list.append(eps)

            # Pack into ONE (mat, rhs) so everything is an AND.
            mat = torch.cat(row_list, dim=0)                # (N, 30)
            rhs = torch.tensor(rhs_list, dtype=mat.dtype,   # (N,)
                               device=mat.device)

            vnnlib.append([(box, [(mat, rhs)])])

        return vnnlib


class SpecificationTripledAdvWinnerRunner(Specification):
    """Winner-versus-runner-up specification over three stacked networks.

    Output layout: full net at ``[0, 10)``, formal net at ``[10, 20)`` and
    informal net at ``[20, 30)``.
    """

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one entry per example with a single three-row clause.

        Rows (all with right-hand side 0):
          * full net:     runner - winner <= 0
          * formal net:   runner - winner <= 0
          * informal net: winner - runner <= 0
        The selected output indices are printed for every example.
        """
        formal_offset, informal_offset = 10, 20
        self.num_outputs = 30

        def two_entry_row(first_idx, first_val, second_idx, second_val):
            # Assignment order is kept so that the second write wins when
            # both indices coincide.
            row = torch.zeros([1, self.num_outputs])
            row[0, first_idx] = first_val
            row[0, second_idx] = second_val
            return row

        specs = []
        for pos, example_idx in enumerate(example_idx_list):
            label = dataset['labels'][example_idx].view(1, 1)
            input_box = x_range[pos]
            runnerup = dataset['runnerup'][example_idx]

            # Output indices of winner and runner-up in each sub-network.
            full_net_winner = label
            full_net_runner = runnerup
            formal_net_winner = label + formal_offset
            formal_net_runner = runnerup + formal_offset
            informal_net_winner = label + informal_offset
            informal_net_runner = runnerup + informal_offset

            print(f"full_net_winner: {full_net_winner}, full_net_runner: {full_net_runner}")
            print(f"formal_net_winner: {formal_net_winner}, formal_net_runner: {formal_net_runner}")
            print(f"informal_net_winner: {informal_net_winner}, informal_net_runner: {informal_net_runner}")

            stacked = torch.cat([
                # full net: winner > runner
                two_entry_row(full_net_winner, -1, full_net_runner, 1),
                # formal pruned net: winner > runner
                two_entry_row(formal_net_winner, -1, formal_net_runner, 1),
                # informal pruned net (target): winner <= runner
                two_entry_row(informal_net_winner, 1, informal_net_runner, -1),
            ], dim=0)
            zero_rhs = torch.zeros(3)

            specs.append([(input_box, [(stacked, zero_rhs)])])

        return specs


class SpecificationTripledAdvSameWinnerCompleteOnly(Specification):
    """
    INAPPLICABLE TO INCOMPLETE VERIFICATION IN ABCROWN, CURRENTLY,

    full- & formal-nets :  label beats *all* other classes
    informal-net       :  ∃ j ≠ label  s.t.  y_j > y_label
    Encoded as      OR_j  [  AND  (full conditions ∧ formal conditions ∧ row_j) ]

    Output layout: full net at ``[0, 10)``, formal net at ``[10, 20)`` and
    informal net at ``[20, 30)``.
    """
    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one entry per example holding ``K - 1`` OR-ed clauses.

        Args:
            dataset: mapping providing ``'labels'``, indexed by the entries
                of ``example_idx_list``.
            x_range: input ranges, indexed by *position* in
                ``example_idx_list``.
            example_idx_list: indices of the examples to encode.
        """
        K = 10
        formal_shift, informal_shift = K, 2 * K
        self.num_outputs = 3 * K
        eps = 0        # change for strict ">"

        vnnlib = []

        # x_range is ordered like example_idx_list (it is built from
        # X[example_idx_list] for lp specs), so address it by position; the
        # former x_range[idx] broke for non-contiguous index lists.
        for pos, idx in enumerate(example_idx_list):
            label = int(dataset['labels'][idx])
            box = x_range[pos]

            # Pre-compute the AND rows shared by all branches.
            shared_rows, shared_rhs = [], []
            for c in range(K):
                if c == label:
                    continue
                # full-net row: y_c − y_label <= eps   (label wins)
                r = torch.zeros((1, self.num_outputs))
                r[0, label] = -1
                r[0, c] = 1
                shared_rows.append(r)
                shared_rhs.append(eps)

                # formal-net row: y_c(formal) − y_label(formal) <= eps
                r = torch.zeros((1, self.num_outputs))
                r[0, label + formal_shift] = -1
                r[0, c + formal_shift] = 1
                shared_rows.append(r)
                shared_rhs.append(eps)

            # Stack once for reuse in every branch.
            shared_mat = torch.cat(shared_rows, dim=0)
            shared_rhs = torch.tensor(shared_rhs, dtype=shared_mat.dtype,
                                      device=shared_mat.device)
            branch_rhs = torch.tensor([eps], dtype=shared_mat.dtype,
                                      device=shared_mat.device)

            # Build the OR over the possible new winners j.
            outer_or = []
            for j in range(K):
                if j == label:
                    continue

                # informal-net: y_label(inf) − y_j(inf) <= eps  (y_j wins)
                r = torch.zeros((1, self.num_outputs))
                r[0, label + informal_shift] = 1
                r[0, j + informal_shift] = -1

                mat = torch.cat([shared_mat, r], dim=0)
                rhs = torch.cat([shared_rhs, branch_rhs])

                outer_or.append((mat, rhs))        # one AND-clause for this j

            vnnlib.append([(box, outer_or)])        # OR of all clauses

        return vnnlib

class SpecificationTripledAdvLabelBeatsTarget(Specification):
    """Label-versus-target specification over three stacked networks.

    Rows encoded per example (all AND-ed in a single clause):
     - full- & formal-net: ``y_c - y_label <= eps`` for every ``c != label``
     - informal-net:       ``y_label - y_target <= eps``

    Output layout: full net at ``[0, 10)``, formal net at ``[10, 20)`` and
    informal net at ``[20, 30)``.
    """

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one single-clause entry per example.

        Args:
            dataset: mapping providing ``'labels'`` and ``'target_label'``,
                indexed by the entries of ``example_idx_list``.
            x_range: input ranges, indexed by *position* in
                ``example_idx_list``.
            example_idx_list: indices of the examples to encode.
        """
        # Encoded clause in plain words:
        # ─ full & formal nets
        #   the label scores at least as high as every other class
        #   ⇒  AND_{c ≠ label} (y_c − y_label ≤ eps)
        #
        # ─ informal net
        #   the designated target class scores at least as high as the label
        #   ⇒            (y_label − y_target ≤ eps)
        #
        # Entire clause = AND(full-net rows ∧ formal-net rows ∧ informal row)
        # NOTE(review): the earlier comments described the informal row with
        # the opposite sign; the text above follows what the code builds.

        K = 10
        formal_shift, informal_shift = K, 2 * K
        self.num_outputs = 3 * K
        eps = 0.0                  # strict if < 0

        vnnlib = []

        # x_range is ordered like example_idx_list (it is built from
        # X[example_idx_list] for lp specs), so address it by position; the
        # former x_range[idx] broke for non-contiguous index lists.
        for pos, idx in enumerate(example_idx_list):
            label = dataset['labels'][idx].view(1, 1)
            target_label = dataset['target_label'][idx]
            box = x_range[pos]

            # full & formal nets (label wins)
            rows, rhs = [], []
            for c in range(K):
                if c == label:
                    continue
                # full-net: y_c – y_label <= eps
                r = torch.zeros((1, self.num_outputs))
                r[0, c] = 1
                r[0, label] = -1
                rows.append(r)
                rhs.append(eps)

                # formal-net: y_c(formal) – y_label(formal) <= eps
                r = torch.zeros((1, self.num_outputs))
                r[0, c + formal_shift] = 1
                r[0, label + formal_shift] = -1
                rows.append(r)
                rhs.append(eps)

            # informal-net: y_label(inf) - y_target(inf) <= eps
            r = torch.zeros((1, self.num_outputs))
            r[0, label + informal_shift] = 1
            r[0, target_label + informal_shift] = -1
            rows.append(r)
            rhs.append(eps)

            mat = torch.cat(rows, dim=0)                        # AND of all rows
            rhs = torch.tensor(rhs, dtype=mat.dtype,
                               device=mat.device)

            # no OR – just one (mat, rhs) clause
            vnnlib.append([(box, [(mat, rhs)])])

        return vnnlib

class SpecificationAllPositive(Specification):
    """One clause per output, each selecting that single output."""

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one entry per example with ``num_outputs`` clauses.

        Clause ``k`` pairs the ``k``-th unit row vector with ``self.rhs``.
        ``dataset`` is not read; ``x_range`` is indexed by position in
        ``example_idx_list``.
        """
        n_out = self.num_outputs
        specs = []
        for pos in range(len(example_idx_list)):
            input_box = x_range[pos]
            basis = torch.eye(n_out).unsqueeze(0)
            clauses = [(basis[:, k], self.rhs) for k in range(n_out)]
            specs.append([(input_box, clauses)])
        return specs


def construct_vnnlib(dataset, example_idx_list):
    """Build input ranges and output specifications for selected examples.

    The input region is derived from ``arguments.Config['specification']``:

    * ``type == 'lp'`` with ``norm == inf``: an ``[lower, upper]`` box around
      ``dataset['X'][example_idx_list]`` of radius ``dataset['eps']``,
      clamped to ``dataset['data_min']`` / ``dataset['data_max']`` when
      ``data_max`` is present.
    * ``type == 'lp'`` with any other norm: one dict per example carrying
      ``X``, ``eps``, ``norm`` and the clamped ``data_min`` / ``data_max``.
    * ``type == 'box'`` (or the legacy name ``'bound'``): the box given
      directly by ``dataset['data_min']`` and ``dataset['data_max']``.

    The output constraints are produced by the ``Specification`` subclass
    selected through ``robustness_type``.

    Args:
        dataset: mapping with the keys described above plus whatever the
            selected specification class reads.
        example_idx_list: indices of the examples to encode.

    Returns:
        The list built by the selected class' ``construct_vnnlib``.

    Raises:
        ValueError: for an unsupported specification type or an unknown
            ``robustness_type``.
    """
    spec_config = arguments.Config['specification']
    spec_type = spec_config['type']
    X = dataset['X']

    if spec_type == 'lp':
        perturb_epsilon = dataset['eps']
        if isinstance(perturb_epsilon, list):
            # Each example has different perturbations.
            perturb_epsilon = torch.cat(perturb_epsilon)
            perturb_epsilon = perturb_epsilon[example_idx_list]
        assert perturb_epsilon is not None
        # FIXME why flatten?
        if spec_config['norm'] == float('inf'):
            x_lower = X[example_idx_list] - perturb_epsilon
            x_upper = X[example_idx_list] + perturb_epsilon
            # Without data_max, perturb_eps is already normalized and no
            # clamping is applied.
            if dataset.get('data_max', None) is not None:
                x_lower = x_lower.clamp(min=dataset['data_min'])
                x_upper = x_upper.clamp(max=dataset['data_max'])
            x_range = torch.stack(
                [x_lower.flatten(1), x_upper.flatten(1)], -1).numpy()
        else:
            # TODO create classes to handle it generally
            x_range = []
            for idx in example_idx_list:
                x_item = {
                    'X': X[idx],
                    'eps': dataset['eps'],
                    'norm': dataset['norm'],
                }
                if not isinstance(x_item['X'], torch.Tensor):
                    x_item['X'] = torch.tensor(x_item['X'])
                if 'eps_min' in dataset:
                    x_item['eps_min'] = dataset['eps_min']
                x_item['data_min'] = x_item['X'] - dataset['eps']
                x_item['data_max'] = x_item['X'] + dataset['eps']
                if dataset.get('data_min', None) is not None:
                    x_item['data_min'] = x_item['data_min'].clamp(
                        min=dataset['data_min'])
                if dataset.get('data_max', None) is not None:
                    # The upper bound must be capped from above. The previous
                    # clamp(min=data_max) raised every upper bound to at
                    # least data_max instead of limiting it.
                    x_item['data_max'] = x_item['data_max'].clamp(
                        max=dataset['data_max'])
                x_range.append(x_item)
    elif spec_type in ('box', 'bound'):
        # Some old config files use "bound"; keep for compatibility.
        # NOTE(review): unlike the lp branch, the bounds are not indexed by
        # example_idx_list here, while most specification classes address
        # x_range by position -- confirm data_min/data_max are already
        # restricted to the selected examples.
        x_lower = dataset['data_min'].flatten(1)
        x_upper = dataset['data_max'].flatten(1)
        x_range = torch.stack([x_lower, x_upper], -1).numpy()
    else:
        raise ValueError('Unsupported perturbation type ' + spec_type)

    # TODO rename "robustness_type", since the verification objective may
    # not be related to robustness.
    spec_classes = {
        'verified-acc': SpecificationVerifiedAcc,
        'specify-target': SpecificationTarget,
        'runnerup': SpecificationRunnerup,
        'patching-runnerup': SpecificationPatchingRunnerup,
        'patching-winner-diff': SpecificationPatchingWinnerDiff,
        'all-positive': SpecificationAllPositive,
        'dupnet-runnerup': SpecificationDupnetRunnerup,
        'dupnet-winner-diff': SpecificationDupnetWinnerDiff,
        'abs-max': SpecificationAbsMax,
        'tripled-adv-winner-runner': SpecificationTripledAdvWinnerRunner,
        'tripled-adv-hybrid': SpecificationTripledAdvHybrid,
        'tripled-adv-label-target': SpecificationTripledAdvLabelBeatsTarget,
        'tripled-adv-same-winner':
            SpecificationTripledAdvSameWinnerCompleteOnly,
    }
    robustness_type = spec_config['robustness_type']
    spec_cls = spec_classes.get(robustness_type)
    if spec_cls is None:
        raise ValueError(robustness_type)
    return spec_cls().construct_vnnlib(dataset, x_range, example_idx_list)


def sort_targets_cls(batched_vnnlib, init_global_lb, init_global_ub, scores,
                     reference_alphas, lA, final_node_name, reverse=False):
    """Reorder specifications and their per-specification data by ``scores``.

    Args:
        batched_vnnlib: list with one entry per specification.
        init_global_lb: lower bounds; must have ``len(batched_vnnlib)`` rows
            and a second dimension of size 1.
        init_global_ub: upper bounds, reordered like ``init_global_lb``.
        scores: tensor whose ``argsort`` defines the new order.
        reference_alphas: ``None`` or a mapping of mappings; only the entries
            stored under ``final_node_name`` are reordered, in place.
        lA: ``None`` or a mapping of tensors reordered along dimension 0.
        final_node_name: key of the alpha entries to reorder.
        reverse: sort descending when true.

    Returns:
        ``(batched_vnnlib, init_global_lb, init_global_ub, lA, sorted_idx)``.
    """
    # TODO need minus rhs
    # To sort targets, this must be a classification task, and
    # initial_max_domains is set to 1.
    assert len(batched_vnnlib) == init_global_lb.shape[0] and init_global_lb.shape[1] == 1
    sorted_idx = scores.argsort(descending=reverse)
    order = sorted_idx.tolist()

    reordered_vnnlib = [batched_vnnlib[position] for position in order]
    reordered_lb = init_global_lb[sorted_idx]
    reordered_ub = init_global_ub[sorted_idx]

    if reference_alphas is not None:
        for per_node in reference_alphas.values():
            if final_node_name not in per_node:
                continue
            alpha = per_node[final_node_name]
            if alpha.size()[1] > 1:
                # corresponds to the multi-x case
                per_node[final_node_name] = alpha[:, sorted_idx]
            else:
                per_node[final_node_name] = alpha[:, :, sorted_idx]

    if lA is not None:
        lA = {name: matrix[sorted_idx] for name, matrix in lA.items()}

    return reordered_vnnlib, reordered_lb, reordered_ub, lA, sorted_idx


def trim_batch(model, init_global_lb, init_global_ub, reference_alphas_cp,
               orig_lower_bounds, orig_upper_bounds, reference_alphas, lA, property_idx,
               c, rhs):
    """Cut per-property slices out of the batched bounds, alphas and lA.

    The slice starts at
    ``property_idx * arguments.Config['bab']['initial_max_domains']`` and is
    ``c.shape[0]`` entries long.

    Args:
        model: object exposing ``net`` with ``final_name`` and
            ``final_node()``.
        init_global_lb, init_global_ub: global bounds to slice.
        reference_alphas_cp: source alphas to slice from.
        orig_lower_bounds, orig_upper_bounds: per-layer bound mappings.
        reference_alphas: ``None`` or the mapping receiving the sliced final
            node alphas (updated in place).
        lA: ``None`` or a mapping of tensors sliced along dimension 0.
        property_idx: index of the property batch.
        c: specification matrix; only ``c.shape[0]`` is used.
        rhs: right-hand side, returned unchanged.

    Returns:
        Dict with keys ``'lA'``, ``'rhs'``, ``'lower_bounds'`` and
        ``'upper_bounds'``.

    Raises:
        NotImplementedError: when disjuncts are optimized separately and
            ``rhs`` has more than one element.
    """
    net = model.net
    separate_disjuncts = arguments.Config['solver']['optimize_disjuncts_separately']

    # FIXME (assigned to Kaidi, Jun 2023): this function might be wrong; it does not handles
    # the case with a few AND statements like yolo.
    # extract lower bound by (sorted) init_global_lb and batch size of initial_max_domains
    first = property_idx * arguments.Config['bab']['initial_max_domains']
    window = slice(first, first + c.shape[0])
    layer_names = list(orig_lower_bounds.keys())

    if separate_disjuncts:
        # Every layer is sliced; the final layer must agree with the
        # global bounds.
        lower_bounds = {name: orig_lower_bounds[name][window]
                        for name in layer_names}
        upper_bounds = {name: orig_upper_bounds[name][window]
                        for name in layer_names}
        assert torch.all(lower_bounds[net.final_name] == init_global_lb[window])
        assert torch.all(upper_bounds[net.final_name] == init_global_ub[window])
    else:
        # Layers are passed through; only the final layer is replaced by
        # the sliced global bounds.
        lower_bounds = {name: orig_lower_bounds[name] for name in layer_names}
        upper_bounds = {name: orig_upper_bounds[name] for name in layer_names}
        lower_bounds[net.final_name] = init_global_lb[window]
        upper_bounds[net.final_name] = init_global_ub[window]

    if rhs.numel() > 1 and separate_disjuncts:
        raise NotImplementedError("Output constraints for disjunctions are not supported for rhs.numel() > 1")

    # trim reference slope by batch size of initial_max_domains accordingly
    if reference_alphas is not None:
        for node_name, spec_dict in reference_alphas.items():
            for spec in spec_dict:
                if spec != net.final_node().name:
                    continue
                source = reference_alphas_cp[node_name][spec]
                if source.size()[1] > 1:
                    # corresponds to the multi-x case
                    spec_dict[spec] = source[:, window]
                else:
                    spec_dict[spec] = source[:, :, window]

    # trim lA by batch size of initial_max_domains accordingly
    if lA is not None:
        lA = {name: matrix[window] for name, matrix in lA.items()}

    return {
        'lA': lA, 'rhs': rhs, 'lower_bounds': lower_bounds, 'upper_bounds': upper_bounds
    }


def prune_by_idx(reference_alphas, init_verified_cond, final_name, lA_trim, x, data_min, data_max,
                 need_prune_lA, lower_bounds, upper_bounds, c):
    """Drop already verified entries, keeping only the unverified ones.

    Args:
        reference_alphas: ``None`` or alphas handed to
            ``LiRPANet.prune_reference_alphas``.
        init_verified_cond: boolean tensor, true where an entry is verified.
        final_name: key of the final layer in the bound mappings.
        lA_trim: lA data, pruned through ``LiRPANet.prune_lA`` when
            ``need_prune_lA`` is set.
        x, data_min, data_max: inputs and their bounds; pruned only when
            ``data_min`` has more than one row.
        need_prune_lA: whether ``lA_trim`` is pruned.
        lower_bounds, upper_bounds: bound mappings; the ``final_name`` entry
            is replaced in place.
        c: specification matrix, pruned along dimension 0.

    Returns:
        ``(reference_alphas, lA_trim, x, data_min, data_max, lower_bounds,
        upper_bounds, c)`` after pruning.
    """
    still_open = ~init_verified_cond

    if reference_alphas is not None:
        LiRPANet.prune_reference_alphas(
            reference_alphas, still_open, final_name)
    if need_prune_lA:
        lA_trim = LiRPANet.prune_lA(lA_trim, still_open)

    open_rows = torch.where(still_open)[0]

    if data_min.shape[0] > 1:
        # The vnnlib has several different x: keep the rows that are still
        # unverified. A single shared x is left untouched.
        x = x[open_rows]
        data_min = data_min[open_rows]
        data_max = data_max[open_rows]

    for bounds in (lower_bounds, upper_bounds):
        bounds[final_name] = bounds[final_name][still_open]
    c = c[open_rows]

    return reference_alphas, lA_trim, x, data_min, data_max, lower_bounds, upper_bounds, c


def batch_vnnlib(vnnlib):
    """Reorganize a vnnlib list so that x, c and rhs are batched.

    Every ``(mat, rhs)`` clause of every entry becomes one row of the batch.
    For each row of ``mat`` the column holding ``-1`` is recorded as its
    target label (``None`` when there is no such column; at most one is
    allowed).

    Args:
        vnnlib: list of ``(x, [(mat, rhs), ...])`` entries, where ``x`` is
            either a dict or array-like.

    Returns:
        A list of ``[x_batch, [(c_batch, rhs_batch, target_label_batch)]]``
        chunks, each covering at most
        ``arguments.Config['bab']['initial_max_domains']`` clauses.
    """
    xs, mats, rhs_values, target_labels = [], [], [], []

    for entry in vnnlib:
        for mat, rhs in entry[1]:
            x_spec = entry[0]
            # Dict inputs are kept as they are, everything else is turned
            # into an array.
            xs.append(x_spec if isinstance(x_spec, dict) else np.array(x_spec))
            mats.append(mat)
            rhs_values.append(rhs)

            row_targets = []
            for row in mat:
                hits = np.where(row == -1)[-1]
                if len(hits) == 0:
                    row_targets.append(None)
                else:
                    assert len(hits) == 1
                    row_targets.append(hits[0])
            target_labels.append(np.array(row_targets))

    if len(xs) == 0 or isinstance(xs[0], np.ndarray):
        # n, shape, 2; the batch dim n is necessary, even if n = 1
        xs = np.array(xs)
    c_all = torch.concat(
        [(mat if isinstance(mat, torch.Tensor)
          else torch.tensor(mat)).unsqueeze(0)
         for mat in mats], dim=0)
    rhs_all = np.array(rhs_values)  # n, n_output
    target_labels = np.array(target_labels)

    # initial_max_domains can be much larger than the solver batch size if
    # auto_enlarge_batch_size is enabled, so it alone sets the chunk size.
    batch_size = arguments.Config['bab']['initial_max_domains']

    total_batch = int(np.ceil(len(xs) / batch_size))
    print(f"Total VNNLIB file length: {len(xs)}, max property batch size: {batch_size}, total number of batches: {total_batch}")

    # [x, [(c, rhs, y, pidx)]], pidx can be none
    chunks = [slice(b * batch_size, (b + 1) * batch_size)
              for b in range(total_batch)]
    return [
        [xs[chunk], [(c_all[chunk], rhs_all[chunk], target_labels[chunk])]]
        for chunk in chunks
    ]


def sort_targets(batched_vnnlib, init_global_lb, init_global_ub,
                 attack_images=None, attack_margins=None, results=None,
                 model_incomplete=None):
    """Optionally reorder the specifications before branch and bound.

    The sort key is chosen from ``arguments.Config['bab']``, in this order of
    precedence:

    * ``attack.enabled``: ascending adversarial attack margins
      (``attack_images`` is reordered as well);
    * ``cut.enabled and cut.cplex_cuts``: descending initial lower bounds;
    * ``sort_targets``: ascending initial lower bounds.

    When none is enabled the inputs are returned unchanged.

    Args:
        batched_vnnlib: list with one entry per specification.
        init_global_lb, init_global_ub: initial global bounds.
        attack_images: attack images, needed when the BaB attack is enabled.
        attack_margins: attack margins, needed when the BaB attack is
            enabled.
        results: mapping that may hold ``'alpha'`` and ``'lA'``; ``None`` is
            treated as an empty mapping.
        model_incomplete: model whose ``net.final_node().name`` selects the
            alpha entries to reorder; needed whenever sorting happens.

    Returns:
        ``(batched_vnnlib, init_global_lb, init_global_ub, lA,
        attack_images)``.

    Raises:
        AssertionError: if sorting is requested while disjuncts are
            optimized separately.
    """
    bab_config = arguments.Config['bab']
    bab_attack_enabled = bab_config['attack']['enabled']
    # Named differently from the function to avoid shadowing it.
    sort_by_bounds = bab_config['sort_targets']
    cplex_cuts = bab_config['cut']['enabled'] and bab_config['cut']['cplex_cuts']
    optimize_disjuncts_separately = arguments.Config['solver']['optimize_disjuncts_separately']

    # ``results`` defaults to None; fall back to an empty mapping instead of
    # failing on ``None.get``.
    if results is None:
        results = {}
    reference_alphas = results.get('alpha', None)
    lA = results.get('lA', None)

    if not (bab_attack_enabled or cplex_cuts or sort_by_bounds):
        return batched_vnnlib, init_global_lb, init_global_ub, lA, attack_images

    # Checked before sorting so that reference_alphas is not reordered in
    # place when the configuration is rejected.
    assert not optimize_disjuncts_separately, (
        "Sorting targets is currently not supported when disjuncts are optimized separately."
    )

    if bab_attack_enabled:
        # Sort specifications based on adversarial attack margins.
        scores, reverse = attack_margins.flatten(), False
    elif cplex_cuts:
        # need to sort pidx such that easier first according to initial
        # alpha crown
        scores, reverse = init_global_lb.flatten(), True
    else:
        # Sort specifications based on incomplete verifier bounds.
        scores, reverse = init_global_lb.flatten(), False

    batched_vnnlib, init_global_lb, init_global_ub, lA, sorted_idx = sort_targets_cls(
        batched_vnnlib, init_global_lb, init_global_ub, lA=lA,
        scores=scores, reference_alphas=reference_alphas,
        final_node_name=model_incomplete.net.final_node().name,
        reverse=reverse)
    if bab_attack_enabled:
        attack_images = attack_images[:, :, sorted_idx]

    return batched_vnnlib, init_global_lb, init_global_ub, lA, attack_images


def add_rhs_offset(
        vnnlib: list,
        rhs_offset: Union[np.ndarray, int, float] = None
) -> list:
    """
    Updates the second operand's offset value where rhs_offset is either a scalar that may be
    broadcast to ALL clauses, or rhs_offset is an array of offset values applied to each clause.
    In both cases an additional constant of 1e-3 is added to every right-hand side.
    @param vnnlib:      The vnnlib file formatted as a list object. Structure can be found in the
                        read_vnnlib.md.
    @param rhs_offset:  Scalar, array, or None. If array, it modifies the offsets in the clauses
                        of the vnnlib file accordingly; its entries are consumed in order,
                        ``len(rhs)`` per clause. If scalar (Python number, NumPy scalar or any
                        0-dimensional array), it is broadcast to all clauses.
    @return:            The original vnnlib object when rhs_offset is None, otherwise a new list
                        of ``(x, clauses)`` tuples with shifted right-hand sides.
    """
    # If rhs_offset is None, return the original vnnlib
    if rhs_offset is None:
        return vnnlib

    # For debugging, add a print statement if sanity check is enabled
    if arguments.Config['debug']['sanity_check'] in ['Full', "Full+Graph"]:
        print('Add an offset to RHS for debugging:', rhs_offset)

    # Python numbers are not the only scalars: NumPy scalars such as
    # np.float32 (not a float subclass) and 0-d arrays cannot be sliced
    # either, so treat everything 0-dimensional as a broadcast scalar.
    is_scalar = isinstance(rhs_offset, (int, float)) or np.ndim(rhs_offset) == 0

    updated_vnnlib = []
    k = 0  # Read position in rhs_offset when it is an array

    for v in vnnlib:
        result = []
        for clause in v[1]:
            mat, rhs = clause[0], clause[1]
            if is_scalar:
                # Broadcast the same rhs_offset to all clauses
                shifted = rhs + rhs_offset + 1e-3
            else:
                # Take the next len(rhs) offsets for this clause
                n_rows = len(rhs)
                shifted = rhs + rhs_offset[k:k + n_rows] + 1e-3
                k += n_rows
            result.append((mat, shifted))
        updated_vnnlib.append((v[0], result))

    return updated_vnnlib
