from datasets.utils import *
from modelZoo.utils import get_model
from utils import *
import torch, random
import numpy as np


def step_on_auc(OOD_model, dataset_in_name, seed, image_size=32,
                batch_size=128, total_batches=15, eps=8 / 255,
                steps=(1, 5, 10, 25, 50, 100, 200), out_dataset_names=None):
    """Print (and return) average AUCs of an OOD detector for several attack step counts.

    The in-distribution dataset is paired with every out-distribution dataset
    and ``attack_and_plot`` is run once per pair for each entry of ``steps``.
    The in/out/both AUCs it returns are averaged over the out-distribution
    datasets that were actually loaded and printed per step count.

    Args:
        OOD_model: OOD detector wrapper passed to ``attack_and_plot``; it must
            expose ``.model`` (only the class name of ``.model`` is used here,
            for printing).
        dataset_in_name: name of the in-distribution dataset.
        seed: seed applied to torch, ``random`` and numpy once, before the
            dataloaders are built.
        image_size: forwarded to ``get_outdist_dataloaders``.
        batch_size: forwarded to ``get_outdist_dataloaders``.
        total_batches: forwarded to ``attack_and_plot`` (presumably the number
            of batches it evaluates -- confirm in ``attack_and_plot``).
        eps: attack budget forwarded to ``attack_and_plot``.
        steps: attack step counts to sweep.
        out_dataset_names: out-distribution dataset names; ``None`` selects the
            default list (mnist, tiny_imagenet, places365, LSUN, iSUN, birds,
            flowers, coil_100).

    Returns:
        dict mapping each step count to ``(avg_in_auc, avg_out_auc, avg_both_auc)``.

    Raises:
        ValueError: if no out-distribution dataloader was returned.
    """
    if out_dataset_names is None:
        out_dataset_names = ['mnist', 'tiny_imagenet', 'places365', 'LSUN', 'iSUN',
                             'birds', 'flowers', 'coil_100']

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Element 0 is the in-distribution (name, loader) pair; the rest are the
    # out-distribution pairs.
    loaders = list(get_outdist_dataloaders(
        [dataset_in_name, *out_dataset_names], image_size=image_size, batch_size=batch_size))
    in_loader, out_loaders = loaders[0], loaders[1:]
    name_in, _ = in_loader
    if not out_loaders:
        raise ValueError("no out-of-distribution dataloaders were loaded")
    # Average over the loaders actually returned rather than a hard-coded count.
    n_out = len(out_loaders)

    results = {}
    for step in steps:
        sum_in_auc = 0
        sum_out_auc = 0
        sum_both_auc = 0
        for out_loader in out_loaders:
            # Step size is chosen so that step * step_size == 2.5 * eps, i.e. the
            # total step length stays fixed as the number of steps varies.
            # The trailing positional ``1`` is passed through unchanged; its
            # meaning is defined by attack_and_plot.
            _, in_auc, out_auc, both_auc, _ = attack_and_plot(
                OOD_model, device, in_loader, out_loader, eps,
                eps / step * 2.5, step, 1,
                print_auc=False, total_batches=total_batches)
            sum_in_auc += in_auc
            sum_out_auc += out_auc
            sum_both_auc += both_auc

        avg_in_auc = sum_in_auc / n_out
        avg_out_auc = sum_out_auc / n_out
        avg_both_auc = sum_both_auc / n_out
        results[step] = (avg_in_auc, avg_out_auc, avg_both_auc)

        print("*" * 20, " model = ", OOD_model.model.__class__.__name__, " in= ", name_in, " ood = ",
              OOD_model.__class__.__name__, "steps=", step, "*" * 20)
        print("average in_auc ", avg_in_auc)
        print("average out_auc", avg_out_auc)
        print("average both_auc", avg_both_auc)

    return results

if __name__ == "__main__":
    # Sweep attack step counts for two OOD detectors trained on CIFAR-10.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed = 2022
    dataset_name = "cifar10"
    # Build the CIFAR-10 training loader once and reuse it below
    # (previously it was constructed a second time for OOD_openMax).
    dataloader_train_in = get_trainloader_cifar10()

    # NOTE(review): variable is named ALOE_MSP but wraps the
    # "Rade2021Helper_R18_extra" model -- confirm the intended label.
    ALOE_MSP = OOD_MSP(get_model("Rade2021Helper_R18_extra", dataset_name, device))
    step_on_auc(ALOE_MSP, dataset_name, seed)

    OSAD = OOD_openMax(get_model("open-set", dataset_name, device), dataloader_train_in, device, robust_model=True)
    step_on_auc(OSAD, dataset_name, seed)
