import copy
import logging
import torch

from ab_crown_utils import write_conf, parse_safety_status, load_x_from_file, run_ab_crown
from config import config
from models.model import DupNet, TripledAdvMnistNet, mnist_config, \
    tripled_patch_mnist_model, dup_patch_mnist_model, dup_patch_cifar10_model, \
    mnist_patching_conf, dup_patch_gtsrb_model, dup_gtsrb_net, get_mnist_neuron_global_order, dup_taxinet_net, \
    dup_patch_taxinet_model
from models.model import dup_cifar10_net
import numpy as np

from models.model_factory import create_comps_from_mask

# Install a default INFO-level console handler only when nothing has configured the
# root logger yet, so an application that set up logging before importing this
# module keeps its own configuration.
if not logging.getLogger().hasHandlers():
    logging.basicConfig(
        handlers=[logging.StreamHandler()],
        format='[%(asctime)s][%(levelname)s] %(message)s',
        level=logging.INFO,
    )

def get_spec_name(metric, verify_patching_only=False):
    """Map a metric identifier to the ab-crown specification name passed to ``write_conf``.

    Args:
        metric: metric key such as ``'same_winner'`` or ``'winner_runner'``.
        verify_patching_only: when True, return the specification used by the
            patching-only (duplicated patching network) query instead of the default one.

    Returns:
        The specification name string.

    Raises:
        ValueError: if ``metric`` is unknown, or has no patching-only specification
            while ``verify_patching_only`` is True.
    """
    mapping = {'same_winner': 'verified-acc',
               'winner_runner': 'dupnet-runnerup',
               'logit_difference': 'logit-diff',
               'abs_max': 'abs-max',
               'tripled_adv_same_winner_iteratively': 'tripled-adv-same-winner-iteratively',
               'tripled_adv_winner_runner': 'tripled-adv-winner-runner',
               'tripled_adv_hybrid': 'tripled-adv-hybrid',
               'tripled_adv_same_winner': 'tripled-adv-same-winner',
               'winner_diff': 'dupnet-winner-diff'}

    # Only a subset of metrics has a patching-only specification.
    patching_only_mapping = {'winner_runner': 'patching-runnerup',
                             'winner_diff': 'patching-winner-diff'}

    if metric not in mapping:
        raise ValueError(f"Unknown metric: {metric}")
    if verify_patching_only:
        if metric not in patching_only_mapping:
            # Previously surfaced as a bare KeyError; report the supported choices instead.
            raise ValueError(
                f"Metric {metric!r} has no patching-only specification; "
                f"supported: {sorted(patching_only_mapping)}")
        return patching_only_mapping[metric]
    return mapping[metric]

def qmsc_formal_patching_mnist(dataset, full_net, X, full_net_path, device, prune_by, metric, epsilon, delta, exp_paths,
                               patching, patch_eps, bottom_up=True, query_timeout=300, verify_patching_only=False, verbose=True, **kwargs):
    """Binary-search a quasi-minimal sufficient circuit (QMSC) on MNIST under formal patching.

    Neurons are patched in the global order returned by
    ``get_mnist_neuron_global_order(bottom_up)``. Each probe patches the first ``k`` neurons of
    that order and asks the tripled formal patch query (``verify_patching_only=False``) whether
    the specification still holds. The search ends with adjacent ``k`` values: the last one
    probed YES and the first one probed NO.

    Extraneous ``**kwargs`` are accepted for interface symmetry with the other pruning functions.

    Assumptions (1 and 2 are checked, 3 is not):
      1. the k=0 mask (no neurons patched) satisfies the spec (YES);
      2. the k=total mask (all neurons patched) violates the spec (NO);
      3. the answer is monotone in k; otherwise the result is only a local YES/NO boundary.

    Args:
        dataset: must be ``'mnist'``.
        full_net: the unpatched network; returned as-is (not copied).
        X: input tensor, moved to ``device`` and forwarded to ``formal_patch_query``.
        prune_by: must be ``'neurons'``.
        patching: must be ``'formal'``.
        verify_patching_only: must be False (tripled query).
        bottom_up: neuron ordering flag passed to ``get_mnist_neuron_global_order``.
        verbose: log each probe and forward ``verbose`` to the query.

    Returns:
        ``(full_net, components, qmsc_info)``. ``components`` is
        ``create_comps_from_mask(dataset, last_yes_mask)``. ``qmsc_info`` is a dict with keys
        ``evaluations``, ``yes_k``, ``no_k`` and ``breaking_neuron`` (the neuron added when
        going from ``yes_k`` to ``no_k``). The last YES mask is also saved to
        ``exp_paths['Z_mask_file']``.

    Raises:
        ValueError: unsupported ``verify_patching_only``/``patching``/``dataset``/``prune_by``.
        RuntimeError: if assumption 1 or 2 does not hold.
    """
    if verify_patching_only:
        raise ValueError("qmsc_formal_patching_mnist currently supports only verify_patching_only=False")
    if patching != 'formal':
        raise ValueError("qmsc_formal_patching_mnist supports only patching='formal'")
    # Reject unless BOTH the dataset and the granularity are supported.
    if dataset != 'mnist' or prune_by != 'neurons':
        raise ValueError("qmsc_formal_patching_mnist supports only dataset='mnist' and prune_by='neurons'")

    total_neurons = mnist_patching_conf['total_neurons']
    order = get_mnist_neuron_global_order(bottom_up)  # respect flag instead of fixed False
    if verbose:
        logging.info(f"[QMSC] Using neuron order bottom_up={bottom_up}")
    X = X.to(device)

    def query_mask(mask_tensor):
        """Run one tripled formal patch query with ``mask_tensor`` as the patch mask."""
        return formal_patch_query(dataset='mnist', pruned_net=full_net, X=X, full_net_path=full_net_path,
            saved_dup_patch_net_path=exp_paths['saved_dup_patch_net_path_mnist'], Z_mask=mask_tensor,
            device=device, metric=metric, epsilon=epsilon, delta=delta, patch_eps=patch_eps,
            exp_paths=exp_paths, adv_x_path=exp_paths['adv_x_path'], verify_patching_only=False,
            query_timeout=query_timeout, verbose=verbose)

    # Assumption 1: the unpatched network (k=0) must satisfy the spec.
    base_mask = torch.zeros(total_neurons)
    base_safe, _ = query_mask(base_mask)
    if not base_safe:
        raise RuntimeError("QMSC invalid: unpatched network does not satisfy spec")
    last_yes_mask = base_mask.clone()

    # Assumption 2: the fully patched network (k=total) must violate the spec.
    full_mask = torch.ones(total_neurons)
    full_safe, _ = query_mask(full_mask)
    if full_safe:
        raise RuntimeError("QMSC abort: fully patched network still satisfies spec; no breaking point")

    # Invariant: prefix length `low` answered YES, prefix length `high` answered NO.
    low, high = 0, total_neurons
    evaluations = 2
    while high - low > 1:
        mid = (low + high) // 2
        mid_mask = torch.zeros(total_neurons)
        mid_mask[order[:mid]] = 1.0
        safe_mid, _ = query_mask(mid_mask)
        evaluations += 1
        if safe_mid:
            low = mid
            last_yes_mask = mid_mask
            if verbose:
                logging.info(f"[QMSC] YES at k={mid}")
        else:
            high = mid
            if verbose:
                logging.info(f"[QMSC] NO at k={mid}")

    # high == low + 1 here, so order[low] is the neuron whose patching flips YES to NO.
    breaking_neuron = order[low] if low < total_neurons else None
    qmsc_info = {"evaluations": evaluations, "yes_k": low, "no_k": high, "breaking_neuron": breaking_neuron}
    logging.info(f"[QMSC] done yes_k={low} no_k={high} breaking_neuron={breaking_neuron} evals={evaluations}")
    torch.save(last_yes_mask, exp_paths['Z_mask_file'])
    return full_net, create_comps_from_mask(dataset, last_yes_mask), qmsc_info


def formal_approval_mnist(pruned_net_path, X, full_net_path, saved_dupnet_weights, device, metric, epsilon,
                          delta, exp_paths, adv_x_path=None, query_timeout=300, **kwargs):
    """Formally check a pruned MNIST network against the full network with ab-crown.

    Builds a ``DupNet`` whose ``net1`` holds the full-network weights and ``net2`` the pruned
    weights, saves it to ``saved_dupnet_weights``, writes the ab-crown config for the
    specification selected by ``metric`` and runs the verifier.

    Args:
        pruned_net_path: checkpoint of the pruned network.
        X: unused; kept for interface symmetry with ``formal_approval_vision``.
        full_net_path: checkpoint of the full network.
        saved_dupnet_weights: where the combined DupNet state dict is written.
        adv_x_path: adversarial-example path passed to ``write_conf``; defaults to
            ``exp_paths['adv_x_path']``.
        query_timeout: ab-crown timeout written into the config.
        **kwargs: ignored (e.g. ``model_type`` passed by ``find_formal_adv_example``).

    Returns:
        ``(is_safe, verification_res)`` as returned by ``parse_safety_status``.
    """
    if not adv_x_path:
        adv_x_path = exp_paths['adv_x_path']

    def load_weights(path):
        # Accept a bare state dict or a checkpoint wrapping it under "state_dict",
        # the same convention formal_approval_vision and formal_patch_query use.
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        checkpoint = torch.load(path, weights_only=False, map_location=device)
        return checkpoint.get("state_dict", checkpoint)

    input_size, hidden_size_1, hidden_size_2, num_classes = (
        mnist_config["input_size"], mnist_config["hidden_size_1"], mnist_config["hidden_size_2"],
        mnist_config["num_classes"])
    dupnet = DupNet(input_size, hidden_size_1, hidden_size_2, num_classes).to(device)
    dupnet.net1.load_state_dict(load_weights(full_net_path))
    dupnet.net2.load_state_dict(load_weights(pruned_net_path))
    torch.save(dupnet.state_dict(), saved_dupnet_weights)
    specification = get_spec_name(metric)

    write_conf(exp_paths['customized_models_paths'], saved_dupnet_weights, adv_x_path,
               exp_paths['mnist_sample_path'], exp_paths['abcrown_specification_path'], device,
               specification=specification, epsilon=epsilon, delta=delta, model_type='mnist', timeout=query_timeout)
    full_stdout = run_ab_crown(exp_paths["abcrown_specification_path"])
    is_safe, verification_res = parse_safety_status(full_stdout)
    logging.info(f"Verification status: is_safe: {is_safe} , verification_res: {verification_res}")
    return is_safe, verification_res


###### TODO refactor into original approval, to avoid code duplication
def formal_approval_vision(pruned_net_path, X, full_net_path, saved_cifar_dupnet_path, device, metric,
                           epsilon, delta, exp_paths, query_timeout=300, model_type='cifar10-big', adv_x_path=None):
    """Formally compare a pruned CNN against its full network with ab-crown.

    The duplicated architecture and sample file are chosen by ``model_type``: ``'taxinet'`` and
    ``'gtsrb'`` have dedicated networks, and anything else is treated as a CIFAR-10 variant named
    by ``model_type``. ``net1`` receives the full-network weights and ``net2`` the pruned ones;
    either checkpoint may wrap its state dict under ``"state_dict"``. The combined weights are
    saved to ``saved_cifar_dupnet_path``, an ab-crown config is written for the spec selected by
    ``metric``, and the verifier is run.

    Args:
        X: unused; kept for interface symmetry with ``formal_approval_mnist``.
        adv_x_path: adversarial-example path passed to ``write_conf``; defaults to
            ``exp_paths['adv_x_path']``.

    Returns:
        ``(is_safe, verification_res)`` as returned by ``parse_safety_status``.
    """
    adv_x_path = adv_x_path or exp_paths['adv_x_path']

    if model_type in ('taxinet', 'gtsrb'):
        sample_path = exp_paths[f'{model_type}_sample_path']
        dupnet_factory = dup_taxinet_net if model_type == 'taxinet' else dup_gtsrb_net
        dupnet = dupnet_factory().to(device)
    else:
        sample_path = exp_paths['cifar10_sample_path']
        dupnet = dup_cifar10_net(model_type=model_type).to(device)

    # Full weights go into net1 and pruned weights into net2, in that order.
    for target_net, weights_path in ((dupnet.net1, full_net_path), (dupnet.net2, pruned_net_path)):
        checkpoint = torch.load(weights_path, weights_only=False, map_location=device)
        target_net.load_state_dict(checkpoint.get("state_dict", checkpoint))
    torch.save(dupnet.state_dict(), saved_cifar_dupnet_path)

    specification = get_spec_name(metric)
    spec_path = exp_paths['abcrown_specification_path']
    write_conf(exp_paths['customized_models_paths'], saved_cifar_dupnet_path, adv_x_path,
               sample_path, spec_path, device,
               model_type=model_type, specification=specification, epsilon=epsilon, delta=delta, timeout=query_timeout)

    is_safe, verification_res = parse_safety_status(run_ab_crown(exp_paths['abcrown_specification_path']))
    logging.info(f"Verification status: is_safe: {is_safe} , verification_res: {verification_res}")
    return is_safe, verification_res

### TODO consider implementing patching
def formal_prune_mnist(dataset, full_net, X, full_net_path, device, prune_by, metric, epsilon, delta, exp_paths, patching, patch_eps=None, data_dist=None, weight_contrib=None,
                       bottom_up=True, query_timeout=300, verify_patching_only=True, **kwargs):
    """Greedily prune MNIST neurons one at a time, keeping a removal only if ab-crown approves it.

    Neurons of ``fc1``, ``fc2`` and ``fc3`` are visited layer by layer, with the layer order
    reversed when ``bottom_up`` is True. Each neuron is removed tentatively and stays removed
    only when the formal query reports safe:

    * ``patching='zero'`` / ``'mean'``: the neuron's incoming weights are zeroed and its bias is
      set to 0 (zero) or ``data_dist[layer_name][neuron_idx]`` (mean) in the copied network. The
      copy is saved to ``exp_paths['pruned_net_path']`` and checked with
      ``formal_approval_mnist``. Rejected neurons get their weights restored.
    * ``patching='formal'``: the weights are left untouched. The neuron's entry in a global patch
      mask ``Z_mask`` is set and checked with ``formal_patch_query``, and the final mask is saved
      to ``exp_paths['Z_mask_file']``.

    Args:
        full_net: network with ``fc1``/``fc2``/``fc3`` linear layers; it is deep-copied, not mutated.
        X: input tensor, moved to ``device`` and forwarded to the verification calls.
        data_dist: per-layer bias values, required for ``patching='mean'``.
        weight_contrib: unused.
        **kwargs: ignored.

    Returns:
        ``(pruned_net, components, timeouts)``:
        * ``pruned_net`` is the modified deep copy of ``full_net`` (unchanged under formal patching).
        * ``components`` is ``{"active": [...], "dead": [...], "granularity": prune_by}`` with
          ``(layer_name, neuron_idx)`` entries.
        * ``timeouts`` lists the neurons whose query reported a non-zero ``'timeout'`` entry and
          no unsafe result.

    Raises:
        ValueError: unsupported ``prune_by`` or ``metric``, or ``patching='mean'`` without ``data_dist``.
        NotImplementedError: unsupported ``patching``.
    """
    logging.info(f"Formal mnist Pruning by: {prune_by}, Patching: {patching}, Verify Patching Only: {verify_patching_only},  Metric: {metric}, dataset:{dataset}, epsilon:{epsilon}, Delta: {delta}, Bottom-up: {bottom_up}, full_net_path: {full_net_path}, exp_paths: {exp_paths}, query_timeout: {query_timeout}")

    if prune_by not in ['neurons']:
        raise ValueError("Invalid value for prune_by.")

    if metric not in ['winner_runner', 'same_winner', 'logit_difference', 'abs_max', 'winner_diff',
                      'patching_winner_runner', 'patching_winner_diff']:
        raise ValueError("Invalid metric. Use: 'winner_runner', 'same_winner', 'logit_difference', 'abs_max', 'winner_diff', 'patching_winner_runner', or 'patching_winner_diff'.")

    if patching not in ['formal', 'zero', 'mean']:
        raise NotImplementedError(f"Patching method '{patching}' is not implemented.")

    # prune_by is necessarily 'neurons' here, so mean patching only needs data_dist.
    if patching == 'mean' and data_dist is None:
        raise ValueError("patching='mean' requires data_dist for neuron pruning.")

    # Work on a copy so the caller's network is unaffected.
    pruned_net = copy.deepcopy(full_net).to(device).eval()
    X = X.to(device)
    dead_components, active_components, timeouts = [], [], []

    layers_order = [("fc1", pruned_net.fc1), ("fc2", pruned_net.fc2), ("fc3", pruned_net.fc3)]
    Z_mask = torch.zeros(mnist_patching_conf['total_neurons'])
    if bottom_up:
        layers_order = list(reversed(layers_order))

    logging.info("starting pruning")

    for name, layer in layers_order:
        for neuron_idx in range(layer.out_features):
            if patching != 'formal':
                # Standard verification: patch the neuron in the copied network itself.
                old_weight = layer.weight[neuron_idx].clone()
                old_bias = layer.bias[neuron_idx].clone()
                layer.weight.data[neuron_idx, :] = 0.0
                if patching == 'zero':
                    layer.bias.data[neuron_idx] = 0.0
                else:  # 'mean'
                    layer.bias.data[neuron_idx] = data_dist[name][neuron_idx]

                pruned_net_path = exp_paths['pruned_net_path']
                torch.save(pruned_net.state_dict(), pruned_net_path)
                formally_approved, ver_res = formal_approval_mnist(
                    pruned_net_path, X, full_net_path,
                    exp_paths["path_to_save_mnist_dupnet"],
                    device, metric=metric, epsilon=epsilon, delta=delta, exp_paths=exp_paths,
                    query_timeout=query_timeout)
            else:
                # Formal patching: mark this neuron in the global mask instead of editing weights.
                mask_idx = mnist_patching_conf['layer_offset'][name] + neuron_idx
                Z_mask[mask_idx] = 1.0
                formally_approved, ver_res = formal_patch_query(
                    dataset='mnist', pruned_net=pruned_net, X=X,
                    full_net_path=full_net_path,
                    saved_dup_patch_net_path=exp_paths['saved_dup_patch_net_path_mnist'],
                    Z_mask=Z_mask, device=device, metric=metric,
                    epsilon=epsilon, delta=delta, patch_eps=patch_eps,
                    exp_paths=exp_paths, adv_x_path=exp_paths['adv_x_path'], verify_patching_only=verify_patching_only, query_timeout=query_timeout
                )

            # Record a timeout that produced no unsafe verdict. A missing 'unsafe' entry counts
            # as 0, as in formal_prune_vision. TODO: confirm this is the desired timeout policy.
            if ver_res.get('timeout', 0) != 0 and ver_res.get('unsafe', 0) == 0:
                timeouts.append((name, neuron_idx))

            if not formally_approved:
                logging.info(f"not pruning neuron {neuron_idx}, layer {name}")
                if patching != 'formal':  # restore weights
                    layer.weight.data[neuron_idx] = old_weight
                    layer.bias.data[neuron_idx] = old_bias
                else:
                    Z_mask[mask_idx] = 0.0  # reset the mask for this neuron
                active_components.append((name, neuron_idx))
            else:
                logging.info(f"pruning neuron {neuron_idx}, layer {name}")
                dead_components.append((name, neuron_idx))

    if patching == 'formal':
        torch.save(Z_mask, exp_paths['Z_mask_file'])

    return pruned_net, {"active": active_components, "dead": dead_components, "granularity": prune_by}, timeouts


def formal_prune_vision(dataset, full_net, X, full_net_path, device, prune_by,  # 'conv_channels' or 'conv_heads'
                        metric, patching, epsilon, delta, exp_paths, patch_eps=None, verify_patching_only=True, bottom_up=True, verbose=False, query_timeout=300, **kwargs):
    """Greedily prune CNN conv channels or whole conv layers, keeping a removal only if ab-crown approves it.

    Conv2d modules whose names contain neither ``downsample`` nor ``shortcut`` are visited in
    ``named_modules`` order, reversed when ``bottom_up`` is True.

    * ``prune_by='conv_channels'``: each output channel is tried individually.
    * ``prune_by='conv_heads'``: the whole layer is tried at once.

    With ``patching='zero'`` the candidate's weights (and bias) are zeroed in the copied network,
    the copy is saved to ``exp_paths['pruned_net_path']``, and it is checked with
    ``formal_approval_vision``; rejected candidates are restored. With ``patching='formal'`` the
    weights are left untouched and a per-layer mask ``Z_mask``
    (``{layer_name: [0.0 or 1.0 per output channel]}``) is checked with ``formal_patch_query``.
    The final mask is saved to ``exp_paths['Z_mask_file']``.

    Args:
        X: input tensor, moved to ``device`` and forwarded to the verification calls.
        patch_eps: required for formal patching.
        verbose: only echoed in the initial log line.
        **kwargs: ignored.

    Returns:
        ``(pruned_net, components, timeouts)``:
        * ``pruned_net`` is a deep copy of ``full_net``, zeroed where pruned under zero patching.
        * ``components`` is ``{"active": [...], "dead": [...], "granularity": prune_by}``, with
          ``(layer_name, channel)`` entries for channels and ``layer_name`` entries for heads.
        * ``timeouts`` lists the candidates whose query reported a timeout and no unsafe result.

    Raises:
        ValueError: invalid ``prune_by``/``metric``, missing ``patch_eps`` for formal patching, or
            ``verify_patching_only=False`` with formal patching.
        NotImplementedError: ``patching`` other than ``'zero'``/``'formal'``.
    """
    logging.info(f"formal {dataset} pruning by: {prune_by}, patching: {patching}, verify-patching-only: {verify_patching_only},  metric: {metric}, epsilon:{epsilon}, dataset: {dataset}, delta: {delta}, bottom-up: {bottom_up}, verbose: {verbose}, device:{device}, full_net_path: {full_net_path}, exp_paths: {exp_paths}")
    if prune_by not in ['conv_channels', 'conv_heads']:
        raise ValueError("Invalid prune_by. Use: 'conv_channels' or 'conv_heads'.")

    if metric not in ['winner_runner', 'same_winner', 'logit_difference', 'abs_max', 'winner_diff']:
        raise ValueError("Invalid metric. Use: 'winner_runner', 'same_winner', 'logit_difference', 'abs_max', or 'winner_diff'.")

    if patching not in ['formal', 'zero']:
        raise NotImplementedError("non-zero patching is not implemented.")
    if patching == 'formal' and patch_eps is None:
        raise ValueError("patch_eps is required for formal patching.")

    # TODO: default verify_patching_only to None; it does not apply to informal patching.
    if patching == 'formal' and not verify_patching_only:
        raise ValueError(f"verify_patching_only=False is not supported for {dataset} dataset")

    active_components, dead_components, timeouts = [], [], []

    pruned_net = copy.deepcopy(full_net).to(device).eval()
    X = X.to(device)

    logging.info(f"Starting formal pruning with granularity '{prune_by}'")

    conv_layers = [(name, layer) for name, layer in pruned_net.named_modules()
                   if isinstance(layer, torch.nn.Conv2d)
                   and 'downsample' not in name.lower() and 'shortcut' not in name.lower()]
    if bottom_up:
        conv_layers = list(reversed(conv_layers))

    if patching == 'formal':  # one 0/1 flag per output channel of every visited layer
        Z_mask = {name: [0.0] * layer.weight.shape[0] for name, layer in conv_layers}

    def run_query():
        """Verify the current candidate: the saved weights (zero) or the current Z_mask (formal)."""
        if patching != 'formal':
            pruned_net_path = exp_paths['pruned_net_path']
            torch.save(pruned_net.state_dict(), pruned_net_path)
            return formal_approval_vision(
                pruned_net_path, X, full_net_path,
                exp_paths[f'path_to_save_{dataset}_dupnet'],
                device, metric, epsilon, delta, exp_paths,
                query_timeout=query_timeout, model_type=dataset)
        return formal_patch_query(
            dataset=dataset, pruned_net=pruned_net, X=X,
            full_net_path=full_net_path,
            saved_dup_patch_net_path=exp_paths[f'saved_dup_patch_net_path_{dataset}'],
            Z_mask=Z_mask, device=device, metric=metric,
            epsilon=epsilon, delta=delta, patch_eps=patch_eps,
            exp_paths=exp_paths, adv_x_path=exp_paths['adv_x_path'], verify_patching_only=verify_patching_only, query_timeout=query_timeout)

    def timed_out(ver_res):
        """True when the query reported a timeout and no unsafe result."""
        return ver_res.get('timeout', 0) != 0 and ver_res.get('unsafe', 0) == 0

    for name, layer in conv_layers:
        if prune_by == 'conv_channels':
            for out_idx in range(layer.weight.shape[0]):
                if patching != 'formal':
                    # zero out the channel, keeping the old values for a possible revert
                    old_weight = layer.weight[out_idx].clone()
                    old_bias = layer.bias[out_idx].clone() if layer.bias is not None else None
                    layer.weight.data[out_idx] = 0.0
                    if layer.bias is not None:
                        layer.bias.data[out_idx] = 0.0
                else:
                    Z_mask[name][out_idx] = 1.0  # mark this channel as patched

                formally_approved, ver_res = run_query()

                if timed_out(ver_res):
                    timeouts.append((name, out_idx))

                if not formally_approved:
                    logging.info(f"[NOT PRUNED] filter {out_idx} in layer {name}")
                    if patching != 'formal':
                        layer.weight.data[out_idx] = old_weight
                        if layer.bias is not None:
                            layer.bias.data[out_idx] = old_bias
                    else:
                        Z_mask[name][out_idx] = 0.0  # reset the mask for this channel
                    active_components.append((name, out_idx))
                else:
                    logging.info(f"[PRUNED] filter {out_idx} in layer {name}")
                    dead_components.append((name, out_idx))

        elif prune_by == 'conv_heads':
            if patching != 'formal':
                # zero out the whole layer, keeping the old values for a possible revert
                old_weight = layer.weight.clone()
                old_bias = layer.bias.clone() if layer.bias is not None else None
                layer.weight.data[:] = 0.0
                if layer.bias is not None:
                    layer.bias.data[:] = 0.0
            else:
                Z_mask[name] = [1.0] * layer.weight.shape[0]  # mark this head as patched

            approved, ver_res = run_query()

            if timed_out(ver_res):
                timeouts.append(name)

            if not approved:
                logging.info(f"not pruning conv head {name}")
                if patching != 'formal':  # restore weights
                    layer.weight.data[:] = old_weight
                    if layer.bias is not None:
                        layer.bias.data[:] = old_bias
                else:
                    Z_mask[name] = [0.0] * layer.weight.shape[0]  # reset the mask for this head
                active_components.append(name)
            else:
                logging.info(f"pruned conv head {name}")
                dead_components.append(name)

    if patching == 'formal':
        torch.save(Z_mask, exp_paths['Z_mask_file'])

    return pruned_net, {"active": active_components, "dead": dead_components, "granularity": prune_by}, timeouts

# TODO remove pruned_net argument, not used
def formal_patch_query(dataset, pruned_net, X, full_net_path, saved_dup_patch_net_path, Z_mask, device, metric, epsilon,
                       delta, patch_eps, exp_paths, adv_x_path, verify_patching_only, query_timeout=300, verbose=True):
    """Run one ab-crown query on a patching network built from the full-network weights.

    * ``verify_patching_only=False`` (MNIST only): a tripled patching network is built. Its
      ``patch_net``, ``pruned_net`` and ``full_net`` all receive the full-network weights, and it
      is checked against ``get_spec_name(metric)`` with the given ``epsilon``.
    * ``verify_patching_only=True``: a duplicated patching network (``patch_net`` and
      ``pruned_net``) is checked against ``get_spec_name(metric, verify_patching_only=True)``.
      The input radius written to the config is fixed to ``1e-5``.

    In both cases ``Z_mask`` and ``patch_eps`` are passed to the model builder and to
    ``write_conf``. For MNIST the mask tensor is first converted to ``[Z_mask.tolist()]``.

    Args:
        pruned_net, X: unused (kept for call-site compatibility).
        full_net_path: full-network checkpoint, optionally wrapped under ``"state_dict"``.
        saved_dup_patch_net_path: where the combined patching network state dict is written.

    Returns:
        ``(is_safe, verification_res)`` as returned by ``parse_safety_status``.

    Raises:
        ValueError: unsupported ``dataset``.
        NotImplementedError: tripled query requested for a non-MNIST dataset.
    """
    if dataset not in ('mnist', 'cifar10-small', 'gtsrb', 'taxinet'):
        raise ValueError("Unsupported dataset. Use 'mnist', 'cifar10-small', 'gtsrb' or 'taxinet'.")

    if not verify_patching_only:  # verify patching and input robustness with the tripled network
        if dataset != 'mnist':
            raise NotImplementedError(f"dataset {dataset} is not supported for tripled formal patch query yet.")
        Z_mask = [Z_mask.tolist()]
        tripled_patch_net = tripled_patch_mnist_model(Z_mask).to(device)

        full_net_ckpt = torch.load(full_net_path, weights_only=False, map_location=device)
        full_state = full_net_ckpt.get("state_dict", full_net_ckpt)
        for sub_net in (tripled_patch_net.patch_net, tripled_patch_net.pruned_net, tripled_patch_net.full_net):
            sub_net.load_state_dict(full_state)
        torch.save(tripled_patch_net.state_dict(), saved_dup_patch_net_path)
        specification = get_spec_name(metric)
        write_conf(exp_paths['customized_models_paths'], saved_dup_patch_net_path, adv_x_path,
                   exp_paths['mnist_sample_path'],
                   exp_paths['abcrown_specification_path'], device, specification=specification, epsilon=epsilon,
                   delta=delta, patch_eps=patch_eps,
                   Z_mask=Z_mask, model_type=dataset, verify_patching_only=verify_patching_only, timeout=query_timeout)
    else:  # verify only patching robustness with the duplicated network
        if dataset == 'mnist':
            Z_mask = [Z_mask.tolist()]
            dup_patch_net = dup_patch_mnist_model(Z_mask=Z_mask).to(device)
            sample_path = exp_paths['mnist_sample_path']
        elif dataset == 'cifar10-small':
            dup_patch_net = dup_patch_cifar10_model(Z_mask=Z_mask).to(device)
            sample_path = exp_paths['cifar10_sample_path']
        elif dataset == 'taxinet':
            dup_patch_net = dup_patch_taxinet_model(Z_mask=Z_mask).to(device)
            sample_path = exp_paths['taxinet_sample_path']
        else:  # 'gtsrb'
            dup_patch_net = dup_patch_gtsrb_model(Z_mask=Z_mask).to(device)
            sample_path = exp_paths['gtsrb_sample_path']

        full_net_ckpt = torch.load(full_net_path, weights_only=False, map_location=device)
        full_state = full_net_ckpt.get("state_dict", full_net_ckpt)
        dup_patch_net.patch_net.load_state_dict(full_state)
        dup_patch_net.pruned_net.load_state_dict(full_state)
        torch.save(dup_patch_net.state_dict(), saved_dup_patch_net_path)
        specification = get_spec_name(metric, verify_patching_only=True)
        # Input robustness is not part of this query, so epsilon is set to a tiny radius.
        write_conf(exp_paths['customized_models_paths'], saved_dup_patch_net_path, adv_x_path,
                   sample_path,
                   exp_paths['abcrown_specification_path'], device, specification=specification, epsilon=1e-5,
                   delta=delta, patch_eps=patch_eps, Z_mask=Z_mask, model_type=dataset, verify_patching_only=verify_patching_only, timeout=query_timeout)

    full_stdout = run_ab_crown(exp_paths["abcrown_specification_path"])
    is_safe, verification_res = parse_safety_status(full_stdout)
    if verbose:
        logging.info(f"Verification status: is_safe: {is_safe} , verification_res: {verification_res}")
    return is_safe, verification_res


def load_and_modify_sample_target(sample_path, new_target):
    """Set the ``'target'`` entry of a saved sample dict and overwrite the same file.

    The file must hold a dict saved with ``np.save`` (read back via ``.item()``) that contains a
    ``'label'`` key.

    Args:
        sample_path: path of the sample file; rewritten in place.
        new_target: value stored under ``'target'``.

    Returns:
        The sample's ``'label'`` value.
    """
    # allow_pickle=True unpickles the stored dict: only use with trusted sample files.
    data = np.load(sample_path, allow_pickle=True).item()
    true_label = data['label']
    data['target'] = new_target
    # Save through a file handle: np.save(<path>, ...) appends '.npy' when the path lacks that
    # suffix, which would leave sample_path untouched and create a sibling file instead.
    with open(sample_path, 'wb') as f:
        np.save(f, data)
    logging.info(f"Modified MNIST sample saved to {sample_path} with new target: {new_target}")
    return true_label

def find_adv_example_per_class_iteratively(adv_x_path, device, saved_tripnet_weights, epsilon, exp_paths):
    """Run one ab-crown 'tripled-adv-label-target' query per MNIST class other than the true label.

    For each class ``i`` other than the sample's label, the steps are:
    1. rewrite the sample file at ``exp_paths['mnist_sample_path']`` with ``target = i``;
    2. write a config to ``exp_paths['abcrown_specification_path']``;
    3. run ab-crown on that config.

    The sample file is not restored afterwards; it keeps the last target tried.

    Returns:
        False as soon as one query is reported unsafe (an adversarial example was found for that
        target), True if every query is safe.
    """
    sample_path = exp_paths['mnist_sample_path']
    true_label = np.load(sample_path, allow_pickle=True).item()['label']
    for i in range(mnist_config["num_classes"]):  # check for every logit
        if i == true_label:
            continue
        load_and_modify_sample_target(sample_path, i)
        spec_path = exp_paths['abcrown_specification_path']
        write_conf(exp_paths['customized_models_paths'], saved_tripnet_weights, adv_x_path,
                   sample_path, spec_path, device,
                   specification='tripled-adv-label-target', epsilon=epsilon, model_type='mnist')

        logging.info(f"running verification for target {i}")
        print(f"running verification for target {i}")
        # Run ab-crown on the config that was just written for this experiment.
        full_stdout = run_ab_crown(spec_path)
        is_safe, verification_res = parse_safety_status(full_stdout)
        logging.info(f"Verification for target {i}: is_safe: {is_safe}, ver_res: {verification_res}")
        if not is_safe:
            logging.info(f"[SUCCESS] found adversarial example for target {i}")
            return False
    return True

#### TODO implement for cifar10 too
def find_adversarial_same_winner(dataset, informal_pruned_net_path, formal_pruned_net_path, full_net_path, x, adv_x_path, metric, device, epsilon, exp_paths):
    """Run an ab-crown adversarial query over the full, formal-pruned and informal-pruned MNIST nets.

    The three networks are stacked in a ``TripledAdvMnistNet`` (net1 = full, net2 = formal pruned,
    net3 = informal pruned), whose weights are saved to ``exp_paths['path_to_save_mnist_tripnet']``.
    The metric selects the query:
    * ``'tripled_adv_same_winner_iteratively'`` delegates to the per-class search.
    * Every other metric runs a single query with the matching spec. Incomplete verification is
      disabled only for ``'tripled_adv_same_winner'``.

    Args:
        x: unused.

    Returns:
        The ab-crown safety verdict; False means an adversarial example was found.

    Raises:
        NotImplementedError: for CNN datasets.
    """
    logging.info(f"[INFO] running find_adversarial_same_winner | dataset: {dataset}, informal_pruned_net_path: {informal_pruned_net_path}, formal_pruned_net_path: {formal_pruned_net_path},  adv_x_path: {adv_x_path}, metric: {metric}, device: {device}, epsilon: {epsilon}")
    if dataset in ('cifar10', 'cifar10-small', 'cifar10-big', 'gtsrb', 'taxinet'):
        raise NotImplementedError("CNN find_adversarial_same_winner is not implemented yet.")

    saved_tripnet_weights = exp_paths['path_to_save_mnist_tripnet']
    arch = [mnist_config[key] for key in ("input_size", "hidden_size_1", "hidden_size_2", "num_classes")]
    tripnet = TripledAdvMnistNet(*arch).to(device)

    # Load in a fixed order: full -> net1, formal pruned -> net2, informal pruned -> net3.
    for sub_net, weights_path in ((tripnet.net1, full_net_path),
                                  (tripnet.net2, formal_pruned_net_path),
                                  (tripnet.net3, informal_pruned_net_path)):
        sub_net.load_state_dict(torch.load(weights_path, weights_only=False, map_location=device))

    torch.save(tripnet.state_dict(), saved_tripnet_weights)
    tripnet.eval()

    if metric == 'tripled_adv_same_winner_iteratively':
        return find_adv_example_per_class_iteratively(adv_x_path, device, saved_tripnet_weights, epsilon, exp_paths)

    enable_incomplete = metric != 'tripled_adv_same_winner'
    specification = get_spec_name(metric)
    write_conf(exp_paths['customized_models_paths'], saved_tripnet_weights, adv_x_path,
               exp_paths['mnist_sample_path'],
               exp_paths['abcrown_specification_path'], device,
               specification=specification, epsilon=epsilon, enable_incomplete=enable_incomplete, model_type=dataset)
    is_safe, verification_res = parse_safety_status(run_ab_crown(exp_paths['abcrown_specification_path']))
    logging.info(f"Verification status: is_safe: {is_safe} , verification_res: {verification_res}")
    return is_safe

def find_adversarial_example(dataset, informal_pruned_net_path, formal_pruned_net_path, full_net_path, x, adv_x_path, metric, device,
                             epsilon,  exp_paths, delta=None, **kwargs):
    """Run an adversarial query and, if one is found, load the adversarial input.

    The ``tripled_adv_*`` metrics go through ``find_adversarial_same_winner``. Every other metric
    goes through ``find_formal_adv_example``, with ``**kwargs`` forwarded.

    Returns:
        ``load_x_from_file(dataset, adv_x_path)`` when the query is not safe, otherwise None.
    """
    logging.info(f"[INFO] running find_adversarial_example | dataset: {dataset}, informal_pruned_net_path: {informal_pruned_net_path}, formal_pruned_net_path: {formal_pruned_net_path}, adv_x_path: {adv_x_path}, metric: {metric}, device: {device}, epsilon: {epsilon}, delta: {delta}")

    tripled_metrics = ('tripled_adv_same_winner_iteratively', 'tripled_adv_winner_runner',
                       'tripled_adv_hybrid', 'tripled_adv_same_winner')
    if metric in tripled_metrics:
        is_safe = find_adversarial_same_winner(dataset, informal_pruned_net_path, formal_pruned_net_path, full_net_path, x, adv_x_path, metric, device, epsilon, exp_paths)
    else:
        # NOTE(review): find_formal_adv_example's positional parameters are
        # (dataset, pruned_net_path, full_net_path, X, ...). Here the informal pruned net fills the
        # pruned slot and the formal pruned net fills the full-net slot, and full_net_path is unused
        # on this branch. Confirm this pairing is intended.
        is_safe, _ = find_formal_adv_example(dataset, informal_pruned_net_path, formal_pruned_net_path, x, adv_x_path, metric, device, epsilon, delta, exp_paths, **kwargs)

    result_line = f"[INFO] adversarial query result: {is_safe}"
    logging.info(result_line)
    print(result_line)

    if is_safe:
        print("[FAILURE] did not find adversarial example")
        logging.info("[FAILURE] did not find adversarial example")
        return None

    print("[SUCCESS] found adversarial example")
    logging.info("[SUCCESS] found adversarial example")
    return load_x_from_file(dataset, adv_x_path)


def find_formal_adv_example(dataset, pruned_net_path, full_net_path, X, adv_x_path, metric, device, epsilon, delta, exp_paths, query_timeout=300, **kwargs):
    """Run a duplicated-network ab-crown query that records any counterexample at ``adv_x_path``.

    Dispatches to ``formal_approval_mnist`` for ``'mnist'`` and to ``formal_approval_vision``
    otherwise. The dupnet weights are saved to ``exp_paths[f"path_to_save_{dataset}_dupnet"]``.
    Extra ``**kwargs`` are ignored.

    Returns:
        ``(is_safe, ver_res)`` from the selected approval function.
    """
    logging.info(f"Formal adversarial query for dataset: {dataset}, pruned_net_path: {pruned_net_path}, full_net_path: {full_net_path}, metric: {metric}, device: {device}, epsilon: {epsilon}, delta: {delta}")
    if dataset == 'mnist':
        approve = formal_approval_mnist
    else:
        approve = formal_approval_vision

    is_safe, ver_res = approve(
        pruned_net_path, X, full_net_path, exp_paths[f"path_to_save_{dataset}_dupnet"], device,
        metric=metric, epsilon=epsilon, delta=delta, adv_x_path=adv_x_path,
        exp_paths=exp_paths, query_timeout=query_timeout, model_type=dataset,
    )

    for label, value in (("is_safe", is_safe), ("ver_res", ver_res)):
        logging.info(f"Formal approval status: {label}: {value}")

    return is_safe, ver_res

