import os

import numpy as np
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader

from datasets import NIPS2017AdversarialCompetition
from constants import BATCH_SIZE, DEVICE
from models import models, sizes


def generate_predict_labels(save_path='resources/ground_truth/predict_labels2.pt'):
    """Run every model in ``models`` over the dataset and save its predicted labels.

    Each image goes through ``T.ToTensor()`` and is then mapped with
    ``x * 2 - 1``. Models whose entry in ``sizes`` is 299 receive the batch at
    its loaded size. Every other model receives the batch passed through
    ``T.Resize([size])``. The predicted label is the arg-max of the model's
    logits along dim 1.

    Args:
        save_path: Destination file for ``torch.save``. Missing parent
            directories are created. Defaults to the original hard-coded path.

    Returns:
        A ``torch.long`` tensor in which row ``i`` holds the predicted labels
        of the ``i``-th model in ``models.items()`` order, one per image in
        dataset order. The same tensor is written to ``save_path``.
    """
    transforms = T.Compose([
        T.ToTensor(),
        # Affine rescale applied instead of ImageNet mean/std normalisation.
        # Presumably maps ToTensor's [0, 1] output to [-1, 1] (assumes the
        # dataset yields PIL/uint8 images — TODO confirm).
        T.Lambda(lambda img: img * 2.0 - 1.0),
    ])
    dataset = NIPS2017AdversarialCompetition(transform=transforms, requires_grad=True)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    # One list of predicted labels per model, in models.items() order.
    predict_labels = [[] for _ in range(len(models))]

    for images, _ in data_loader:
        images = images.to(DEVICE)

        # NOTE(review): assumes every model is already in eval mode — TODO confirm.
        with torch.no_grad():
            for index, (model_name, target_model) in enumerate(models.items()):
                size = sizes[model_name]
                if size == 299:
                    # Each model gets its own tensor, so an in-place op inside
                    # one model cannot alter the batch seen by the next one.
                    model_input = images.clone()
                else:
                    model_input = T.Resize([size])(images)
                logits = target_model(model_input)
                # One .tolist() copies the whole batch to the host at once,
                # instead of one device sync per element via .item().
                predict_labels[index].extend(logits.argmax(dim=1).tolist())

    # Explicit int64: np.array over Python ints is platform-dependent (int32 on
    # Windows with NumPy < 2), and becomes float64 when no images were seen.
    predict_labels_tensor = torch.tensor(predict_labels, dtype=torch.long)

    save_dir = os.path.dirname(save_path)
    if save_dir:
        # torch.save does not create missing directories itself.
        os.makedirs(save_dir, exist_ok=True)
    torch.save(predict_labels_tensor, save_path)
    return predict_labels_tensor
