#########################################################################
##                                                                     ##
##   Copyright (C) 2021-2025 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com> (UIUC)         ##
##                     Zhouxing Shi <zshi@cs.ucla.edu> (UCLA)          ##
##                     Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################
'''Branch and bound for activation space split.'''
import time
import numpy as np
import torch
import copy
from pyomo.environ import * 

from branching_domains import BatchedDomainList, ShallowFirstBatchedDomainList, check_worst_domain
from auto_LiRPA.utils import (stop_criterion_batch_any, multi_spec_keep_func_all,
                              AutoBatchSize)
from auto_LiRPA.bound_ops import (BoundInput)
from attack.domains import SortedReLUDomainList
from attack.bab_attack import bab_loop_attack
from heuristics import get_branching_heuristic
from input_split.input_split_on_relu_domains import input_split_on_relu_domains, InputReluSplitter
from lp_mip_solver import batch_verification_all_node_split_LP
from cuts.cut_verification import cut_verification, get_impl_params
from cuts.cut_utils import fetch_cut_from_cplex, clean_net_mps_process, cplex_update_general_beta
from cuts.infered_cuts import BICCOS
from utils import (print_splitting_decisions, print_average_branching_neurons,
                   Stats, get_unstable_neurons, check_auto_enlarge_batch_size)
from prune import prune_alphas
import arguments 

import time
from typing import Dict, List, Optional, Tuple, Any
import torch
import torch.nn as nn
from pyomo.environ import (
    ConcreteModel, Var, Param, RangeSet, Reals, NonNegativeReals,
    Constraint, ConstraintList, Objective, minimize, SolverFactory, Suffix, value
)
from pyomo.opt import TerminationCondition

def get_split_depth(batch_size, min_batch_size, min_depth):
    """Return how many split levels to apply to the current batch of domains.

    If the picked-out batch is smaller than ``min_batch_size``, several levels
    are split at once so the batch grows towards ``min_batch_size``. The depth
    is ``int(log2(min_batch_size / max(min_depth, batch_size)))``, never below
    ``min_depth``. Otherwise ``min_depth`` is returned unchanged.
    """
    if batch_size < min_batch_size:
        # The denominator is clamped by min_depth, so an empty batch does not
        # divide by zero when min_depth >= 1.
        growth = min_batch_size / max(min_depth, batch_size)
        levels = int(np.log(growth) / np.log(2))
        return max(min_depth, levels)
    return min_depth

def split_domain(nlp_info, net, domains, d, batch, impl_params=None, stats=None,
                 set_init_alpha=False, fix_interm_bounds=True,
                 branching_heuristic=None, iter_idx=None):
    """Branch one picked-out batch of domains and bound the resulting children.

    Steps:
    1. Choose branching decisions with ``branching_heuristic``, which also
       receives ``nlp_info``.
    2. Record the splits in the domain history.
    3. Run ``net.update_bounds`` on the children.
    4. Optionally update BICCOS cuts.
    5. Add the children back into ``domains``.

    Args:
        nlp_info: Forwarded unchanged as the first argument of
            ``branching_heuristic.get_branching_decisions``. NOTE(review): the
            expected format presumably matches the dict built by
            ``NLP_solver_output``; confirm.
        net: Bounding network wrapper (must provide ``update_bounds``,
            ``build_history_and_set_bounds``, ``x``, ...).
        domains: Domain list the children are added back into.
        d: Dict of picked-out domains (as returned by ``domains.pick_out``).
        batch: Requested batch size. It is only used to derive
            ``min_batch_size`` and is then overwritten with the actual number
            of picked-out domains.
        impl_params: Implication parameters forwarded to helpers.
        stats: Stats object. Required despite the ``None`` default, because
            its timers and counters are used unconditionally.
        set_init_alpha: If True, store the (pruned) alpha structure of the
            result in ``domains.init_alpha``.
        fix_interm_bounds: Passed to ``net.update_bounds``. When False,
            alphas are pruned and ``domains.add`` checks infeasibility.
        branching_heuristic: Object providing ``get_branching_decisions``.
        iter_idx: Iteration index forwarded to BICCOS cut inference.

    Returns:
        One of three values:
        - ``(global_lb, torch.inf)`` when every node is already split and
          ``all_node_split_LP`` is disabled. ``global_lb`` is the entry of
          ``d['global_lb'] - d['thresholds']`` whose maximum is smallest;
        - ``torch.inf`` when the all-node-split LP check returns True
          (``stats.all_split_result`` is set to ``'unsafe'``);
        - otherwise, the dict ``ret`` returned by ``net.update_bounds``.
    """
    solver_args = arguments.Config['solver']
    bab_args = arguments.Config['bab']
    branch_args = bab_args['branching']
    biccos_args = bab_args['cut']['biccos']
    biccos_enable = biccos_args['enabled']
    biccos_heuristic = biccos_args['heuristic']
    stop_func = stop_criterion_batch_any

    # Target batch size: below it, get_split_depth splits more than one level.
    min_batch_size = min(
        solver_args['min_batch_size_ratio'] * solver_args['batch_size'],
        batch)
    # From here on `batch` is the number of domains actually picked out
    # (first dimension of the first lower-bound tensor).
    batch = next(iter(d['lower_bounds'].values())).shape[0] 

    stats.timer.start('decision')
    # Multi-tree (BFS) search splits k_splits levels per round; regular BaB uses 1.
    if isinstance(domains, ShallowFirstBatchedDomainList) and domains.use_bfs:
        depth = biccos_args['multi_tree_branching']['k_splits']
    else:
        depth = 1
    split_depth = get_split_depth(batch, min_batch_size, depth)
    # Increase the maximum number of candidates for fsb and kfsb if there are more splits needed.
    # The heuristic may revise split_depth, so it is re-read from its return value.
    branching_decision, branching_points, split_depth = (
        branching_heuristic.get_branching_decisions(nlp_info,
            d, split_depth, method=branch_args['method'],
            branching_candidates=max(branch_args['candidates'], split_depth),
            branching_reduceop=branch_args['reduceop'], skip_bound_propagation = False))
    print_average_branching_neurons(
        branching_decision, stats.implied_cuts, impl_params=impl_params)
    # Fewer decisions than picked-out domains is treated as "all nodes are split".
    if len(branching_decision) < len(next(iter(d['mask'].values()))):
        print('all nodes are split!!')
        print(f'{stats.visited} domains visited')
        stats.all_node_split = True
        stats.all_split_result = 'unknown'
        if not solver_args['beta-crown']['all_node_split_LP']:
            # Keep the entry whose max(lb - threshold) is smallest (ties -> later entry).
            global_lb = d['global_lb'][0] - d['thresholds'][0]
            for i in range(1, len(d['global_lb'])):
                if max(d['global_lb'][i] - d['thresholds'][i]) <= max(global_lb):
                    global_lb = d['global_lb'][i] - d['thresholds'][i]
            # NOTE(review): the 'decision' timer started above is not stopped on this path.
            return global_lb, torch.inf
    split = {
        'decision': branching_decision,
        'points': branching_points,
    }
    # Non-default branching points require intermediate-bound transfer.
    if split['points'] is not None and not bab_args['interm_transfer']:
        raise NotImplementedError(
            'General branching points are not supported '
            'when interm_transfer==False')
    print_splitting_decisions(
        net, d, split_depth, split,
        verbose=arguments.Config['debug']['print_verbose_decisions'])
    stats.timer.add('decision')

    stats.timer.start('set_bounds')
    # Multi-tree search records splits breadth-first; regular BaB depth-first.
    if isinstance(domains, ShallowFirstBatchedDomainList) and domains.use_bfs:
        net.build_history_and_set_bounds(d, split, impl_params=impl_params, mode='breadth')
    else:
        net.build_history_and_set_bounds(d, split, impl_params=impl_params, mode='depth')
    stats.timer.add('set_bounds')
    # Number of split decisions (overwrites the picked-out batch size).
    batch = len(split['decision'])
    stats.timer.start('solve')
    # Caution: we use 'all' predicate to keep the domain when multiple specs
    # are present: all lbs should be <= threshold, otherwise pruned
    # maybe other 'keeping' criterion needs to be passed here
    ret = net.update_bounds(
        d, fix_interm_bounds=fix_interm_bounds,
        stop_criterion_func=stop_func(d['thresholds']),
        multi_spec_keep_func=multi_spec_keep_func_all,
        beta_bias=branching_points is not None)
    stats.timer.add('solve')

    # Optional LP check once some domain has every ambiguous node split.
    if (solver_args['beta-crown']['all_node_split_LP']
            and torch.any(torch.tensor(d['depths']) == net.tot_ambi_nodes)):
        # FIXME build_history_and_set_bounds doesn't return correct split
        # (just dummy elements) when split_depth > 1
        stats.all_split_result = 'unknown'
        if batch_verification_all_node_split_LP(net, d, ret, split, stats):
            stats.all_node_split = True
            stats.all_split_result = 'unsafe'
            return torch.inf

    if set_init_alpha:
        print('Setting the initial alpha')
        ret['alphas'] = prune_alphas(ret['alphas'], net.alpha_start_nodes)
        # We just want the data structure here, not the values
        domains.init_alpha = {
            k: {kk: vv[:, :, :1].detach().clone().to(net.x.device).to(
                torch.get_default_dtype()) for kk, vv in v.items()}
            for k, v in ret['alphas'].items()
        }
    else:
        if not fix_interm_bounds:
            ret['alphas'] = prune_alphas(ret['alphas'], net.alpha_start_nodes)

    # We have to add cuts now, because domains.add might modify the list of domains in ret
    if ret and bab_args['cut']['enabled'] and biccos_enable:
        # We only enforce cut usage for multi-tree-searching
        enforce_cut_usage = (
        isinstance(domains, ShallowFirstBatchedDomainList)
        and domains.use_bfs)
        # If constraint strengthening is disabled, set iter_idx to a very large
        # value to skip the inference procedure
        iter_idx = iter_idx if biccos_args['constraint_strengthening'] else float('inf')
        net.biccos.update_cut(d, net, ret,
                            enforce_usage=enforce_cut_usage,
                            heuristic=biccos_heuristic,
                            iter_idx=iter_idx)

    stats.timer.start('add')
    # Count newly added domains as "visited".
    old_d_len = len(domains)
    domains.add(ret, d, check_infeasibility=not fix_interm_bounds)
    stats.visited += len(domains) - old_d_len
    domains.print()
    stats.timer.add('add')
    del d
    return ret

def act_split_round(nlp_info, domains, net, batch, iter_idx, stats=None, impl_params=None,
                    branching_heuristic=None):
    """Run one round of activation-split branch and bound.

    The round picks out up to ``batch`` domains. If they carry a split mask,
    it branches and re-bounds them via ``split_domain``. It then optionally
    sorts the domain list and reports the current worst lower bound.

    Args:
        nlp_info: Forwarded unchanged to ``split_domain`` (and from there to
            the branching heuristic).
        domains: Domain list to pick from and add children back into.
        net: Bounding network wrapper.
        batch: Maximum number of domains to pick out this round.
        iter_idx: Current round index; used for periodic sorting and BICCOS.
        stats: Stats object. Required in practice (its timers are used).
        impl_params: Implication parameters forwarded to helpers.
        branching_heuristic: Object providing ``get_branching_decisions``.

    Returns:
        Tuple ``(global_lb, alphas, history)``:
          global_lb: result of ``check_worst_domain(domains)``, plus
              ``rhs_offset`` when configured.
          alphas: ``ret['alphas']`` when ``split_domain`` returned its usual
              dict, otherwise ``None``.
          history: ``d['history']`` of the picked-out batch (set to ``None``
              when vanilla CROWN is enabled).
    """
    bab_args = arguments.Config['bab']
    sort_domain_iter = bab_args['sort_domain_interval']
    recompute_interm = bab_args['recompute_interm']
    vanilla_crown = bab_args['vanilla_crown']
    spec_args = arguments.Config['specification']

    stats.timer.start('pickout')
    d = domains.pick_out(batch=batch, device=net.x.device, impl_params=impl_params)
    if vanilla_crown:
        d['history'] = None
    stats.timer.add('pickout')

    if bab_args['cut']['enabled'] and bab_args['cut']['cplex_cuts']:
        cplex_update_general_beta(net, d)

    # split_domain normally returns the dict produced by net.update_bounds().
    # On the all-nodes-split paths it returns a tuple or torch.inf instead.
    ret_out = None
    if d['mask'] is not None:
        ret_out = split_domain(nlp_info,
            net, domains, d, batch, impl_params=impl_params,
            stats=stats, fix_interm_bounds=not recompute_interm,
            branching_heuristic=branching_heuristic, iter_idx=iter_idx)
        print('Length of domains:', len(domains))
        stats.timer.print()

    if len(domains) == 0:
        print('No domains left, verification finished!')

    if sort_domain_iter > 0 and iter_idx % sort_domain_iter == 0:
        domains.sort()

    global_lb = check_worst_domain(domains)
    rhs_offset = spec_args['rhs_offset']
    if rhs_offset is not None:
        global_lb += rhs_offset
    if 1 < global_lb.numel() <= 5:
        print(f'Current (lb-rhs): {global_lb}')
    else:
        print(f'Current (lb-rhs): {global_lb.max().item()}')
    print(f'{stats.visited} domains visited')

    # Only the usual dict return carries alphas for the freshly split domains.
    alphas = ret_out.get('alphas', None) if isinstance(ret_out, dict) else None

    return global_lb, alphas, d['history']
 

def multi_tree_bab(net, domains, batch,
    stop_criterion, biccos_args, impl_params,
    stats, start_time, nlp_info=None):
    '''
    Usually, BaB uses a single binary tree. In multi-tree search, keep track of multiple trees,
    and each node may have multiple children. This allows us to e.g. explore both the splits (A, B),
    (A, C) and (D, C) in parallel. By doing so, we can generate more diverse BICCOS cuts.
    After the multi-tree search terminates, we drop all but one tree, which is pruned to become a
    binary tree. This tree is then used for the rest of the BaB process.
    In each iteration, we select the best n leaf nodes and perform k splits each.

    input:
        net: LirpaNet
        domains: ShallowFirstBatchedDomainList (must contain exactly one domain)
        batch: int
        stop_criterion: callable
        biccos_args: dict
        impl_params: dict
        stats: Stats
        start_time: float
        nlp_info: optional; forwarded to act_split_round, which requires it
            as its first positional argument. Previously this argument was
            missing, so every call failed with a TypeError.
            NOTE(review): confirm the 'kfsb' heuristic accepts None here.
    '''
    shallowbranching_heuristic = get_branching_heuristic(net, 'kfsb')
    assert len(domains) == 1

    # At the end of the multi-tree search, we have to restore the initial domain
    initial_domain = domains.pick_out(batch=batch, device=net.x.device)
    initial_ret = net.update_bounds(
        initial_domain,
        fix_interm_bounds=True,
        stop_criterion_func=stop_criterion,
        multi_spec_keep_func=multi_spec_keep_func_all,
        beta_bias=False
    )
    domains.add(initial_ret, initial_domain, check_infeasibility=False)

    total_round = 0
    max_iter_shallow = biccos_args['multi_tree_branching']['iterations']
    num_domains = len(domains)
    # In rare cases, adding the initial domain back might prove it to be UNSAT.
    # This might happen due to randomness in the gradient updates.
    # If it happens, we're done and don't need to proceed with regular BaB.
    if num_domains == 0:
        return
    assert num_domains == 1

    while num_domains > 0 and total_round < max_iter_shallow:
        total_round += 1
        print(f'Shallow-BaB round {total_round}')
        act_split_round(nlp_info, domains, net, batch, iter_idx=total_round,
                        impl_params=impl_params, stats=stats,
                        branching_heuristic=shallowbranching_heuristic)
        num_domains = len(domains)
        print(f'Cumulative time: {time.time() - start_time}\n')

    # Drop current list of domains
    domains.use_bfs = False
    if len(domains) > 0:
        domains.pick_out(batch=len(domains), device=net.x.device)

    if not biccos_args['multi_tree_branching']['restore_best_tree']:
        domains.add(initial_ret, initial_domain, check_infeasibility=False)
    else:
        domains.restore_best_domains(initial_ret, initial_domain)
        # We might have added some domains that are UNSAT
        print('Shallow branching resets to n domains: ', len(domains))
        base_d = domains.pick_out(batch=len(domains), device=net.x.device)
        new_ret = net.update_bounds(
                base_d,
                fix_interm_bounds=True,
                stop_criterion_func=stop_criterion,
                multi_spec_keep_func=multi_spec_keep_func_all,
                beta_bias=False
            )
        domains.add(new_ret, base_d, check_infeasibility=False)
        print('After pruning, left: ', len(domains))

    print('\nBack to Regular BaB\n')
  
from typing import Dict, List, Tuple
 

def collect_unstable_alphas(ret_alphas: Dict[str, Dict[str, torch.Tensor]],
                            inputkey: Optional[List[str]] = None,
                            outputkey: Optional[List[str]] = None,
                            idx: int = 0
                           ) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
    """Concatenate per-neuron alphas of several layers into one 1-D tensor.

    For each pair ``(lk, subkey)`` from ``zip(inputkey, outputkey)`` the tensor
    ``ret_alphas[lk][subkey][0, 0, idx]`` is flattened and appended.
    NOTE(review): dims 0/1/2 are presumably (alpha kind, spec, batch); confirm.

    Args:
        ret_alphas: Nested dict ``{layer_key: {sub_key: alpha_tensor}}``.
        inputkey: Outer keys to read; defaults to ``['/8', '/10']``.
        outputkey: Inner keys paired with ``inputkey``; defaults to
            ``['/input-3', '/11']``.
        idx: Index selected on dimension 2 of each alpha tensor.

    Returns:
        alphas_cat: 1-D tensor with one entry per collected neuron.
        index_map: ``(layer_position, neuron_idx)`` for each entry of
            ``alphas_cat``. ``layer_position`` is the 0-based position in
            ``inputkey``; ``neuron_idx`` is the 0-based index in the
            flattened layer.

    Raises:
        RuntimeError: from ``torch.cat`` when no layers are selected.
    """
    # None sentinels avoid shared mutable list defaults.
    if inputkey is None:
        inputkey = ['/8', '/10']
    if outputkey is None:
        outputkey = ['/input-3', '/11']

    per_layer = []
    index_map: List[Tuple[int, int]] = []
    for layer_pos, (lk, subkey) in enumerate(zip(inputkey, outputkey)):
        # Flatten so multi-dim (e.g. conv) alphas yield one column per neuron,
        # keeping index_map aligned with the concatenated tensor.
        layer_alpha = ret_alphas[lk][subkey][0, 0, idx].reshape(-1)
        per_layer.append(layer_alpha)
        index_map.extend((layer_pos, j) for j in range(layer_alpha.shape[0]))

    alphas_cat = torch.cat(per_layer, dim=0)
    return alphas_cat, index_map

from typing import List, Tuple, Dict
import math

def split_alpha_sets(
    alphas_cat,
    idx_map: List[Tuple[int, int]],
    threshold: float = math.tan(math.pi / 8.0)):
    """Partition neurons into an overset and an underset by alpha value.

    Args:
        alphas_cat: Tensor of alphas (flattened before use), e.g. the first
            output of ``collect_unstable_alphas``.
        idx_map: ``(layer, neuron_idx)`` pairs aligned with the flattened
            alphas. Pairing uses ``zip``, so surplus entries on either side
            are ignored.
        threshold: Alphas strictly greater than this go to the overset.
            Defaults to ``tan(pi / 8)``.

    Returns:
        Tuple ``(overset, underset)`` of lists of ``(layer, neuron_idx + 1)``
        (neuron indices shifted by one). Values ``<= threshold``, and any
        value for which ``>`` is False (e.g. NaN), go to the underset.
    """
    flat_alphas = alphas_cat.detach().cpu().numpy().ravel()
    overset, underset = [], []
    for (layer, neuron), alpha in zip(idx_map, flat_alphas):
        bucket = overset if alpha > threshold else underset
        bucket.append((layer, neuron + 1))
    return overset, underset


def model_info(model):
    """Extract the affine parameters of the top-level ``nn.Linear`` children of ``model``.

    Only direct children (``model.named_children()``) are inspected, in
    registration order. Nested containers are not searched.

    Returns:
        Tuple ``(W_list, b_list, in_list, out_list)``:
          W_list: weight arrays of shape (out_features, in_features);
          b_list: bias arrays of shape (out_features,). A layer created with
              ``bias=False`` gets a zero bias, so downstream indexing works.
          in_list / out_list: ``in_features`` / ``out_features`` per layer.
    """
    W_list, b_list, in_list, out_list = [], [], [], []
    for _, module in model.named_children():
        if not isinstance(module, torch.nn.Linear):
            continue
        weight = module.weight.detach().cpu().numpy()
        W_list.append(weight)
        if module.bias is None:
            # nn.Linear(bias=False) stores bias as None; use an explicit zero bias.
            b_list.append(np.zeros(module.out_features, dtype=weight.dtype))
        else:
            b_list.append(module.bias.detach().cpu().numpy())
        in_list.append(module.in_features)
        out_list.append(module.out_features)
    return W_list, b_list, in_list, out_list

def classify_neurons_from_bounds(lbs_key, ubs_key, batch_idx):
    """Classify ReLU neurons of one batch element from intermediate bounds.

    - I_n: neurons with negative upper bound (always inactive);
    - I_p: remaining neurons with positive lower bound (always active);
    - I0:  everything else (may be either active or inactive).

    Args:
        lbs_key: Dict ``layer_name -> lower-bound tensor``, indexed by batch
            on its first dimension. It is iterated in dict order.
        ubs_key: Dict with the same keys holding the upper-bound tensors.
        batch_idx: Batch element to classify.

    Returns:
        Tuple ``(I_p, I_n, I0, lbs, ubs)``:
          I_p / I_n / I0: lists of ``(layer_idx, neuron_idx + 1)`` (1-based
              neuron index into the flattened layer, ascending within a layer).
          lbs / ubs: per-layer numpy arrays of shape (1, num_neurons) for
              every layer, including the last. The last layer is not
              classified because no ReLU follows it.
    """
    I_p, I_n, I0 = [], [], []
    lbs, ubs = [], []
    num_layers = len(lbs_key)
    for layer_idx, name in enumerate(lbs_key):
        # Flatten so multi-dim (e.g. conv) bounds are handled per neuron.
        layer_lbs = lbs_key[name][batch_idx].reshape(-1)
        layer_ubs = ubs_key[name][batch_idx].reshape(-1)
        lbs.append(layer_lbs.reshape(1, -1).detach().cpu().numpy())
        ubs.append(layer_ubs.reshape(1, -1).detach().cpu().numpy())

        if layer_idx >= num_layers - 1:
            break  # the last layer has no ReLU

        # Same precedence as an if/elif chain: a negative upper bound wins.
        inactive = layer_ubs < 0
        active = ~inactive & (layer_lbs > 0)
        ambiguous = ~(inactive | active)
        for mask, bucket in ((inactive, I_n), (active, I_p), (ambiguous, I0)):
            hits = torch.nonzero(mask, as_tuple=False).reshape(-1).tolist()
            bucket.extend((layer_idx, j + 1) for j in hits)  # +1 for 1-indexing
    return I_p, I_n, I0, lbs, ubs

def classify_neurons_from_bounds_raw(model, image, lb, ub, norm=float("inf"), method='CROWN',
                                     bounds_fn=None):
    """Compute intermediate bounds with ``bounds_fn`` and classify hidden ReLU neurons.

    This function used to call ``self.compute_interm_bounds`` although it is
    a module-level function, so every call raised ``NameError``. The bound
    computation is now supplied by the caller through ``bounds_fn``.

    Args:
        model, image, lb, ub, norm, method: Forwarded to ``bounds_fn``.
        bounds_fn: Callable ``bounds_fn(model, image, lb, ub, norm=..., method=...)``
            returning ``(lbs, ubs)``: per-layer bound arrays whose row ``[0]``
            holds one value per neuron. The last entry of ``lbs`` is treated
            as the output layer and is not classified.

    Returns:
        Tuple ``(I_p, I_n, I0, lbs, ubs)``. I_p / I_n / I0 hold
        ``(layer_idx, neuron_idx + 1)`` pairs (1-based neuron index), and
        ``lbs`` / ``ubs`` are the bounds returned by ``bounds_fn``.

    Raises:
        TypeError: If ``bounds_fn`` is not provided.
    """
    if bounds_fn is None:
        raise TypeError(
            'classify_neurons_from_bounds_raw requires bounds_fn(model, image, lb, ub, '
            'norm=..., method=...) returning (lbs, ubs)')
    lbs, ubs = bounds_fn(model, image, lb, ub, norm=norm, method=method)
    I_p, I_n, I0 = [], [], []

    # Only hidden layers are classified (the last entry is the output layer);
    # layers missing from ubs are skipped.
    num_hidden_layers = min(len(lbs) - 1, len(ubs))
    for layer_idx in range(num_hidden_layers):
        layer_lbs = lbs[layer_idx][0]
        layer_ubs = ubs[layer_idx][0]
        for neuron_idx in range(len(layer_lbs)):
            lb_val = layer_lbs[neuron_idx]
            ub_val = layer_ubs[neuron_idx]
            if ub_val < 0:
                # Upper bound negative -> always inactive (ReLU output = 0)
                I_n.append((layer_idx, neuron_idx + 1))
            elif lb_val > 0:
                # Lower bound positive -> always active (ReLU output = input)
                I_p.append((layer_idx, neuron_idx + 1))
            else:
                # Can be either active or inactive
                I0.append((layer_idx, neuron_idx + 1))
    return I_p, I_n, I0, lbs, ubs

# Re-classify ambiguous (I0) neurons from the s1/s2 values of a solved Pyomo NLP.
def find_I(m, eps, I_p, I_n, I0):
    """Re-classify ambiguous ReLU neurons using a solved complementarity NLP.

    For every ``(i, j)`` in ``I0``, the solved values of the model components
    ``s1_{i}_{j}`` and ``s2_{i}_{j}`` are compared with ``tol = sqrt(eps)``:

    * ``s1 < tol`` and ``s2 >= tol``: moved to ``I_n``;
    * ``s2 < tol`` and ``s1 >= tol``: moved to ``I_p``;
    * otherwise (both small or both large): stays ambiguous.

    The previous version looped over ``range(len(I0) - 1)`` and so never
    examined the last neuron of ``I0``. Every neuron is examined now.

    Args:
        m: Solved Pyomo model containing ``s1_{i}_{j}`` / ``s2_{i}_{j}`` for
            each neuron in ``I0``.
        eps: Complementarity tolerance; the decision threshold is ``sqrt(eps)``.
        I_p, I_n: Lists of active / inactive neurons. They are extended in
            place and also returned.
        I0: Ambiguous neurons; it is not modified.

    Returns:
        Tuple ``(I_p, I_n, I0_new)``. ``I0_new`` keeps the still-ambiguous
        neurons in their original order.
    """
    tol = np.sqrt(eps)
    I0_new = []
    for (i, j) in I0:
        s1_val = value(m.component(f"s1_{i}_{j}"))
        s2_val = value(m.component(f"s2_{i}_{j}"))
        if s1_val < tol and s2_val >= tol:
            I_n.append((i, j))
        elif s2_val < tol and s1_val >= tol:
            I_p.append((i, j))
        else:
            # Both parts small or both large: the phase is not resolved.
            I0_new.append((i, j))
    print('total neuron is', len(I_p) + len(I_n) + len(I0_new))
    return I_p, I_n, I0_new
  
 
def _pack_nlp_info(m, L, layer_keys=None):
    """
    Pack NLP information from Pyomo model into torch tensors.

    Returns a dict:
        {
          'z':    {key: (1, n_j)},
          's1':   {key: (1, n_j)},
          's2':   {key: (1, n_j)},
          'width':{key: (1, n_j)},
        }
    Only entries that actually exist in the model are filled.
    """
    nlp_info = {'z': {}, 's1': {}, 's2': {}, 'width': {}}

    for kk in range(L - 1):
        J = getattr(m, f"J_{kk}", None)
        if J is None:
            continue

        key = layer_keys[kk] if (layer_keys is not None and kk < len(layer_keys)) else kk

        # pre-activations
        z = getattr(m, f"z_{kk}", None)
        z_vals = [float(value(z[j])) if z is not None else 0.0 for j in J]
        nlp_info['z'][key] = torch.tensor(z_vals, dtype=torch.float32).unsqueeze(0)

        # width = ubz - lbz
        lbz = getattr(m, f"lbz_{kk}", None)
        ubz = getattr(m, f"ubz_{kk}", None)
        if lbz is not None and ubz is not None:
            width_vals = [float(value(ubz[j]) - value(lbz[j])) for j in J]
            nlp_info['width'][key] = torch.tensor(width_vals, dtype=torch.float32).unsqueeze(0)

        # s1, s2 may not exist for fixed neurons (in I_p or I_n)
        s1_layer = []
        s2_layer = []
        for j in J:
            s1_name = f"s1_{kk}_{j}"
            s2_name = f"s2_{kk}_{j}"
            if hasattr(m, s1_name) and hasattr(m, s2_name):
                s1_var = getattr(m, s1_name)
                s2_var = getattr(m, s2_name)
                s1_layer.append(float(value(s1_var)))
                s2_layer.append(float(value(s2_var)))
            else:
                # fixed ReLU: store 0 to keep shapes consistent
                s1_layer.append(0.0)
                s2_layer.append(0.0)

        nlp_info['s1'][key] = torch.tensor(s1_layer, dtype=torch.float32).unsqueeze(0)
        nlp_info['s2'][key] = torch.tensor(s2_layer, dtype=torch.float32).unsqueeze(0)

    return nlp_info
 

from pyomo.environ import value

def _update_bounds_for_phases(m, pos_neurons, neg_neurons):
    """
    For neurons that moved from I0 -> I_p / I_n, tighten pre-activation bounds:

        (k, j) in pos_neurons:  z_{k,j} >= 0  => lbz_k[j] = max(lbz_k[j], 0)
        (k, j) in neg_neurons:  z_{k,j} <= 0  => ubz_k[j] = min(ubz_k[j], 0)
    """
    # Positive phase: tighten lower bound to >= 0
    for (k, j) in pos_neurons:
        lbz_name = f"lbz_{k}"
        if hasattr(m, lbz_name):
            lbz = getattr(m, lbz_name)
            if j in lbz:
                old = value(lbz[j])           # <-- use value() instead of float(...)
                new = max(old, 0.0)
                lbz[j] = new                  # mutable Param, so direct assignment is fine

    # Negative phase: tighten upper bound to <= 0
    for (k, j) in neg_neurons:
        ubz_name = f"ubz_{k}"
        if hasattr(m, ubz_name):
            ubz = getattr(m, ubz_name)
            if j in ubz:
                old = value(ubz[j])           # <-- use value() here as well
                new = min(old, 0.0)
                ubz[j] = new

   
def NLP_solver_output(u_best, over_set, under_set,
                      model, lb, ub, spec, lbs, ubs, I_p, I_n, I0, eps=1e-5,
                      layer_keys=None):
    """Minimize the label-vs-attack margin with a complementarity-relaxed NLP (IPOPT).

    The network (direct ``nn.Linear`` children of ``model``) is encoded
    layer by layer:
    - affine wiring, with box constraints ``lbz <= z <= ubz`` written as
      ``z - ubz <= 0`` / ``lbz - z <= 0``;
    - for each neuron in I_p: ``zhat = z``;
    - for each neuron in I_n: ``zhat = 0``;
    - for every other neuron: ``z = s1 - s2``, ``zhat = s1`` and
      ``s1 * s2 <= eps``.
    The objective is the output margin ``zo = (W[label] - W[attack]) zhat + (b[label] - b[attack])``,
    clamped to ``[lbs[L-1][0], ubs[L-1][0]]``.

    Args:
        u_best: Returned in place of the objective when the solve fails.
        over_set, under_set: Not used inside this function.
        model: torch module; see ``model_info``.
        lb, ub: Input box bounds, read as ``[0][i]`` after ``.cpu().numpy()``.
        spec: ``spec[0][0]`` must contain exactly one ``1`` (label) and one
            ``-1`` (attack target).
        lbs, ubs: Per-layer bounds, read as ``[k][0, j]`` for hidden layers
            and ``[L-1][0]`` for the output margin.
        I_p, I_n: ``(layer, 1-based neuron)`` pairs fixed active / inactive.
        I0: Not used directly; every neuron outside I_p/I_n gets s1/s2.
        eps: Bound on ``s1 * s2``.
        layer_keys: Keys for the packed ``nlp_info`` dicts, one per hidden
            layer. Defaults to ``0..L-2``. They are now actually forwarded
            to ``_pack_nlp_info``; previously they were ignored.

    Returns:
        Tuple ``(status, fmin, runtime, m, nlp_info)``:
          status: 'solvable' | 'infeasible' | 'maxtimelimit' | 'failed'.
          fmin: objective value if solvable, else ``u_best``.
          runtime: wall-clock seconds of ``opt.solve``.
          m: the Pyomo model.
          nlp_info: if solvable, the ``_pack_nlp_info`` dict with keys
              'z', 's1', 's2', 'width' (each ``{key: (1, n_j) tensor}``);
              otherwise ``{'z': {}}``.

    NOTE(review): constraint duals are imported into ``m.dual`` (Suffix) but
    are not packed into ``nlp_info``. Earlier docs mentioned 'pi_u'/'pi_l'
    keys that were never produced; read them from ``m.dual`` if needed.
    """
    # ------------------ unpack model structure ------------------
    W_list, b_list, in_list, out_list = model_info(model)
    label = int((spec[0][0] == 1).nonzero(as_tuple=False))
    attack = int((spec[0][0] == -1).nonzero(as_tuple=False))
    L = len(W_list)                      # number of affine layers incl. the output layer
    input_size = in_list[0]

    if layer_keys is None:
        # One key per pre-activation (hidden) layer.
        layer_keys = list(range(L - 1))

    # Set lookups instead of list scans inside the per-neuron loop.
    fixed_on = set(I_p)
    fixed_off = set(I_n)

    # ------------------ build Pyomo model ------------------
    m = ConcreteModel()
    # Ask the solver to return constraint duals (available afterwards in m.dual).
    m.dual = Suffix(direction=Suffix.IMPORT)

    # input bounds (1-based indices)
    lb_numpy, ub_numpy = lb.cpu().numpy(), ub.cpu().numpy()
    input_bounds = {i + 1: (float(lb_numpy[0][i]), float(ub_numpy[0][i])) for i in range(input_size)}

    m.constraints = ConstraintList()
    m.input = Var(RangeSet(1, input_size), domain=Reals, bounds=input_bounds)

    # hidden/pre-activation layers: k = 0..L-2
    for k in range(L - 1):
        m.add_component(f"I_{k}", RangeSet(1, in_list[k]))     # input indices of layer k
        m.add_component(f"J_{k}", RangeSet(1, out_list[k]))    # output (= pre-activation) indices
        I = getattr(m, f"I_{k}")
        J = getattr(m, f"J_{k}")

        # pre-activation z and post-ReLU zhat (>= 0)
        m.add_component(f"z_{k}", Var(J, domain=Reals))
        m.add_component(f"zhat_{k}", Var(J, domain=NonNegativeReals))
        z = getattr(m, f"z_{k}")
        zhat = getattr(m, f"zhat_{k}")

        # affine parameters W (out x in) and b
        r, c = W_list[k].shape[0], W_list[k].shape[1]
        W_ini = {(i + 1, j + 1): W_list[k][i, j] for i in range(r) for j in range(c)}
        b_ini = {i + 1: b_list[k][i] for i in range(r)}
        m.add_component(f"W_{k}", Param(J, I, initialize=W_ini))
        m.add_component(f"b_{k}", Param(J, initialize=b_ini))
        Wp = getattr(m, f"W_{k}")
        bp = getattr(m, f"b_{k}")

        # per-neuron pre-activation box (mutable, so it can be tightened later)
        lb_ini = {j + 1: float(lbs[k][0, j]) for j in range(out_list[k])}
        ub_ini = {j + 1: float(ubs[k][0, j]) for j in range(out_list[k])}
        m.add_component(f"lbz_{k}", Param(J, initialize=lb_ini, mutable=True))
        m.add_component(f"ubz_{k}", Param(J, initialize=ub_ini, mutable=True))
        lbz = getattr(m, f"lbz_{k}")
        ubz = getattr(m, f"ubz_{k}")

        # affine wiring from the previous layer (or the input)
        zhat_old = m.input if k == 0 else getattr(m, f"zhat_{k-1}")
        for j in J:
            m.constraints.add(z[j] == sum(Wp[j, i] * zhat_old[i] for i in I) + bp[j])

        # named bound constraints in canonical <= 0 form
        m.add_component(f"z_le_{k}", Constraint(J))  # z_j - ubz_j <= 0
        m.add_component(f"z_ge_{k}", Constraint(J))  # lbz_j - z_j <= 0
        z_le = getattr(m, f"z_le_{k}")
        z_ge = getattr(m, f"z_ge_{k}")
        for j in J:
            z_le[j] = z[j] - ubz[j] <= 0
            z_ge[j] = lbz[j] - z[j] <= 0

        # phase fixing for I_p / I_n, complementarity split otherwise
        for j in J:
            if (k, j) in fixed_on:
                m.constraints.add(zhat[j] == z[j])   # ON: zhat = z
            elif (k, j) in fixed_off:
                m.constraints.add(zhat[j] == 0)      # OFF: zhat = 0
            else:
                m.add_component(f"s1_{k}_{j}", Var(domain=NonNegativeReals))
                m.add_component(f"s2_{k}_{j}", Var(domain=NonNegativeReals))
                s1 = getattr(m, f"s1_{k}_{j}")
                s2 = getattr(m, f"s2_{k}_{j}")
                # soft complementarity s1 * s2 <= eps
                m.add_component(f"relu_comp_{k}_{j}", Constraint(expr=s1 * s2 <= eps))
                # zhat = s1 and z = s1 - s2
                m.add_component(f"relu_link_hat_{k}_{j}", Constraint(expr=zhat[j] == s1))
                m.add_component(f"relu_link_z_{k}_{j}", Constraint(expr=z[j] == s1 - s2))

    # output margin layer (k = L-1)
    k = L - 1
    J_last = getattr(m, f"J_{k-1}")      # indices of the last hidden layer
    zhat_last = getattr(m, f"zhat_{k-1}")

    W_o = {j + 1: (float(W_list[k][label, j]) - float(W_list[k][attack, j]))
           for j in range(W_list[k].shape[1])}
    b_o = {1: (float(b_list[k][label]) - float(b_list[k][attack]))}

    # clamp the margin zo by its precomputed bounds
    lb_o = {1: float(lbs[k][0])}
    ub_o = {1: float(ubs[k][0])}

    m.add_component("lbo", Param(RangeSet(1, 1), initialize=lb_o))
    m.add_component("ubo", Param(RangeSet(1, 1), initialize=ub_o))
    m.add_component("Wo", Param(RangeSet(1, W_list[k].shape[1]), initialize=W_o))
    m.add_component("bo", Param(RangeSet(1, 1), initialize=b_o))
    m.add_component("zo", Var(RangeSet(1, 1), domain=Reals))
    zo, Wo, bo, lbo, ubo = m.zo, m.Wo, m.bo, m.lbo, m.ubo

    m.constraints.add(zo[1] == sum(Wo[j] * zhat_last[j] for j in J_last) + bo[1])
    m.constraints.add(zo[1] <= ubo[1])
    m.constraints.add(zo[1] >= lbo[1])

    m.obj = Objective(expr=zo[1], sense=minimize)

    # ------------------ solve ------------------
    opt = SolverFactory('ipopt')
    opt.options["max_cpu_time"] = 900
    start = time.time()
    solver_status = opt.solve(m, tee=True)
    runtime = time.time() - start
    tc = solver_status.solver.termination_condition

    # classify termination
    if tc in (
        TerminationCondition.infeasible,
        TerminationCondition.infeasibleOrUnbounded,
        TerminationCondition.maxTimeLimit,
        TerminationCondition.maxIterations,
        TerminationCondition.minStepLength,
        TerminationCondition.noSolution,
        TerminationCondition.error
    ):
        msg = getattr(solver_status.solver, "message", "")
        print(f"[{k}] Skip: {tc.name}. {msg}")
        if tc == TerminationCondition.infeasible:
            return 'infeasible', u_best, runtime, m, {'z': {}}
        if tc == TerminationCondition.maxTimeLimit:
            return 'maxtimelimit', u_best, runtime, m, {'z': {}}
        return 'failed', u_best, runtime, m, {'z': {}}

    if tc in (TerminationCondition.optimal,
              TerminationCondition.globallyOptimal,
              TerminationCondition.feasible):
        print("Problem solved successfully.")
        fmin = float(value(m.obj))
        nlp_info = _pack_nlp_info(m, L, layer_keys=layer_keys)
        return 'solvable', fmin, runtime, m, nlp_info

    # Fallback
    print("The solver failed to solve the problem.")
    return 'failed', u_best, runtime, m, {'z': {}}

import math 

def split_lA_case_by_angle( lA_dict, case_idx=0, spec_idx=0,
                           use_abs=True, threshold=math.pi / 8):
    """
    Split entries of lA for a single case into overset/underset by slope angle.

    For every entry ``s`` of ``lA_dict[name][case_idx, spec_idx]`` the angle
    ``atan(|s|)`` (or ``atan(s)`` when ``use_abs`` is False) is compared with
    ``threshold``.

    Args:
        lA_dict (dict[str, torch.Tensor]):
            e.g. {'/8': tensor[B, S, ...], '/10': tensor[B, S, ...], ...}
        case_idx (int): which batch index to use.
        spec_idx (int): which spec index to use (usually 0 if single-spec).
        use_abs (bool): if True, use |slope|; angle in [0, pi/2).
        threshold (float): angle threshold in radians (default pi/8).

    Returns:
        overset (list[tuple[int, int]]): (l, j) with angle > threshold
        underset (list[tuple[int, int]]): (l, j) with angle < threshold
        ``l`` is the position of the layer in ``lA_dict``'s iteration order and
        ``j`` is the 0-based index into that layer's flattened tensor.
        Entries exactly at the threshold (and NaN entries) are in neither list.
        Layers where ``case_idx`` / ``spec_idx`` is out of range are skipped.

    Raises:
        ValueError: if a layer tensor has fewer than 3 dims.
    """
    overset = []
    underset = []

    # Layer numbering l = 0,1,2,... follows the dict's iteration (insertion)
    # order; names are NOT sorted.
    for l_idx, layer_name in enumerate(lA_dict):
        v = lA_dict[layer_name]  # tensor with shape [B, S, ...]
        if v.ndim < 3:
            raise ValueError(f"Expected lA[{layer_name}] to have at least 3 dims [B,S,...], "
                             f"got shape {tuple(v.shape)}")

        if case_idx >= v.shape[0] or spec_idx >= v.shape[1]:
            # skip if requested case/spec is out of range for this layer
            continue

        # pick one case + one spec: shape [...layer_dims...]
        one = v[case_idx, spec_idx]
        slopes = one.abs() if use_abs else one
        # Compare in float64 so the result is identical to comparing Python
        # floats (as ``.item()`` would give) with ``threshold``, but in one
        # vectorized pass instead of one device sync per element.
        flat_angles = torch.atan(slopes).reshape(-1).to(torch.float64)

        over_idx = torch.nonzero(flat_angles > threshold, as_tuple=False).reshape(-1).tolist()
        under_idx = torch.nonzero(flat_angles < threshold, as_tuple=False).reshape(-1).tolist()
        overset.extend((l_idx, j) for j in over_idx)
        underset.extend((l_idx, j) for j in under_idx)

    return overset, underset



def update_I_sets_from_history_list(I_p_old, I_n_old, I_p, I_n, I0, history_list):
    """
    Update I_p, I_n, I0 from a list of split histories and report new fixings.

    history_list: list of dicts, each of form
        h[layer_name] = (indices, directions, biases, _, _)
    Only ``indices`` and ``directions`` (the first two entries) are read.

    I_p, I_n, I0 are collections of (l_idx, j):
        l_idx: integer layer index (0,1,2,...) from the sorted layer names that
               appear anywhere in history_list
        j:     1-based neuron index in that layer (history index + 1)

    NOTE(review): names are sorted as plain strings, so e.g. '/10' sorts before
    '/8'; confirm this matches the layer index k used by the NLP model.

    A positive direction moves (l_idx, j) into I_p, a negative one into I_n
    (removing it from the other two sets); a zero direction keeps the previous
    status. Later histories override earlier ones. The input collections are
    not modified.

    Returns:
        (changed_pos, changed_neg, I_p, I_n, I0), all lists. changed_pos /
        changed_neg are the entries of the updated I_p / I_n that were not in
        I_p_old / I_n_old. An empty history_list yields the same 5-tuple shape.
    """
    # Copy to avoid in-place modification
    I_p = set(I_p)
    I_n = set(I_n)
    I0 = set(I0)

    # No early return: callers always unpack five values, so an empty history
    # must still fall through to the 5-tuple at the end.
    history_list = history_list or []

    # Build a consistent layer-name -> index map from all histories
    all_layer_names = set()
    for h in history_list:
        all_layer_names.update(h.keys())
    name_to_lidx = {name: l for l, name in enumerate(sorted(all_layer_names))}

    for hist in history_list:
        for layer_name, record in hist.items():
            if not record or len(record) < 2:
                continue

            idxs, dirs = record[0], record[1]

            # Skip layers with no splits
            if idxs is None or dirs is None:
                continue
            if isinstance(idxs, (list, tuple)) and len(idxs) == 0:
                continue
            if isinstance(idxs, torch.Tensor) and idxs.numel() == 0:
                continue

            # Convert to 1D tensors
            if not isinstance(idxs, torch.Tensor):
                idxs = torch.tensor(idxs, dtype=torch.long)
            if not isinstance(dirs, torch.Tensor):
                dirs = torch.tensor(dirs, dtype=torch.float32)

            idxs = idxs.view(-1)
            dirs = dirs.view(-1)

            l_idx = name_to_lidx[layer_name]

            for j_idx, d in zip(idxs.tolist(), dirs.tolist()):
                # +1: neuron indices in the I-sets are 1-based
                key = (l_idx, int(j_idx) + 1)

                if d > 0:      # positive side
                    I_p.add(key)
                    I_n.discard(key)
                    I0.discard(key)
                elif d < 0:    # negative side
                    I_n.add(key)
                    I_p.discard(key)
                    I0.discard(key)
                # d == 0: keep previous status

    changed_pos = list(I_p - set(I_p_old))
    changed_neg = list(I_n - set(I_n_old))
    return changed_pos, changed_neg, list(I_p), list(I_n), list(I0)

def NLP_solver_output_warm(
        u_best, over_set, under_set,
        model, lb, ub, spec, lbs, ubs, I_p, I_n, I0,
        prev_nlp_info=None,          # <-- warm-start info from previous NLP
        eps=1e-5, warm_eps=1e-2,     # <-- warm_eps: threshold for using prev s1/s2
        layer_keys=None
    ):
    """
    Build and solve the ReLU-network NLP, warm-starting s1/s2 from a previous solve.

    Same interface as NLP_solver_output. For each hidden layer k and 1-based
    neuron j the model contains
        z_k[j] = W_k[j, :] . zhat_{k-1} + b_k[j],   lbz_k[j] <= z_k[j] <= ubz_k[j]
    (zhat_{-1} is the input variable), and the ReLU is handled by phase:
        (k, j) in I_p -> zhat_k[j] == z_k[j]
        (k, j) in I_n -> zhat_k[j] == 0
        otherwise     -> zhat_k[j] == s1, z_k[j] == s1 - s2,
                         s1, s2 >= 0, s1 * s2 <= eps
    The objective minimises the margin
        zo = (W[label] - W[attack]) . zhat_{L-2} + (b[label] - b[attack])
    subject to lbs[L-1][0] <= zo <= ubs[L-1][0].

    Args:
        u_best: current best upper bound; returned unchanged on any failure.
        over_set, under_set: not used by this function (interface compatibility).
        model: network whose affine layers are read via ``model_info``.
        lb, ub: input-box bounds, indexed as ``[0, i]`` (presumably shape
            [1, n_in] -- confirm with caller).
        spec: specification; the position equal to 1 in ``spec[0][0]`` is the
            label and the position equal to -1 is the attack target.
        lbs, ubs: per-layer pre-activation bounds; hidden layers indexed
            ``[k][0, j]`` (0-based j), output margin bound ``[L-1][0]``.
        I_p, I_n: collections of (k, j) (1-based j) fixed active / inactive.
        I0: not used by this function; neurons outside I_p and I_n are
            modelled with the s1/s2 split.
        prev_nlp_info: dict returned by ``_pack_nlp_info``; for each hidden
            layer, ``prev_nlp_info['s1'][key]`` / ``['s2'][key]`` are flattened
            and indexed by ``j - 1``.
        eps: relaxation of the complementarity constraint s1 * s2 <= eps.
        warm_eps: previous s1/s2 are reused (clipped at 0) only if either one
            has magnitude >= warm_eps; otherwise both start at 0.
        layer_keys: per-hidden-layer keys into ``prev_nlp_info``; defaults to
            ``range(L - 1)``.

    Returns:
        (status, objective, runtime, pyomo_model, nlp_info) where status is one
        of 'solvable', 'infeasible', 'maxtimelimit', 'failed'. On any
        non-'solvable' status the objective is ``u_best`` and nlp_info is
        ``{'z': {}}``.
    """

    # ------------------ unpack model structure ------------------
    W_list, b_list, in_list, out_list = model_info(model)
    label = int((spec[0][0] == 1).nonzero(as_tuple=False))
    attack = int((spec[0][0] == -1).nonzero(as_tuple=False))
    L = len(W_list)                      # number of layers incl. output affine
    input_size = in_list[0]

    if layer_keys is None:
        layer_keys = list(range(L - 1))

    # Set membership is O(1); I_p / I_n may be passed in as lists.
    I_p_set = set(I_p)
    I_n_set = set(I_n)

    # ------------------ build Pyomo model ------------------
    m = ConcreteModel()
    m.dual = Suffix(direction=Suffix.IMPORT)

    # input bounds (NEW ones)
    lb_numpy, ub_numpy = lb.cpu().numpy(), ub.cpu().numpy()
    input_bounds = {i + 1: (float(lb_numpy[0, i]), float(ub_numpy[0, i]))
                    for i in range(input_size)}

    m.constraints = ConstraintList()
    m.input = Var(RangeSet(1, input_size), domain=Reals, bounds=input_bounds)

    # hidden/pre-activation layers: k = 0..L-2
    for k in range(L - 1):
        m.add_component(f"I_{k}", RangeSet(1, in_list[k]))
        m.add_component(f"J_{k}", RangeSet(1, out_list[k]))
        I = getattr(m, f"I_{k}")
        J = getattr(m, f"J_{k}")

        # variables for pre- and post-ReLU
        m.add_component(f"z_{k}",    Var(J, domain=Reals))
        m.add_component(f"zhat_{k}", Var(J, domain=NonNegativeReals))
        z    = getattr(m, f"z_{k}")
        zhat = getattr(m, f"zhat_{k}")

        # parameters W, b (shifted to 1-based Pyomo indices)
        r = W_list[k].shape[0]
        c = W_list[k].shape[1]
        W_ini = {(i + 1, j + 1): W_list[k][i, j] for i in range(r) for j in range(c)}
        b_ini = {i + 1: b_list[k][i] for i in range(r)}
        m.add_component(f"W_{k}", Param(J, I, initialize=W_ini))
        m.add_component(f"b_{k}", Param(J, initialize=b_ini))
        Wp = getattr(m, f"W_{k}")
        bp = getattr(m, f"b_{k}")

        # NEW bounds (mutable Params)
        lb_ini = {j + 1: float(lbs[k][0, j]) for j in range(out_list[k])}
        ub_ini = {j + 1: float(ubs[k][0, j]) for j in range(out_list[k])}
        m.add_component(f"lbz_{k}", Param(J, initialize=lb_ini, mutable=True))
        m.add_component(f"ubz_{k}", Param(J, initialize=ub_ini, mutable=True))
        lbz = getattr(m, f"lbz_{k}")
        ubz = getattr(m, f"ubz_{k}")

        # wiring: z_k = W_k zhat_{k-1} + b_k, with the input feeding layer 0
        zhat_old = m.input if k == 0 else getattr(m, f"zhat_{k-1}")
        for j in J:
            m.constraints.add(z[j] == sum(Wp[j, i] * zhat_old[i] for i in I) + bp[j])

        # bound constraints (canonical ≤ 0)
        m.add_component(f"z_le_{k}", Constraint(J))  # z_j - ubz_j <= 0
        m.add_component(f"z_ge_{k}", Constraint(J))  # lbz_j - z_j <= 0
        z_le = getattr(m, f"z_le_{k}")
        z_ge = getattr(m, f"z_ge_{k}")
        for j in J:
            z_le[j] = z[j] - ubz[j] <= 0
            z_ge[j] = lbz[j] - z[j] <= 0

        # fetch previous warm-start arrays for this layer (if any)
        lay_key = layer_keys[k] if k < len(layer_keys) else k
        prev_s1_layer = None
        prev_s2_layer = None
        if prev_nlp_info is not None and 's1' in prev_nlp_info and lay_key in prev_nlp_info['s1']:
            prev_s1_layer = prev_nlp_info['s1'][lay_key].reshape(-1).cpu().numpy()
            prev_s2_layer = prev_nlp_info['s2'][lay_key].reshape(-1).cpu().numpy()

        # complementarity-style split variables (s1,s2), and phase fixing
        for j in J:
            if (k, j) in I_p_set:
                m.constraints.add(zhat[j] == z[j])   # ON
            elif (k, j) in I_n_set:
                m.constraints.add(zhat[j] == 0)      # OFF
            else:
                ini_s1 = 0.0
                ini_s2 = 0.0

                # overwrite with warm-start if available and "active"
                if prev_s1_layer is not None and prev_s2_layer is not None:
                    idx = j - 1  # zero-based index into the previous arrays
                    s1_prev = float(prev_s1_layer[idx])
                    s2_prev = float(prev_s2_layer[idx])
                    if (abs(s1_prev) >= warm_eps) or (abs(s2_prev) >= warm_eps):
                        ini_s1 = max(0.0, s1_prev)
                        ini_s2 = max(0.0, s2_prev)

                m.add_component(f"s1_{k}_{j}", Var(domain=NonNegativeReals, initialize=ini_s1))
                m.add_component(f"s2_{k}_{j}", Var(domain=NonNegativeReals, initialize=ini_s2))
                s1 = getattr(m, f"s1_{k}_{j}")
                s2 = getattr(m, f"s2_{k}_{j}")

                # complementarity surrogate (relaxed by eps)
                m.constraints.add(s1 * s2 <= eps)
                # link: zhat = s1 (positive part), z = s1 - s2
                m.constraints.add(zhat[j] == s1)
                m.constraints.add(z[j] == s1 - s2)

    # ------------------ output margin layer (k = L-1) ------------------
    k = L - 1
    J_last = getattr(m, f"J_{k - 1}")

    W_o = {j + 1: (float(W_list[k][label, j]) - float(W_list[k][attack, j]))
           for j in range(W_list[k].shape[1])}
    b_o = {1: (float(b_list[k][label]) - float(b_list[k][attack]))}

    lb_o = {1: float(lbs[k][0])}
    ub_o = {1: float(ubs[k][0])}

    m.add_component("lbo", Param(RangeSet(1, 1), initialize=lb_o, mutable=True))
    m.add_component("ubo", Param(RangeSet(1, 1), initialize=ub_o, mutable=True))
    m.add_component("Wo",  Param(RangeSet(1, W_list[k].shape[1]), initialize=W_o))
    m.add_component("bo",  Param(RangeSet(1, 1), initialize=b_o))
    m.add_component("zo",  Var(RangeSet(1, 1), domain=Reals))
    zo  = m.zo; Wo = m.Wo; bo = m.bo; lbo = m.lbo; ubo = m.ubo

    m.constraints.add(zo[1] == sum(Wo[j] * getattr(m, f"zhat_{k - 1}")[j] for j in J_last) + bo[1])
    m.constraints.add(zo[1] <= ubo[1])
    m.constraints.add(zo[1] >= lbo[1])

    # objective
    m.obj = Objective(expr=zo[1], sense=minimize)

    # ------------------ solve with Ipopt warm-start ------------------
    opt = SolverFactory('ipopt')
    opt.options["max_cpu_time"] = 900
    opt.options["warm_start_init_point"] = "yes"  # tell Ipopt to use initial values

    start = time.time()
    solver_status = opt.solve(m, tee=True)
    runtime = time.time() - start
    tc = solver_status.solver.termination_condition

    # ------------------ classify termination ------------------
    if tc in (
        TerminationCondition.infeasible,
        TerminationCondition.infeasibleOrUnbounded,
        TerminationCondition.maxTimeLimit,
        TerminationCondition.maxIterations,
        TerminationCondition.minStepLength,
        TerminationCondition.noSolution,
        TerminationCondition.error
    ):
        msg = getattr(solver_status.solver, "message", "")
        print(f"[{k}] Warm-start NLP: {tc.name}. {msg}")
        if tc == TerminationCondition.infeasible:
            return 'infeasible', u_best, runtime, m, {'z': {}}
        elif tc == TerminationCondition.maxTimeLimit:
            return 'maxtimelimit', u_best, runtime, m, {'z': {}}
        else:
            return 'failed', u_best, runtime, m, {'z': {}}

    if tc in (TerminationCondition.optimal,
              TerminationCondition.locallyOptimal,
              TerminationCondition.globallyOptimal,
              TerminationCondition.feasible):
        print("Warm-start NLP solved successfully.")
        fmin = float(value(m.obj))
        nlp_info = _pack_nlp_info(m, L, layer_keys=layer_keys)
        return 'solvable', fmin, runtime, m, nlp_info

    print("Warm-start NLP: solver failed.")
    # Same contract as the other failure paths: keep u_best and report the
    # runtime, so callers comparing the objective with u_best never see None.
    return 'failed', u_best, runtime, m, {'z': {}}

def _apply_exact_branch(m, k:int, j:int, branch:str, eps_nudge:float=0.0):
    """
    Fix neuron (k, j) of a Pyomo NLP to an exact ReLU phase, in place.

    'positive': enforce ``z_k[j] - zhat_k[j] == 0`` and set zhat's starting
                value to ``max(0, z - eps_nudge)`` (z taken as 0 if unset).
    'negative': enforce ``zhat_k[j] == 0`` and ``z_k[j] <= 0``; zhat starts at
                0 and z is moved to 0 if it currently has a positive value.

    Constraints named ``relu_comp_/relu_link_hat_/relu_link_z_{k}_{j}`` are
    deactivated when present. Exact-phase constraints are re-activated if a
    previous call already created them, otherwise they are added.

    NOTE(review): NLP_solver_output_warm adds its s1/s2 constraints to
    ``m.constraints`` rather than under the names above, so for such models
    nothing is deactivated here; the exact constraints remain compatible with
    the s1/s2 links (they force s2 == 0 or s1 == 0 respectively).

    Raises:
        ValueError: if ``branch`` is not 'positive' or 'negative'; the model
            is left untouched in that case.
        RuntimeError: if ``m`` has no ``z_{k}`` / ``zhat_{k}`` component
            (e.g. k is the output layer).
    """
    # Validate first so an invalid request never half-modifies the model.
    if branch not in ('positive', 'negative'):
        raise ValueError("branch must be 'positive' or 'negative'")

    # Look components up with a default: a plain getattr would raise
    # AttributeError before the "no ReLU" check could ever run.
    z_comp = getattr(m, f"z_{k}", None)
    zhat_comp = getattr(m, f"zhat_{k}", None)
    if z_comp is None or zhat_comp is None:
        raise RuntimeError(f"Layer {k} is output; no ReLU to split.")
    z = z_comp[j]
    zhat = zhat_comp[j]

    # Turn off undecided constraints if they exist
    for nm in (f"relu_comp_{k}_{j}", f"relu_link_hat_{k}_{j}", f"relu_link_z_{k}_{j}"):
        if hasattr(m, nm):
            getattr(m, nm).deactivate()

    # Add/activate exact linear regime
    if branch == 'positive':
        cname = f"relu_exact_pos_{k}_{j}"
        if hasattr(m, cname):
            getattr(m, cname).activate()
        else:
            m.add_component(cname, Constraint(expr= z - zhat == 0 ))
        zv = 0.0 if z.value is None else float(value(z))
        zhat.set_value(max(0.0, zv - eps_nudge))
    else:  # branch == 'negative'
        c1, c2 = f"relu_exact_neg_hat_{k}_{j}", f"relu_exact_neg_zle0_{k}_{j}"
        if hasattr(m, c1):
            getattr(m, c1).activate()
        else:
            m.add_component(c1, Constraint(expr= zhat == 0 ))
        if hasattr(m, c2):
            getattr(m, c2).activate()
        else:
            m.add_component(c2, Constraint(expr= z <= 0 ))
        zhat.set_value(0.0)
        if z.value is not None and float(value(z)) > 0:
            z.set_value(0.0)

def _ensure_warmstart_suffixes(m):
    """
    Declare on ``m`` the Suffix components used to warm-start Ipopt.

    Following Pyomo's Ipopt warm-start convention, multipliers are sent to
    Ipopt through ``ipopt_zL_in`` / ``ipopt_zU_in`` and the solved bound
    multipliers come back through ``ipopt_zL_out`` / ``ipopt_zU_out``. The
    ``*_out`` suffixes are declared here so that solves run after this call
    actually import them. Components that already exist are left untouched.
    """
    if not hasattr(m, 'dual'):         m.dual = Suffix(direction=Suffix.IMPORT_EXPORT)
    if not hasattr(m, 'ipopt_zL_in'):  m.ipopt_zL_in = Suffix(direction=Suffix.IMPORT_EXPORT)
    if not hasattr(m, 'ipopt_zU_in'):  m.ipopt_zU_in = Suffix(direction=Suffix.IMPORT_EXPORT)
    # Without these IMPORT suffixes the solved multipliers are discarded.
    if not hasattr(m, 'ipopt_zL_out'): m.ipopt_zL_out = Suffix(direction=Suffix.IMPORT)
    if not hasattr(m, 'ipopt_zU_out'): m.ipopt_zU_out = Suffix(direction=Suffix.IMPORT)

def _cache_state(m) -> Dict[str, Dict[str, float]]:
    """
    Snapshot primal values and bound multipliers of every Var in ``m``.

    Values are keyed by ``Var.name`` so the snapshot can be re-applied to a
    deep copy or a structurally edited model.

    Returns:
        {'x': name -> value (0.0 if unset),
         'zL': name -> lower-bound multiplier (0.0 if unknown),
         'zU': name -> upper-bound multiplier (0.0 if unknown)}

    Multipliers are taken from ``ipopt_zL_out`` / ``ipopt_zU_out`` (filled by
    the last solve, if those suffixes were declared) and fall back to
    ``ipopt_zL_in`` / ``ipopt_zU_in`` (what was last sent to Ipopt).
    """
    # Preferred source first: solved multipliers, then previously exported ones.
    zL_sources = [getattr(m, nm) for nm in ('ipopt_zL_out', 'ipopt_zL_in') if hasattr(m, nm)]
    zU_sources = [getattr(m, nm) for nm in ('ipopt_zU_out', 'ipopt_zU_in') if hasattr(m, nm)]

    def _multiplier(sources, var):
        for sfx in sources:
            if var in sfx:
                return sfx[var]
        return 0.0

    state = {'x': {}, 'zL': {}, 'zU': {}}
    # Single pass over the variables instead of one pass per dict.
    for v in m.component_data_objects(Var):
        name = v.name
        state['x'][name] = 0.0 if v.value is None else float(value(v))
        state['zL'][name] = _multiplier(zL_sources, v)
        state['zU'][name] = _multiplier(zU_sources, v)
    return state

def _load_state(m, state, skip_var_names=None):
    """
    Load a cached warm-start state into model ``m``, projecting into current bounds.

    Args:
        m: Pyomo model to warm-start; warm-start suffixes are declared if
            missing, and ``ipopt_zL_in`` / ``ipopt_zU_in`` are cleared first.
        state: dict with 'x', 'zL', 'zU' maps from ``Var.name`` to float
            (the shape produced by ``_cache_state``).
        skip_var_names: optional iterable of ``Var.name`` that are not touched
            (used for (k, j) that changed phase). Such variables keep whatever
            value ``m`` already holds and get no bound-multiplier entry.
    """
    _ensure_warmstart_suffixes(m)
    skip_var_names = set(skip_var_names or [])

    # Clear old suffix info to avoid stale keys
    m.ipopt_zL_in.clear()
    m.ipopt_zU_in.clear()

    for v in m.component_data_objects(Var):
        name = v.name
        if name in skip_var_names:
            continue

        # ---- Primal warm start ----
        if name in state['x']:
            val = state['x'][name]

            # Project into [lb, ub] to respect NonNegativeReals, tightened bounds, etc.
            lb = v.lb
            ub = v.ub
            if lb is not None and val < lb:
                val = lb
            if ub is not None and val > ub:
                val = ub

            try:
                v.set_value(val)
            except Exception:
                # Best effort: if the value is rejected, fall back to 0 or a
                # positive lower bound. Catching Exception (not a bare except)
                # lets KeyboardInterrupt / SystemExit propagate.
                v.set_value(lb if lb is not None and lb > 0 else 0.0)

        # ---- Bound multipliers warm start ----
        if name in state['zL']:
            m.ipopt_zL_in[v] = state['zL'][name]
        if name in state['zU']:
            m.ipopt_zU_in[v] = state['zU'][name]

 
 

def _ipopt(warm=True, it=40, tol=1e-6, mu=1e-4, exe=None):
    """
    Create an Ipopt solver handle configured for short warm-started re-solves.

    Args:
        warm: if True, set ``warm_start_init_point`` to 'yes', else 'no'.
        it: Ipopt ``max_iter``.
        tol: Ipopt ``tol``.
        mu: Ipopt ``mu_init``.
        exe: optional path to the ipopt executable; when falsy the default
            SolverFactory lookup is used.

    Returns:
        The Pyomo solver object with the options applied.
    """
    factory_kwargs = {'executable': exe} if exe else {}
    solver = SolverFactory('ipopt', **factory_kwargs)

    ipopt_settings = {
        'warm_start_init_point': 'yes' if warm else 'no',
        'warm_start_bound_push': 1e-4,
        'warm_start_slack_bound_frac': 1e-2,
        'warm_start_slack_bound_push': 1e-4,
        'mu_init': mu,
        'tol': tol,
        'max_iter': it,
        'print_level': 0,
    }
    solver.options.update(ipopt_settings)
    return solver

def resolve_with_phase_updates(u_best,
    m_prev,
    pos_neurons,          # iterable of (k, j) that go to I_p (positive)
    neg_neurons,          # iterable of (k, j) that go to I_n (negative)
    ipopt_path: Optional[str] = None,
    it: int = 40,
    tol: float = 1e-6,
    mu: float = 1e-4
):
    """
    Warm-re-solve the NLP after moving some neurons from I0 -> I_p / I_n.

    Steps:
      1) cache primal values / multipliers of m_prev
      2) deep-copy m_prev (the parent model is not mutated)
      3) tighten bounds for the changed neurons via ``_update_bounds_for_phases``
      4) collect the variable names of changed neurons (z, zhat, s1, s2)
      5) warm-start all other variables from the cache
      6) apply exact phase constraints for changed neurons
      7) solve with Ipopt warm start (``max_iter=it``, ``tol``, ``mu_init=mu``)

    Returns:
        (status, obj, runtime, m):
        status is 'solvable' on optimal/locallyOptimal/feasible termination,
        'infeasible (<tc>)' if Ipopt reports infeasibility, otherwise
        'failed (<tc>)' (e.g. the iteration limit); obj is the objective on
        success and ``u_best`` otherwise; m is the new model in all cases.
    """
    # 1) cache state from previous model
    _ensure_warmstart_suffixes(m_prev)
    state = _cache_state(m_prev)

    # 2) deep-copy model
    m = copy.deepcopy(m_prev)
    _ensure_warmstart_suffixes(m)

    # 3) tighten pre-activation bounds according to new phases
    #    (materialize once: the iterables are traversed several times below)
    pos_neurons = list(pos_neurons)
    neg_neurons = list(neg_neurons)
    _update_bounds_for_phases(m, pos_neurons, neg_neurons)

    # 4) skip warm-start for variables of changed neurons
    skip_vars = set()
    for (k, j) in pos_neurons + neg_neurons:
        skip_vars.add(f"z_{k}[{j}]")
        skip_vars.add(f"zhat_{k}[{j}]")
        skip_vars.add(f"s1_{k}_{j}")
        skip_vars.add(f"s2_{k}_{j}")

    # 5) load warm-start state (projected into new bounds)
    _load_state(m, state, skip_var_names=skip_vars)

    # 6) apply exact branch constraints for changed neurons
    for (k, j) in pos_neurons:
        _apply_exact_branch(m, k, j, branch='positive')
    for (k, j) in neg_neurons:
        _apply_exact_branch(m, k, j, branch='negative')

    # 7) solve with Ipopt warm-start
    opt = _ipopt(warm=True, it=it, tol=tol, mu=mu, exe=ipopt_path)
    t0 = time.time()
    res = opt.solve(m, tee=False)
    runtime = time.time() - t0

    tc = res.solver.termination_condition
    ok = tc in {
        TerminationCondition.optimal,
        TerminationCondition.locallyOptimal,
        TerminationCondition.feasible,
    }

    if ok:
        obj = float(value(m.obj))
        status = 'solvable'
    else:
        print(f"[resolve_with_phase_updates] Ipopt termination: {tc}")
        obj = u_best
        # Only call it infeasible when Ipopt says so; hitting the (small)
        # iteration limit or a numerical error is a solver failure.
        infeasible_tcs = {
            TerminationCondition.infeasible,
            TerminationCondition.infeasibleOrUnbounded,
        }
        label = 'infeasible' if tc in infeasible_tcs else 'failed'
        status = f'{label} ({tc})'

    return status, obj, runtime, m

def resolve_with_phase_updates_old_skipwarm(u_best,
    m_prev,
    pos_neurons,
    neg_neurons,
    ipopt_path=None,
    it=40,
    tol=1e-6,
    mu=1e-4
):
    """
    Legacy warm re-solve that does not warm-start the variables of changed neurons.

    NOTE: this function used to be named ``resolve_with_phase_updates_old`` as
    well and was silently shadowed by the later definition with that name, so
    it was unreachable. It is renamed so that both variants can be called; the
    later definition keeps the original name.

    Unlike ``resolve_with_phase_updates`` it does not tighten bounds for the
    changed neurons before solving.

    Args:
        u_best: value returned as the objective when the solve fails.
        m_prev: previously solved Pyomo model (deep-copied, not mutated).
        pos_neurons, neg_neurons: iterables of (k, j) to fix active / inactive.
        ipopt_path: optional ipopt executable path.
        it, tol, mu: Ipopt ``max_iter``, ``tol`` and ``mu_init``.

    Returns:
        ('solvable', objective, runtime, m) on optimal/locallyOptimal/feasible
        termination, otherwise ('failed', u_best, runtime, m).
    """
    # Materialize once: the neuron iterables are traversed twice below, and a
    # generator would otherwise be empty by the time phases are applied.
    pos_neurons = list(pos_neurons)
    neg_neurons = list(neg_neurons)

    _ensure_warmstart_suffixes(m_prev)
    state = _cache_state(m_prev)

    m = copy.deepcopy(m_prev)
    _ensure_warmstart_suffixes(m)

    # collect variable names to skip warm-start for changed neurons
    skip_vars = set()
    for (k, j) in pos_neurons + neg_neurons:
        skip_vars.add(f"z_{k}[{j}]")
        skip_vars.add(f"zhat_{k}[{j}]")
        skip_vars.add(f"s1_{k}_{j}")
        skip_vars.add(f"s2_{k}_{j}")

    _load_state(m, state, skip_var_names=skip_vars)

    for (k, j) in pos_neurons:
        _apply_exact_branch(m, k, j, branch='positive')
    for (k, j) in neg_neurons:
        _apply_exact_branch(m, k, j, branch='negative')

    opt = _ipopt(warm=True, it=it, tol=tol, mu=mu, exe=ipopt_path)
    t0 = time.time()
    res = opt.solve(m, tee=False)
    runtime = time.time() - t0

    tc = res.solver.termination_condition
    ok = tc in {
        TerminationCondition.optimal,
        TerminationCondition.locallyOptimal,
        TerminationCondition.feasible,
    }

    if ok:
        return 'solvable', float(value(m.obj)), runtime, m
    else:
        print(f"[resolve_with_phase_updates] Ipopt termination: {tc}")
        return 'failed', u_best, runtime, m

def resolve_with_phase_updates_old(u_best,
    m_prev,
    pos_neurons,          # iterable of (k, j) that go to I_p (positive)
    neg_neurons,          # iterable of (k, j) that go to I_n (negative)
    ipopt_path: Optional[str] = None,
    it: int = 40,
    tol: float = 1e-6,
    mu: float = 1e-4
):
    """
    Warm-re-solve NLP after moving some neurons from I0 -> I_p / I_n.

    All variables (including those of the changed neurons) are warm-started
    from m_prev before the exact phase constraints are applied.

    Args:
        u_best:      value returned as the objective when the solve fails
        m_prev:      previously solved Pyomo model (deep-copied, not mutated)
        pos_neurons: list/set of (k,j) to fix active (I_p)
        neg_neurons: list/set of (k,j) to fix inactive (I_n)
        ipopt_path:  optional path to the ipopt executable
        it, tol, mu: Ipopt ``max_iter``, ``tol`` and ``mu_init``
    Returns:
        status: 'solvable' or 'failed (<termination condition>)'
        obj:    objective value on success, otherwise ``u_best``
        runtime: solve time
        m_new:  new Pyomo model with updated constraints
    """
    # 1) Ensure warm-start suffixes exist on the previous model, then cache state
    _ensure_warmstart_suffixes(m_prev)
    state = _cache_state(m_prev)

    # 2) Deep copy the model to avoid mutating parent node
    m = copy.deepcopy(m_prev)
    _ensure_warmstart_suffixes(m)

    # 3) Load warm-start values & bound multipliers
    _load_state(m, state)

    # 4) Apply phase updates
    for (k, j) in pos_neurons:
        _apply_exact_branch(m, k, j, branch='positive')
    for (k, j) in neg_neurons:
        _apply_exact_branch(m, k, j, branch='negative')

    # 5) Warm-start Ipopt (honouring the caller's executable path)
    opt = _ipopt(warm=True, it=it, tol=tol, mu=mu, exe=ipopt_path)

    t0 = time.time()
    res = opt.solve(m, tee=False)
    runtime = time.time() - t0

    tc = res.solver.termination_condition
    ok = tc in (
        TerminationCondition.optimal,
        TerminationCondition.locallyOptimal,
        TerminationCondition.feasible,
    )

    if ok:
        obj = float(value(m.obj))
        status = 'solvable'
    else:
        obj = u_best
        status = f'failed ({tc})'

    return status, obj, runtime, m


def general_bab(  model, lb, ub,spec, net, domain, x, refined_lower_bounds=None,
                refined_upper_bounds=None, activation_opt_params=None,
                reference_alphas=None, reference_lA=None, attack_images=None,
                timeout=None, max_iterations=None, refined_betas=None, rhs=0,
                model_incomplete=None, time_stamp=0, property_idx=None, earlystop = True, eps_s = 0.01):

    start_time = time.time()
    stats = Stats() 
    solver_args = arguments.Config['solver']
    varepsilon = solver_args['NLP']['varepsilon'] 
    num_t = solver_args['NLP']['num_t'] 
    bab_args = arguments.Config['bab']
    branch_args = bab_args['branching']
    timeout = timeout or bab_args['timeout']
    max_domains = bab_args['max_domains']
    batch = solver_args['batch_size']
    cut_enabled = bab_args['cut']['enabled']
    biccos_args = bab_args['cut']['biccos']
    multi_tree_branching_enabled = (cut_enabled and biccos_args['enabled'] and
                                    biccos_args['multi_tree_branching']['enabled'])
    max_iterations = max_iterations or bab_args['max_iterations']

    input_relu_splitter = (InputReluSplitter() if
                branch_args['branching_input_and_activation'] else None)
    #  initial NLP
    u_best = float('inf')   
    u_best_list = []  
    u_NLP = u_best
    total_round = 0
    result = None  
    global_lb = -torch.inf
    W_list, _,_,_  = model_info(model) 
    L = len(W_list)  
    wait = 0
    lb_list =[]
    cum_time = 0

    if not isinstance(rhs, torch.Tensor):
        rhs = torch.tensor(rhs)
    stop_criterion = stop_criterion_batch_any(rhs)

    if refined_lower_bounds is None or refined_upper_bounds is None:
        assert arguments.Config['general']['enable_incomplete_verification'] is False
        global_lb, ret = net.build(
            domain, x, stop_criterion_func=stop_criterion, decision_thresh=rhs)
        updated_mask, lA, alpha = (ret['mask'], ret['lA'], ret['alphas'])
        global_ub = global_lb + torch.inf
    else:
        print('build refined bounds')
        ret = net.build_with_refined_bounds(
            domain, x, refined_lower_bounds, refined_upper_bounds,
            activation_opt_params, reference_lA=reference_lA,
            reference_alphas=reference_alphas, stop_criterion_func=stop_criterion,
            cutter=net.cutter, refined_betas=refined_betas, decision_thresh=rhs)
        (global_ub, global_lb, updated_mask, lA, alpha) = (
            ret['global_ub'], ret['global_lb'], ret['mask'], ret['lA'],
            ret['alphas']) 
        
        # release some storage to save memory
        if activation_opt_params is not None: del activation_opt_params
        torch.cuda.empty_cache()   
    # Transfer A_saved to the new LiRPANet
    if hasattr(model_incomplete, 'A_saved'):
        net.A_saved = model_incomplete.A_saved

    if cut_enabled:
        net.set_cuts(model_incomplete.A_saved, x, ret['lower_bounds'], ret['upper_bounds'])
        if biccos_args['enabled']:
            print('Inferred cuts enabled')
            print('Warning: The mininal batch size ratio is set to 0')
            initial_bs_ratio = arguments.Config['solver']['min_batch_size_ratio']
            arguments.Config['solver']['min_batch_size_ratio'] = 0
            net.biccos = BICCOS(ret, property_idx, rhs)

    impl_params = get_impl_params(net, model_incomplete, time_stamp)

    if solver_args['beta-crown']['all_node_split_LP']:
        # Initialize the LP solver model and pre-store the names of the layers
        timeout = bab_args['timeout']
        net.build_solver_model(timeout, model_type='lp')
        net.pre_relu_layer_names = [relu_layer.inputs[0].name for relu_layer in net.net.relus]
        net.relu_layer_names = [relu_layer.name for relu_layer in net.net.relus]
        input_name = [name for name in net.net.input_name if type(net.net[name]) == BoundInput]
        assert len(input_name) == 1, 'there should be only 1 BoundInput!'
        input_name = input_name[0]
        def extract_var_names(solver_vars):
            if isinstance(solver_vars, list):
                return [extract_var_names(sub_solver_vars) for sub_solver_vars in solver_vars]
            else:
                return solver_vars.VarName
        net.input_name = extract_var_names(net.net[input_name].solver_vars)
    # tell the AutoLiRPA class not to transfer intermediate bounds to save time
    net.interm_transfer = bab_args['interm_transfer']
    if not bab_args['interm_transfer']:
        # Branching domains cannot support
        # bab_args['interm_transfer'] == False and bab_args['sort_domain_interval'] > 0
        # at the same time.
        assert bab_args['sort_domain_interval'] == -1

    net.pre_relu_layer_names = [relu_layer.inputs[0].name for relu_layer in net.net.relus]
    net.relu_layer_names = [relu_layer.name for relu_layer in net.net.relus]
         
    
    # If we are not optimizing intermediate layer bounds, we do not need to
    # save all the intermediate alpha.
    # We only keep the alpha for the last layer.
    if not solver_args['beta-crown']['enable_opt_interm_bounds']:
        # new_alpha shape:
        # [dict[relu_layer_name, {final_layer: torch.tensor storing alpha}]
        # for each sample in batch]
        alpha = prune_alphas(alpha, net.alpha_start_nodes)

    if bab_args['attack']['enabled']:
        DomainClass = SortedReLUDomainList
    elif multi_tree_branching_enabled:
        DomainClass = ShallowFirstBatchedDomainList
    else:
        DomainClass = BatchedDomainList
    # initial NLP solution
    time_nlp_list = []
    status_list = []
    ub_list = []
    start_time = time.time()
    I_p, I_n, I0, lbs, ubs,=classify_neurons_from_bounds(ret["lower_bounds"], ret["upper_bounds"],  batch_idx = 0 ) 
    I_p_old , I_n_old = I_p.copy(), I_n.copy() 
    u_best = float('inf')
    status0,  u_best, runtimeNLP, model_nlp, nlp_info =  NLP_solver_output(u_best, I_p.copy(), I_n.copy(),   model, lb, ub,spec, lbs, ubs, I_p.copy(), I_n.copy(), I0.copy() ,  eps = 1e-5,   layer_keys = net.pre_relu_layer_names) 
    cum_time +=  runtimeNLP  
    time_nlp_list.append(runtimeNLP)
    status_list.append(status0)
    ub_list.append(u_best)
    rhs = torch.tensor(u_best) 
    all_label_global_lb = torch.min(global_lb - rhs).item()
    all_label_global_ub = torch.max(global_ub - rhs).item()
    if arguments.Config['debug']['lp_test'] in ['LP', 'MIP']:
        return all_label_global_lb, 0, 'unknown', stats,u_best  , cum_time, lb_list, u_best_list 

    if stop_criterion(global_lb).all():
        return all_label_global_lb, 0, 'safe', stats,u_best  , cum_time, lb_list, u_best_list 
    # This is the first (initial) domain.
    domains = DomainClass(nlp_info,
        ret, lA, global_lb, global_ub, alpha,
        copy.deepcopy(ret['history']), rhs, net=net, x=x,
        branching_input_and_activation=branch_args['branching_input_and_activation'])
    num_domains = len(domains)

    # after domains are added, we replace global_lb, global_ub with the multile
    # targets 'real' global lb and ub to make them scalars
    global_lb, global_ub = all_label_global_lb, all_label_global_ub
    updated_mask, tot_ambi_nodes = get_unstable_neurons(updated_mask, net)
    net.tot_ambi_nodes = tot_ambi_nodes 

    if cut_enabled and impl_params is None:
        cut_verification(net, domains)

    if bab_args['attack']['enabled']:
        return bab_loop_attack(
            domains, net, batch, rhs, start_time, timeout,
            updated_mask, attack_images, all_label_global_ub)

    branching_heuristic = get_branching_heuristic(net, 'kfsb')

    # If we are using shallow branching, we need to do the multi-tree search
    # as the pre-solve part for BICCOS.
    if isinstance(domains, ShallowFirstBatchedDomainList):
        multi_tree_bab(
            net, domains, batch, stop_criterion, biccos_args,
            impl_params, stats, start_time)

    num_domains = len(domains)
    vram_ratio = 0.85 if cut_enabled else 0.9
    auto_batch_size = AutoBatchSize(
        batch, net.device, vram_ratio,
        enable=arguments.Config['solver']['auto_enlarge_batch_size'])  
    while ((global_lb + varepsilon ) < u_best and num_domains > 0 and (total_round < max_iterations )):
        start_round = time.time()
        total_round += 1 
        print(f'BaB round {total_round}')
        print("gap is ", (u_best - global_lb   ))
        wait += 1 
        if (cut_enabled and biccos_args['enabled']
            and total_round - 1 == net.biccos.max_iter):
            print('Cut Inference reaches max iterations. Recover the setting')
            arguments.Config['solver']['min_batch_size_ratio'] = initial_bs_ratio  # pylint: disable=used-before-assignment

        auto_batch_size.record_actual_batch_size(min(batch, len(domains)))
        if input_relu_splitter and input_relu_splitter.split_condition(
                total_round-1):
            global_lb = input_split_on_relu_domains(
                domains, wrapped_net=net, batch_size=batch)
        else:
            if (bab_args['cut']['enabled'] and bab_args['cut']['cplex_cuts']
                and not biccos_args['enabled']):
                fetch_cut_from_cplex(net)
            global_lb, alphas,  history= act_split_round(nlp_info,
                domains, net, batch, iter_idx=total_round,
                impl_params=impl_params, stats=stats,
                branching_heuristic=branching_heuristic)
            lb_list.append(global_lb.item() )
        batch = check_auto_enlarge_batch_size(auto_batch_size)   
        betaTime = time.time() - start_round
        cum_time +=  betaTime  
        ini = False  
        
        if wait >  num_t: 
            wait = 0 
            changed_pos, changed_neg, I_p1, I_n1, I01 = update_I_sets_from_history_list(I_p_old, I_n_old,  I_p.copy(), I_n.copy(), I0.copy(), history.copy()) 
            start_NLP = time.time()   
            #status0,  u_NLP, runtimeNLP, _, _ =  NLP_solver_output(u_best, I_p1.copy(), I_n1.copy(),   model, lb, ub,spec, lbs, ubs, I_p1.copy(), I_n1.copy(), I01.copy() ,  eps = 1e-5,   layer_keys = net.pre_relu_layer_names) 
            status0,  u_NLP, runtimeNLP, model_nlp  = resolve_with_phase_updates(u_best,  m_prev=model_nlp,  pos_neurons=changed_pos, neg_neurons=changed_neg, it=40,tol=1e-6,mu=1e-4)
            I_p_old , I_n_old = I_p1.copy(), I_n1.copy()  
            time_nlp_list.append(runtimeNLP)
            status_list.append(status0)
            ub_list.append(u_NLP)
            if u_NLP < u_best:
                u_best = u_NLP 
                u_best_list.append(u_best) 
                print('u_best', u_best) 
            NLPtime = time.time() - start_NLP 
            cum_time += NLPtime 
        if u_best < 0 and earlystop == True:
            print('Early Stop')
            return global_lb, stats.visited, 'unsafe', stats,u_best  , cum_time, lb_list, u_best_list
                  
        if isinstance(global_lb, torch.Tensor):
            global_lb = global_lb.max().item()

        num_domains = len(domains)
 

        if stats.all_node_split:
            if stats.all_split_result == 'unsafe' or u_best < 0:
                stats.all_node_split = False
                result = 'unsafe'
            else:
                stats.all_node_split = False
                result = 'unknown'
        elif num_domains > max_domains:
            print('Maximum number of visited domains has reached.')
            result = 'unknown'
        elif time.time() - start_time > timeout:
            print('Time out!!!!!!!!')
            result = 'unknown'  
        elif global_lb > 0:
            result = 'safe'
        else:
            result = 'unknown'
        print(f'Cumulative time: {cum_time}\n')
 
    if result is None:
        if u_best < 0:
            result = 'unsafe'
        elif global_lb > 0:
            result = 'safe'
        else:
            result = 'unknown'
            
    del domains
    clean_net_mps_process(net)
    print('beta_crown bounds list')
    print(lb_list )
    print('time_nlp_list')
    print(time_nlp_list)
    print('ub_list')
    print(ub_list)
    print('status_list')
    print(status_list)
    return global_lb, stats.visited, result, stats,u_best  , cum_time, lb_list, u_best_list 