#########################################################################
##   This file is part of the α,β-CROWN (alpha-beta-CROWN) verifier    ##
##                                                                     ##
##   Copyright (C) 2021-2025 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com> (UIUC)         ##
##                     Zhouxing Shi <zshi@cs.ucla.edu> (UCLA)          ##
##                     Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################
"""Branch and bound for input space split."""
from typing import Union, Tuple, List, Callable, Optional, Dict
import time
from datetime import datetime
import torch
from torch import Tensor
from numpy import ndarray
import math
import sys
import os

import arguments
from beta_CROWN_solver import LiRPANet
from auto_LiRPA.utils import (stop_criterion_batch_any, stop_criterion_all,
                              AutoBatchSize)
from utils import check_auto_enlarge_batch_size, Stats
from input_split.branching_domains import UnsortedInputDomainList
from input_split.attack import (massive_pgd_attack, check_adv,
                                attack_in_input_bab_parallel,
                                update_rhs_with_attack)
from input_split.branching_heuristics import input_split_branching
from input_split.clip import clip_domains
from input_split.split import get_split_depth, input_split_and_repeat
from input_split.utils import to_numpy, transpose_c_back, initial_verify_criterion
from input_split.branching_domains import InputDomainList, UnsortedInputDomainList

import numpy as np

# Module-level state shared across input-split BaB runs:
#   global_lbs       -- history of recorded global lower bounds (numpy array);
#                       written by update_global_lb_record and reset per vnnlib id
#   dir_timestamp    -- timestamp string used to name the sanity-check output dir
#   global_vnnlib_id -- vnnlib id that the current global_lbs history belongs to
global_lbs = None
dir_timestamp = None
global_vnnlib_id = None


def update_global_lb_record(global_lb: torch.Tensor) -> None:
    """Append ``global_lb`` to the module-level ``global_lbs`` history.

    The tensor is detached and copied to host memory as a numpy array. On the
    first call (``global_lbs is None``) the array becomes the history; later
    calls stack it below the existing rows with ``np.vstack``.

    @param global_lb: Lower bound tensor to record.
    """
    global global_lbs
    new_row = global_lb.detach().cpu().numpy()
    global_lbs = new_row if global_lbs is None else np.vstack([global_lbs, new_row])

# Recover the constant term of a CROWN linear lower bound from its concretized value.
def deconstruct_lbias(_x_L, _x_U, _lA, _dm_lb):
    """Return the bias term ``lbias`` such that ``min_{x in box} lA x + lbias == dm_lb``.

    Over the box ``[_x_L, _x_U]`` with center ``c`` and radius ``r``, the minimum
    of ``lA x`` equals ``lA c - |lA| r``. The bias is therefore
    ``dm_lb - (lA c - |lA| r)``.

    @param _x_L:   Lower input bounds, flattened from dim 1 to (batch, in_dim).
    @param _x_U:   Upper input bounds, same shape as ``_x_L``.
    @param _lA:    Linear coefficients, flattened from dim 2 to (batch, spec_dim, in_dim).
    @param _dm_lb: Concretized lower bounds of shape (batch, spec_dim).
    @return:       Bias tensor of shape (batch, spec_dim).
    """
    coeffs = _lA.flatten(2)                                 # (batch, spec_dim, in_dim)
    center = ((_x_U + _x_L) / 2).flatten(1).unsqueeze(-1)   # (batch, in_dim, 1)
    radius = ((_x_U - _x_L) / 2).flatten(1).unsqueeze(-1)   # (batch, in_dim, 1)
    box_min = torch.bmm(coeffs, center) - torch.bmm(coeffs.abs(), radius)  # (batch, spec_dim, 1)
    bias = _dm_lb.unsqueeze(2) - box_min
    return bias.squeeze(2)  # (batch, spec_dim)

def reordered_batch_verification_input_split(
        d: InputDomainList,
        net: LiRPANet, batch:int, num_iter:int, decision_thresh:Tensor, shape:Optional[tuple]=None,
        bounding_method:str="crown", branching_method:str="sb",
        stop_func:Callable=stop_criterion_batch_any, split_partitions:int=2, stats:Stats=None):
    """
    Reordering of the batch_verification_input_split method: bound first, then split.

    One round of input BaB works in the following order:
    1. Pick out up to `batch` domains.
    2. Bound them.
    3. Filter out the domains already verified against their thresholds.
    4. Choose split dimensions for the remaining domains.
    5. Split them, optionally clipping the children.
    6. Add the children back to `d`.
    batch_verification_input_split uses the opposite order: it splits the picked-out
    domains first and then bounds the children.

    @param d:                   Domain list
    @param net:                 Bounded neural network
    @param batch:               Number of effective batches to evaluate
    @param num_iter:            The current iteration number of the input BaB run
    @param decision_thresh:     The specification threshold to verify against
    @param shape:               The shape of the network's input
    @param bounding_method:     The method to use when bounding the subdomains of the network
    @param branching_method:    The branching heuristic to use when splitting on input dimensions
    @param stop_func:           Criterion to stop naive lower bound of network
    @param split_partitions:    The number of partitions to create for subdomains, currently is always 2 for input split
    @param stats:               Stats object to profile and capture statistics from this round of BaB.
                                Must not be None despite the default: it is used unconditionally.
    @return:                    If no domains remain in `d`, decision_thresh.max() + 1e-7.
                                Otherwise, worst_val[0] - worst_val[-1] of the domain selected by
                                d.get_topk_indices() (index 0 when skip_getting_worst_domain is set).
                                This value is printed as "Current (lb-rhs)".
    """

    input_split_args = arguments.Config["bab"]["branching"]["input_split"]
    split_hint = input_split_args['split_hint']
    # Clipping runs if it is explicitly enabled or if constrained concretization is enabled.
    enable_clip_domains = arguments.Config["bab"]['clip_n_verify']['clip_input_domain']['enabled'] or arguments.Config["bab"]['clip_n_verify']['clip_input_domain']['enable_constrained_concretize']
    save_global_lbs = arguments.Config["debug"]["sanity_check"] == "Full+Graph"

    # STEP 1: pick out domains
    stats.timer.start("pickout")
    ret = d.pick_out_batch(batch, device=net.x.device)
    # The stored split_idx is discarded (`_`): this variant recomputes split decisions after bounding.
    alphas, dm_lb, x_L, x_U, cs, thresholds, constraints, _ = ret
    pickout_batch = len(x_L)
    print(f"Current pickout batch: {pickout_batch}")
    # Counts parent domains; input_bab_parallel doubles `batch` for this variant to compensate.
    stats.visited += pickout_batch
    stats.timer.add('pickout')

    # Optionally replace the thresholds with ones updated by an attack on the picked-out domains.
    if input_split_args["update_rhs_with_attack"]:
        stats.timer.start('update_rhs_with_attack')
        if arguments.Config['model']['with_jacobian']:
            model_to_attack = net.net
        else:
            model_to_attack = net.model_ori
        thresholds = update_rhs_with_attack(x_L, x_U, cs, thresholds, dm_lb, model_to_attack)
        stats.timer.add('update_rhs_with_attack')

    # STEP 2: Compute bounds for all domains
    # Old bounds are forwarded to the bounding call only when compare_with_old_bounds is set.
    stats.timer.start('bounding')
    ret = net.get_lower_bound_naive(
        dm_lb=dm_lb if input_split_args["compare_with_old_bounds"] else None, dm_l=x_L, dm_u=x_U, alphas=alphas,
        bounding_method=bounding_method, branching_method=branching_method,
        C=cs, stop_criterion_func=stop_func, thresholds=thresholds,
        constraints=constraints, stats=stats)
    dm_lb, alphas, lA, lbias, lb_crown = ret  # here alphas is a dict
    dm_lb = dm_lb.to(device=thresholds.device)  # ensures it is on the same device as it may be different
    lb_crown = lb_crown.to(device=thresholds.device)

    assert not dm_lb.isnan().any()
    # Add constraints if constrained concretize is enabled:
    #   constraints_A = lA flattened over the input, constraints_b = lbias - thresholds.
    # NOTE(review): unlike get_bound_and_decision, this also overwrites `lbias` returned by
    # get_lower_bound_naive with the recomputed value. That value is then used for filtering,
    # splitting and clipping below — confirm this is intended.
    enable_constrained_concretize = arguments.Config['bab']['clip_n_verify']['clip_input_domain']['enable_constrained_concretize']
    if enable_constrained_concretize:
        batch_size = x_L.shape[0]
        x_dim = x_L.flatten(1).shape[1]
        constraints_A = lA.reshape( (batch_size, -1, x_dim) ).detach()
        constraints_lbias = lbias = deconstruct_lbias(x_L, x_U, lA, lb_crown)
        constraints_b = (constraints_lbias - thresholds).detach()
        constraints = (constraints_A, constraints_b)
    else:
        constraints = None

    stats.timer.add('bounding')

    # STEP 2.5: Filter out verified subdomains
    stats.timer.start('filtering')
    # Since we have only bounded the domains and not clipped them, we only need to check thresholds
    ret_filt = UnsortedInputDomainList.filter_verified_domains(pickout_batch, dm_lb, x_L, x_U,
                                         alphas, cs, d.use_alpha, threshold=thresholds, lA=lA,
                                         lbias=lbias, constraints = constraints,
                                         check_thresholds=True, check_bounds=False)
    num_unverified_domains, dm_lb, x_L, x_U, alphas, cs, thresholds, lA, lbias, constraints, _ = ret_filt
    stats.timer.add('filtering')

    # when num_unverified_domains > 0, there are still unverified subdomains after filtering from step 2.5
    if num_unverified_domains > 0:

        # STEP 3: Make decisions
        stats.timer.start('decision')
        split_idx = input_split_branching(
            net, dm_lb, x_L, x_U, lA, thresholds,
            branching_method, stats.storage_depth, num_iter=num_iter
        )
        stats.timer.add('decision')

        # STEP 4: create new split domains.
        # Per-domain tensors (cs, thresholds, dm_lb, alphas, lA, lbias, constraints) are
        # returned alongside the new bounds, presumably repeated for each child.
        stats.timer.start('split')
        split_depth = get_split_depth(x_L, split_partitions=split_partitions)

        new_x_L, new_x_U, split_depth, cs, thresholds, dm_lb, alphas, lA, lbias, constraints = input_split_and_repeat(
            x_L, x_U, shape, split_depth, split_idx, split_partitions, split_hint,
            cs, thresholds, dm_lb, alphas, lA, lbias, constraints)
        stats.timer.add('split')

        # STEP 5: shrink these new domains
        if enable_clip_domains:
            stats.timer.start('clip')
            ret = clip_domains(new_x_L, new_x_U, thresholds, lA, dm_lb=None, lbias=lbias, calculate_dm_lb=True)
            new_x_L, new_x_U = ret
            stats.timer.add('clip')

        # STEP 6: Add new domains back to domain list.
        stats.timer.start('add_domain')
        # Clipping only updates the input bounds but not the thresholds
        # split_idx=None: decisions for the children are made after they are bounded next round.
        # NOTE(review): presumably check_bounds=True makes add() drop children whose clipped
        # bounds became empty — confirm in InputDomainList.add.
        d.add(dm_lb, new_x_L.detach(), new_x_U.detach(),
              alphas, cs, thresholds, constraints=constraints, split_idx=None, check_thresholds=False, check_bounds=True)
        stats.timer.add('add_domain')


    stats.timer.start('others')
    len_domains = len(d)

    if len_domains == 0:
        print("No domains left, verification finished!")
        if dm_lb is not None and len(dm_lb) > 0:
            dm_lb_min = dm_lb.min().item()
            print(f"The lower bound of last batch is {dm_lb_min}")
        _print_final_results(stats, 0)
        # No domains remain: return a value strictly greater than every entry of decision_thresh.
        return decision_thresh.max() + 1e-7
    else:
        if input_split_args["skip_getting_worst_domain"]:
            # It can be costly to call get_topk_indices when the domain list is long
            worst_idx = 0
        else:
            worst_idx = d.get_topk_indices()
        worst_val = d[worst_idx]
        global_lb = worst_val[0] - worst_val[-1]
        _print_final_results(stats, len_domains)
        if not input_split_args["skip_getting_worst_domain"]:
            # NOTE(review): batch_verification_input_split prints the full tensor for up to
            # 15 entries; this variant uses 5.
            if 1 < global_lb.numel() <= 5:
                print(f"Current (lb-rhs): {global_lb}")
            else:
                print(f"Current (lb-rhs): {global_lb.max().item()}")

    # Record global_lb when debug.sanity_check == "Full+Graph". global_lb is always bound
    # here because the empty-domain branch above returns early.
    if save_global_lbs:
        update_global_lb_record(-1 * global_lb)

    if input_split_args["show_progress"]:
        print(f"Progress: {d.get_progess():.10f}")
    sys.stdout.flush()

    return global_lb

def _print_final_results(stats: Stats, len_domains: int):
    """Stop the 'others' timer, print all timings, and report domain counters.

    @param stats:        Stats object whose timer and `visited` counter are reported.
    @param len_domains:  Number of domains remaining in the domain list.
    """
    timer = stats.timer
    timer.add('others')
    timer.print()
    print("Length of domains:", len_domains)
    visited = stats.visited
    print(f"{visited} domains visited")

def batch_verification_input_split(
        d: InputDomainList,
        net: LiRPANet, batch:int, num_iter:int, decision_thresh:Tensor, shape:Optional[tuple]=None,
        bounding_method:str="crown", branching_method:str="sb",
        stop_func:Callable=stop_criterion_batch_any, split_partitions:int=2,
        stats:Optional[Stats]=None):
    """
    Run one round of input-split BaB: split first, then bound.

    One round proceeds as follows:
    1. Pick out up to `batch` domains from `d`.
    2. Split them using the `split_idx` stored with each picked-out domain.
    3. Bound the resulting children and choose their next split dimensions via
       get_bound_and_decision.
    4. Add the children back to `d`.

    @param d:                   Domain list
    @param net:                 Bounded neural network
    @param batch:               Number of domains to pick out from `d`
    @param num_iter:            The current iteration number of the input BaB run
    @param decision_thresh:     The specification threshold to verify against
    @param shape:               The shape of the network's input
    @param bounding_method:     The method to use when bounding the subdomains of the network
    @param branching_method:    The branching heuristic to use when splitting on input dimensions
    @param stop_func:           Criterion to stop naive lower bound of network
    @param split_partitions:    The number of partitions to create for subdomains
    @param stats:               Stats object to profile this round. Must not be None despite the
                                Optional annotation: it is used unconditionally.
    @return:                    If no domains remain in `d`, decision_thresh.max() + 1e-7.
                                Otherwise, worst_val[0] - worst_val[-1] of the domain selected by
                                d.get_topk_indices() (index 0 when skip_getting_worst_domain is set).
                                This value is printed as "Current (lb-rhs)".
    """
    input_split_args = arguments.Config["bab"]["branching"]["input_split"]
    save_global_lbs = arguments.Config["debug"]["sanity_check"] == "Full+Graph"
    split_hint = input_split_args["split_hint"]

    # STEP 1: pick out domains
    # split_idx holds the split decisions made when these domains were bounded.
    stats.timer.start("pickout")
    ret = d.pick_out_batch(batch, device=net.x.device)
    alphas, dm_lb, x_L, x_U, cs, thresholds, constraints, split_idx = ret

    stats.timer.add('pickout')

    # Optionally replace the thresholds with ones updated by an attack on the picked-out domains.
    if input_split_args["update_rhs_with_attack"]:
        stats.timer.start('update_rhs_with_attack')
        if arguments.Config['model']['with_jacobian']:
            model_to_attack = net.net
        else:
            model_to_attack = net.model_ori
        thresholds = update_rhs_with_attack(x_L, x_U, cs, thresholds, dm_lb, model_to_attack)
        stats.timer.add('update_rhs_with_attack')

    # STEP 2: create new split domains.
    stats.timer.start('split')
    split_depth = get_split_depth(x_L, split_partitions=split_partitions)


    # Without compare_with_old_bounds the parents' bounds are not carried over to the children.
    if not input_split_args["compare_with_old_bounds"]:
        dm_lb = None
    # lA / lbias are not available before bounding here, so None is passed and the results are ignored.
    new_x_L, new_x_U, split_depth, cs, thresholds, dm_lb, alphas, _, _, constraints = input_split_and_repeat(
        x_L, x_U, shape, split_depth, split_idx, split_partitions, split_hint,
        cs, thresholds, dm_lb, alphas, lA=None, lbias=None, constraints=constraints)
    stats.timer.add('split')

    # Counts the children produced by the split (not the parents).
    pickout_batch = len(new_x_L)
    print(f"Current pickout batch: {pickout_batch}")
    stats.visited += pickout_batch

    # STEP 3: Compute bounds for all domains and make decisions.

    # Use the constraints of the split (child) domains
    new_x_L, new_x_U, new_dm_lb, alphas, constraints, split_idx = get_bound_and_decision(
        net, dm_lb, new_x_L, new_x_U, alphas, cs, thresholds, constraints,
        bounding_method, branching_method, stop_func, num_iter, stats=stats
    )

    # STEP 4: Add new domains back to domain list.
    stats.timer.start('add_domain')
    # Constraints are stored per domain so that they are split together with it.
    d.add(new_dm_lb, new_x_L.detach(), new_x_U.detach(),
          alphas, cs, thresholds, constraints, split_idx)
    stats.timer.add('add_domain')

    stats.timer.start('others')
    len_domains = len(d)


    if len_domains == 0:
        print("No domains left, verification finished!")
        if new_dm_lb is not None:
            new_dm_lb_min = new_dm_lb.min().item()
            print(f"The lower bound of last batch is {new_dm_lb_min}")
        _print_final_results(stats, 0)
        # No domains remain: return a value strictly greater than every entry of decision_thresh.
        return decision_thresh.max() + 1e-7
    else:
        if input_split_args["skip_getting_worst_domain"]:
            # It can be costly to call get_topk_indices when the domain list is long
            worst_idx = 0
        else:
            worst_idx = d.get_topk_indices()
        worst_val = d[worst_idx]
        global_lb = worst_val[0] - worst_val[-1]
        _print_final_results(stats, len_domains)
        if not input_split_args["skip_getting_worst_domain"]:
            # NOTE(review): reordered_batch_verification_input_split uses 5 instead of 15 here.
            if 1 < global_lb.numel() <= 15:
                print(f"Current (lb-rhs): {global_lb}")
            else:
                print(f"Current (lb-rhs): {global_lb.max().item()}")

    # Record global_lb when debug.sanity_check == "Full+Graph". global_lb is always bound
    # here because the empty-domain branch above returns early.
    if save_global_lbs:
        update_global_lb_record(-1 * global_lb)

    if input_split_args["show_progress"]:
        print(f"Progress: {d.get_progess():.10f}")
    sys.stdout.flush()

    return global_lb


def get_bound_and_decision(net: LiRPANet, dm_lb: Tensor, x_L: Tensor, x_U: Tensor, alphas: List, cs: Tensor, thresholds: Tensor, constraints,
                           bounding_method: str, branching_method: str, stop_func: Callable,
                           num_iter: int,
                           stats: Stats
                           ) -> Tuple[Tensor, Tensor, Tensor, Dict, Optional[Tuple[Tensor, Tensor]], Tensor]:
    """Bound a batch of input domains, optionally clip them, and choose the next split dimensions.

    @param net:               Bounded network.
    @param dm_lb:             Existing lower bounds of the domains (may be None). They are forwarded
                              to ``net.get_lower_bound_naive`` only when ``compare_with_old_bounds`` is set.
    @param x_L:               Lower input bounds of the domains.
    @param x_U:               Upper input bounds of the domains.
    @param alphas:            Alpha parameters forwarded to ``net.get_lower_bound_naive``.
    @param cs:                Specification matrices.
    @param thresholds:        Right-hand sides the domains are verified against.
    @param constraints:       ``None`` or a ``(constraints_A, constraints_b)`` tuple attached to these
                              domains. It is checked for NaNs and forwarded to the bounding call.
    @param bounding_method:   Bounding method passed to the bounding call.
    @param branching_method:  Branching heuristic passed to the bounding and branching calls.
    @param stop_func:         Stopping criterion forwarded to the bounding call.
    @param num_iter:          Current BaB iteration, forwarded to the branching heuristic.
    @param stats:             Stats object. Required: its timer and ``storage_depth`` are used.
    @return: ``(x_L, x_U, new_dm_lb, alphas, constraints, split_idx)``, where:
             - ``x_L`` / ``x_U`` may have been clipped;
             - ``new_dm_lb`` is moved to ``thresholds.device``;
             - ``alphas`` is the dict returned by the bounding call;
             - ``constraints`` is the freshly built ``(A, b)`` tuple when constrained
               concretization is enabled, and None otherwise;
             - ``split_idx`` is the decision from ``input_split_branching``.
    """
    clip_args = arguments.Config['bab']['clip_n_verify']['clip_input_domain']
    enable_constrained_concretize = clip_args['enable_constrained_concretize']
    # Clipping also runs whenever constrained concretization is enabled.
    enable_clip_domains = clip_args['enabled'] or enable_constrained_concretize
    compare_with_old_bounds = arguments.Config["bab"]["branching"]["input_split"]["compare_with_old_bounds"]

    stats.timer.start("bounding")

    # Constraints are a per-domain feature that is split together with the domains, so a single
    # bounding pass suffices. (Previously bounds were computed twice: once to build the constraint
    # matrices, and once for the constrained bound estimation.)
    if constraints is not None:
        constraints_A, constraints_b = constraints
        assert not constraints_A.isnan().any()
        assert not constraints_b.isnan().any()

    ret = net.get_lower_bound_naive(
        dm_lb=dm_lb if compare_with_old_bounds else None,
        dm_l=x_L, dm_u=x_U, alphas=alphas,
        bounding_method=bounding_method, branching_method=branching_method,
        C=cs, stop_criterion_func=stop_func, thresholds=thresholds,
        constraints=constraints, stats=stats)
    new_dm_lb, alphas, lA, lbias, lb_crown = ret  # here alphas is a dict
    assert not new_dm_lb.isnan().any()

    # The bounds may live on a different device than the thresholds; align them.
    new_dm_lb = new_dm_lb.to(device=thresholds.device)
    lb_crown = lb_crown.to(device=thresholds.device)

    # Rebuild the per-domain linear constraints from the new bounds:
    #   constraints_A = lA flattened over the input, constraints_b = lbias - thresholds.
    constraints = None
    if enable_constrained_concretize:
        batch_size = x_L.shape[0]
        # Unlike view(), flatten(1) does not require x_L to be contiguous.
        x_dim = x_L.flatten(1).shape[1]
        constraints_A = lA.reshape((batch_size, -1, x_dim)).detach()
        constraints_lbias = deconstruct_lbias(x_L, x_U, lA, lb_crown)
        constraints_b = (constraints_lbias - thresholds).detach()
        constraints = (constraints_A, constraints_b)

    stats.timer.add("bounding")

    # Shrink the domains.
    if enable_clip_domains:
        stats.timer.start("clip")
        x_L, x_U = clip_domains(x_L, x_U, thresholds, lA, lb_crown, lbias)
        stats.timer.add("clip")

    stats.timer.start("decision")
    split_idx = input_split_branching(
        net, new_dm_lb, x_L, x_U, lA, thresholds,
        branching_method, stats.storage_depth, num_iter=num_iter
    )
    stats.timer.add("decision")

    return x_L, x_U, new_dm_lb, alphas, constraints, split_idx


def input_bab_parallel(net: LiRPANet, init_domain: Tensor, x: Tensor, rhs: Optional[Tensor]=None,
                       timeout:Optional[float]=None, max_iterations: Optional[int]=None,
                       vnnlib=None, c_transposed:bool=False, return_domains:Union[bool, int]=False,
                       vnnlib_meta:Optional[Dict]=None):
    """Run input split bab.

    First initializes the domain list via input_bab_initialization. Then it repeatedly calls
    reordered_batch_verification_input_split (when reorder_bab is enabled) or
    batch_verification_input_split. The loop stops when any of these holds:
    - the result is no longer "unknown";
    - no domains remain;
    - max_iterations is exceeded (-1 means no limit);
    - the timeout elapses.

    net: Bounded network.
    init_domain: Input domain passed to input_bab_initialization.
    x: Input passed to initialization, attacks and adversarial checks.
    rhs: Right-hand side thresholds of the specification.
    timeout: Time limit in seconds; a falsy value falls back to bab.timeout from the config.
    max_iterations: Iteration limit; a falsy value falls back to bab.max_iterations.
    vnnlib: Specification forwarded to initialization and attacks.
    c_transposed: bool, by default False, indicating whether net.c matrix has
        transposed between dim=0 and dim=1. As documented in abcrown.py bab(),
        if using input split, and if there are multiple specs with shared single input,
        we transposed the c matrix from [multi-spec, 1, ...] to [1, multi-spec, ...] so that
        net.build() process in input_bab_parallel could share intermediate layer
        bounds across all specs. If such transpose happens, c_transposed is set
        to True, so that after net.build() in this func, we can transpose c
        matrix back, repeat x_LB & x_UB, duplicate alphas, to prepare for input domain bab.
    return_domains: If truthy, return the remaining domains instead of the verification result.
        An int selects how many domains to return; -1 returns all of them.
    vnnlib_meta: Optional dict with "property_idx", "vnnlib_id" and "benchmark_name",
        used for sanity-check graph recording.

    Returns:
        If return_domains is truthy: (lower_bound, x_L, x_U) of the first `return_domains`
        domains after sorting, picked out onto the CPU.
        Otherwise: (global_lb.max(), number of visited domains, result), where result is the
        status string (e.g. "safe", "unsafe", "unknown").
        NOTE(review): the early "unsafe" return after check_adv always uses the 3-tuple form
        (even when return_domains is set). It also skips saving sanity-check graphs.
    """
    stats = Stats()
    start = time.time()

    # Module-level state used for sanity-check graph recording.
    global global_vnnlib_id
    global global_lbs
    global dir_timestamp

    if dir_timestamp is None:
        # Example format: "2024-08-27_14-30-45"
        dir_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if vnnlib_meta is None:
        vnnlib_meta = {
            "property_idx": 0, "vnnlib_id": 0, "benchmark_name": None
        }
    property_idx = vnnlib_meta.get("property_idx")
    vnnlib_id = vnnlib_meta.get("vnnlib_id")
    benchmark_name = vnnlib_meta.get("benchmark_name")

    if global_vnnlib_id is None or vnnlib_id != global_vnnlib_id:
        # Reset global_lbs if a vnnlib_id has not been set (this is the very first vnnlib instance
        # being run) or if we have moved onto a new vnnlib_id instance
        global_vnnlib_id = vnnlib_meta.get("vnnlib_id", 0)
        global_lbs = None


    bab_args = arguments.Config["bab"]
    debug_args = arguments.Config["debug"]
    sanity_check_with_graphs = debug_args["sanity_check"] == "Full+Graph"
    branching_args = bab_args["branching"]
    input_split_args = branching_args["input_split"]

    # Falsy arguments (None or 0) fall back to the configured values.
    timeout = timeout or bab_args["timeout"]
    batch = arguments.Config["solver"]["batch_size"]
    bounding_method = arguments.Config["solver"]["bound_prop_method"]
    init_bounding_method = arguments.Config["solver"]["init_bound_prop_method"]
    max_iterations = max_iterations or bab_args["max_iterations"]
    sort_domain_iter = bab_args["sort_domain_interval"]
    branching_method = branching_args["method"]
    adv_check = input_split_args["adv_check"]
    split_partitions = input_split_args["split_partitions"]
    use_clip_domains = arguments.Config['bab']['clip_n_verify']['clip_input_domain']['enabled'] or arguments.Config['bab']['clip_n_verify']['clip_input_domain']['enable_constrained_concretize']
    use_reordered_bab = arguments.Config['bab']['clip_n_verify']['clip_input_domain']["reorder_bab"]
    split_hint = input_split_args["split_hint"]
    enable_check_adv = arguments.Config["attack"]["input_split_check_adv"]["enabled"]
    # "auto" enables the adversarial check iff PGD attacks are not skipped; otherwise only "true" enables it.
    enable_check_adv = (
        arguments.Config["attack"]["pgd_order"] != "skip" if enable_check_adv == "auto"
        else enable_check_adv == "true"
    )
    # For this reordering, we are adjusting the effective batch size since (2*batches) number of domains should
    # be bounded per iteration no matter the BaB order

    batch = 2*batch if use_reordered_bab else batch

    if init_bounding_method == "same":
        init_bounding_method = bounding_method

    if c_transposed or net.c.shape[1] == 1:
        # When c_transposed applied, previous checks have ensured that there is only single spec,
        # but we compressed multiple data samples to the spec dimension by transposing C
        # so we need "all" to be satisfied to stop.
        # Another case is there is only single spec, so batch_any equals to all. Then, we set to all
        # so we can allow prune_after_crown optimization
        stop_func = stop_criterion_all
    else:
        # Possibly multiple specs in each data sample
        stop_func = stop_criterion_batch_any

    stats.visited = 0

    # Perform initialization so that we can start the BaB loop. Initialization arguments are saved in a dictionary
    # in case we need to perform re-initialization when enhanced bounding is called (if enabled).
    initialization_args = {
        'net': net, 'init_domain': init_domain, 'x': x, 'stop_func': stop_func, 'rhs': rhs,
        'init_bounding_method': init_bounding_method, 'bounding_method': bounding_method, 'branching_method': branching_method,
        'use_clip_domains': use_clip_domains, 'use_reordered_bab': use_reordered_bab, 'c_transposed': c_transposed,
        'split_partitions': split_partitions, 'sort_index': input_split_args['sort_index'],
        'sort_descending': input_split_args['sort_descending'], 'presplit_domains': input_split_args['presplit_domains'],
        'vnnlib': vnnlib, 'split_hint': split_hint, 'pgd_order': arguments.Config["attack"]["pgd_order"],
        'num_iter': 0, 'stats': stats,
    }
    init_ret = input_bab_initialization(**initialization_args)
    result, domains, global_lb = init_ret

    auto_batch_size = AutoBatchSize(
        batch, net.device,
        enable=arguments.Config["solver"]["auto_enlarge_batch_size"])

    num_iter = 1
    enhanced_bound_initialized = False
    batch_verification_fn = reordered_batch_verification_input_split if use_reordered_bab else batch_verification_input_split
    while (result == "unknown" and len(domains) > 0
           and (max_iterations == -1 or num_iter <= max_iterations)):
        print(f"Iteration {num_iter}")
        # sort the domains every certain number of iterations
        if sort_domain_iter > 0 and num_iter % sort_domain_iter == 0:
            stats.timer.start("sort_domains")
            domains.sort()
            stats.timer.add("sort_domains")

        # Bound before this iteration; used below to detect that the bound stopped improving.
        last_glb = global_lb.max()

        if enable_check_adv:
            # Only check once at least `adv_check` domains have been visited (-1 disables it).
            if adv_check != -1 and stats.visited >= adv_check:
                stats.timer.start("adv_check")
                # check whether adv example found
                if arguments.Config['model']['with_jacobian']:
                    model_to_attack = net.net
                else:
                    model_to_attack = net.model_ori
                if check_adv(domains, model_to_attack, x, vnnlib=vnnlib):
                    return global_lb.max(), stats.visited, "unsafe"
                stats.timer.add("adv_check")

        # Brute-force branching uses its own batch size for the first bf_iters iterations.
        batch_ = batch
        if branching_method == "brute-force" and num_iter <= input_split_args["bf_iters"]:
            batch_ = input_split_args["bf_batch_size"]
        auto_batch_size.record_actual_batch_size(min(batch_, len(domains)))
        global_lb = batch_verification_fn(
            domains, net, batch_,
            num_iter=num_iter, decision_thresh=rhs, shape=x.shape,
            bounding_method=bounding_method, branching_method=branching_method,
            stop_func=stop_func, split_partitions=split_partitions, stats=stats)
        batch = check_auto_enlarge_batch_size(auto_batch_size)

        # Once the lower bound stops improving (and enhanced_bound_patience seconds have passed),
        # switch once to the enhanced bounding method by re-initializing via enhanced_bound_init.
        if (arguments.Config["solver"]["bound_prop_method"]
            != input_split_args["enhanced_bound_prop_method"]
            and time.time() - start > input_split_args["enhanced_bound_patience"]
            and global_lb.max().cpu() <= last_glb.cpu()
            and bounding_method != "alpha-crown"
            and not enhanced_bound_initialized
            and not input_split_args["skip_enhance"]
        ):
            enhanced_bound_initialized = True
            enhanced_ret = enhanced_bound_init(initialization_args, num_iter)
            global_lb, domains, branching_method, bounding_method = enhanced_ret

        # After attack_patience seconds, a massive PGD attack is attempted on every
        # subsequent iteration (there is no flag limiting it to a single run).
        if arguments.Config["attack"]["pgd_order"] != "skip":
            if time.time() - start > input_split_args["attack_patience"]:
                print("Perform PGD attack with massively random starts finally.")
                if arguments.Config['model']['with_jacobian']:
                    model_to_attack = net.net
                else:
                    model_to_attack = net.model_ori
                ret_adv = massive_pgd_attack(x, model_to_attack, vnnlib=vnnlib)[1]
                if ret_adv:
                    result = "unsafe"
                    break

        if time.time() - start > timeout:
            print("Time out!")
            break

        print(f"Cumulative time: {time.time() - start}\n")
        num_iter += 1

    # All domains were eliminated without a counterexample.
    if result == "unknown" and len(domains) == 0:
        result = "safe"

    # Save sanity_check graphs if --sanity_check_with_graphs was enabled
    if sanity_check_with_graphs:
        from input_split.sanity_check import save_sanity_check_graphs
        save_sanity_check_graphs(global_lbs, benchmark_name, vnnlib_id, property_idx, dir_timestamp)

    if return_domains:
        # Thresholds may have been updated by PGD attack so that different
        # domains may have different thresholds. Restore thresholds to the
        # default RHS for the sorting.
        domains.threshold._storage.data[:] = rhs
        domains.sort()
        if return_domains == -1:
            return_domains = len(domains)
        lower_bound, x_L, x_U = domains.pick_out_batch(
            return_domains, device="cpu")[1:4]
        return lower_bound, x_L, x_U
    else:
        del domains
        return global_lb.max(), stats.visited, result


def load_presplit_domains(domains: InputDomainList,
                          net: LiRPANet, bounding_method: str, branching_method: str, stop_func: Callable, stats: Stats):
    """
    Replace the domains in `domains` with pre-split input domains loaded from disk.

    All current domains are picked out of `domains`. Their (first) lower bound, spec matrix and
    threshold are expanded to batch size and reused as templates. The pre-split lower/upper
    input bounds stored at `bab.branching.input_split.presplit_domains` are then bounded
    batch by batch with get_bound_and_decision. The resulting domains are added back to
    `domains`.

    :param domains:             Input domain list for storing subdomains during BaB.
    :param net:                 AutoLiRPA net object. This is a mutable object.
    :param bounding_method:     Bounding method to be used the original domain has been bounded and split.
    :param branching_method:    Input BaB branching method i.e. naive, sb, etc.
    :param stop_func:           lambda function defining stopping (verification) criteria
    :param stats:               Stats object to profile and capture statistics from this round of BaB
    """
    input_split_args = arguments.Config["bab"]["branching"]["input_split"]
    use_reordered_bab = arguments.Config['bab']['clip_n_verify']['clip_input_domain']["reorder_bab"]
    batch_size = arguments.Config["solver"]["batch_size"]
    # The reordered BaB bounds twice as many domains per batch (see input_bab_parallel).
    batch_size = batch_size*2 if use_reordered_bab else batch_size
    ret = domains.pick_out_batch(len(domains))
    # Bounds, constraints and split decisions of the picked-out domains are not reused:
    # the pre-split domains replace them and are bounded from scratch.
    alphas, dm_lb, _, _, cs, thresholds, _, _ = ret

    # NOTE: torch.load deserializes the file; only load trusted pre-split files.
    presplit_dm_l, presplit_dm_u = torch.load(
        input_split_args["presplit_domains"])
    presplit_dm_l = presplit_dm_l.to(dm_lb)
    presplit_dm_u = presplit_dm_u.to(dm_lb)
    num_presplit_domains = presplit_dm_l.shape[0]
    print(f"Loaded {num_presplit_domains} pre-split domains")

    # Expand the template bound/spec/threshold to a full batch. This assumes the picked-out
    # tensors have a leading dimension of 1 (as required by expand).
    dm_lb = dm_lb.expand(batch_size, -1)
    cs = cs.expand(batch_size, -1, -1)
    thresholds = thresholds.expand(batch_size, -1)
    num_batches = (num_presplit_domains + batch_size - 1) // batch_size

    for i in range(num_batches):
        print(f"Pre-split domains batch {i+1}/{num_batches}:")
        begin, end = i * batch_size, (i + 1) * batch_size
        x_L = presplit_dm_l[begin:end]
        x_U = presplit_dm_u[begin:end]
        size = x_L.shape[0]  # the last batch may be smaller than batch_size
        # No inherited constraints: get_bound_and_decision rebuilds them when constrained
        # concretization is enabled. Arguments are passed by keyword to match its signature.
        x_L, x_U, new_dm_lb, alphas, constraints, split_idx = get_bound_and_decision(
            net, dm_lb[:size], x_L, x_U, alphas, cs[:size], thresholds[:size],
            constraints=None, bounding_method=bounding_method,
            branching_method=branching_method, stop_func=stop_func,
            num_iter=1, stats=stats
        )
        num_domains_pre = len(domains)
        # The reordered BaB recomputes split decisions after bounding, so none are stored.
        domains.add(new_dm_lb, x_L, x_U, alphas, cs[:size], thresholds[:size],
                    constraints=constraints,
                    split_idx=None if use_reordered_bab else split_idx)
        print(f"{len(domains) - num_domains_pre} domains added, "
              f"{len(domains)} in total")
        print()

    print(f"{len(domains)} pre-split domains added out of {num_presplit_domains}")
    # Guard against an empty pre-split file (would otherwise divide by zero).
    if num_presplit_domains > 0:
        verified_ratio = 1 - len(domains) * 1. / num_presplit_domains
        print(f"Verified ratio: {verified_ratio}")

def input_bab_initialization(net: LiRPANet, init_domain: Tensor, x: Tensor, stop_func: Callable, rhs: Tensor, init_bounding_method: str,
                             bounding_method: str, branching_method: str, use_clip_domains: bool,
                             use_reordered_bab: bool, c_transposed: bool, split_partitions: int,
                             sort_index, sort_descending: bool, presplit_domains: bool, vnnlib,
                             split_hint: Union[Tensor, float], pgd_order: str, num_iter: int, stats: Stats
                             ) -> Tuple[str, InputDomainList,
Union[float, ndarray, Tensor]]:
    """

    Initialization that sets up input bab by bounding the original network, and returning the global lower bound
    as well as the domain list. A method since we may need to perform re-initialization when performing enhanced
    bound method.

    :param net:                     AutoLiRPA net object. This is a mutable object.
    :param init_domain:             Lower and upper bound of the input domain.
    :param x:                       Initial BoundedTensor of input.
    :param stop_func:               lambda function defining stopping (verification) criteria
    :param rhs:                     Right-hand side thresholds.
    :param init_bounding_method:    The bounding method to be used when bounding the original network.
    :param bounding_method:         Bounding method to be used the original domain has been bounded and split.
    :param branching_method:        Input BaB branching method i.e. naive, sb, etc. Overridden to "naive" when only
                                    a single entry of (dm_u - dm_l) is positive.
    :param use_clip_domains:        If true, performs clipping on the input domains.
    :param use_reordered_bab:       If true, need to slightly alter the order of initialization.
    :param c_transposed:            If true, c was transposed since every property had exactly one term.
    :param split_partitions:        Partitions per node. split_partition=2 simply creates a binary tree.
    :param sort_index:              If given, we care about the gap between curr_lb and threshold along a particular
                                    specification dimension. Otherwise, we take the max gap along all specification
                                    dimensions.
    :param sort_descending:         True if domains in input domain lists should be sorted in descending order w.r.t.
                                    the max(curr_lb - threshold) or curr_lb[sort_index] - threshold[sort_index] gap.
    :param presplit_domains:        True if we should load in presplit domains and add them to the domain list.
    :param vnnlib:                  vnnlib structured data
    :param split_hint:              If given, we should first split along these values.
    :param pgd_order:               PGD attacks on the initial unverified domain(s) are run only when this equals
                                    'after'. Any other value (e.g. False on re-initialization) skips them.
    :param num_iter:                Current iteration of input BaB
    :param stats:                   Stats object to profile and capture statistics from this round of BaB

    :return result:                 'safe' if the initial bound verifies everything, 'unsafe' if a PGD attack
                                    succeeded, otherwise 'unknown'.
    :return domains:                Input domain list containing first set of split domains, ready to be used by input bab.
    :return global_lb:              Global lower bound attained after bounding the initial set of domains.

    :raises ValueError:             If the input node has no lA but the "sb" branching heuristic or domain
                                    clipping was requested (both need lA).
    """

    # get input domain lower/upper limits
    dm_l = x.ptb.x_L
    dm_u = x.ptb.x_U

    def _broadcast_dm(dm):
        # Expand a batch-1 domain so its first dimension matches net.c's first dimension.
        if dm.shape[0] == 1 and net.c.shape[0] > 1:
            dm = dm.expand(net.c.shape[0], *[-1] * (dm.ndim - 1))
        assert dm.shape[0] == net.c.shape[0]
        return dm

    def _attack_model():
        # PGD attacks run on the bounded graph when a Jacobian is part of the model, otherwise on the original model.
        if arguments.Config['model']['with_jacobian']:
            return net.net
        return net.model_ori

    # c is expanded in build(). dm_l(u) must be expanded as well
    dm_l = _broadcast_dm(dm_l)
    dm_u = _broadcast_dm(dm_u)

    use_alpha = init_bounding_method.lower() == "alpha-crown" or bounding_method == "alpha-crown"

    # If only one entry of the whole (dm_u - dm_l) tensor is positive, there is nothing for a
    # heuristic to choose between, so fall back to naive branching.
    if (dm_u - dm_l > 0).int().sum() == 1:
        branching_method = "naive"

    global_lb, ret = net.build(
        init_domain, x, stop_criterion_func=stop_func(rhs),
        bounding_method=init_bounding_method, decision_thresh=rhs, return_A=False)
    # Record on each pre-activation node the list of splittable activations that consume it.
    for node in net.net.get_splittable_activations():
        for preact_node in node.inputs:
            if hasattr(preact_node, "output_activations"):
                preact_node.output_activations.append(node)
            else:
                preact_node.output_activations = [node]

    if getattr(net.net[net.net.input_name[0]], "lA", None) is not None:
        lA = net.net[net.net.input_name[0]].lA.transpose(0, 1)
    else:
        lA = None
        # "sb" is a *branching* heuristic; it and domain clipping both require lA.
        if branching_method == "sb":
            raise ValueError("sb heuristic cannot be used without lA.")
        if use_clip_domains:
            raise ValueError("clip domains cannot be used without lA.")

    if c_transposed:
        lA, global_lb, rhs, dm_l, dm_u = transpose_c_back(
            lA, global_lb, rhs, dm_l, dm_u, ret, net)

    result = "unknown"

    # compute storage depth
    min_batch_size = (
            arguments.Config["solver"]["min_batch_size_ratio"]
            * arguments.Config["solver"]["batch_size"]
    )
    max_depth = max(int(math.log(max(min_batch_size, 1)) // math.log(split_partitions)), 1)
    stats.storage_depth = min(max_depth, dm_l.shape[-1])

    initial_verified, remaining_index = initial_verify_criterion(global_lb, rhs)

    domains = UnsortedInputDomainList(
        # here this device is only used to perform some computation efficiently
        # data of all domains are stored on 'cpu' to save memory
        stats.storage_depth, arguments.Config["general"]["device"],
        use_alpha=use_alpha,
        sort_index=sort_index,
        sort_descending=sort_descending,
        use_split_idx=not use_reordered_bab
    )

    if initial_verified:
        result = "safe"
    else:
        # compute initial split idx
        split_idx = input_split_branching(
            net, global_lb, dm_l, dm_u, lA, rhs, branching_method, stats.storage_depth, num_iter=num_iter)
        alphas = ret["alphas"]
        cs = net.c
        thresholds = rhs
        if use_reordered_bab:
            if pgd_order == "after":
                # Perform pgd attacks on the original unverified domains before we split them. Passing domains post
                # split can produce many new subdomains, adding additional overhead
                attack_init_domains = [
                    (None, l, u, c, r)
                    for i, (l, u, c, r) in enumerate(zip(dm_l, dm_u, cs, thresholds))
                    if i in remaining_index
                ]
                if attack_in_input_bab_parallel(_attack_model(), attack_init_domains, x, vnnlib=vnnlib):
                    print("pgd attack succeed in input_bab_parallel")
                    result = "unsafe"
            # filter out verified subdomains
            lbias = deconstruct_lbias(dm_l, dm_u, lA, global_lb)
            constraints = None
            # Construct constraints matrices if constrained concretize is enabled.
            if arguments.Config['bab']['clip_n_verify']['clip_input_domain']['enable_constrained_concretize']:
                # Constraints matrices will come from the output node A matrices and b matrics from last CROWN call
                current_batch_size = dm_l.shape[0]
                current_x_dim = dm_l.view((current_batch_size, -1)).shape[1]
                constraints_A = lA.reshape((current_batch_size, -1, current_x_dim)).detach()
                constraints_b = (lbias - rhs).reshape((current_batch_size, -1)).detach()
                constraints = (constraints_A, constraints_b)
            filt_ret = UnsortedInputDomainList.filter_verified_domains(len(dm_l), global_lb, dm_l, dm_u, alphas,
                                                                        cs, use_alpha, thresholds, lA, lbias, constraints=constraints,
                                                                        split_idx=split_idx,
                                                                        remaining_index=remaining_index)
            _, global_lb, dm_l, dm_u, alphas, cs, thresholds, lA, lbias, constraints, split_idx = filt_ret
            # perform the initial split on the domains
            split_depth = get_split_depth(dm_l, split_partitions=split_partitions)
            dm_l, dm_u, split_depth, cs, thresholds, global_lb, alphas, lA, lbias, constraints = input_split_and_repeat(
                dm_l, dm_u, x.shape, split_depth, split_idx, split_partitions, split_hint,
                cs, thresholds, global_lb, alphas, lA, lbias, constraints)
            # shrink the split dm_l and dm_u
            if use_clip_domains:
                dm_l, dm_u = clip_domains(dm_l, dm_u, thresholds, lA, global_lb, lbias, calculate_dm_lb=True)
            domains.add(global_lb, dm_l.detach(), dm_u.detach(),
                        alphas, cs, thresholds, constraints=constraints, split_idx=None, remaining_index=None,
                        check_thresholds=False, check_bounds=use_clip_domains,
                        )
        else:
            # Non-reordered BaB: shrink the initial dm_l and dm_u before they are stored.
            if use_clip_domains:
                dm_l, dm_u = clip_domains(dm_l, dm_u, rhs, lA, global_lb)
            constraints = None
            # Construct constraints matrices if constrained concretize is enabled.
            if arguments.Config['bab']['clip_n_verify']['clip_input_domain']['enable_constrained_concretize']:
                # Constraints matrices will come from the output node A matrices and b matrics from last CROWN call
                lbias = deconstruct_lbias(dm_l, dm_u, lA, global_lb)
                lbias = lbias - rhs
                current_batch_size = dm_l.shape[0]
                current_x_dim = dm_l.view((current_batch_size, -1)).shape[1]
                constraints = (lA.reshape((current_batch_size, -1, current_x_dim)).detach(), lbias.reshape((current_batch_size, -1)).detach())
            domains.add(global_lb, dm_l.detach(), dm_u.detach(),
                        alphas, cs, thresholds, constraints=constraints, split_idx=split_idx,
                        remaining_index=remaining_index, check_bounds=use_clip_domains)
            # Honor the pgd_order argument (not the global config) so that callers such as a
            # re-initialization can disable attacks that already ran.
            if pgd_order == "after":
                # The domain list currently only has the original unverified domains, we can perform
                # pgd attacks on them directly
                if attack_in_input_bab_parallel(_attack_model(), domains, x, vnnlib=vnnlib):
                    print("pgd attack succeed in input_bab_parallel")
                    result = "unsafe"
        if presplit_domains:
            assert not use_alpha
            load_presplit_domains(
                domains, net, bounding_method, branching_method, stop_func, stats
            )
        stats.visited += remaining_index.shape[0]

    return result, domains, global_lb



def enhanced_bound_init(initialization_args: dict, num_iter: int):
    """
    Re-run input BaB initialization using the configured "enhanced" bounding and branching methods.

    The methods are taken from arguments.Config["bab"]["branching"]["input_split"], under the keys
    "enhanced_bound_prop_method" and "enhanced_branching_method".

    :param initialization_args:     Keyword arguments previously passed to input_bab_initialization().
                                    This dict is MUTATED in place. Its bounding/branching methods are
                                    replaced, presplit domains and PGD attacks are disabled, and num_iter
                                    is updated. Copy it first if the original values must be kept.
    :param num_iter:                The current iteration of BaB.
    :return global_lb:              Max over the global lower bound returned by the re-initialization.
    :return domains:                Input domain list produced by the re-initialization.
    :return branching_method:       The enhanced branching method now in use.
    :return bounding_method:        The enhanced bounding method now in use.
    """
    split_cfg = arguments.Config["bab"]["branching"]["input_split"]
    enhanced_bounding = split_cfg["enhanced_bound_prop_method"]
    enhanced_branching = split_cfg["enhanced_branching_method"]

    # Intentionally mutate the caller's dict (documented above).
    initialization_args.update(
        bounding_method=enhanced_bounding,
        init_bounding_method=enhanced_bounding,
        branching_method=enhanced_branching,
        presplit_domains=False,  # pre-split domains are not supported with alpha-crown
        pgd_order=False,  # attacks should have run during the first initialization (if enabled)
        num_iter=num_iter,
    )

    print(f"Using enhanced bound propagation method {enhanced_bounding} "
          f"with {enhanced_branching} branching.")

    # The 'safe'/'unsafe'/'unknown' status from re-initialization is not propagated to the caller.
    _, domains, global_lb = input_bab_initialization(**initialization_args)

    return global_lb.max(), domains, enhanced_branching, enhanced_bounding
