#########################################################################  
##   Copyright (C) 2021-2025 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com> (UIUC)         ##
##                     Zhouxing Shi <zshi@cs.ucla.edu> (UCLA)          ##
##                     Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################
import  random,os,time,gc,torch, socket,re
import numpy as np
import torch.nn as nn
from collections import defaultdict
from torch.utils.data import  TensorDataset 
import arguments
from auto_LiRPA import BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm
from auto_LiRPA.utils import stop_criterion_all, stop_criterion_batch_any
from auto_LiRPA.operators.convolution import BoundConv
from beta_CROWN_solver import LiRPANet
from specifications import (trim_batch, batch_vnnlib, sort_targets, prune_by_idx, add_rhs_offset)  
from bab_snov  import general_bab
from input_split.batch_branch_and_bound import input_bab_parallel 
from specifications import construct_vnnlib
from utils import  expand_path
from utils import Logger, print_model

 
import torch
import math,copy
import matplotlib.pyplot as plt 
import seaborn as sns
import pandas as pd
import yaml
import time
import pickle
import os
import sys 
from typing import Dict, Any   
import sys, atexit

# Step 1: Export the results (log messages are written to a file; see the logging setup below)
 
 
import logging

# Configure the root logger: records at INFO level and above are written to
# the file 'stdout_ours.log', each prefixed with a timestamp.
# NOTE: this call executes when the module is imported, so importing this
# file has the side effect of setting up logging.
logging.basicConfig(
    filename='stdout_ours.log',
    level=logging.INFO,
    format='%(asctime)s - %(message)s'
)



def _close_streams():
    """Flush the active stdout/stderr, put the saved streams back, close the files.

    NOTE(review): ``_stdout_orig``, ``_stderr_orig``, ``f_out`` and ``f_err``
    are read as module-level names but are not defined in the visible part of
    this file, and the ``atexit`` registration of this function is commented
    out. Confirm those names exist before re-enabling it, otherwise calling
    this function raises ``NameError``.
    """
    # Push out anything still buffered on the currently installed streams.
    sys.stdout.flush()
    sys.stderr.flush()
    # Restore first, then close, so nothing writes to a closed file object.
    sys.stdout, sys.stderr = _stdout_orig, _stderr_orig
    f_out.close()
    f_err.close()

#atexit.register(_close_streams)

  
# Step 2: load data

def load_data_old(eps, data_path='./cifar_test.pt', a=4, size_image=32):
    """Load an ``(inputs, labels)`` pair from ``data_path`` via a TensorDataset.

    Args:
        eps: Perturbation size; stored unchanged under the ``"eps"`` key.
        data_path: File read with ``torch.load``; it must unpack into two
            tensors (inputs, labels).
        a: Value every entry of ``"target_label"`` is set to.
        size_image: Side length used to flatten each input to
            ``size_image * size_image * 3`` features.

    Returns:
        dict with keys ``"X"`` (flattened inputs), ``"labels"``,
        ``"target_label"`` (same shape as the labels, filled with ``a``) and
        ``"eps"``.
    """
    images, classes = torch.load(data_path)
    # TensorDataset checks that both tensors agree on their first dimension.
    dataset = TensorDataset(images, classes)
    all_images, all_labels = dataset[:]
    flat_dim = size_image * size_image * 3
    return {
        "X": all_images.view(-1, flat_dim),
        "labels": all_labels,
        "target_label": torch.ones_like(all_labels) * a,
        "eps": eps,
    }

def load_data(eps, data_path='./cifar_test/cifarAug_train.pt', a=4, size_image=32):
    """Load an ``(inputs, labels)`` pair from ``data_path`` and flatten the inputs.

    Args:
        eps: Perturbation size; stored unchanged under the ``"eps"`` key.
        data_path: File read with ``torch.load``; it must unpack into two
            tensors (inputs, labels).
        a: Value every entry of ``"target_label"`` is set to.
        size_image: Side length used to flatten each input to
            ``size_image * size_image * 3`` features.

    Returns:
        dict with keys ``"X"`` (flattened inputs), ``"labels"``,
        ``"target_label"`` (same shape as the labels, filled with ``a``) and
        ``"eps"``.
    """
    images, classes = torch.load(data_path)
    flat_dim = size_image * size_image * 3
    return {
        "X": images.view(-1, flat_dim),
        "labels": classes,
        "target_label": torch.ones_like(classes) * a,
        "eps": eps,
    }
 
# Step 3: load model
def load_model(input_size, model_file="best_NN_cifar10_256.pth", hidden_size=256):
    """Build a ``NoSoftmaxNet`` and load its weights from ``model_file``.

    Args:
        input_size: Number of input features of the first linear layer.
        model_file: Checkpoint path. Either a mapping with a ``"state_dict"``
            entry or a bare state dict is accepted.
        hidden_size: Width of the two hidden layers.

    Returns:
        The network on CPU, switched to evaluation mode.
    """
    net = NoSoftmaxNet(input_size, hidden_size=hidden_size)
    # SECURITY NOTE: weights_only=False allows torch.load to unpickle
    # arbitrary Python objects; only use this with checkpoint files you trust.
    checkpoint = torch.load(model_file, map_location="cpu", weights_only = False)
    # Full checkpoints keep the weights under "state_dict"; otherwise the
    # loaded object itself is treated as the state dict.
    weights = checkpoint.get("state_dict", checkpoint)
    net.load_state_dict(weights)
    net.eval()
    print("Model loaded successfully.")
    return net

# Step 4: Define the network to be verified (outputs raw logits, no softmax)
class NoSoftmaxNet(nn.Module):
    """Fully connected classifier: two ReLU hidden layers and 10 raw outputs.

    No softmax is applied; ``forward`` returns the output of the last linear
    layer directly.
    """

    def __init__(self, input_size, hidden_size = 256):
        """Create the layers.

        Args:
            input_size: Number of input features.
            hidden_size: Width of both hidden layers.
        """
        super().__init__()
        # Attribute names and their order determine the state_dict keys.
        self.hidden1 = nn.Linear(input_size, hidden_size)
        self.hidden2 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the 10 output scores for the input batch ``x``."""
        first = self.relu(self.hidden1(x))
        second = self.relu(self.hidden2(first))
        return self.output(second)
    
class Specification:
    """Base class for specification builders; reads its settings from the global config."""

    def __init__(self):
        """Cache the number of outputs and the decision threshold from ``arguments.Config``."""
        config = arguments.Config
        self.num_outputs = config['data']['num_outputs']
        # FIXME Do not use numpy. Use torch instead.
        threshold = config['bab']['decision_thresh']
        self.rhs = np.array([threshold])

    def construct_vnnlib(self):
        """Build the specification list; must be overridden by subclasses."""
        raise NotImplementedError

class SpecificationAllPositive(Specification):
    """Specification with one clause per output, each paired with ``self.rhs``."""

    def construct_vnnlib(self, dataset, x_range, example_idx_list):
        """Return one ``[(input_range, clauses)]`` entry per selected example.

        Args:
            dataset: Unused here; kept for a uniform signature.
            x_range: Indexable by position; ``x_range[i]`` is the input range
                of the i-th selected example.
            example_idx_list: Selected example indices; only its length is used.

        Returns:
            A list whose i-th element is ``[(x_range[i], clauses)]`` where
            ``clauses`` holds ``num_outputs`` pairs ``(row, self.rhs)`` and
            ``row`` is a slice of a batched identity matrix.
        """
        specs = []
        for position, _ in enumerate(example_idx_list):
            input_box = x_range[position]
            # Fresh identity per example so entries do not share tensors.
            identity = torch.eye(self.num_outputs).unsqueeze(0)
            clauses = [
                (identity[:, row], self.rhs)
                for row in range(self.num_outputs)
            ]
            specs.append([(input_box, clauses)])
        return specs
    
class SpecificationTarget(Specification):
    """Specification comparing the labelled output against one target output."""

    def construct_vnnlib(self, dataset, x_range, example_idx_list, num_class = 10):
        """Return one ``[(input_range, [(c, rhs)])]`` entry per selected example.

        For each example a row vector ``c`` is built with ``+1`` at the true
        label and ``-1`` at the target label. When the target equals the
        label, the target is moved to ``label - 1`` (or to ``label + 1`` when
        the label is 0 and is not also the last class).

        Args:
            dataset: Mapping with ``'labels'`` and ``'target_label'`` entries
                indexable by example index.
            x_range: Indexable by position; ``x_range[i]`` is the input range
                of the i-th selected example.
            example_idx_list: Indices of the examples to build specs for.
            num_class: Number of classes, used only for the label/target
                collision rule.

        Returns:
            A list with one ``[(x_range[i], [(c, self.rhs)])]`` per example.
        """
        specs = []
        for position, example_idx in enumerate(example_idx_list):
            label = int(dataset['labels'][example_idx])
            target_label = dataset['target_label'][example_idx]
            if label == target_label:
                # The target must differ from the label: step down one class
                # when possible, otherwise step up.
                if label != 0:
                    target_label = label - 1
                elif label != (num_class - 1):
                    target_label = label + 1
            input_box = x_range[position]

            c = torch.zeros([1, self.num_outputs])
            c[0, label] = 1
            c[0, target_label] = -1
            print("construct c is", c)
            specs.append([(input_box, [(c, self.rhs)])])
        return specs


def construct_vnnlib(dataset, example_idx_list):
    """Build L∞ "specify-target" specifications for the selected examples.

    This module-level definition replaces the ``construct_vnnlib`` name
    imported from ``specifications`` at the top of the file.

    The input range of each example is ``[X - eps, X + eps]``. The range is
    NOT clamped to ``dataset['data_min']`` / ``dataset['data_max']`` even when
    those are present (the original code had clamping disabled and computed
    the same unclamped bounds in both branches).

    Args:
        dataset: Mapping with ``'X'`` (inputs), ``'eps'`` (a value usable in
            tensor arithmetic, or a list of tensors that is concatenated and
            indexed by ``example_idx_list``), plus the ``'labels'`` and
            ``'target_label'`` entries read by ``SpecificationTarget``.
        example_idx_list: List of example indices to process.

    Returns:
        The list produced by ``SpecificationTarget.construct_vnnlib`` for the
        selected examples.

    Raises:
        AssertionError: If the configured norm is not L∞, the robustness type
            is not ``'specify-target'``, or ``eps`` is ``None``.
    """
    X = dataset['X']

    # Only the L∞ / specify-target combination is supported.
    assert arguments.Config['specification']['norm'] == float('inf'), \
        "This simplified version only supports L∞ norm"
    assert arguments.Config['specification']['robustness_type'] == 'specify-target', \
        "This simplified version only supports 'specify-target' robustness type"

    perturb_epsilon = dataset['eps']
    if isinstance(perturb_epsilon, list):
        # Each example has its own perturbation: merge and select.
        perturb_epsilon = torch.cat(perturb_epsilon)[example_idx_list]

    assert perturb_epsilon is not None, "Perturbation epsilon must be provided"

    # Index the inputs once; bounds are deliberately left unclamped.
    selected = X[example_idx_list]
    x_lower = (selected - perturb_epsilon).flatten(1)
    x_upper = (selected + perturb_epsilon).flatten(1)

    # Last axis holds (lower, upper) for every input feature.
    x_range = torch.stack([x_lower, x_upper], -1).numpy()

    # Label-vs-target specification (not the all-positive one).
    specification = SpecificationTarget()
    return specification.construct_vnnlib(dataset, x_range, example_idx_list)

class SNOV:
    def __init__(self, args=None, **kwargs):
        """Translate keyword options into CLI-style flags and parse the config.

        Each keyword argument is appended to ``args`` as follows:
        a list becomes ``--key v1 v2 ...``, a bool becomes ``--key`` or
        ``--no_key``, anything else becomes ``--key=value``.

        Args:
            args: Argument list handed to ``arguments.Config.parse_config``.
                When it is a list and keyword options are given, the flags
                are appended to it in place. May be ``None``.
            **kwargs: Extra options converted to flags as described above.
        """
        extra_args = []
        for key, value in kwargs.items():
            if isinstance(value, list):
                extra_args.append(f'--{key}')
                extra_args.extend(map(str, value))
            elif isinstance(value, bool):
                extra_args.append(f'--{key}' if value else f'--no_{key}')
            else:
                extra_args.append(f'--{key}={value}')

        if extra_args:
            # Previously args=None (the declared default) combined with any
            # keyword option crashed on ``None.append``; start from an empty
            # list in that case instead.
            if args is None:
                args = []
            args.extend(extra_args)
        # With no keyword options ``args`` is forwarded untouched (including
        # None) so parse_config keeps whatever default handling it has.
        arguments.Config.parse_config(args)

    def load_data_spec(self, input_size, eps, data_path,model_path, save_path ):
        """Load the model and dataset and build the per-example specifications.

        Args:
            input_size: Input feature count passed to ``load_model``.
            eps: Perturbation size passed to ``load_data``.
            data_path: Dataset file passed to ``load_data``.
            model_path: Checkpoint file passed to ``load_model``.
            save_path: Returned unchanged.

        Returns:
            Tuple ``(run_mode, save_path, file_root, example_idx_list,
            model_ori, vnnlib_all, shape)`` where ``run_mode`` is always
            ``"customized_data"`` and ``shape`` is ``[-1, *X.shape[1:]]``.
            The loaded model is also stored on ``self.model_ori``.
        """
        file_root = expand_path(arguments.Config['general']['root_path'])
        self.model_ori = load_model(input_size, model_path)
        dataset = load_data(eps, data_path)
        inputs = dataset["X"]

        # Restrict the examples to the configured [start:end) window.
        first = arguments.Config['data']['start']
        last = arguments.Config['data']['end']
        example_idx_list = list(range(inputs.shape[0]))[first:last]

        vnnlib_all = construct_vnnlib(dataset, example_idx_list)
        shape = [-1, *inputs.shape[1:]]
        return ("customized_data", save_path, file_root, example_idx_list,
                self.model_ori, vnnlib_all, shape)

    def incomplete_verifier(
        self,
        model_ori,
        data,
        data_ub=None,
        data_lb=None,
        vnnlib=None,
        interm_bounds=None
    ):
        """Compute initial bounds for one input specification with ``LiRPANet.build``.

        Args:
            model_ori: Network handed to ``LiRPANet``.
            data: Input tensor; its shape is used as ``in_size`` and its
                device/dtype are the targets for ``c`` and ``rhs``.
            data_ub: Upper input bound (``x_U`` of the perturbation). Although
                it defaults to ``None``, it is dereferenced unconditionally.
            data_lb: Lower input bound (``x_L`` of the perturbation). Same
                caveat as ``data_ub``.
            vnnlib: List with exactly one ``(input_x, specs)`` pair, where
                ``specs`` is a list of ``(c, rhs)`` items.
            interm_bounds: Forwarded to ``model.build``.

        Returns:
            NOTE(review): the return arity depends on the exit path:
              * ``model`` alone when ``general.return_optimized_model`` is set;
              * ``('safe-incomplete', {}, model)`` when the initial bound
                already exceeds ``rhs``;
              * ``('unsafe-pgd', {}, model)`` when ``attack.pgd_order`` is
                ``'middle'`` and ``ret['attack_images']`` is not None;
              * ``('unknown', ret, model_ori, global_lb)`` otherwise. Note
                this last form returns ``model_ori`` (not the LiRPANet) and
                has four elements. Confirm callers handle every shape.

        Raises:
            AssertionError: On unsupported spec layouts or config combinations.
            NotImplementedError: If disjuncts are not merged into one batch
                while ``solver.prune_after_crown`` is enabled.
        """
        # Generally, c should be constructed from vnnlib
        assert len(vnnlib) == 1, 'incomplete_verifier only support single x spec'
        # input_x is unpacked but not used below; the bounds come from data_lb/data_ub.
        input_x, specs = vnnlib[0]
        c_transposed = False
        tighten_input_bounds = (
            arguments.Config['solver']['invprop']['tighten_input_bounds']
        )
        apply_output_constraints_to = (
            arguments.Config['solver']['invprop']['apply_output_constraints_to']
        )

        if len(specs) > 1: 
            # single OR with many clauses (e.g., robustness verification)
            assert all([len(_[0]) == 1 for _ in specs]), \
                'for each property in OR, only one clause supported so far'
            c = torch.concat([
                item[0] if isinstance(item[0], torch.Tensor) else torch.tensor(item[0])
                for item in specs], dim=0).unsqueeze(1).to(data)  # c shape: (batch, 1, num_outputs)
            do_transpose = not arguments.Config['solver']['optimize_disjuncts_separately']
            rhs = torch.tensor(np.array([item[1] for item in specs])).to(data)  # (batch, 1)
            if do_transpose and c.shape[0] != 1 and data.shape[0] == 1:
                # transpose c to shape (1,batch,num_outputs) to share intermediate bounds
                assert len(apply_output_constraints_to) == 0, (
                    'To apply output constraints, set --optimize_disjuncts_separately'
                )
                c = c.transpose(0, 1)
                rhs = rhs.t()  # (1, batch)
                c_transposed = True 
            else:
                # Disjuncts stay as separate batch entries; pruning after
                # CROWN is rejected in this configuration.
                if arguments.Config['solver']['prune_after_crown']:
                    raise NotImplementedError(
                        'To use optimize_disjuncts_separately=True, do not set '
                        'prune_after_crown=True'
                    )
            stop_func = stop_criterion_all(rhs)

        else:
            # single AND with many clauses (e.g., Yolo).
            # shape: (batch=1, num_clauses in AND, num_outputs)
            c = torch.tensor(specs[0][0]).unsqueeze(0).to(data)
            # shape: (1, num_clauses in AND)
            rhs = torch.tensor(specs[0][1], dtype=data.dtype, device=data.device).unsqueeze(0)
            stop_func = stop_criterion_batch_any(rhs)

        model = LiRPANet(model_ori, in_size=data.shape, c=c)

        bound_prop_method = arguments.Config['solver']['bound_prop_method']
        if len(apply_output_constraints_to) > 0:
            # Output constraints are only supported together with alpha-CROWN.
            assert bound_prop_method == 'alpha-crown'
            model.net.constraints = torch.tensor([x[0] for x in specs])
            assert model.net.constraints.ndim == 3
            assert rhs.ndim ==  2
            # Drop whichever rhs dimension has size 1 to get a flat threshold vector.
            if len(specs) == 1:
                assert rhs.size(0) == 1
                model.net.thresholds = rhs.squeeze(0)
            else:
                assert rhs.size(1) == 1
                model.net.thresholds = rhs.squeeze(1)

            # We need to use matrix mode for the layer that should utilize output constraints
            for node in model.net.nodes():
                if node.are_output_constraints_activated_for_layer(apply_output_constraints_to):
                    if isinstance(node, BoundConv) and node.mode == 'patches':
                        node.mode = 'matrix'

        norm = arguments.Config['specification']['norm']
        # The input region is given directly by the element-wise bounds
        # data_lb / data_ub (passed as x_L / x_U); no eps is supplied here.
        ptb = PerturbationLpNorm(norm=norm, x_L=data_lb, x_U=data_ub)
        x = BoundedTensor(data, ptb).to(data.device)
        output =  model.net(x) 
        print('Original output:', output)

        # save output
        if arguments.Config['general']['save_output']:
            arguments.Globals['out']['pred'] = output.cpu()

        # Last axis of domain holds (lower, upper); the leading batch axis is squeezed.
        domain = torch.stack([data_lb.squeeze(0), data_ub.squeeze(0)], dim=-1)
        # one of them is sufficient.
        global_lb, ret = model.build(
            domain, x, stop_criterion_func=stop_func, decision_thresh=rhs, vnnlib_ori=vnnlib,
            interm_bounds=interm_bounds)

        # Early exit: the caller only wants the optimized LiRPANet.
        if arguments.Config['general']['return_optimized_model']:
            return model

        if c_transposed:
            # transpose back to get ready for general verified condition check and final outputs
            global_lb = global_lb.t()
            rhs = rhs.t()

        if torch.any((global_lb - rhs) > 0, dim=-1).all():
            # Any spec in AND verified means verified. Also check all() at batch dim.
            print('verified with init bound!')
            return 'safe-incomplete', {}, model

        if arguments.Config['attack']['pgd_order'] == 'middle':
            if ret['attack_images'] is not None:
                return 'unsafe-pgd', {}, model

        # Save the alpha variables during optimization. Here the batch size is 1.
        saved_alphas = defaultdict(dict)
        for m in model.net.optimizable_activations:
            for spec_name, alpha in m.alpha.items():
                # each alpha size is (2, spec, 1, *shape); batch size is 1.
                saved_alphas[m.name][spec_name] = alpha.detach().clone()

        # FIXME there may be some duplicate with saved_alphas
        if bound_prop_method == 'alpha-crown':
            ret['activation_opt_params'] = {
                node.name: node.dump_optimized_params()
                for node in model.net.optimizable_activations
            }

        if c_transposed:
            # Undo the transpose on the final-layer bounds and on lA as well.
            ret['lower_bounds'][model.final_name] = ret['lower_bounds'][model.final_name].t()
            ret['upper_bounds'][model.final_name] = ret['upper_bounds'][model.final_name].t()
            if ret['lA'] is not None:
                ret['lA'] = {k: v.transpose(0, 1) for k, v in ret['lA'].items()}

        ret.update({'model': model, 'global_lb': global_lb, 'alpha': saved_alphas})

        if tighten_input_bounds:
            # Find the single perturbed root node and export its current bounds.
            perturbed_root = None
            for root in model.net.roots():
                if hasattr(root, 'perturbation') and root.perturbation is not None:
                    assert perturbed_root is None, (
                        'BaB based on tightened bounds currently supports only one input layer'
                    )
                    perturbed_root = root
            assert perturbed_root is not None
            ret['tightened_input_bounds'] = [
                perturbed_root.perturbation.x_L.detach(),
                perturbed_root.perturbation.x_U.detach(),
            ]
        # NOTE(review): four elements here, versus three on the early exits above.
        return 'unknown', ret, model_ori, global_lb

    def bab(self, data_lb, data_ub, c, rhs,
            data=None, targets=None, vnnlib=None, timeout=None,
            time_stamp=0, data_dict=None, lower_bounds=None, upper_bounds=None,
            reference_alphas=None, attack_images=None, cplex_processes=None,
            activation_opt_params=None, reference_lA=None,
            model_incomplete=None, refined_betas=None,
            create_model=True, model=None, return_domains=False,
            max_iterations=None, property_idx=None, vnnlib_meta=None,
            orig_lirpa_model=None):
        """Run branch-and-bound on one property batch.

        Dispatches to ``input_bab_parallel`` when input splitting is enabled
        in the config, otherwise to ``general_bab``.

        Args:
            data_lb, data_ub: Lower/upper input bounds (``x_L`` / ``x_U``).
            c: Specification matrix.
            rhs: Decision thresholds.
            data: Optional nominal input; when given, the model is evaluated
                on it and the clean output is optionally checked.
            targets: Sized object; when it has more than one element its
                length replaces the batch dimension of ``in_size``. Must not
                be ``None`` when ``create_model`` is true (``len`` is taken).
            vnnlib: Original specification, forwarded to the solvers.
            timeout: Time budget forwarded to the solvers.
            time_stamp: Forwarded to ``general_bab``.
            data_dict: Optional dict with a float ``'eps'`` (and optional
                ``'eps_min'``) used to build the perturbation.
            lower_bounds, upper_bounds, reference_alphas, attack_images,
            activation_opt_params, reference_lA, model_incomplete,
            refined_betas, property_idx: Forwarded to ``general_bab``.
            cplex_processes: Forwarded to ``LiRPANet``.
            create_model: Build a fresh ``LiRPANet`` into ``self.model``;
                otherwise an existing ``self.model`` is reused.
            model: Network wrapped by ``LiRPANet`` when ``create_model``.
            return_domains: Input-split only; return the raw solver result.
            max_iterations: Iteration cap. For input splitting it is passed
                through as given; for ``general_bab`` ``None`` means 5000.
            vnnlib_meta: Metadata dict for input splitting; a default is
                created when ``None``.
            orig_lirpa_model: Accepted for interface compatibility; unused.

        Returns:
            Normally ``(min_lb, *result[1:8])`` where ``min_lb`` is a Python
            scalar (``-inf`` if the solver returned ``None``). With
            ``return_domains`` the raw ``input_bab_parallel`` result. When
            the clean-output check fails, ``(-inf, None, verified_status)``.

        NOTE(review): the clean-check path calls ``check_and_save_cex``, which
        is not imported in the visible part of this file, and its 3-tuple
        return does not match the 8-value unpacking in ``complete_verifier``.
        NOTE(review): the ``general_bab`` path reads ``self.model_ori``, which
        is assigned in ``load_data_spec``.
        """
        # This will use the refined bounds if the complete verifier is 'bab-refine'.
        # FIXME do not repeatedly create LiRPANet which creates a new BoundedModule each time.

        # Save these arguments in case that they need to retrieved the next time
        # this function is called.
        if vnnlib_meta is None:
            vnnlib_meta = {
                'property_idx': 0, 'vnnlib_id': 0, 'benchmark_name': None
            }
        self.data_lb, self.data_ub, self.c, self.rhs = data_lb, data_ub, c, rhs
        self.data, self.targets, self.vnnlib = data, targets, vnnlib

        # if using input split, transpose C if there are multiple specs with shared input,
        # to improve efficiency when calling the incomplete verifier later
        if arguments.Config['bab']['branching']['input_split']['enable']:
            c_transposed = False
            if (data_lb.shape[0] == 1 and data_ub.shape[0] == 1 and c is not None
                    and c.shape[0] > 1 and c.shape[1] == 1):
                # multiple c instances (multiple vnnlibs) since c.shape[0] > 1,
                # but they share the same input (since data.shape[0] == 1) and
                # only single spec in each instance (c.shape[1] == 1)
                c = c.transpose(0, 1)
                rhs = rhs.transpose(0, 1)
                c_transposed = True

        if create_model:
            self.model = LiRPANet(
                model, c=c, cplex_processes=cplex_processes,
                in_size=(data_lb.shape if len(targets) <= 1
                        else [len(targets)] + list(data_lb.shape[1:])),
            )
            if not model_incomplete:
                print_model(self.model.net)
        data_lb, data_ub = data_lb.to(self.model.device), data_ub.to(self.model.device)
        norm = arguments.Config['specification']['norm']
        if data_dict is not None:
            assert isinstance(data_dict['eps'], float)
            ptb = PerturbationLpNorm(
                norm=norm, eps=data_dict['eps'],
                eps_min=data_dict.get('eps_min', 0), x_L=data_lb, x_U=data_ub)
        else:
            ptb = PerturbationLpNorm(norm=norm, x_L=data_lb, x_U=data_ub)

        if data is not None:
            data = data.to(self.model.device)
            x = BoundedTensor(data, ptb).to(data_lb.device)
            output = self.model.net(x).flatten()
            print('Model prediction is:', output)

            # save output:
            if arguments.Config['general']['save_output']:
                arguments.Globals['out']['pred'] = output.cpu()

            if arguments.Config['attack']['check_clean'] and not arguments.Config['debug'][
                'sanity_check']:
                clean_rhs = c.matmul(output)
                print(f'Clean RHS: {clean_rhs}')
                if (clean_rhs < rhs).any():
                    # add and set output batch_size dimension to 1
                    verified_status, _ = check_and_save_cex(
                        x.detach(), output.unsqueeze(0), vnnlib,
                        arguments.Config['attack']['cex_path'], 'unsafe')
                    return -torch.inf, None, verified_status
        else:
            # No nominal input: use the lower bound as the tensor's value.
            x = BoundedTensor(data_lb, ptb).to(data_lb.device)

        # Last axis of the domain holds (lower, upper).
        self.domain = torch.stack([data_lb.squeeze(0), data_ub.squeeze(0)], dim=-1)
        if arguments.Config['bab']['branching']['input_split']['enable']:
            result = input_bab_parallel(
                self.model, self.domain, x, rhs=rhs,
                timeout=timeout, max_iterations=max_iterations,
                vnnlib=vnnlib, c_transposed=c_transposed,
                return_domains=return_domains, vnnlib_meta=vnnlib_meta
            )
            if return_domains:
                return result
        else:
            assert not return_domains, 'return_domains is only for input split for now'

            # Honour the caller's iteration cap; the previous hard-coded 5000
            # remains the default when none is given.
            iteration_cap = 5000 if max_iterations is None else max_iterations
            result = general_bab(    self.model_ori,  data_lb,  data_ub,c,
                self.model, self.domain, x,
                refined_lower_bounds=lower_bounds, refined_upper_bounds=upper_bounds,
                activation_opt_params=activation_opt_params, reference_lA=reference_lA,
                reference_alphas=reference_alphas, attack_images=attack_images,
                timeout=timeout, max_iterations=iteration_cap,
                refined_betas=refined_betas, rhs=rhs, property_idx=property_idx,
                model_incomplete=model_incomplete, time_stamp=time_stamp)

        # Normalise the first element to a plain Python number.
        min_lb = result[0]
        if min_lb is None:
            min_lb = -torch.inf
        elif isinstance(min_lb, torch.Tensor):
            min_lb = min_lb.item()
        result = (min_lb, *result[1:8])
        return result
    def complete_verifier(
            self, model_ori, model_incomplete, vnnlib, batched_vnnlib, vnnlib_shape,
            index, timeout_threshold, bab_ret=None, cplex_processes=None,
            attack_images=None, attack_margins=None, results=None, vnnlib_id=None,
            benchmark_name=None, orig_lirpa_model=None
    ):
        """Simplified complete verifier focusing on core BaB verification.

        Iterates over ``batched_vnnlib`` and runs ``self.bab`` on every
        property batch that is not already verified by the initial bounds.

        Args:
            model_ori: Network passed to ``self.bab`` as ``model``.
            model_incomplete: Result model of the incomplete verifier; read
                only when incomplete verification is enabled in the config.
            vnnlib: Original specification, forwarded to ``self.bab``.
            batched_vnnlib: Iterable of property batches; ``properties[0]``
                holds the inputs and ``properties[1]`` the ``(c, rhs, pidx)``
                triple.
            vnnlib_shape: Shape the tensor-format input ranges are reshaped to.
            index: Instance index recorded in ``bab_ret`` and printed.
            timeout_threshold: Total time budget in seconds.
            bab_ret: Optional list that per-property records are appended to.
            cplex_processes: Forwarded to ``self.bab`` on the incomplete path.
            attack_images, attack_margins: Forwarded to ``sort_targets``.
            results: Optional dict from the incomplete verifier; ``None`` is
                treated as an empty dict.
            vnnlib_id, benchmark_name: Stored in the ``vnnlib_meta`` dict.
            orig_lirpa_model: Forwarded to ``self.bab``.

        Returns:
            ``(status, l, u, snov_time, global_list, u_best_list)`` with
            status ``'safe'``, ``'unsafe-bab'`` or ``'unknown'``. The five
            trailing values come from the last ``self.bab`` call and are
            ``None`` if no call was needed.
            NOTE(review): when the time budget is exhausted before a property
            starts, the bare string ``'unknown'`` is returned instead of a
            tuple; confirm callers handle both shapes.

        Raises:
            ValueError: If ``self.bab`` reports an unrecognised status.
        """
        start_time = time.time()

        # Normalise the optional containers BEFORE they are first used;
        # previously ``results.get`` ran ahead of the ``None`` guard and
        # crashed for the default ``results=None``.
        if results is None:
            results = {}
        if bab_ret is None:
            bab_ret = []

        enable_incomplete = arguments.Config['general']['enable_incomplete_verification']
        init_global_lb = results.get('global_lb', None)
        lower_bounds = results.get('lower_bounds', None)
        upper_bounds = results.get('upper_bounds', None)
        reference_alphas = results.get('alpha', None)
        lA = results.get('lA', None)
        # NOTE(review): cplex_cuts is computed but not used in this method.
        cplex_cuts = (arguments.Config['bab']['cut']['enabled']
                    and arguments.Config['bab']['cut']['cplex_cuts'])

        reference_alphas_cp = None
        if enable_incomplete:
            final_name = model_incomplete.final_name
            init_global_ub = upper_bounds[final_name]
            print('lA shape:', [lAitem.shape for lAitem in lA.values()])
            (batched_vnnlib, init_global_lb, init_global_ub,
            lA, attack_images) = sort_targets(
                batched_vnnlib, init_global_lb, init_global_ub,
                attack_images, attack_margins, results, model_incomplete)
            if reference_alphas is not None:
                reference_alphas_cp = copy.deepcopy(reference_alphas)

        solved_c_rows = []

        time_stamp = 0
        rhs_offsets = arguments.Config['specification']['rhs_offset']

        # Defaults for the values reported at the end, so the final return
        # cannot hit an unbound local when every property is skipped (or
        # there are no properties at all).
        l = u = snov_time = global_list = u_best_list = None

        # Iterate through per VNNLIB property
        for property_idx, properties in enumerate(batched_vnnlib):  # loop of x
            print(f'\nProperties batch {property_idx}, size {len(properties[0])}') 
            # Handle timeout management
            timeout = timeout_threshold - (time.time() - start_time)
            print(f'Remaining timeout: {timeout}')

            if timeout <= 0:
                print('Timeout reached before processing property')
                return 'unknown'

            start_time_bab = time.time()
            print(f'Verifying property {property_idx} with {len(properties[0])} instances.')
            rhs_offset = 0 if rhs_offsets is None else rhs_offsets
            if (arguments.Config['bab']['cut']['enabled'] and
                arguments.Config['bab']['initial_max_domains'] == 1
                and not arguments.Config['debug']['sanity_check']):
                if init_global_lb[property_idx][0] > rhs_offset:
                    print('Verified by alpha-CROWN bound!')
                    continue
            # Process different input formats (dict and tensor based)
            if isinstance(properties[0][0], dict):
                # Dictionary-based format
                def _get_item(properties, key):
                    return torch.concat([
                        item[key].unsqueeze(0) for item in properties[0]], dim=0)

                x = _get_item(properties, 'X')
                data_min = _get_item(properties, 'data_min')
                data_max = _get_item(properties, 'data_max') 
                # All items in the batch must share the same eps.
                for item in properties[0]:
                    assert item['eps'] == properties[0][0]['eps']
                data_dict = {
                    'eps': properties[0][0]['eps'],
                    'eps_min': properties[0][0].get('eps_min', 0),
                }
            else:
                # Tensor-based format

                x_range  = torch.as_tensor(properties[0], dtype=torch.get_default_dtype())  # [N_props, d, 2] or [d,2]
                data_min = x_range.select(-1, 0).reshape(vnnlib_shape)   # X - eps (no clamp)
                data_max = x_range.select(-1, 1).reshape(vnnlib_shape)   # X + eps (no clamp)
                x        = x_range.mean(-1).reshape(vnnlib_shape)        # X
                data_dict = None 

            if 'tightened_input_bounds' in results:
                assert (
                    results['tightened_input_bounds'][0][property_idx:property_idx+1].shape
                    == data_min.shape
                )
                data_min = results['tightened_input_bounds'][0][property_idx:property_idx+1]
                data_max = results['tightened_input_bounds'][1][property_idx:property_idx+1]

            target_label_arrays = list(properties[1])  # properties[1]: (c, rhs, pidx)
            assert len(target_label_arrays) == 1
            c, rhs, pidx = target_label_arrays[0] 
            print('c is', c)
            # FIXME Clean up.
            # Shape and type of rhs is very confusing
            rhs = torch.tensor(rhs, device=arguments.Config['general']['device'],
                               dtype=torch.get_default_dtype())

            if enable_incomplete and len(init_global_lb) > 1:
                # no need to trim_batch if batch = 1
                ret_trim = trim_batch(
                    model_incomplete, init_global_lb, init_global_ub,
                    reference_alphas_cp, lower_bounds, upper_bounds,
                    reference_alphas, lA, property_idx, c, rhs)
                lA_trim, rhs = ret_trim['lA'], ret_trim['rhs']
                trimmed_lower_bounds = ret_trim['lower_bounds']
                trimmed_upper_bounds = ret_trim['upper_bounds']
            else:
                lA_trim = lA.copy() if lA is not None else lA
                trimmed_lower_bounds = lower_bounds
                trimmed_upper_bounds = upper_bounds

            print(f'##### Instance {index} first 10 spec matrices: ')
            print(f'{c[:10]}\nthresholds: {rhs.flatten()[:10]} ######')

            torch.cuda.empty_cache()
            gc.collect()
            c = c.to(rhs)  # both device and dtype

            time_stamp  += 1
            input_split = arguments.Config['bab']['branching']['input_split']['enable']
            init_failure_idx = np.array([])
            if enable_incomplete and not input_split:
                if len(init_global_lb) > 1:  # if batch == 1, there is no need to filter here.
                    # Reuse results from incomplete results, or from refined MIPs.
                    # skip the prop that already verified
                    rlb = trimmed_lower_bounds[final_name]
                    # The following flatten is dangerous, each clause in OR only
                    # has one output bound.
                    assert len(rlb.shape) == len(rhs.shape) == 2
                    assert rlb.shape[1] == rhs.shape[1] == 1
                    init_verified_cond = rlb.flatten() > rhs.flatten()
                    init_verified_idx = torch.where(init_verified_cond)[0]
                    if len(init_verified_idx) > 0:
                        print('Initial alpha-CROWN verified for spec index '
                                f'{init_verified_idx} with bound '
                                f'{rlb[init_verified_idx].squeeze()}.')
                        l = init_global_lb[init_verified_idx].tolist()
                        bab_ret.append([index, l, 0, time.time() - start_time_bab, pidx])
                    init_failure_idx = torch.where(~init_verified_cond)[0]
                    if len(init_failure_idx) == 0:
                        # This batch of x verified by init opt crown
                        continue
                    print(f'Remaining spec index {init_failure_idx} with '
                            f'bounds {rlb[init_failure_idx]} need to verify.')

                    (reference_alphas, lA_trim, x, data_min, data_max,
                    trimmed_lower_bounds, trimmed_upper_bounds, c) = prune_by_idx(
                        reference_alphas, init_verified_cond, final_name, lA_trim, x,
                        data_min, data_max, lA is not None,
                        trimmed_lower_bounds, trimmed_upper_bounds, c) 
                l, nodes,ret, stats, u, snov_time, global_list, u_best_list = self.bab(
                    data=x, targets=init_failure_idx, time_stamp=time_stamp,
                    data_ub=data_max, data_lb=data_min, data_dict=data_dict,
                    lower_bounds=trimmed_lower_bounds, upper_bounds=trimmed_upper_bounds,
                    c=c, reference_alphas=reference_alphas, cplex_processes=cplex_processes,
                    activation_opt_params=results.get('activation_opt_params', None),
                    refined_betas=results.get('refined_betas', None), rhs=rhs[0:1],
                    reference_lA=lA_trim, attack_images=None,
                    model_incomplete=model_incomplete, timeout=timeout, vnnlib=vnnlib,
                    model=model_ori, property_idx=property_idx,
                    vnnlib_meta={
                        'property_idx': property_idx,
                        'vnnlib_id': vnnlib_id,
                        'benchmark_name': benchmark_name
                    },
                    orig_lirpa_model=orig_lirpa_model,
                )
                bab_ret.append([index, float(l), nodes,
                                time.time() - start_time_bab,
                                init_failure_idx.tolist()])
            else:
                assert arguments.Config['general']['complete_verifier'] == 'bab'
                assert not arguments.Config['bab']['attack']['enabled'], (
                    'BaB-attack must be used with incomplete verifier.')
                # input split also goes here directly
                l, nodes,ret, stats, u, snov_time, global_list, u_best_list= self.bab(
                    data=x, targets=pidx, time_stamp=time_stamp,
                    data_ub=data_max, data_lb=data_min, c=c, data_dict=data_dict,
                    cplex_processes=None,
                    rhs=rhs, timeout=timeout, attack_images=None,
                    vnnlib=vnnlib, model=model_ori, vnnlib_meta={
                        'property_idx': property_idx,
                        'vnnlib_id': vnnlib_id,
                        'benchmark_name': benchmark_name
                    },
                    orig_lirpa_model=orig_lirpa_model,
                )
                bab_ret.append([index, l, nodes, time.time() - start_time_bab, pidx])  

            # Re-evaluate the remaining budget after this property's BaB run.
            timeout = timeout_threshold - (time.time() - start_time)  
            if ret == 'unsafe':
                return 'unsafe-bab', l, u, snov_time, global_list, u_best_list
            elif ret == 'unknown' or timeout < 0:
                return 'unknown', l, u,snov_time, global_list, u_best_list
            elif ret != 'safe':
                raise ValueError(f'Unknown return value of bab: {ret}')

        return 'safe', l , u, snov_time, global_list, u_best_list
 

    def main(self, eps, input_size=3*32*32, data_path='./cifar10_test.pt',
             model_path='snov_configs/cifar_test/classification.pth',
             save_path='simple_test/sol_file.txt', interm_bounds=None):
        """Verify every loaded example and return the result of the last one.

        For each example the method optionally runs ``self.incomplete_verifier``
        (when ``arguments.Config['general']['enable_incomplete_verification']``
        is set) and falls back to ``self.complete_verifier`` (BaB) when the
        status is still ``'unknown'``.

        Args:
            eps: Forwarded to ``self.load_data_spec`` (presumably the
                perturbation radius -- confirm against ``load_data_spec``).
            input_size: Forwarded to ``self.load_data_spec`` as its first
                positional argument.
            data_path: Forwarded to ``self.load_data_spec``.
            model_path: Forwarded to ``self.load_data_spec``.
            save_path: Forwarded to ``self.load_data_spec``; the value returned
                by that call is the one handed to ``Logger``.
            interm_bounds: Forwarded to ``self.incomplete_verifier``.

        Returns:
            tuple: ``(verified_status, l, u, snov_time, global_list,
            u_best_list)`` for the last example in the loaded list.
            ``snov_time`` is the wall-clock time of the ``complete_verifier``
            call (it overrides the time that call returns), or of the
            ``incomplete_verifier`` call when BaB was not needed. Slots that no
            verifier produced for that example keep their defaults:
            ``'unknown'``, ``None``, ``None``, ``0.0``, ``[]``, ``[]``.
        """
        print(f'Experiments at {time.ctime()} on {socket.gethostname()}')

        # Basic setup: seed every RNG in use and fix numeric behaviour.
        seed = arguments.Config['general']['seed']
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.set_printoptions(precision=8)
        device = arguments.Config['general']['device']

        if device != 'cpu':
            torch.cuda.manual_seed_all(seed)
            # TF32 is switched off for both matmul and cuDNN on non-CPU devices.
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False

        if arguments.Config['general']['double_fp']:
            torch.set_default_dtype(torch.float64)

        bab_args = arguments.Config['bab']
        timeout_threshold = bab_args['timeout']

        # Load model and vnnlib specifications.
        (run_mode, save_path, file_root, example_idx_list, model_ori,
         vnnlib_all, shape) = self.load_data_spec(
            input_size, eps=eps, data_path=data_path,
            model_path=model_path, save_path=save_path)
        self.logger = Logger(run_mode, save_path, timeout_threshold)

        # Defaults so the final return is well defined even when the example
        # list is empty (previously an UnboundLocalError).
        verified_status = 'unknown'
        l = u = None
        snov_time = 0.0
        global_list, u_best_list = [], []

        for new_idx, _csv_item in enumerate(example_idx_list):
            arguments.Globals['example_idx'] = new_idx
            vnnlib_id = new_idx + arguments.Config['data']['start']
            print(vnnlib_id)

            print(f'\n {"%"*35} idx: {new_idx}, vnnlib ID: {vnnlib_id} {"%"*35}')
            self.logger.record_start_time()

            vnnlib = vnnlib_all[new_idx]

            # Setup model and data.
            model_ori.eval()
            vnnlib_shape = shape
            if isinstance(vnnlib[0][0], dict):
                x = vnnlib[0][0]['X'].reshape(vnnlib_shape)
                data_min = vnnlib[0][0]['data_min'].reshape(vnnlib_shape)
                data_max = vnnlib[0][0]['data_max'].reshape(vnnlib_shape)
            else:
                # The last axis holds (lower, upper); their mean is the centre.
                x_range = torch.tensor(vnnlib[0][0])
                data_min = x_range.select(-1, 0).reshape(vnnlib_shape)
                data_max = x_range.select(-1, 1).reshape(vnnlib_shape)
                x = x_range.mean(-1).reshape(vnnlib_shape)

            model_ori = model_ori.to(device)
            x, data_max, data_min = x.to(device), data_max.to(device), data_min.to(device)

            # Reset per-example state so results of an earlier example never
            # leak into the values reported for this one.
            verified_status = 'unknown'
            verified_success = False
            model_incomplete = None
            ret = {}
            orig_lirpa_model = None
            l = u = None
            snov_time = 0.0
            global_list, u_best_list = [], []

            # Run incomplete verification.
            if arguments.Config['general']['enable_incomplete_verification']:
                start = time.time()
                verified_status, ret, orig_lirpa_model, l = self.incomplete_verifier(
                    model_ori,
                    x,
                    data_ub=data_max,
                    data_lb=data_min,
                    vnnlib=vnnlib,
                    interm_bounds=interm_bounds
                )
                # Reported only if BaB is skipped; BaB overwrites it below.
                snov_time = time.time() - start
                verified_success = verified_status != 'unknown'
                model_incomplete = ret.get('model', None)

                print(f'Incomplete verification result: {verified_status}')

            # Run the complete verifier (BaB) if not already decided.
            if not verified_success:
                batched_vnnlib = batch_vnnlib(vnnlib)
                start = time.time()
                (verified_status, l, u, snov_time, global_list,
                 u_best_list) = self.complete_verifier(
                    model_ori, model_incomplete, vnnlib, batched_vnnlib, vnnlib_shape,
                    new_idx, bab_ret=self.logger.bab_ret, cplex_processes=None,
                    timeout_threshold=timeout_threshold - (time.time() - self.logger.start_time),
                    attack_images=None, attack_margins=None, results=ret,
                    vnnlib_id=vnnlib_id, benchmark_name=None, orig_lirpa_model=orig_lirpa_model
                )
                # Wall-clock time of the whole call replaces the returned time.
                snov_time = time.time() - start
        return verified_status, l, u, snov_time, global_list, u_best_list

                


def snov_test(eps, idx_case: int, input_size: int, seed: int, device: str = "cpu",
              verbose: bool = False) -> Dict[str, Any]:
    """
    Run the SNOV (Scalable Near Optimal Verifier) tool on a single test case.

    The YAML file ``snov_configs/cifar_test/cifar_small.yaml`` is rewritten in
    place so that its ``data`` section selects only ``idx_case``
    (``start = idx_case``, ``end = idx_case + 1``); a ``SNOV`` instance is then
    built from that file and its ``main`` method is run.

    Args:
        eps: Forwarded to ``SNOV.main`` (presumably the perturbation radius --
            confirm against ``SNOV.load_data_spec``).
        idx_case (int): Index written to ``data.start`` in the configuration.
        input_size (int): Forwarded to ``SNOV.main``.
        seed (int): Only reported in the verbose message; this function does
            not seed any random number generator with it.
        device (str): Only reported in the verbose message.
        verbose (bool): Whether to print progress messages.

    Returns:
        Dict[str, Any]: Dictionary with the values returned by ``SNOV.main``:
            - status: Verification status
            - runtime: Time reported by ``SNOV.main``
            - final_lb: Final lower bound
            - final_ub: Final upper bound
            - global_list: As returned by ``SNOV.main``
            - u_best: The ``u_best_list`` returned by ``SNOV.main``

    Raises:
        FileNotFoundError: If the config file is not found
        ValueError: If the config file is not valid YAML or its top level is
            not a mapping (e.g. the file is empty)
        IOError: If the updated configuration cannot be written back
    """

    # Configuration paths
    snov_config_name = "snov_configs/cifar_test/cifar_small.yaml"
    snov_config_prefix = "snov_configs/cifar_test/"

    # Ensure configuration file exists
    if not os.path.exists(snov_config_name):
        raise FileNotFoundError(f"Configuration file not found: {snov_config_name}")

    # Load configuration
    try:
        with open(snov_config_name, "r") as f:
            snov_config_dict = yaml.safe_load(f)
    except yaml.YAMLError as e:
        raise ValueError(f"Error parsing YAML configuration: {e}") from e

    # safe_load yields None for an empty document; fail with a clear message
    # instead of an AttributeError further down.
    if not isinstance(snov_config_dict, dict):
        raise ValueError(
            f"Configuration file must contain a YAML mapping: {snov_config_name}")

    # Define file paths
    data_path = snov_config_prefix + "/cifar_trainnew.pt"
    save_path = snov_config_prefix + "/snov_sol.pkl"
    model_path = snov_config_prefix + "/best_NN_cifar10_256.pth"

    # Restrict the run to the single requested case, keeping every other
    # entry of the "data" section ("data:" with no value loads as None).
    data_section = snov_config_dict.get("data") or {}
    snov_config_dict["data"] = {
        **data_section,
        "start": idx_case,
        "end": idx_case + 1,
    }

    # Save updated configuration. Serialize first so a serialization failure
    # cannot leave a truncated config file behind.
    try:
        config_text = yaml.dump(snov_config_dict, sort_keys=False,
                                default_flow_style=False)
        with open(snov_config_name, "w") as f:
            f.write(config_text)
    except Exception as e:
        raise IOError(f"Error saving configuration file: {e}") from e

    # Initialize and run SNOV
    snovbab = SNOV(["--config", snov_config_name])
    if verbose:
        print(f"Starting SNOV test with seed {seed} on device {device}")

    status, l, u, snov_time, global_list, u_best_list = snovbab.main(
        eps, input_size, data_path=data_path, model_path=model_path,
        save_path=save_path)

    if verbose:
        print(f"SNOV execution completed in {snov_time:.2f} seconds")

    # Prepare return dictionary
    return_dict = {
        "status": status,
        "runtime": snov_time,
        "final_lb": l,
        "final_ub": u,
        "global_list": global_list,
        "u_best": u_best_list,
    }

    # Clean up temporary files
    temp_files = ["out.txt"]
    for temp_file in temp_files:
        if os.path.exists(temp_file):
            os.remove(temp_file)
            if verbose:
                print(f"Cleaned up temporary file: {temp_file}")

    return return_dict
  

def run_snov_experiments(input_size: int, seeds: list, device: str = "cpu",
                         verbose: bool = False, eps: float = 0.03,
                         idx_list=None):
    """
    Run ``snov_test`` on a list of case indices and bucket the results by status.

    Args:
        input_size (int): Forwarded to ``snov_test``.
        seeds (list): Accepted for interface compatibility but currently
            unused; every run passes the fixed seed 42 to ``snov_test``.
        device (str): Forwarded to ``snov_test``.
        verbose (bool): Forwarded to ``snov_test``.
        eps (float): Forwarded to ``snov_test``. Defaults to 0.03.
        idx_list: Iterable of case indices to run. ``None`` (the default)
            selects the single case ``[42]``.

    Returns:
        tuple: ``(safe_list, unsafe_list, unknown_list, other_list,
        time_list)``. The first four hold the ``snov_test`` result dicts whose
        status is ``'safe'``, ``'unsafe-bab'``, ``'unknown'`` or anything else,
        respectively; unsafe results additionally get a ``'case'`` key with
        the case index. ``time_list`` holds the ``'runtime'`` of every run in
        execution order.
    """
    idx_list = [42] if idx_list is None else list(idx_list)
    seed = 42
    print(idx_list)

    safe_list, unsafe_list, unknown_list, other_list = [], [], [], []
    time_list = []
    for idx_case in idx_list:
        result = snov_test(eps, idx_case, input_size, seed, device, verbose)
        status = result['status']
        if status == 'unsafe-bab':
            result['case'] = idx_case
            unsafe_list.append(result)
        elif status == 'unknown':
            unknown_list.append(result)
        elif status == 'safe':
            safe_list.append(result)
        else:
            # Unexpected status: surface it and keep the result for inspection.
            print(status)
            other_list.append(result)
        time_list.append(result['runtime'])
    return safe_list, unsafe_list, unknown_list, other_list, time_list

if __name__ == "__main__":
    # Script entry point: run the experiment sweep once on the flattened
    # CIFAR-sized input (3 x 32 x 32) and dump every result bucket.
    input_size = 3 * 32 * 32
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_seeds = [42]
    buckets = run_snov_experiments(
        input_size, test_seeds, device=device, verbose=False)
    safe_list, unsafe_list, unknown_list, other_list, time_list = buckets
    print("\nExperiment Results of MIP solver Summary:")
    # Printed in order: safe, unsafe, unknown, other, runtimes.
    for bucket in (safe_list, unsafe_list, unknown_list, other_list, time_list):
        print(bucket)

# Executed on every import as well as on script runs: one record goes to the
# logging module, one line goes to stdout.
logging.info("This message is saved to file")
print("This still prints to terminal")
