import os

# Timeout forwarded as `query_timeout` to the formal pruning and
# adversarial-example search calls below.
QUERY_TIMEOUT = 6

# Make sure the output directories exist before anything writes into them:
# "tests" receives temporary checkpoints/samples and main.log, and
# "tests/abcrown_logs" is handed to the helpers as 'abcrown_log_dir'.
for _out_dir in ("tests", "tests/abcrown_logs"):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

from informal_pruning import informal_mnist_pruning, informal_mnist_kl_div_pruning, generate_perturbations
from formal_pruning import formal_prune_mnist, find_formal_adv_example
from informal_criteria import winner_runner_criterion, abs_max_criterion
from config import config
import logging
from ab_crown_utils import load_x_from_file
from models.model import mnist_config
import torch

# use dataset_utils for all model and sample loading
from dataset_utils import load_dataset_and_model, load_mnist_model

# Use the current CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def setup_logging():
    """Send root-logger output at INFO level to tests/main.log and to the console.

    Every handler already attached to the root logger is detached first,
    because ``logging.basicConfig`` does nothing when the root logger already
    has handlers. The log file is opened in append mode, so repeated runs
    accumulate in the same file.
    """
    root = logging.root
    # Iterate over a snapshot: removeHandler mutates root.handlers.
    for existing in list(root.handlers):
        root.removeHandler(existing)

    file_handler = logging.FileHandler("tests/main.log", mode='a')
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s][%(levelname)s] %(message)s',
        handlers=[file_handler, console_handler],
    )

# Configure logging at import time so root-logger output from this module goes
# to tests/main.log and the console before any test function runs.
setup_logging()


# State dict written by test_informal_pruning_winner_runner and read by
# test_formal_adv_example_detection.
INF_PRUNED_PATH = "tests/tmp_informal_pruned.pth"
# State dict written by test_formal_pruning_winner_runner and read by
# test_formal_pruned_small_large_perturbation.
FORMAL_PRUNED_PATH = "tests/tmp_formal_pruned.pth"
# Dataset sample id passed to load_dataset_and_model by most tests.
SAMPLE_ID = 7437
# Tolerance passed as `delta` to the pruning routines and to the
# winner_runner / abs_max criterion checks.
DELTA = 4.36
# Perturbation bound passed as `epsilon` to the formal pruning/verification
# calls and to generate_perturbations.
EPSILON = 0.02

def test_informal_pruning_winner_runner():
    """Informally prune the MNIST model on SAMPLE_ID using the winner_runner metric.

    The pruned weights are saved to INF_PRUNED_PATH, which
    test_formal_adv_example_detection reuses. On the same sample, the pruned
    network's output must still satisfy winner_runner_criterion against the
    full network.
    """
    paths_cfg = config['paths']
    exp_paths = {
        'pruned_net_path': "tests/tmp_pruned.pth",
        'path_to_save_mnist_dupnet': "tests/tmp_dupnet.pth",
        'customized_models_paths': paths_cfg['customized_models_paths'],
        'adv_x_path': "tests/tmp_adv_x.txt",
        'mnist_sample_path': "tests/tmp_sample.npy",
        'abcrown_specification_path': "tests/tmp_spec.yaml",
        'abcrown_run_path': paths_cfg['abcrown_run_path'],
        'abcrown_log_dir': "tests/abcrown_logs"
    }

    # Real MNIST model plus the real dataset sample SAMPLE_ID.
    (full_net, _model_path, _train_gen, _test_gen, _test_data,
     _batch_path, sample, _logit_diff) = load_dataset_and_model(
        "mnist", device, exp_paths, sample_ids=[SAMPLE_ID]
    )

    pruned_net, _active = informal_mnist_pruning(
        full_net=full_net,
        X=sample,
        device=device,
        prune_by='neurons',
        metric='winner_runner',
        patching='zero',
        delta=DELTA,
        bottom_up=True,
    )

    # Persist the informally pruned weights for test_formal_adv_example_detection.
    torch.save(pruned_net.state_dict(), INF_PRUNED_PATH)

    with torch.no_grad():
        full_out = full_net(sample)
        pruned_out = pruned_net(sample)
        criterion_holds = winner_runner_criterion(full_out, pruned_out, DELTA)
        assert criterion_holds, "Informal pruning failed winner_runner criterion"
    print("test_informal_pruning_winner_runner: PASSED")

def test_informal_pruning_absmax():
    """Informally prune the MNIST model on dataset sample 1 using the abs_max metric.

    On that sample, the pruned network's output must satisfy
    abs_max_criterion against the full network within DELTA.
    """
    paths_cfg = config['paths']
    exp_paths = {
        'pruned_net_path': "tests/tmp_pruned.pth",
        'path_to_save_mnist_dupnet': "tests/tmp_dupnet.pth",
        'customized_models_paths': paths_cfg['customized_models_paths'],
        'adv_x_path': "tests/tmp_adv_x.txt",
        'mnist_sample_path': "tests/tmp_sample.npy",
        'abcrown_specification_path': "tests/tmp_spec.yaml",
        'abcrown_run_path': paths_cfg['abcrown_run_path'],
        'abcrown_log_dir': "tests/abcrown_logs"
    }

    (full_net, _model_path, _train_gen, _test_gen, _test_data,
     _batch_path, sample, _logit_diff) = load_dataset_and_model(
        "mnist", device, exp_paths, sample_ids=[1]
    )

    pruned_net, _active = informal_mnist_pruning(
        full_net=full_net,
        X=sample,
        device=device,
        prune_by='neurons',
        metric='abs_max',
        patching='zero',
        delta=DELTA,
        bottom_up=True,
    )

    with torch.no_grad():
        full_out = full_net(sample)
        pruned_out = pruned_net(sample)
        criterion_holds = abs_max_criterion(full_out, pruned_out, DELTA)
        assert criterion_holds, "Informal pruning failed abs_max criterion"
    print("test_informal_pruning_absmax: PASSED")

def test_informal_kl_div_pruning():
    """Smoke-test informal_mnist_kl_div_pruning on dataset sample 2 with tau = 0.5.

    Only checks that the pruned network runs on the sample and returns an
    output of shape (1, mnist_config['num_classes']).
    """
    paths_cfg = config['paths']
    exp_paths = {
        'pruned_net_path': "tests/tmp_pruned.pth",
        'path_to_save_mnist_dupnet': "tests/tmp_dupnet.pth",
        'customized_models_paths': paths_cfg['customized_models_paths'],
        'adv_x_path': "tests/tmp_adv_x.txt",
        'mnist_sample_path': "tests/tmp_sample.npy",
        'abcrown_specification_path': "tests/tmp_spec.yaml",
        'abcrown_run_path': paths_cfg['abcrown_run_path'],
        'abcrown_log_dir': "tests/abcrown_logs"
    }

    (full_net, _model_path, _train_gen, _test_gen, _test_data,
     _batch_path, sample, _logit_diff) = load_dataset_and_model(
        "mnist", device, exp_paths, sample_ids=[2]
    )

    pruned_net, _active = informal_mnist_kl_div_pruning(
        full_net=full_net,
        x=sample,
        device=device,
        tau=0.5,
        prune_by='neurons',
        patching='zero',
        bottom_up=True,
    )

    expected_shape = (1, mnist_config['num_classes'])
    with torch.no_grad():
        assert pruned_net(sample).shape == expected_shape
    print("test_informal_kl_div_pruning: PASSED")

def test_formal_pruning_winner_runner():
    """Formally prune the MNIST model on SAMPLE_ID with the winner_runner metric.

    Side effects: the full model's state dict is written to
    tests/tmp_mnist.pth, and the pruned model's state dict is written to
    FORMAL_PRUNED_PATH. The pruned net must produce a
    (1, num_classes) output. Every one of 100 perturbations from
    generate_perturbations(X, EPSILON) must also satisfy
    winner_runner_criterion, checked sample by sample.
    """
    paths_cfg = config['paths']
    exp_paths = {
        'pruned_net_path': "tests/tmp_pruned.pth",
        'path_to_save_mnist_dupnet': "tests/tmp_dupnet.pth",
        'customized_models_paths': paths_cfg['customized_models_paths'],
        'adv_x_path': "tests/tmp_adv_x.txt",
        'mnist_sample_path': "tests/tmp_sample.npy",
        'abcrown_specification_path': "tests/tmp_spec.yaml",
        'abcrown_run_path': paths_cfg['abcrown_run_path'],
        'abcrown_log_dir': "tests/abcrown_logs"
    }

    (full_net, model_path, _train_gen, _test_gen, _test_data,
     _batch_path, X, _logit_diff) = load_dataset_and_model(
        "mnist", device, exp_paths, sample_ids=[SAMPLE_ID]
    )

    # test_formal_adv_example_detection (original version) reads this
    # checkpoint as its full_net_path.
    full_ckpt = "tests/tmp_mnist.pth"
    torch.save(full_net.state_dict(), full_ckpt)

    pruned_net, _active, _timeouts = formal_prune_mnist(
        dataset="mnist",
        full_net=full_net,
        X=X,
        full_net_path=model_path,
        device=device,
        prune_by='neurons',
        metric='winner_runner',
        patching='zero',
        epsilon=EPSILON,
        delta=DELTA,
        exp_paths=exp_paths,
        query_timeout=QUERY_TIMEOUT,
    )

    # Persist for test_formal_pruned_small_large_perturbation.
    torch.save(pruned_net.state_dict(), FORMAL_PRUNED_PATH)

    with torch.no_grad():
        assert pruned_net(X).shape == (1, mnist_config['num_classes'])

    # Check the criterion sample by sample on 100 perturbations of X.
    perturbed = generate_perturbations(X, EPSILON, num_samples=100)
    chunk = 25
    with torch.no_grad():
        for start in range(0, perturbed.shape[0], chunk):
            batch = perturbed[start:start + chunk]
            full_out, pruned_out = full_net(batch), pruned_net(batch)
            for offset in range(batch.shape[0]):
                row = slice(offset, offset + 1)
                assert winner_runner_criterion(full_out[row], pruned_out[row], DELTA), (
                    f"Formal pruning failed winner_runner criterion on perturbed sample {start + offset}"
                )
    print("test_formal_pruning_winner_runner: PASSED")

def test_formal_adv_example_detection():
    """Check that the formal search finds a counterexample for the informally pruned net.

    Precondition: INF_PRUNED_PATH must exist. It is written by
    test_informal_pruning_winner_runner.

    The full model's state dict is written to tests/tmp_mnist.pth here.
    Previously the test relied on test_formal_pruning_winner_runner having
    already created that file, so running it alone or in a different order
    failed or used a stale checkpoint.

    find_formal_adv_example must return a bool ``is_safe`` equal to False.
    The adversarial input is then loaded from exp_paths['adv_x_path'], and on
    it the informally pruned model must violate winner_runner_criterion
    against the full model.
    """
    assert os.path.exists(INF_PRUNED_PATH), (
        f"{INF_PRUNED_PATH} not found; run test_informal_pruning_winner_runner first"
    )

    exp_paths = {
        'pruned_net_path': INF_PRUNED_PATH,
        'path_to_save_mnist_dupnet': "tests/tmp_dupnet.pth",
        'customized_models_paths': config['paths']['customized_models_paths'],
        'adv_x_path': "tests/tmp_adv_x.txt",
        'mnist_sample_path': "tests/tmp_sample.npy",
        'abcrown_specification_path': "tests/tmp_spec.yaml",
        'abcrown_run_path': config['paths']['abcrown_run_path'],
        'abcrown_log_dir': "tests/abcrown_logs"
    }
    model, model_path, train_gen, test_gen, test_data, batch_path, X, winner_runner_logit_diff = load_dataset_and_model(
        "mnist", device, exp_paths, sample_ids=[SAMPLE_ID]
    )
    x = X

    # Write the full-model checkpoint here (same content as
    # test_formal_pruning_winner_runner writes), so this test does not depend
    # on test execution order.
    tmp_path = "tests/tmp_mnist.pth"
    torch.save(model.state_dict(), tmp_path)

    is_safe, ver_res = find_formal_adv_example(
        dataset='mnist',
        pruned_net_path=INF_PRUNED_PATH,
        full_net_path=tmp_path,
        X=x,
        adv_x_path=exp_paths['adv_x_path'],
        metric='winner_runner',
        device=device,
        epsilon=EPSILON,
        delta=DELTA,
        exp_paths=exp_paths,
        query_timeout=QUERY_TIMEOUT
    )
    assert isinstance(is_safe, bool), "Expected is_safe to be a boolean"
    assert not is_safe, "Expected informal net to be non-robust"

    # The adversarial input is read back from adv_x_path after the search.
    adv_x = load_x_from_file("mnist", exp_paths['adv_x_path'])
    informal_pruned_net_loaded = load_mnist_model(INF_PRUNED_PATH, device)
    with torch.no_grad():
        orig = model(adv_x)
        pruned = informal_pruned_net_loaded(adv_x)
        assert not winner_runner_criterion(orig, pruned, DELTA), "Adversarial example does not break the criterion"
    print("test_formal_adv_example_detection: PASSED")

def _winner_runner_per_sample(full_net, pruned_net, inputs, delta, batch_size=25):
    """Return one bool per row of ``inputs``: whether winner_runner_criterion holds for it.

    Runs both networks in chunks of ``batch_size`` under ``torch.no_grad`` and
    evaluates the criterion on single-row slices. This matches how
    test_formal_pruning_winner_runner applies it.
    """
    results = []
    with torch.no_grad():
        for start in range(0, inputs.shape[0], batch_size):
            batch = inputs[start:start + batch_size]
            orig = full_net(batch)
            pruned = pruned_net(batch)
            for j in range(batch.shape[0]):
                results.append(bool(winner_runner_criterion(orig[j:j + 1], pruned[j:j + 1], delta)))
    return results


def test_formal_pruned_small_large_perturbation():
    """Compare the formally pruned model to the full model under small and large perturbations.

    Precondition: FORMAL_PRUNED_PATH must exist. It is written by
    test_formal_pruning_winner_runner.

    - Small (EPSILON) perturbations: every one of 100 perturbed samples must
      satisfy winner_runner_criterion.
    - Large (0.95) perturbations: at least one of 100 perturbed samples must
      violate it.

    The criterion is evaluated per sample. The previous code made a single
    call on the concatenated 100-row batch, which did not literally check
    "every sample" / "some sample".
    """
    pruned_net = load_mnist_model(FORMAL_PRUNED_PATH, device)
    exp_paths = {
        'pruned_net_path': "tests/tmp_pruned.pth",
        'path_to_save_mnist_dupnet': "tests/tmp_dupnet.pth",
        'customized_models_paths': config['paths']['customized_models_paths'],
        'adv_x_path': "tests/tmp_adv_x.txt",
        'mnist_sample_path': "tests/tmp_sample.npy",
        'abcrown_specification_path': "tests/tmp_spec.yaml",
        'abcrown_run_path': config['paths']['abcrown_run_path'],
        'abcrown_log_dir': "tests/abcrown_logs"
    }
    # Load the original (unpruned) model and sample for comparison.
    full_net, model_path, train_gen, test_gen, test_data, batch_path, X, winner_runner_logit_diff = load_dataset_and_model(
        "mnist", device, exp_paths, sample_ids=[SAMPLE_ID]
    )

    # Small epsilon: the formal guarantee should hold for every sample.
    small_perturbed = generate_perturbations(X, EPSILON, num_samples=100)
    small_ok = _winner_runner_per_sample(full_net, pruned_net, small_perturbed, DELTA)
    assert all(small_ok), "Formal pruned model failed small epsilon criterion"
    print("Formal pruned model: small epsilon perturbation test PASSED")

    # Large epsilon: far outside the verified region, so expect some sample to break.
    large_epsilon = 0.95
    large_perturbed = generate_perturbations(X, large_epsilon, num_samples=100)
    large_ok = _winner_runner_per_sample(full_net, pruned_net, large_perturbed, DELTA)
    num_broken = large_ok.count(False)
    found_break = num_broken > 0
    if found_break:
        print(f"Large epsilon perturbation broke the criterion on {num_broken}/{len(large_ok)} samples")
    assert found_break, "No sample broke the criterion for large epsilon"


if __name__ == "__main__":
    # Order matters: later tests read checkpoints that earlier ones write
    # (INF_PRUNED_PATH, tests/tmp_mnist.pth, FORMAL_PRUNED_PATH).
    # The abs_max and KL-divergence tests are not run from here.
    for test_fn in (
        test_informal_pruning_winner_runner,
        test_formal_pruning_winner_runner,
        test_formal_adv_example_detection,
        test_formal_pruned_small_large_perturbation,
    ):
        test_fn()
