import sys

from datasets.utils import *
from modelZoo.utils import get_model
from utils import *


def one_ood_model_result_var_eps(OOD_model, device, l, model_name, total_batches, steps, *, max_eps=8):
    """Attack an OOD detector over a sweep of epsilons and print mean AUCs.

    For every ``epsilon`` in ``0/255, 1/255, ..., max_eps/255`` the
    in-distribution entry ``l[0]`` is paired with each out-of-distribution
    entry in ``l[1:]`` and passed to ``attack_and_plot``. The four AUCs it
    returns are averaged over the OOD datasets and printed.

    Args:
        OOD_model: OOD detector wrapper; only its class name is used here
            (for printing) besides forwarding it to ``attack_and_plot``.
        device: Device forwarded to ``attack_and_plot``.
        l: Sequence of 2-tuples whose first element is the dataset name
            (presumably ``(name, loader)`` — confirm against
            ``get_outdist_dataloaders``). ``l[0]`` is in-distribution, the
            rest are OOD datasets.
        model_name: Backbone name, used only in the printed header.
        total_batches: Forwarded to ``attack_and_plot``.
        steps: Number of attack iterations; the step size passed to
            ``attack_and_plot`` is ``(epsilon / steps) * 2.5``.
        max_eps: Largest epsilon numerator (epsilon = j / 255). The default
            of 8 reproduces the original 9-point sweep.

    Returns:
        dict mapping each epsilon to ``(clean, in, out, both)`` mean AUCs.

    Raises:
        ValueError: if ``l`` has no OOD entries after ``l[0]`` or ``steps``
            is not positive.
    """
    if len(l) < 2:
        raise ValueError("l must contain an in-distribution entry followed by at least one OOD entry")
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")

    name_in, _ = l[0]
    ood_entries = l[1:]
    # Average over the datasets actually evaluated rather than a fixed count.
    num_ood = len(ood_entries)

    results = {}
    for j in range(max_eps + 1):
        epsilon = j / 255
        step_size = (epsilon / steps) * 2.5

        clean_sum = 0
        in_sum = 0
        out_sum = 0
        both_sum = 0
        # attack_and_plot hands back a value that is fed into the next call for
        # the same epsilon (presumably cached in-distribution results — confirm
        # in attack_and_plot). Reset it whenever epsilon changes.
        in_dist_help = None
        for ood_entry in ood_entries:
            clean_auc, in_auc, out_auc, both_auc, in_dist_help = attack_and_plot(
                OOD_model, device, l[0], ood_entry, epsilon, step_size, steps, 1,
                print_auc=False, total_batches=total_batches, in_dist_help=in_dist_help)
            clean_sum += clean_auc
            in_sum += in_auc
            out_sum += out_auc
            both_sum += both_auc

        clean_mean = clean_sum / num_ood
        in_mean = in_sum / num_ood
        out_mean = out_sum / num_ood
        both_mean = both_sum / num_ood
        results[epsilon] = (clean_mean, in_mean, out_mean, both_mean)

        print("*" * 20, " model = ", model_name, " in= ", name_in, " ood = ",
              OOD_model.__class__.__name__, "eps=", epsilon, "steps=", steps, "*" * 20)
        print("total result claen", clean_mean)
        print("total result in", in_mean)
        print("total result out", out_mean)
        print("total result both", both_mean)

    return results


def select_experiment(argv, model_names, ood_method_names):
    """Return the ``(model_name, OOD_name)`` pair selected on the command line.

    ``argv[1]`` is a 1-based experiment index into the parallel lists
    ``model_names`` / ``ood_method_names``. The index is range-checked, so
    e.g. ``0`` no longer silently wraps around to the last experiment via
    negative indexing.

    Raises:
        ValueError: if the index is missing, not an integer, out of range,
            or the two lists differ in length.
    """
    if len(model_names) != len(ood_method_names):
        raise ValueError("model_names and ood_method_names must have the same length")
    count = len(model_names)
    if len(argv) < 2:
        raise ValueError(f"expected an experiment index in 1..{count} as the first argument")
    index = int(argv[1])  # int() itself raises ValueError on non-integer text
    if not 1 <= index <= count:
        raise ValueError(f"experiment index must be in 1..{count}, got {index}")
    return model_names[index - 1], ood_method_names[index - 1]


if __name__ == "__main__":
    # Parallel lists: CLI index k (1-based) runs model_names[k-1] with the
    # OOD detection method OOD_method_names[k-1].
    model_names = ["open-set", "ALOE", "ViT-L_32", "adversarial_train_Madry", "ATOM"]
    OOD_method_names = ["OpenMax", "MSP", "MSP", "MD", "ATOM"]
    model_name, OOD_name = select_experiment(sys.argv, model_names, OOD_method_names)
    print(f"run for model:{model_name}, with OOD detection method:{OOD_name}")

    # ViT-L_32 runs at image_size 224 with a 4x smaller batch; 4x as many
    # batches keeps the evaluated sample count equal (128 * 15 == 32 * 60).
    is_vit = model_name == "ViT-L_32"
    image_size = 224 if is_vit else 32
    batch_size = 32 if is_vit else 128
    total_batches = 15 * 4 if is_vit else 15

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    trainloader_in = get_trainloader_cifar100(image_size=image_size, batch_size=batch_size)
    dataset_in_name = "cifar100"
    # First entry is the in-distribution dataset; the rest are OOD datasets.
    l = get_outdist_dataloaders(
        [dataset_in_name, 'mnist', 'tiny_imagenet', 'places365', 'LSUN', 'iSUN', 'birds', 'flowers', 'coil_100'],
        image_size=image_size, batch_size=batch_size, num_workers=1)

    model = get_model(model_name, dataset_in_name, device)

    # Seeded after model/loader construction (original order kept so results
    # stay comparable with earlier runs).
    seed = 2022

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    if OOD_name == "MSP":
        OOD_model = OOD_MSP(model)
    elif OOD_name == "MD":
        OOD_model = OOD_maha_dist(model, trainloader_in, device)
    elif OOD_name == "RMD":
        OOD_model = OOD_rel_maha_dist(model, trainloader_in, device)
    elif OOD_name == "OpenMax":
        OOD_model = OOD_openMax(model, trainloader_in, device, True)
    elif OOD_name == "ATOM":
        OOD_model = OOD_ATOM(model)
    else:
        raise ValueError(f"unknown OOD detection method: {OOD_name}")

    one_ood_model_result_var_eps(OOD_model, device, l, model_name, total_batches, 100)

    # Paper name -> code name (the model_name passed to get_model)
    # ViT       # ViT-L_32
    # HAT       # Rade2021Helper_R18_extra
    # OSAD      # open-set
    # AT        # adversarial_train_Madry
    # ALOE      # ALOE
    # AOE       # AOE
