import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset

from datasets import get_dataset, get_dataset_mean_var
from model_mimic_attack_utils import random_dataset_subset, full_model_mimic_attack, attack_vary, show_images_diff
from models_zoo import resnet101, resnet50, resnet34, resnet18, lipschitz_cnn, small_cnn
from nn_model_manager import NNModelManager


def return_model_by_name(model_name, num_classes, **kwargs):
    """Build an untrained classifier by architecture name.

    Args:
        model_name: One of ``"resnet18"``, ``"resnet34"``, ``"resnet50"``,
            ``"resnet101"``, ``"lipschitz_cnn"`` or ``"small_cnn"``.
        num_classes: Number of output classes of the model.
        **kwargs: Extra constructor options. They are forwarded ONLY to
            ``lipschitz_cnn``; for every other architecture they are ignored.

    Returns:
        The freshly constructed model (ResNets are built with
        ``pretrained=False``).

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
            (``ValueError`` subclasses ``Exception``, so callers that caught
            the previous generic ``Exception`` keep working.)
    """
    if model_name == "resnet18":
        return resnet18(pretrained=False, num_classes=num_classes)
    if model_name == "resnet34":
        return resnet34(pretrained=False, num_classes=num_classes)
    if model_name == "resnet50":
        return resnet50(pretrained=False, num_classes=num_classes)
    if model_name == "resnet101":
        return resnet101(pretrained=False, num_classes=num_classes)
    if model_name == "lipschitz_cnn":
        return lipschitz_cnn(num_classes=num_classes, **kwargs)
    if model_name == "small_cnn":
        return small_cnn(num_classes=num_classes)
    raise ValueError(
        f"Choose another model: {model_name!r} is not one of "
        "resnet18, resnet34, resnet50, resnet101, lipschitz_cnn, small_cnn"
    )


def test_attack(nn_mm, target_image, target_label, mean_dataset, std_dataset, train_data, *,
                coefficient=4, n_attack_samples=25, alpha=0.01, num_iter=10):
    """Run a PGD ``attack_vary`` sweep against ``nn_mm.model`` and report the results.

    Prints the max/min/mean/variance of the distances between the successful
    adversarial samples and ``target_image``. It then shows the image diff for
    the farthest and the closest successful sample.

    Args:
        nn_mm: Model manager whose ``.model`` is attacked.
        target_image: Image to perturb; ``target_image[0][0]`` is what gets
            displayed (NOTE(review): exact nesting/shape comes from the caller —
            confirm).
        target_label: Label passed through to ``attack_vary``.
        mean_dataset, std_dataset: Dataset normalisation statistics.
        train_data: Dataset passed to ``attack_vary`` as ``train_dataset``.
        coefficient: Scales the epsilon bounds to
            ``[0.01 * coefficient, 0.1 * coefficient]``.
        n_attack_samples: Number of attack samples requested from ``attack_vary``.
        alpha: PGD step size.
        num_iter: Number of PGD iterations.
    """
    adv_dataset, _adversarial_images_for_train, successful_adv_samples_distance = attack_vary(
        model=nn_mm.model,
        target_image=target_image,
        train_dataset=train_data,
        target_label=target_label,
        epsilon_min_bound=0.01 * coefficient,
        epsilon_max_bound=0.1 * coefficient,
        n_attack_samples=n_attack_samples,
        mean_dataset=mean_dataset,
        std_dataset=std_dataset,
        attack_type="pgd",
        alpha=alpha,
        num_iter=num_iter,
    )
    # Reductions (max/min/argmin/argmax) on an empty tensor raise a RuntimeError,
    # so bail out cleanly when no adversarial sample succeeded.
    if successful_adv_samples_distance.numel() == 0:
        print("\nNo successful adversarial samples were found; nothing to report.\n")
        return

    # NOTE: with a single successful sample, .var() (unbiased) evaluates to nan.
    print(f"\nMax distance between successful adv sample and target image: "
          f"{successful_adv_samples_distance.max()}.\n"
          f"Min distance between successful adv sample and target image: "
          f"{successful_adv_samples_distance.min()}.\n"
          f"Mean distance between successful adv sample and target image: "
          f"{successful_adv_samples_distance.mean()}.\n"
          f"Var distance between successful adv sample and target image: "
          f"{successful_adv_samples_distance.var()}.\n"
          )
    min_index = torch.argmin(successful_adv_samples_distance)
    max_index = torch.argmax(successful_adv_samples_distance)
    # Slicing [i:i + 1] keeps the batch dimension for show_images_diff.
    adv_images = adv_dataset.tensors[0]
    show_images_diff(target_image[0][0], adv_images[max_index:max_index + 1], mean_dataset, std_dataset)
    show_images_diff(target_image[0][0], adv_images[min_index:min_index + 1], mean_dataset, std_dataset)


# Architectures accepted on the command line (see return_model_by_name).
_AVAILABLE_MODELS = ("small_cnn", "lipschitz_cnn", "resnet18", "resnet34", "resnet50", "resnet101")
# Supported datasets and their number of classes.
_DATASET_NAME2NUM_CLASSES = {
    "mnist": 10,
    "cifar10": 10,
    "cifar100": 100,
}


def _parse_args(argv=None):
    """Parse the command-line options of the mimic-attack experiment.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--blackbox-model", help="name of black box model (default: %(default)s)",
                        choices=_AVAILABLE_MODELS, default="resnet18")
    parser.add_argument("--whitebox-model", help="name of white box model (mimic) (default: %(default)s)",
                        choices=_AVAILABLE_MODELS, default="small_cnn")
    parser.add_argument("--dataset-name", help="name of the dataset (default: %(default)s)",
                        choices=list(_DATASET_NAME2NUM_CLASSES), default="cifar100")
    # NOTE(review): this value is passed as num_epochs to both NNModelManager
    # instances; confirm whether "attack iterations" is the right description.
    parser.add_argument("--num-epochs", help="number of attack iterations (default: %(default)d)",
                        type=int, default=180)
    parser.add_argument("--subset-size", help="size of the subset dataset (default: %(default)d)",
                        type=int, default=600)
    parser.add_argument("--n-attack-samples", help="number of attack samples (default: %(default)d)",
                        type=int, default=30)
    parser.add_argument("--batch-size", help="size of batch (default: %(default)d)",
                        type=int, default=256)
    parser.add_argument("--device", help="torch device to run on, e.g. cpu or cuda (default: %(default)s)",
                        default="cpu")
    return parser.parse_args(argv)


def test(argv=None):
    """Run the black-box model-mimic attack experiment.

    Steps:
      1. Parse the CLI options.
      2. Load the black-box model from its saved state, or train it if no
         saved state exists.
      3. Build the white-box (mimic) model.
      4. Draw a random subset of the test split (or the train split).
      5. Hand everything to ``full_model_mimic_attack``.

    Args:
        argv: Optional argument list for the CLI parser; ``None`` (default)
            reads ``sys.argv`` as before.
    """
    args = _parse_args(argv)

    batch_size = args.batch_size
    num_epochs = args.num_epochs
    device = args.device

    # Black-box (attacked) model, any name from _AVAILABLE_MODELS.
    model_name = args.blackbox_model
    # Dataset and its number of classes.
    dataset_name = args.dataset_name
    num_classes = _DATASET_NAME2NUM_CLASSES[dataset_name]
    # White-box (mimic) model; a simpler architecture than the black box is usually preferable.
    model_mimic_name = args.whitebox_model

    # Random subset parameters.
    subset_size = args.subset_size
    test_part = 0

    # attack_vary parameters: min/max image perturbation and number of attack samples.
    coefficient = 3
    n_attack_samples = args.n_attack_samples
    epsilon_min_bound = 0.01 * coefficient
    epsilon_max_bound = 0.1 * coefficient

    # PGD attack parameters.
    alpha = 0.01
    num_iter = 10

    show_image_flag = False
    # n_components parameter for PCA (None = keep all components).
    n_components = None
    # Whether to project all attacking examples onto the span after receiving new attacks.
    span_project_flag = False
    # Draw the mimic model's initial subset from the test split rather than the train split.
    first_train_mimic_model_on_test_part_flag = True

    print("CUDA available:", torch.cuda.is_available())
    print('Available Device ', device)

    train_data = get_dataset(dataset=dataset_name, split="train")
    test_data = get_dataset(dataset=dataset_name, split="test")
    mean_dataset, std_dataset = get_dataset_mean_var(dataset=dataset_name)

    model = return_model_by_name(model_name=model_name, num_classes=num_classes)
    nn_mm = NNModelManager(
        dataset_name=dataset_name, batch=batch_size,
        model_name=model_name, model=model,
        device=device, num_epochs=num_epochs,
    )
    # Try to load saved state 0; fall back to training from scratch if it is missing.
    nn_mm.save_name = 0
    try:
        nn_mm.load()
    except FileNotFoundError:
        nn_mm.save_name = None
        nn_mm.train_model(dataset=(train_data, test_data))

    model_mimic = return_model_by_name(model_name=model_mimic_name, num_classes=num_classes)

    subset_source = test_data if first_train_mimic_model_on_test_part_flag else train_data
    train_subset, test_subset, target_image, target_label = random_dataset_subset(
        dataset=subset_source, subset_size=subset_size, test_part=test_part
    )
    test_subset_dataloader = DataLoader(test_subset, batch_size=batch_size, shuffle=True)
    nn_mm_mimic = NNModelManager(
        dataset_name=dataset_name, batch=batch_size,
        model_name=model_mimic_name, model=model_mimic,
        device=device, num_epochs=num_epochs,
    )

    full_model_mimic_attack(
        nn_mm_mimic=nn_mm_mimic, nn_mm=nn_mm,
        train_subset=train_subset, test_subset=test_subset,
        test_subset_dataloader=test_subset_dataloader, target_image=target_image,
        target_label=target_label, mean_dataset=mean_dataset,
        std_dataset=std_dataset, batch_size=batch_size, n_components=n_components,
        span_project_flag=span_project_flag, show_image_flag=show_image_flag,
        n_attack_samples=n_attack_samples,
        epsilon_min_bound=epsilon_min_bound, epsilon_max_bound=epsilon_max_bound,
        alpha=alpha, num_iter=num_iter,
    )

    # Optional: attack the black-box model directly instead of through the mimic.
    # test_attack(nn_mm, target_image, target_label, mean_dataset, std_dataset, train_data)


if __name__ == "__main__":
    test()
