import numpy as np
import random
import torch

from datasets.utils import *
from modelZoo.utils import get_model
from utils import *


def one_ood_model_result(OOD_model, device, l, model_name, epsilon, total_batches, batch_size, steps=10,
                         eps_initialize=None):
    """Evaluate one OOD detector against every out-distribution loader in ``l``.

    ``l[0]`` is used as the in-distribution ``(name, loader)`` pair and each
    entry of ``l[1:]`` as an out-distribution pair. For every out-distribution
    pair, ``attack_and_plot`` is called. Its first four results are printed on
    lines prefixed with ``"RESULT"``. Going by the variable names, these are
    the clean, in-attacked, out-attacked and both-attacked AUCs.

    Args:
        OOD_model: detector under evaluation; only its class name is logged here.
        device: torch device string forwarded to ``attack_and_plot``.
        l: sequence of ``(name, loader)`` pairs, in-distribution first.
        model_name: classifier name, used only in the log header.
        epsilon: attack budget forwarded to ``attack_and_plot``.
        total_batches: forwarded to ``attack_and_plot``.
        batch_size: only printed in the log header; it is not forwarded.
        steps: number of attack iterations; also sets the step size.
        eps_initialize: forwarded to ``attack_and_plot``.

    Returns:
        None. Results are reported via ``print`` only.
    """
    # Without at least one out-distribution entry there is nothing to run.
    # Return before touching l[0] so an empty ``l`` stays a no-op.
    if len(l) < 2:
        return

    in_entry = l[0]
    name_in, _ = in_entry
    detector_name = OOD_model.__class__.__name__
    # NOTE(review): presumably the usual PGD heuristic of step = 2.5 * eps / steps;
    # assumes attack_and_plot's 6th positional argument is the per-step size.
    step_size = (epsilon / steps) * 2.5

    # attack_and_plot hands back ``in_dist_help``, which is fed into the next
    # call. Presumably it reuses in-distribution work across out-distributions.
    in_dist_help = None
    for out_entry in l[1:]:
        name_out, _ = out_entry
        print("*" * 20, " model = ", model_name, " in= ", name_in, " out= ", name_out, " ood = ",
              detector_name, "steps=", steps, "batch_size=", batch_size, "eps=", epsilon, "*" * 20)
        clean_auc, in_auc, out_auc, both_auc, in_dist_help = attack_and_plot(OOD_model, device, in_entry, out_entry,
                                                                             epsilon, step_size, steps, 1,
                                                                             print_auc=False,
                                                                             total_batches=total_batches,
                                                                             in_dist_help=in_dist_help,
                                                                             eps_initialize=eps_initialize)
        # Three blank lines before and after the results, as separators in the log.
        print("\n" * 3, end="")
        for value in (clean_auc, in_auc, out_auc, both_auc):
            print("RESULT", value)
        print("\n" * 3, end="")


def run_experiment(epsilon, model_name, dataset_in_name, dataloader_train_in, seed, image_size=32, batch_size=128,
                   adv_model=True, total_batches=15):
    """Seed RNGs, run accuracy checks, then evaluate four OOD detectors.

    Steps:
      1. Seed ``torch``, ``random`` and ``numpy`` with ``seed``.
      2. Build ``(name, loader)`` pairs via ``get_outdist_dataloaders``.
         The list is ``dataset_in_name`` followed by a fixed set of
         out-distribution datasets.
      3. Load the classifier with ``get_model``. Then call
         ``get_clean_accuracy_on_test`` and ``get_adversarial_accuracy_on_test``
         (10 attack steps, step size ``2.5 * epsilon / 10``). Their return
         values are discarded; presumably they print their results.
      4. Call ``one_ood_model_result`` for ``OOD_MSP``, ``OOD_maha_dist``,
         ``OOD_rel_maha_dist`` and ``OOD_openMax``, in that order.

    Args:
        epsilon: attack budget used for accuracy and OOD evaluation.
        model_name: classifier identifier passed to ``get_model``.
        dataset_in_name: in-distribution dataset name.
        dataloader_train_in: ``(name, loader, extra)`` triple. The third item is
            unused here; the whole triple is passed to the Mahalanobis/OpenMax
            detector constructors.
        seed: RNG seed.
        image_size: forwarded to ``get_outdist_dataloaders``.
        batch_size: forwarded to ``get_outdist_dataloaders`` and logged.
        adv_model: forwarded as ``robust_model`` to ``OOD_openMax``.
        total_batches: number of batches used per evaluation.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    l = get_outdist_dataloaders(
        [dataset_in_name, 'mnist', 'tiny_imagenet', 'places365', 'LSUN', 'iSUN', 'birds', 'flowers',
         'coil_100'], image_size=image_size, batch_size=batch_size)
    model = get_model(model_name, dataset_in_name, device)

    train_name, train_dataloader, _ = dataloader_train_in
    if train_name == "Tiny imagenet":
        # NOTE(review): for this dataset, accuracy is measured on the
        # *training* loader (80 batches) rather than on l[0]. Confirm this is intended.
        acc_loader, acc_batches = (train_name, train_dataloader), 80
    else:
        acc_loader, acc_batches = l[0], total_batches

    acc_steps = 10
    get_clean_accuracy_on_test(model, device, acc_loader, total_batches=acc_batches)
    get_adversarial_accuracy_on_test(model, device, acc_loader, epsilon, (epsilon / acc_steps) * 2.5, acc_steps, 1,
                                     total_batches=acc_batches)

    # Each detector is built only after the previous one has been evaluated
    # (same order as before). This function does not keep it alive afterwards,
    # so its state can be released before the next detector is constructed.
    detector_factories = (
        lambda: OOD_MSP(model),
        lambda: OOD_maha_dist(model, dataloader_train_in, device),
        lambda: OOD_rel_maha_dist(model, dataloader_train_in, device),
        lambda: OOD_openMax(model, dataloader_train_in, device, robust_model=adv_model),
    )
    for make_detector in detector_factories:
        one_ood_model_result(make_detector(), device, l, model_name, epsilon, total_batches, batch_size)




def run_experiment_ATOM(epsilon, model_name, dataset_in_name, dataloader_train_in, seed, image_size=32, batch_size=128,
                        adv_model=True, total_batches=15):
    """Seed RNGs and evaluate the ``OOD_ATOM`` detector with a 100-step attack.

    Uses the same seeding and the same in/out-distribution loader list as
    ``run_experiment``. No clean or adversarial accuracy check is run, and only
    the ``OOD_ATOM`` detector is evaluated.

    Args:
        epsilon: attack budget forwarded to ``one_ood_model_result``.
        model_name: classifier identifier passed to ``get_model``.
        dataset_in_name: in-distribution dataset name.
        dataloader_train_in: unused here; kept for signature parity with
            ``run_experiment``.
        seed: RNG seed.
        image_size: forwarded to ``get_outdist_dataloaders``.
        batch_size: forwarded to ``get_outdist_dataloaders`` and logged.
        adv_model: unused here; kept for signature parity with ``run_experiment``.
        total_batches: number of batches used per evaluation.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    l = get_outdist_dataloaders(
        [dataset_in_name, 'mnist', 'tiny_imagenet', 'places365', 'LSUN', 'iSUN', 'birds', 'flowers',
         'coil_100'], image_size=image_size, batch_size=batch_size)
    model = get_model(model_name, dataset_in_name, device)
    # This experiment skips the clean/adversarial accuracy checks that
    # run_experiment performs. To re-enable them, call
    # get_clean_accuracy_on_test / get_adversarial_accuracy_on_test on l[0] as there.

    ATOM_model = OOD_ATOM(model)
    one_ood_model_result(ATOM_model, device, l, model_name, epsilon, total_batches, batch_size, steps=100)
