#########################################################################  
##   Copyright (C) 2021-2025 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com> (UIUC)         ##
##                     Zhouxing Shi <zshi@cs.ucla.edu> (UCLA)          ##
##                     Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################
import  random,os,time,gc,torch, socket,re
import numpy as np
import torch.nn as nn
from collections import defaultdict
from torch.utils.data import  TensorDataset 
import arguments
from auto_LiRPA import BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm
from auto_LiRPA.utils import stop_criterion_all, stop_criterion_batch_any
from auto_LiRPA.operators.convolution import BoundConv
from beta_CROWN_solver import LiRPANet
from specifications import (trim_batch, batch_vnnlib, sort_targets, prune_by_idx, add_rhs_offset)  
#from bab_mip import model_info
from input_split.batch_branch_and_bound import input_bab_parallel 
from specifications import construct_vnnlib
from utils import  expand_path
from utils import Logger, print_model

from pyomo.environ import * 
import torch
import math
import matplotlib.pyplot as plt 
import seaborn as sns
import pandas as pd
import yaml
import time
import pickle
import os
import sys 
from typing import Dict, Any  
import sys, atexit
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Redirect everything printed by this script (including solver output streamed
# with tee=True) into log files in the current working directory.
# NOTE: this runs at import time, so merely importing this module also
# redirects the interpreter's stdout/stderr.
f_out = open("stdout_cmp_eps1.log", "w")
f_err = open("stderr_cmp_eps1.log", "w") 
# Keep the original streams so they can be restored at interpreter exit.
_stdout_orig, _stderr_orig = sys.stdout, sys.stderr
sys.stdout, sys.stderr = f_out, f_err

def _close_streams():
    """Flush the redirected streams, restore the original ones, close the logs."""
    # Flush first so buffered output reaches the log files, then restore the
    # original streams before closing the files they were replaced with.
    sys.stdout.flush(); sys.stderr.flush()
    sys.stdout, sys.stderr = _stdout_orig, _stderr_orig
    f_out.close(); f_err.close()

# Run the cleanup when the interpreter shuts down.
atexit.register(_close_streams)

# Step 1: define the network architecture
class NoSoftmaxNet(nn.Module):
    """Fully connected ReLU network with two hidden layers that returns raw logits.

    No softmax is applied to the output, so the result of ``forward`` is the
    pre-softmax score of every class.

    Args:
        input_size: Number of input features per example.
        hidden_size: Width of both hidden layers.
        num_classes: Number of output logits. Defaults to 10, the value that
            used to be hard-coded, so existing checkpoints still load.
    """

    def __init__(self, input_size, hidden_size=256, num_classes=10):
        super().__init__()
        # The attribute names are the state_dict keys of saved checkpoints and
        # the creation order is the order in which named_children() yields the
        # layers, so neither must change.
        self.hidden1 = nn.Linear(input_size, hidden_size)
        self.hidden2 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the logits for a batch ``x`` whose last dimension is ``input_size``."""
        x = self.relu(self.hidden1(x))
        x = self.relu(self.hidden2(x))
        return self.output(x)
    
# Step 2: load data

def load_data_old(eps, data_path='./cifar_test.pt', a=4, size_image=32):
    """Load an ``(images, labels)`` tensor pair and package it for verification.

    Args:
        eps: Perturbation radius, stored unchanged under ``"eps"``.
        data_path: File readable by ``torch.load`` that holds the pair.
        a: Class index used as the attack target for every example.
        size_image: Side length of the square 3-channel images.

    Returns:
        dict with keys ``"X"`` (images flattened to ``size_image**2 * 3``
        columns), ``"labels"``, ``"target_label"`` (filled with ``a``) and
        ``"eps"``.
    """
    images, targets = torch.load(data_path)
    # TensorDataset checks that both tensors have the same number of rows.
    paired = TensorDataset(images, targets)
    all_images, all_targets = paired[:]
    flat_width = size_image * size_image * 3
    return {
        "X": all_images.view(-1, flat_width),
        "labels": all_targets,
        "target_label": torch.ones_like(all_targets) * a,
        "eps": eps,
    }

def save_datasets(batch_size=128, num_workers=4, augmentation='none'):
    """Download CIFAR-10 and save the transformed train/test splits as tensors.

    Every image is converted to a tensor and normalised with the CIFAR-10
    channel statistics; the training split additionally goes through the
    selected augmentation pipeline. The resulting ``(images, labels)`` pairs
    are written with ``torch.save`` to
    ``snov_configs/cifar_test/cifar_trainnew.pt`` and
    ``snov_configs/cifar_test/cifar_testnew.pt``.

    Args:
        batch_size: Unused; kept so existing callers keep working.
        num_workers: Unused; kept so existing callers keep working.
        augmentation: 'none', 'basic', 'strong', or 'autoaugment'.

    Raises:
        ValueError: If ``augmentation`` is not one of the supported modes.
    """
    supported = ('none', 'basic', 'strong', 'autoaugment')
    # Fail before downloading anything instead of hitting an unbound
    # ``train_transform`` further down.
    if augmentation not in supported:
        raise ValueError(
            f"Unknown augmentation {augmentation!r}; expected one of {supported}"
        )

    # CIFAR-10 statistics
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2470, 0.2435, 0.2616]

    # Training transforms
    if augmentation == 'none':
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    elif augmentation == 'basic':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    elif augmentation == 'strong':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.RandomErasing(p=0.5, scale=(0.02, 0.1))
        ])
    else:  # 'autoaugment'
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    # Test transform (no augmentation)
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_transform
    )
    test_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=test_transform
    )

    def _to_tensors(dataset):
        # Walk the dataset once so every image is transformed a single time.
        images, targets = zip(*dataset)
        return torch.stack(images), torch.tensor(targets)

    x_train_tensor, y_train_tensor = _to_tensors(train_dataset)
    x_test_tensor, y_test_tensor = _to_tensors(test_dataset)

    # Save tensors to disk, creating the target directory when it is missing.
    train_path = 'snov_configs/cifar_test/cifar_trainnew.pt'
    test_path = 'snov_configs/cifar_test/cifar_testnew.pt'
    os.makedirs(os.path.dirname(train_path), exist_ok=True)
    torch.save((x_train_tensor, y_train_tensor), train_path)
    torch.save((x_test_tensor, y_test_tensor), test_path)
    return

def load_data(eps, data_path='./cifar_test/cifarAug_train.pt', a=4, size_image=32):
    """Load an ``(images, labels)`` tensor pair and package it for verification.

    Args:
        eps: Perturbation radius, stored unchanged under ``"eps"``.
        data_path: File readable by ``torch.load`` that holds the pair.
        a: Class index used as the attack target for every example.
        size_image: Side length of the square 3-channel images.

    Returns:
        dict with keys ``"X"`` (images flattened to ``size_image**2 * 3``
        columns), ``"labels"``, ``"target_label"`` (filled with ``a``) and
        ``"eps"``.
    """
    images, targets = torch.load(data_path)
    flat_width = size_image * size_image * 3
    return {
        "X": images.view(-1, flat_width),
        "labels": targets,
        "target_label": torch.ones_like(targets) * a,
        "eps": eps,
    }
 

def load_model(input_size, model_file="best_NN_cifar10_256.pth", hidden_size=256):
    """Build a ``NoSoftmaxNet`` and populate it from a checkpoint on the CPU.

    Args:
        input_size: Number of input features of the network.
        model_file: Checkpoint path; either a bare state_dict or a mapping
            that stores one under the ``"state_dict"`` key.
        hidden_size: Width of the hidden layers.

    Returns:
        The network in evaluation mode.
    """
    net = NoSoftmaxNet(input_size, hidden_size=hidden_size)
    # NOTE: weights_only=False unpickles arbitrary objects -- only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(model_file, map_location="cpu", weights_only=False)
    # A full checkpoint wraps the weights; a plain state_dict is used as is.
    weights = checkpoint.get("state_dict", checkpoint)
    net.load_state_dict(weights)
    net.eval()
    print("Model loaded successfully.")
    return net

def model_info(model):
    """Collect weights, biases and sizes of the direct ``nn.Linear`` children.

    Only immediate children of ``model`` are inspected (``named_children``),
    in registration order; every other module type is skipped.

    Args:
        model: Module whose direct linear layers should be exported.

    Returns:
        Tuple ``(W_list, b_list, in_list, out_list)`` with one entry per
        linear layer: the weight matrix and bias vector as NumPy arrays, and
        the layer's ``in_features`` / ``out_features``. A layer created with
        ``bias=False`` contributes an all-zero bias vector.
    """
    W_list, b_list, in_list, out_list = [], [], [], []
    for _name, module in model.named_children():
        if not isinstance(module, torch.nn.Linear):
            continue
        # .cpu() makes the export work after the model was moved to a GPU;
        # .numpy() alone raises for tensors that live on CUDA.
        weight = module.weight.detach().cpu().numpy()
        if module.bias is None:
            bias = np.zeros(module.out_features, dtype=weight.dtype)
        else:
            bias = module.bias.detach().cpu().numpy()
        W_list.append(weight)
        b_list.append(bias)
        in_list.append(module.in_features)
        out_list.append(module.out_features)
    return W_list, b_list, in_list, out_list

def NLP_solver_output(u_best, over_set, under_set,    model, lb, ub,spec, lbs, ubs, I_p, I_n, I0 ,eps = 1e-5): 
    """Minimise the label-vs-attack logit margin with a nonlinear (Ipopt) model.

    The network's linear layers are exported with ``model_info`` and encoded
    in a Pyomo ``ConcreteModel``. For every hidden layer ``k`` the
    pre-activation ``z_k`` and the non-negative post-activation ``zhat_k``
    are variables. Neurons listed in ``I_p`` are fixed to ``zhat == z``,
    neurons in ``I_n`` to ``zhat == 0``; every other neuron gets two
    non-negative variables with ``z == s1 - s2``, ``zhat == s1`` and the
    relaxed complementarity constraint ``s1 * s2 <= eps``.

    Args:
        u_best: Accepted but never read in this function.
        over_set: Collection of ``(layer, neuron)`` pairs; for a neuron that
            reaches the s1/s2 branch and is in this set, ``s1`` is
            initialised from the bounds and ``s2`` at 0.
        under_set: Same as ``over_set`` with the roles of ``s1`` and ``s2``
            swapped.
        model: Network passed to ``model_info``.
        lb, ub: Input lower/upper bounds; only row 0 is used (via ``.numpy()``).
        spec: Indexable such that ``spec[0][0]`` holds +1 at the label index
            and -1 at the attack index.
        lbs, ubs: Per-layer pre-activation bounds, indexed ``[k][0, j]``.
        I_p, I_n: Collections of ``(layer, neuron)`` pairs with 1-based
            neuron index (Pyomo sets here start at 1).
        I0: Accepted but never read in this function.
        eps: Right-hand side of the ``s1 * s2 <= eps`` constraint.

    Returns:
        Tuple ``(status, fmin, runtime, m)``. ``status`` is ``'solvable'``,
        ``'infeasible'``, ``'maxtimelimit'`` or ``'failed'``; ``fmin`` and
        ``runtime`` are ``None`` unless the status is ``'solvable'``; ``m`` is
        the Pyomo model.
    """
    m = ConcreteModel()  
    W_list, b_list, in_list , out_list   = model_info(model) 
    # Positions of the +1 / -1 entries of the specification row.
    label = int((spec[0][0] == 1).nonzero(as_tuple=False))
    attack =int((spec[0][0] == -1).nonzero(as_tuple=False)) #(spec[0] == 1).nonzero(as_tuple =True)[0].item()
    #attack =  (spec[0] == -1).nonzero(as_tuple=True)[0].item()   
    L = len(W_list)  
    input_size = in_list[0]
    output_size = out_list[-1] 
    lb_numpy, ub_numpy = lb.numpy(), ub.numpy()
    # Box constraints on the input variables, keyed by 1-based feature index.
    input_bounds = {i + 1: (float(lb_numpy[0][i]), float(ub_numpy[0][i])) for i in range(input_size)}        
    m.constraints = ConstraintList()  
    m.input = Var(RangeSet(1, input_size), domain=Reals, bounds=input_bounds  ) 
    #m.output =  Var(RangeSet(0,  output_size-1 ), domain= Reals)   
    # One pass per hidden layer; the last linear layer is handled after the loop.
    for k in range(L-1):  
        m.add_component(f"I_{k}", RangeSet(1, in_list[k]))
        m.add_component(f"J_{k}", RangeSet(1, out_list[k])) 
        I = getattr(m, f"I_{k}")
        J = getattr(m, f"J_{k}") 
        m.add_component(f"z_{k}", Var(  J, domain= Reals))
        r = W_list[k].shape[0]
        c = W_list[k].shape[1]
        W_ini = {(i + 1, j + 1): W_list[k][i,j] for i in range(r) for j in range(c)}
        b_ini = {i +1: b_list[k][i] for i in range(r)}
        m.add_component(f"W_{k}", Param(    J,I, initialize= W_ini) )
        m.add_component(f"b_{k}", Param(  J, initialize = b_ini) )  
        lb_ini = {j + 1: lbs[k][0, j] for j in range(out_list[k])}
        ub_ini = {j + 1: ubs[k][0, j] for j in range(out_list[k])}
        m.add_component(f"lb_{k}", Param(J, initialize=lb_ini))
        m.add_component(f"ub_{k}", Param(J, initialize=ub_ini))   
        # NOTE: from here on `lb` / `ub` no longer refer to the input tensors
        # but to this layer's Pyomo Params (the tensors were consumed above).
        lb= getattr(m, f"lb_{k}")
        ub = getattr(m, f"ub_{k}") 
        m.add_component(f"zhat_{k}", Var(  J, domain=NonNegativeReals))
        zhat  = getattr(m, f"zhat_{k}")  
        z   = getattr(m, f"z_{k}") 
        W = getattr(m, f"W_{k}")
        b = getattr(m, f"b_{k}")
        # The first layer reads the input variables, later layers read the
        # previous layer's post-activation variables.
        if k == 0: 
            zhat_old = m.input
        else:
            zhat_old = getattr(m, f"zhat_{k-1}") 
        for j in J:
            # Affine layer plus the precomputed pre-activation bounds.
            m.constraints.add(z[j] == sum(W[j, i] * zhat_old[i] for i in I) + b[j])   
            m.constraints.add(z[j] <= ub[j])  
            m.constraints.add(z[j] >= lb[j]) 
            if (k,j) in I_p:
                m.constraints.add(zhat[j] ==  z[j])
            elif (k,j) in I_n:
                m.constraints.add(zhat[j] ==  0)  
            else:  
                # Remaining neurons: split z into s1 - s2 with s1, s2 >= 0.
                # The three branches differ only in the initial values.
                if (k,j) in over_set:
                    lbj = float(value(lb[j]))
                    ubj = float(value(ub[j]))
                    ini_s1 = 0.5 * max(0.0, lbj + ubj)
                    m.add_component(f"s1_{k}_{j}", Var(domain=NonNegativeReals, initialize= ini_s1))
                    m.add_component(f"s2_{k}_{j}", Var(domain=NonNegativeReals, initialize=0.0))
                    s1 = getattr(m, f"s1_{k}_{j}")
                    s2 = getattr(m, f"s2_{k}_{j}")
                elif (k,j) in under_set:
                    lbj = float(value(lb[j]))
                    ubj = float(value(ub[j]))
                    ini_s2 = 0.5 * max(0.0,   -(lbj + ubj)) 
                    m.add_component(f"s1_{k}_{j}", Var(domain=NonNegativeReals, initialize=0.0))
                    m.add_component(f"s2_{k}_{j}", Var(domain=NonNegativeReals, initialize=ini_s2))
                    s1 = getattr(m, f"s1_{k}_{j}")
                    s2 = getattr(m, f"s2_{k}_{j}")
                else: 
                    m.add_component(f"s1_{k}_{j}", Var(  domain=NonNegativeReals))
                    m.add_component(f"s2_{k}_{j}", Var(  domain=NonNegativeReals))
                    s1 = getattr(m, f"s1_{k}_{j}")
                    s2 = getattr(m, f"s2_{k}_{j}")
                # Relaxed complementarity: at most one of s1, s2 may be
                # noticeably positive.
                m.constraints.add(s1 * s2<= eps)
                m.constraints.add(zhat[j] ==  s1)
                m.constraints.add(z[j] == s1 - s2) 

    # Output layer collapsed to the single margin row (label - attack).
    # `zhat` and `J` below are the ones left over from the last loop
    # iteration, i.e. they belong to the last hidden layer.
    k=L-1  
    W_o ={   j+1  : (W_list[k][label, j] - W_list[k][attack, j])  for j in range(W_list[k].shape[1])} 
    b_o = {1:  (b_list[k][label] - b_list[k][attack]) }
    lb_o = {1: lbs[k][0,label] - ubs[k][0,attack]}  # min(a) - max(b)
    ub_o = {1: ubs[k][0,label] - lbs[k][0,attack]}  # max(a) - min(b)   
    m.add_component(f"lbo", Param(RangeSet(1,1), initialize= lb_o) )
    m.add_component(f"ubo", Param(    RangeSet(1,1), initialize= ub_o) ) 
    m.add_component(f"Wo", Param(    RangeSet(1,W_list[k].shape[1]), initialize= W_o) )
    m.add_component(f"bo", Param(  RangeSet(1,1), initialize = b_o) )   
    m.add_component(f"zo", Var(  RangeSet(1,1), domain= Reals))
    zo   = getattr(m, f"zo") 
    Wo = getattr(m, f"Wo")
    bo = getattr(m, f"bo")
    lbo= getattr(m, f"lbo")
    ubo = getattr(m, f"ubo")   
    m.constraints.add(zo[1]   == sum(Wo[j] * zhat[j] for j in J) + bo[1]   )   
    m.constraints.add(zo[1]  <= ubo[1]  )   
    m.constraints.add(zo[1]   >= lbo[1] )  
    #m.constraints.add( (m.output[label] - m.output[attack]) == zo[0]   )
    
    ################################### Objective function ####################  
 
    m.obj = Objective(expr=zo[1], sense = minimize)  
    ################################### Solve the problem ###################################
    
    opt = SolverFactory('ipopt')
    opt.options["max_cpu_time"] = 900  # seconds of CPU time
    start = time.time()
    solver_status = opt.solve(m, tee=True)
    runtime = time.time() - start
    tc = solver_status.solver.termination_condition 
    # Log-only branch: it prints a "Skip" message for unsuccessful exits but
    # does not return; the chain below decides the return value.
    # NOTE: the `k` printed here is L-1 (the output-layer index).
    if tc in (
        TerminationCondition.infeasible, 
        TerminationCondition.infeasibleOrUnbounded,
        TerminationCondition.maxTimeLimit,
        TerminationCondition.maxIterations,
        TerminationCondition.minStepLength,
        TerminationCondition.noSolution,
        TerminationCondition.error ):
        msg = getattr(solver_status.solver, "message", "")
        print(f"[{k}] Skip: {tc.name}. {msg}") 

    if tc in (
        TerminationCondition.optimal,
        TerminationCondition.globallyOptimal,
        TerminationCondition.feasible ):
        print("Problem solved successfully.")  
        print(m.obj())
        #I_p, I_n, I0 =  find_I(m,    eps,I_p, I_n, I0)
        fmin= m.obj()  
        status =  'solvable'
        return status,   fmin, runtime, m
    elif tc == TerminationCondition.infeasible :
        print("no bad point found")
        # Rebinds local names only; the caller's lists are not modified.
        I_p, I_n, I0 = [],[],[]
        status = 'infeasible'
        return status,   None, None, m
    elif tc == TerminationCondition.maxTimeLimit:
        print("Hit time limit")
        I_p, I_n, I0 = [],[],[]
        status = 'maxtimelimit'
        return status,   None, None, m
    else:
        print("The solver failed to solve the problem.")
        I_p, I_n, I0 = [],[],[]
        status = 'failed'
        return status,   None, None, m 

# Step 3: build the verification specifications
class Specification:
    """Base class for specification builders; reads its settings from the global config."""

    def __init__(self):
        config = arguments.Config
        # Number of network outputs the specification matrices are built for.
        self.num_outputs = config['data']['num_outputs']
        # FIXME Do not use numpy. Use torch instead.
        threshold = config['bab']['decision_thresh']
        # One-element array holding the decision threshold (right-hand side).
        self.rhs = np.array([threshold])

    def construct_vnnlib(self):
        """Build the specification list; concrete subclasses must override this."""
        raise NotImplementedError

class SpecificationAllPositive(Specification):
    """Specification with one unit-vector row per network output."""

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one ``[(x_range[i], rows)]`` entry per selected example.

        ``rows`` pairs every row of the identity matrix with ``self.rhs``.
        ``dataset`` is accepted but not read; only the length of
        ``example_idx_list`` is used.
        """

        def unit_rows():
            # A fresh identity matrix per example, as (1, n, n).
            identity = torch.eye(self.num_outputs).unsqueeze(0)
            return [(identity[:, out], self.rhs) for out in range(self.num_outputs)]

        return [
            [(x_range[pos], unit_rows())]
            for pos in range(len(example_idx_list))
        ]
    
class SpecificationTarget(Specification):
    """Specification comparing the true label against one target class."""

    def construct_vnnlib(self, dataset, x_range, example_idx_list, num_class = 10):
        """Return one ``[(x_range[i], [(c, rhs)])]`` entry per selected example.

        ``c`` is a ``1 x num_outputs`` row with +1 at the example's label and
        -1 at its target label. When the target coincides with the label it
        is moved to ``label - 1``, or to ``label + 1`` when the label is 0
        (unless 0 is also the last class, in which case it stays unchanged).
        """
        entries = []
        for pos, example_idx in enumerate(example_idx_list):
            label = int(dataset['labels'][example_idx])
            target_label = dataset['target_label'][example_idx]
            if label == target_label:
                # Pick a neighbouring class so label and target differ.
                if label != 0:
                    target_label = label - 1
                elif label != (num_class - 1):
                    target_label = label + 1
            bounds = x_range[pos]

            row = torch.zeros([1, self.num_outputs])
            row[0, label] = 1
            row[0, target_label] = -1
            print("construct c is", row)
            entries.append([(bounds, [(row, self.rhs)])])
        return entries



def construct_vnnlib(dataset, example_idx_list):
    """
    Simplified construct_vnnlib for L∞ norm with specify-target robustness type only.

    Args:
        dataset: Dictionary containing 'X' (inputs), 'eps' (perturbation),
            'labels' and 'target_label'. 'eps' is either a value that
            broadcasts against the selected rows of 'X' or a list of tensors
            that is concatenated and then indexed with ``example_idx_list``.
        example_idx_list: List of example indices to process

    Returns:
        VNNLIB specification for the given examples, as produced by
        ``SpecificationTarget.construct_vnnlib``.
    """
    # Extract inputs for selected examples
    X = dataset['X']

    # Only handle L∞ perturbation type
    assert arguments.Config['specification']['norm'] == float('inf'), \
        "This simplified version only supports L∞ norm"
    assert arguments.Config['specification']['robustness_type'] == 'specify-target', \
        "This simplified version only supports 'specify-target' robustness type"

    # Get perturbation epsilon
    perturb_epsilon = dataset['eps']
    if isinstance(perturb_epsilon, list):
        # Each example has different perturbations
        perturb_epsilon = torch.cat(perturb_epsilon)[example_idx_list]

    assert perturb_epsilon is not None, "Perturbation epsilon must be provided"

    # Compute the L∞ box around the selected inputs.
    # NOTE: 'data_min' / 'data_max' are not applied. The previous code
    # branched on 'data_max' but computed the same unclamped bounds in both
    # branches (the clamping was commented out), so a single path is kept.
    selected = X[example_idx_list]
    x_lower = (selected - perturb_epsilon).flatten(1)
    x_upper = (selected + perturb_epsilon).flatten(1)

    # Create range tensor for VNNLIB format: last axis is (lower, upper).
    x_range = torch.stack([x_lower, x_upper], -1).numpy()

    # Create the label-vs-target specification
    specification = SpecificationTarget()

    # Generate VNNLIB format
    return specification.construct_vnnlib(dataset, x_range, example_idx_list)

def load_example():
    """Load the stored adversarial example as a flat NumPy array.

    Reads the ``x_adv_flat`` entry of ``x_adv_example.npz`` in the current
    working directory and returns a flattened copy of it. The archive is
    closed before returning.

    Returns:
        numpy.ndarray: 1-D array with the values of ``x_adv_flat``.
    """
    # NpzFile keeps the archive open until it is closed; the context manager
    # releases the file handle. flatten() returns a copy, so the result stays
    # valid after the archive is closed.
    with np.load("x_adv_example.npz") as data:
        return data["x_adv_flat"].flatten()

class SNOV:
    def __init__(self, args=None, **kwargs):
        """Translate keyword overrides into CLI-style flags and parse the config.

        Args:
            args: Optional list of command-line style arguments. It is passed
                to ``arguments.Config.parse_config`` unchanged when no
                keyword overrides are given, and is never modified in place.
            **kwargs: Overrides appended after ``args``. A list value becomes
                ``--key v1 v2 ...``, ``True`` becomes ``--key``, ``False``
                becomes ``--no_key`` and anything else ``--key=value``.
        """
        extra = []
        for k, v in kwargs.items():
            flag = f'--{k}'
            if isinstance(v, list):
                extra.append(flag)
                extra.extend(map(str, v))
            elif isinstance(v, bool):
                extra.append(flag if v else f'--no_{k}')
            else:
                extra.append(f'{flag}={v}')
        if extra:
            # Build a new list: this avoids `None.append` when only keyword
            # overrides are given and leaves the caller's list untouched.
            args = [*(args or []), *extra]
        arguments.Config.parse_config(args)

    def load_data_spec(self, input_size, eps, data_path,model_path, save_path ):
        """Load the model and data and build the specifications for the configured slice.

        The examples are restricted to the half-open index range given by
        ``arguments.Config['data']['start']`` / ``['end']``. The loaded model
        is also stored on ``self.model_ori``.

        Returns:
            Tuple ``(run_mode, save_path, file_root, example_idx_list,
            model_ori, vnnlib_all, shape)`` where ``run_mode`` is always
            ``"customized_data"`` and ``shape`` is ``[-1]`` followed by the
            per-example dimensions of the loaded inputs.
        """
        file_root = expand_path(arguments.Config['general']['root_path'])
        self.model_ori = load_model(input_size, model_path)
        dataset = load_data(eps, data_path)
        features = dataset["X"]
        data_cfg = arguments.Config['data']
        selected = list(range(features.shape[0]))[data_cfg['start']:data_cfg['end']]
        specs = construct_vnnlib(dataset, selected)
        input_shape = [-1, *features.shape[1:]]
        return (
            "customized_data",
            save_path,
            file_root,
            selected,
            self.model_ori,
            specs,
            input_shape,
        )

    def compute_interm_bounds(self, model, image, lb,ub,   norm = float("inf"), method='IBP' ):
        """Bound the network with auto_LiRPA and return selected layer bounds.

        Args:
            model: Network to wrap in a ``BoundedModule``.
            image: Input tensor used as the centre of the perturbation.
            lb, ub: Element-wise lower/upper input bounds (``x_L`` / ``x_U``).
            norm: Norm handed to ``PerturbationLpNorm``.
            method: Bounding method handed to ``compute_bounds``.

        Returns:
            Tuple ``(lbs, ubs)`` of lists of NumPy arrays, one entry per node
            of ``save_intermediate()`` whose name is '/input', '/input-3' or
            '/11', in the iteration order of that mapping.
        """
        from auto_LiRPA import BoundedModule, BoundedTensor
        from auto_LiRPA.perturbations import PerturbationLpNorm
        wrapped = BoundedModule(model, torch.empty_like(image).float())
        perturbation = PerturbationLpNorm(norm=norm, x_L=lb, x_U=ub)
        perturbed_input = BoundedTensor(image, perturbation)
        #lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}})
        # The output bounds themselves are not needed; the call is made
        # before the intermediate bounds are read back below.
        wrapped.compute_bounds(x=(perturbed_input,), method=method)
        saved = wrapped.save_intermediate()
        # Node names are tied to the traced graph of the network in use.
        wanted = {'/input', '/input-3', '/11'}
        kept = [bounds for node, bounds in saved.items() if node in wanted]
        lbs = [bounds[0].detach().numpy() for bounds in kept]
        ubs = [bounds[1].detach().numpy() for bounds in kept]
        return lbs, ubs
    
    def classify_neurons_from_bounds(self,model, image,lb,ub,  norm = float("inf"), method='CROWN' ): 
        lbs, ubs = self.compute_interm_bounds(model, image, lb,ub,   norm = norm, method=method )
        I_p, I_n, I0 = [], [], []
            
            # Only classify hidden layers (exclude output layer)
        num_hidden_layers = len(lbs) -1
            
        for layer_idx in range(num_hidden_layers):
            if layer_idx < len(lbs) and layer_idx < len(ubs):
                layer_lbs = lbs[layer_idx][0]  # Shape: [num_neurons]
                layer_ubs = ubs[layer_idx][0]  # Shape: [num_neurons]
                        
                for neuron_idx in range(len(layer_lbs)):
                    lb_val = layer_lbs[neuron_idx]
                    ub_val = layer_ubs[neuron_idx]
                            
                    if ub_val < 0:
                                # Upper bound is negative -> always inactive (ReLU output = 0)
                        I_n.append((layer_idx, neuron_idx + 1))  # +1 for 1-indexing
                    elif lb_val > 0:
                                # Lower bound is positive -> always active (ReLU output = input)
                        I_p.append((layer_idx, neuron_idx  + 1))  # +1 for 1-indexing
                    else:
                                # Can be either active or inactive
                        I0.append((layer_idx, neuron_idx + 1 ))  # +1 for 1-indexing
        return I_p, I_n, I0, lbs, ubs 
    
    def MIP_solver( self,   model, lb, ub,spec, lbs, ubs, I_p, I_n, I0  ): 
        """Minimise the label-vs-attack logit margin with a mixed-integer (Gurobi) model.

        The network's linear layers are exported with ``model_info`` and
        encoded in a Pyomo ``ConcreteModel``. Neurons listed in ``I_p`` are
        fixed to ``zhat == z``, neurons in ``I_n`` to ``zhat == 0``; every
        other neuron gets a binary variable ``v`` and the four inequalities
        ``zhat >= 0``, ``zhat >= z``, ``zhat <= ub * v`` and
        ``zhat <= z - lb * (1 - v)``.

        Args:
            model: Network passed to ``model_info``.
            lb, ub: Input lower/upper bounds; only row 0 is used (via ``.numpy()``).
            spec: Indexable such that ``spec[0][0]`` holds +1 at the label
                index and -1 at the attack index.
            lbs, ubs: Per-layer pre-activation bounds, indexed ``[k][0, j]``.
            I_p, I_n: Collections of ``(layer, neuron)`` pairs with 1-based
                neuron index.
            I0: Accepted but never read in this method.

        Returns:
            Tuple ``(status, fmin, runtime)``. ``status`` is ``'optimal'``,
            ``'infeasible'``, ``'Timeout'`` or ``'failed'``; ``fmin`` is the
            objective value for ``'optimal'`` and ``None`` otherwise.
        """
        m = ConcreteModel()  
        W_list, b_list, in_list , out_list   = model_info(model) 
        # Positions of the +1 / -1 entries of the specification row.
        label = int((spec[0][0] == 1).nonzero(as_tuple=False))
        attack =int((spec[0][0] == -1).nonzero(as_tuple=False))     
        L = len(W_list)  
        input_size = in_list[0] 
        lb_numpy, ub_numpy = lb.numpy(), ub.numpy()
        # Box constraints on the input variables, keyed by 1-based feature index.
        input_bounds = {i + 1: (float(lb_numpy[0][i]), float(ub_numpy[0][i])) for i in range(input_size)}        
        m.constraints = ConstraintList()  
        m.input = Var(RangeSet(1, input_size), domain=Reals, bounds=input_bounds  ) 
        #m.output =  Var(RangeSet(0,  output_size-1 ), domain= Reals)   
        # One pass per hidden layer; the last linear layer is handled after the loop.
        for k in range(L-1):  
            m.add_component(f"I_{k}", RangeSet(1, in_list[k]))
            m.add_component(f"J_{k}", RangeSet(1, out_list[k])) 
            I = getattr(m, f"I_{k}")
            J = getattr(m, f"J_{k}") 
            m.add_component(f"z_{k}", Var(  J, domain= Reals))
            r = W_list[k].shape[0]
            c = W_list[k].shape[1]
            W_ini = {(i + 1, j + 1): W_list[k][i,j] for i in range(r) for j in range(c)}
            b_ini = {i +1: b_list[k][i] for i in range(r)}
            m.add_component(f"W_{k}", Param(    J,I, initialize= W_ini) )
            m.add_component(f"b_{k}", Param(  J, initialize = b_ini) )    
            lb_ini = {j + 1: lbs[k][0, j] for j in range(out_list[k])}
            ub_ini = {j + 1: ubs[k][0, j] for j in range(out_list[k])}
            m.add_component(f"lb_{k}", Param(J, initialize=lb_ini))
            m.add_component(f"ub_{k}", Param(J, initialize=ub_ini))
            # NOTE: from here on `lb` / `ub` no longer refer to the input
            # tensors but to this layer's Pyomo Params.
            lb= getattr(m, f"lb_{k}")
            ub = getattr(m, f"ub_{k}") 
            m.add_component(f"zhat_{k}", Var(  J, domain=NonNegativeReals))
            zhat  = getattr(m, f"zhat_{k}")  
            z   = getattr(m, f"z_{k}") 
            W = getattr(m, f"W_{k}")
            b = getattr(m, f"b_{k}")
            # The first layer reads the input variables, later layers read the
            # previous layer's post-activation variables.
            if k == 0: 
                zhat_old = m.input
            else:
                zhat_old = getattr(m, f"zhat_{k-1}") 
            for j in J:
                # Affine layer plus the precomputed pre-activation bounds.
                m.constraints.add(z[j] == sum(W[j, i] * zhat_old[i] for i in I) + b[j])   
                m.constraints.add(z[j] <= ub[j])  
                m.constraints.add(z[j] >= lb[j]) 
                if (k,j) in I_p:
                    m.constraints.add(zhat[j] ==  z[j])
                elif (k,j) in I_n:
                    m.constraints.add(zhat[j] ==  0)  
                else:  
                    # Remaining neurons: ReLU encoded with a binary indicator
                    # v and the layer bounds as big-M constants.
                    m.add_component(f"v_{k}_{j}", Var(  domain= Binary))
                    v = getattr(m, f"v_{k}_{j}") 
                    m.constraints.add(zhat[j] >= 0  )
                    m.constraints.add(zhat[j] >= z[j]) 
                    m.constraints.add(zhat[j] <= ub[j]* v )
                    m.constraints.add(zhat[j] <= z[j] - lb[j]*(1-v ))   

        # Output layer collapsed to the single margin row (label - attack).
        # `zhat` and `J` below are the ones left over from the last loop
        # iteration, i.e. they belong to the last hidden layer.
        k=L-1   
        W_o ={   j+1  : (W_list[k][label, j] - W_list[k][attack, j])  for j in range(W_list[k].shape[1])} 
        b_o = {1:  (b_list[k][label] - b_list[k][attack]) }
        lb_o = { 1: lbs[k][0,label] - ubs[k][0,attack]  }
        ub_o = { 1: ubs[k][0,label] - lbs[k][0,attack]  }
        m.add_component(f"lbo", Param(RangeSet(1,1), initialize= lb_o) )
        m.add_component(f"ubo", Param(    RangeSet(1,1), initialize= ub_o) ) 
        m.add_component(f"Wo", Param(    RangeSet(1,W_list[k].shape[1]), initialize= W_o) )
        m.add_component(f"bo", Param(  RangeSet(1,1), initialize = b_o) )   
        m.add_component(f"zo", Var(  RangeSet(1,1), domain= Reals))
        zo   = getattr(m, f"zo") 
        Wo = getattr(m, f"Wo")
        bo = getattr(m, f"bo")
        lbo= getattr(m, f"lbo")
        ubo = getattr(m, f"ubo")   
        m.constraints.add(zo[1]   == sum(Wo[j] * zhat[j] for j in J) + bo[1]   )   
        m.constraints.add(zo[1]  <= ubo[1]  )     
        m.constraints.add(zo[1]   >= lbo[1] ) 
        #m.constraints.add( (m.output[label] - m.output[attack]) == zo[0]   )
        
        ################################### Objective function ####################  
    
        m.obj = Objective(expr=zo[1], sense = minimize)  
        ################################### Solve the problem ###################################
         
        # The measured runtime includes creating the solver interface.
        start = time.time()
        opt = SolverFactory('gurobi')
        opt.options["TimeLimit"] = 1800 # 30 mins is the max
        solver_status = opt.solve(m, tee=True)
        runtime = time.time() - start
        if solver_status.solver.termination_condition == TerminationCondition.optimal:
            print("Problem solved successfully.")  
            print(m.obj()) 
            fmin= m.obj()     
            status = 'optimal'
        elif solver_status.solver.termination_condition == TerminationCondition.infeasible :
            print("no bad point found") 
            status = 'infeasible'
            return status,  None, runtime
        elif solver_status.solver.termination_condition == TerminationCondition.maxTimeLimit:
            print("Hit time limit.")
            return 'Timeout',  None, runtime
        else:
            print("The solver failed to solve the problem.") 
            return 'failed',  None, runtime
        
        # Only the optimal branch falls through to here.
        return  status,  fmin, runtime
    
     
    # Step 4: verify through MIP
    def main(self,  input_size = 3*32*32, eps = 0.01, data_path = './cifar10_test.pt', model_path='snov_configs/cifar_test/classification.pth', save_path = 'simple_test/sol_file.txt', interm_bounds=None): 
        """Run the MIP and the NLP formulation on every configured example and print both results.

        Args:
            input_size: Number of input features of the network.
            eps: Perturbation radius used to build the input box.
            data_path: Path of the saved ``(images, labels)`` tensors.
            model_path: Path of the model checkpoint.
            save_path: Passed through ``load_data_spec``; nothing is written
                to it in this method.
            interm_bounds: Accepted but never read in this method.

        Returns:
            Tuple ``(MIP_status, fmin_mip, mip_time)`` of the LAST processed
            example only. NOTE(review): if the configured example range is
            empty the loop body never runs and these names are unbound.
        """
        print(f'Experiments at {time.ctime()} on {socket.gethostname()}')
        
        # Basic setup: seed every RNG from the configured seed.
        torch.manual_seed(arguments.Config['general']['seed'])
        random.seed(arguments.Config['general']['seed'])
        np.random.seed(arguments.Config['general']['seed'])
        torch.set_printoptions(precision=8)
        device = arguments.Config['general']['device']
        
        if device != 'cpu':
            torch.cuda.manual_seed_all(arguments.Config['general']['seed'])
            # Disable TF32 for matmul and cuDNN.
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False
        
        if arguments.Config['general']['double_fp']:
            torch.set_default_dtype(torch.float64)
        
        # Configuration updates
        # (read from the config; the values are not used further in this method)
        bab_args = arguments.Config['bab']
        timeout_threshold = bab_args['timeout']
        
        # Load model and vnnlib specifications
        (run_mode, save_path, file_root, example_idx_list, model_ori,
        vnnlib_all, shape) = self.load_data_spec(input_size, eps, data_path = data_path,model_path = model_path, save_path = save_path)
        
        # Process each example
        for new_idx, csv_item in enumerate(example_idx_list):
            arguments.Globals['example_idx'] = new_idx
            # Index of the example in the full dataset.
            vnnlib_id = new_idx + arguments.Config['data']['start']
            print(vnnlib_id)
            
            print(f'\n {"%"*35} idx: {new_idx}, vnnlib ID: {vnnlib_id} {"%"*35}')
            
            # Load model and vnnlib for this instance
            vnnlib = vnnlib_all[new_idx] 
            
            # Setup model and data
            model_ori.eval()
            vnnlib_shape = shape 
            x_range = torch.tensor(vnnlib[0][0])
            data_min = x_range.select(-1, 0).reshape(vnnlib_shape)      # X - eps
            data_max = x_range.select(-1, 1).reshape(vnnlib_shape)      # X + eps
            x = x_range.mean(-1).reshape(vnnlib_shape)                  # (X - eps + X + eps)/2 = X
            
            model_ori = model_ori.to(device)
            x, data_max, data_min = x.to(device), data_max.to(device), data_min.to(device)
            _, specs = vnnlib[0]
            # Specification row, with a leading axis added and moved to x's
            # device and dtype.
            c = torch.tensor(specs[0][0]).unsqueeze(0).to(x)   
            # Run MIP complete verifier strategy  
            
            norm = float("inf") 
            I_p, I_n, I0, lbs, ubs = self.classify_neurons_from_bounds(model_ori, x, data_min, data_max,   norm = norm, method='CROWN') 
            # Wall-clock time of the whole MIP call (model build + solve).
            start = time.time() 
            MIP_status , fmin_mip , _ =  self.MIP_solver(    model_ori,data_min, data_max,c,   lbs,  ubs,  I_p.copy(),  I_n.copy(),  I0.copy()  )
            mip_time = time.time() - start
            # Same for the NLP formulation. Copies of I_p / I_n are passed as
            # both over_set / under_set and as I_p / I_n.
            start = time.time() 
            NLP_status0,  u_NLP, runtimeNLP, mp = NLP_solver_output(float('inf'), I_p.copy(), I_n.copy(),  model_ori,data_min, data_max,c, lbs, ubs, I_p.copy(), I_n.copy(), I0.copy() ,  eps = 1e-5) 
            nlp_time = time.time() - start 
            print("Results comparison \n")  
            print("MIP results ")
            print(MIP_status)
            print(fmin_mip)
            print(mip_time) 
            print("NLP results ")
            print(NLP_status0)
            print(u_NLP)
            print(runtimeNLP)
            print(nlp_time)
        return  MIP_status , fmin_mip, mip_time 

                


def snov_test(eps, idx_case: int, input_size: int, seed: int, device: str = "cpu",   
              verbose: bool = False) -> Dict[str, Any]:
    """
    Run the SNOV comparison on a single example and return the MIP result.

    The YAML configuration ``snov_configs/cifar_test/cifar_small.yaml`` is
    rewritten in place so that its ``data.start`` / ``data.end`` select
    exactly the example ``idx_case``; a ``SNOV`` instance is then created
    from that file and its ``main`` method is run.

    Args:
        eps: Perturbation radius passed to ``SNOV.main``.
        idx_case (int): Index of the example to verify.
        input_size (int): Number of input features of the network.
        seed (int): Only reported in the verbose start message.
        device (str): Only reported in the verbose start message.
        verbose (bool): Whether to print progress messages.

    Returns:
        Dict[str, Any]: Dictionary with keys:
            - status: Status string returned by the MIP solver
            - runtime: Wall-clock time of the MIP call
            - final_lb: Objective value returned by the MIP solver

    Raises:
        FileNotFoundError: If the config file is not found
        ValueError: If the config file is not valid YAML
        IOError: If the updated config file cannot be written
    """
    
    # Configuration paths
    snov_config_name = "snov_configs/cifar_test/cifar_small.yaml"
    snov_config_prefix = "snov_configs/cifar_test/"
    
    # Ensure configuration file exists
    if not os.path.exists(snov_config_name):
        raise FileNotFoundError(f"Configuration file not found: {snov_config_name}")
    
    # Load configuration
    try:
        with open(snov_config_name, "r") as f:
            # safe_load returns None for an empty file; fall back to an
            # empty mapping so the update below still works.
            snov_config_dict = yaml.safe_load(f) or {}
    except yaml.YAMLError as e:
        raise ValueError(f"Error parsing YAML configuration: {e}") from e
    
    # Define file paths
    data_path = snov_config_prefix +"/cifar_trainnew.pt" 
    save_path  = snov_config_prefix +  "/snov_sol.pkl"
    model_path = snov_config_prefix +"/best_NN_cifar10_256.pth"
    # Restrict the run to the single example `idx_case`, keeping every other
    # entry of the "data" section (a missing or null section counts as empty).
    snov_config_dict.update({  
        "data": {
            **(snov_config_dict.get("data") or {}),
            "start":  idx_case,
            "end": idx_case + 1
        }
    }) 
    
    # Save updated configuration
    try:
        with open(snov_config_name, "w") as f:
            yaml.dump(snov_config_dict, f, sort_keys=False, default_flow_style=False)
    except Exception as e:
        raise IOError(f"Error saving configuration file: {e}") from e
    # Initialize and run SNOV
    snovbab = SNOV(["--config", snov_config_name])
    if verbose:
        print(f"Starting SNOV test with seed {seed} on device {device}")
      
    status, l, mip_time = snovbab.main(
        input_size, eps, data_path=data_path, model_path=model_path, save_path=save_path
    )
        
    if verbose:
        print(f"SNOV execution completed in {mip_time:.2f} seconds")       
     
    
    # Prepare return dictionary
    return_dict = { 
        "status": status ,
        "runtime": mip_time,
        "final_lb": l, 
    }
    
    # Clean up temporary files
    temp_files = ["out.txt"]
    for temp_file in temp_files:
        if os.path.exists(temp_file):
            os.remove(temp_file)
            if verbose:
                print(f"Cleaned up temporary file: {temp_file}")
    
    return return_dict


def run_snov_experiments(input_size: int, seeds: list, device: str = "cpu",  
                        verbose: bool = False )  :
    """
    Run ``snov_test`` on the first 100 examples and group the results by status.

    Args:
        input_size (int): Number of input features of the network
        seeds (list): Accepted but not used; a fixed seed of 42 is forwarded
        device (str): Forwarded to ``snov_test``
        verbose (bool): Whether to enable verbose output

    Returns:
        tuple: ``(results, infeasible_list, failed_list, other_list)``, four
        lists of result dictionaries. Entries of ``results`` (status
        'optimal') additionally carry the example index under 'case'; any
        status other than 'optimal', 'infeasible' or 'failed' is printed and
        collected in ``other_list``.
    """
    eps = 0.01
    seed = 42
    idx_list = list(range(100))
    print(idx_list)

    results, infeasible_list, failed_list, other_list = [], [], [], []
    for idx_case in idx_list:
        result = snov_test(eps, idx_case, input_size, seed, device, verbose)
        status = result['status']
        if status == 'optimal':
            result['case'] = idx_case
            results.append(result)
            continue
        if status == 'infeasible':
            infeasible_list.append(result)
            continue
        if status == 'failed':
            failed_list.append(result)
            continue
        print(result['status'])
        other_list.append(result)

    return results, infeasible_list, failed_list, other_list

if __name__ == "__main__":
    # Script entry point: regenerate the CIFAR-10 tensor files, then run the
    # MIP/NLP comparison and print the grouped results.
    save_datasets()
    test_seeds = [42]#, 123, 456, 789, 999] 
    # NOTE: run_snov_experiments currently ignores the seed list.
    input_size = 3*32*32
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
    results, infeasible_list, failed_list , other_list = run_snov_experiments(input_size, test_seeds, device=device,   verbose=False)
    print("\nExperiment Results of MIP solver Summary:")
    # Printed in this order: infeasible, failed, optimal, everything else.
    print( infeasible_list)
    print(failed_list)
    print(results )
    print(other_list)
    

