from datasets.utils import *
from modelZoo.utils import get_model
from utils import *
import numpy as np  # v 1.19.2


def attack_classification_to_OOD(OOD_model, model, device, testloader_in, testloader_out, epsilon, alpha, attack_iters,
                                 restarts, total_batches=15):
    """Measure OOD-detection AUROC before and after a PGD attack on the classifier.

    For up to ``total_batches`` batches of the in-distribution loader, a
    perturbation is crafted with ``attack_pgd_classification`` against
    ``model`` (norm argument ``"l_inf"``), and ``OOD_model`` scores both the
    perturbed and the clean images. Up to ``total_batches`` batches of the
    out-of-distribution loader are scored clean. Two AUROC values are then
    computed with OOD samples as the positive class (label 1) and
    in-distribution samples as the negative class (label 0).

    Args:
        OOD_model: Callable mapping an image batch to per-sample scores.
        model: Classifier handed to ``attack_pgd_classification``.
        device: Device the image and label batches are moved to.
        testloader_in: ``(name, loader)`` pair for in-distribution data; the
            name is not used here.
        testloader_out: ``(name, loader)`` pair for out-of-distribution data;
            the name is not used here.
        epsilon, alpha, attack_iters, restarts: Forwarded unchanged to
            ``attack_pgd_classification``.
        total_batches: Maximum number of batches consumed from each loader.

    Returns:
        Tuple ``(auroc_clean, auroc_in)``: AUROC with clean in-distribution
        scores, and AUROC with attacked in-distribution scores, both against
        the clean out-of-distribution scores.
    """
    _, loader_in = testloader_in
    _, loader_out = testloader_out

    def auroc_against_ood(in_scores, ood_scores):
        # OOD samples come first and are labelled 1; in-distribution samples are labelled 0.
        targets = np.concatenate((np.ones(ood_scores.shape[0], dtype=int),
                                  np.zeros(in_scores.shape[0], dtype=int)))
        stacked_scores = np.concatenate((ood_scores, in_scores), axis=0)
        return roc_auc_score(targets, stacked_scores)

    clean_in_batches = []
    adv_in_batches = []
    for batch_idx, batch in enumerate(tqdm(loader_in, total=total_batches, leave=False)):
        if batch_idx == total_batches:
            break
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)
        perturbation = attack_pgd_classification(model, images, labels,
                                                 epsilon, alpha, attack_iters, restarts, "l_inf")

        # Score the perturbed batch first, then the clean one.
        adv_in_batches.append(OOD_model(images + perturbation).detach().cpu())
        clean_in_batches.append(OOD_model(images).detach().cpu())

    ood_batches = []
    for batch_idx, batch in enumerate(tqdm(loader_out, total=total_batches, leave=False)):
        if batch_idx == total_batches:
            break
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)  # moved for parity with the in-distribution loop; not used for scoring

        ood_batches.append(OOD_model(images).detach().cpu())

    # Every collected tensor already lives on the CPU, so it converts to NumPy directly.
    clean_in_scores = torch.cat(clean_in_batches).numpy()
    ood_scores = torch.cat(ood_batches).numpy()
    adv_in_scores = torch.cat(adv_in_batches).numpy()

    # In-distribution (clean) vs. out-of-distribution (clean).
    auroc_clean = auroc_against_ood(clean_in_scores, ood_scores)
    # In-distribution (attacked) vs. out-of-distribution (clean).
    auroc_in = auroc_against_ood(adv_in_scores, ood_scores)

    return auroc_clean, auroc_in


def attack_OOD_to_classificatoin(OOD_model, model, device, testloader_in, epsilon, alpha, attack_iters,
                                 restarts, total_batches=15):
    """Report classifier accuracy on inputs perturbed by an attack on the OOD detector.

    For up to ``total_batches`` batches of the in-distribution loader, a
    perturbation is crafted with ``attack_pgd_ood_detection`` against
    ``OOD_model`` (norm argument ``"l_inf"``), and ``model`` classifies the
    perturbed images. The resulting accuracy is printed and returned.

    Args:
        OOD_model: Detector handed to ``attack_pgd_ood_detection``.
        model: Classifier evaluated on the perturbed images.
        device: Device the image and label batches are moved to.
        testloader_in: ``(name, loader)`` pair; the name appears in the
            printed report.
        epsilon, alpha, attack_iters, restarts: Forwarded unchanged to
            ``attack_pgd_ood_detection``.
        total_batches: Maximum number of batches consumed from the loader.

    Returns:
        Accuracy in percent over the processed samples, or ``nan`` when no
        sample was processed (empty loader or ``total_batches == 0``).
    """
    name_in, testloader_in = testloader_in
    correct = 0
    total = 0
    for step, data in enumerate(tqdm(testloader_in, total=total_batches, leave=False)):
        if step == total_batches:
            break
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        # The attack is given an all-zero target vector, one entry per image.
        # NOTE(review): presumably 0 encodes "in-distribution" for the detector - confirm
        # against attack_pgd_ood_detection.
        attack_targets = torch.zeros(images.shape[0], device=device)
        delta = attack_pgd_ood_detection(OOD_model, images, attack_targets,
                                         epsilon, alpha, attack_iters, restarts, "l_inf")
        outputs = model(images + delta)
        # detach() instead of the legacy .data attribute; the predicted indices are the same.
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Avoid ZeroDivisionError when nothing was evaluated; nan makes the empty run visible.
    accuracy = 100 * correct / total if total else float("nan")

    print('Accuracy of the network on %s with OOD attacked examples, epsilon=%.4f : %.3f %%' % (
        name_in, epsilon, accuracy))
    return accuracy


if __name__ == "__main__":
    # Evaluation script: loads a CIFAR-10 model and an OpenMax-based OOD detector, reports clean and
    # adversarial accuracy, then averages the detection AUROC under a classification attack over all
    # out-of-distribution loaders, and finally reports accuracy under an attack on the OOD detector.
    model_name = "open-set"

    image_size = 32
    batch_size = 128

    # NOTE(review): `seed` is assigned but never passed to any RNG in this script, so runs are not
    # seeded from here - confirm whether seeding was intended.
    seed = 2022
    dataset_name = "cifar10"
    dataloader_train_in = get_trainloader_cifar10(image_size=image_size, batch_size=batch_size)
    class_num = 10

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Attack hyper-parameters, shared by every attack call below.
    epsilon = 8 / 255
    alpha = (epsilon / 10) * 2.5
    attack_iters = 10
    restarts = 1
    total_batches = 15

    # The first loader (cifar10) is the in-distribution test set; the rest are treated as OOD sets.
    loaders = get_outdist_dataloaders(
        ["cifar10", 'mnist', 'tiny_imagenet', 'places365', 'LSUN', 'iSUN', 'birds', 'flowers', 'coil_100'],
        image_size=image_size, batch_size=batch_size)
    in_loader = loaders[0]
    ood_loaders = loaders[1:]

    model = get_model(model_name, dataset_name, device)
    OOD_model = OOD_openMax(model, dataloader_train_in, device)

    get_clean_accuracy_on_test(model, device, in_loader, total_batches=total_batches)
    get_adversarial_accuracy_on_test(model, device, in_loader, epsilon, alpha, attack_iters, restarts,
                                     total_batches=total_batches)

    name_in, _ = in_loader

    attacked_aurocs = []
    for ood_loader in ood_loaders:
        _, in_auc = attack_classification_to_OOD(OOD_model, model, device, in_loader, ood_loader, epsilon,
                                                 alpha, attack_iters, restarts, total_batches=total_batches)
        attacked_aurocs.append(in_auc)

    # Average over the number of OOD sets actually evaluated, not a hard-coded 8, so the mean stays
    # correct if the dataset list (or the list of returned loaders) changes.
    if attacked_aurocs:
        mean_attacked_auroc = sum(attacked_aurocs) / len(attacked_aurocs)
    else:
        mean_attacked_auroc = float("nan")

    print("*" * 20, " model = ", model_name, " in= ", name_in, " ood = ",
          OOD_model.__class__.__name__, "*" * 20)
    print("Detection AUROC from classification attack:", mean_attacked_auroc)

    attack_OOD_to_classificatoin(OOD_model, model, device, in_loader, epsilon, alpha, attack_iters, restarts,
                                 total_batches=total_batches)
