import os
import sys
import numpy as np

import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt

sys.path.append(".")

from image_uncertainty.cifar.cifar_evaluate import described_plot
from experiments.imagenet_discrete import test_loaders, normalize, parse_args, dump_ues
from spectral_normalized_models.ddu import (
    gmm_fit, gmm_evaluate, get_embeddings, logsumexp
)

from imagenet.models import load_model
from imagenet import dataset_image


def get_train_loader(data_dir=None, batch_size=32, subsample=False):
    """Build a DataLoader over the ImageNet ``train`` split.

    Images get a random 224px resized crop, a random horizontal flip,
    tensor conversion and the shared ``normalize`` transform.

    Args:
        data_dir: ImageNet root containing a ``train`` folder; falls back to
            the cluster path when ``None``.
        batch_size: number of images per batch.
        subsample: when true, keep only every 200th sample of the dataset.

    Returns:
        An unshuffled ``torch.utils.data.DataLoader`` (8 workers, pinned memory).
    """
    root = '/gpfs/gpfs0/datasets/ImageNet/ILSVRC2012/' if data_dir is None else data_dir

    preprocessing = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    dataset = dataset_image.ImageFolder(os.path.join(root, "train"), preprocessing)

    if subsample:
        # Deterministic thinning: indices 0, 200, 400, ...
        dataset = torch.utils.data.Subset(dataset, range(0, len(dataset), 200))

    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)

def _channel_stat(value, image):
    """Return ``value`` in a form that broadcasts against a CHW / NCHW ``image``.

    Scalars are returned unchanged; sequences are reshaped to ``(C, 1, 1)`` so
    each entry applies to one channel.
    """
    if np.isscalar(value):
        return value
    if isinstance(image, torch.Tensor):
        dtype = image.dtype if image.is_floating_point() else None
        stat = torch.as_tensor(value, dtype=dtype, device=image.device)
    else:
        stat = np.asarray(value, dtype=getattr(image, 'dtype', None))
    return stat.reshape(-1, 1, 1)


def denormalize(image, mean=0.45, std=0.225):
    """Invert ``(x - mean) / std`` normalization: return ``image * std + mean``.

    Args:
        image: normalized image tensor/array; per-channel stats assume the
            channel axis is third from the end (CHW or NCHW).
        mean: scalar or per-channel sequence. The default 0.45 looks like a
            single-value approximation of the ImageNet channel means.
        std: scalar or per-channel sequence. The default 0.225 looks like a
            single-value approximation of the ImageNet channel stds.

    Returns:
        The de-normalized image, same type as ``image``.
    """
    return image * _channel_stat(std, image) + _channel_stat(mean, image)

# Class-index -> class-name lookup used by ``show`` (indexed with an int label).
# The file is read through a context manager so the handle is closed.
# NOTE(review): ``eval`` executes arbitrary code from this file; if it only holds
# a dict/list literal, ``ast.literal_eval`` would be a safer drop-in replacement.
with open('imagenet/class_names.txt', 'r', encoding='utf-8') as _names_file:
    names = eval(_names_file.read())


def show(image_tensor, label=None):
    """Display a normalized image with matplotlib, optionally with a title.

    Args:
        image_tensor: normalized image with channels first (the channel axis
            is moved last for ``imshow``); may live on any device.
        label: optional title. Integer labels (Python, numpy, or single-element
            torch tensors) are looked up in ``names``; anything else is used
            as the title verbatim. ``0`` is a valid class index and is shown.
    """
    if label is not None:
        # Unwrap single-element tensors so class indices from a DataLoader work.
        if isinstance(label, torch.Tensor) and label.numel() == 1:
            label = label.item()
        print(label)
        if isinstance(label, (int, np.integer)) and not isinstance(label, bool):
            plt.title(names[int(label)])
        else:
            plt.title(label)

    image = denormalize(image_tensor)
    if isinstance(image, torch.Tensor):
        # .numpy() requires a CPU tensor that does not track gradients.
        image = image.detach().cpu().numpy()
    # imshow clips float RGB to [0, 1] anyway; clipping explicitly avoids its warning.
    plt.imshow(np.clip(np.moveaxis(image, 0, 2), 0, 1))
    plt.show()



# Directory shared by the 'generate' (write) and 'evaluate' (read) modes.
_EMBEDDINGS_DIR = 'checkpoint'
_X_EMBEDDINGS_FILE = 'x_train_embeddings.pt'
_Y_EMBEDDINGS_FILE = 'y_train_embeddings.pt'


def _generate_train_embeddings(model, batch_size, device, storage_device):
    """Extract 2048-d train-set embeddings and save them with their labels."""
    train_loader = get_train_loader(batch_size=batch_size, subsample=False)
    print('training len', len(train_loader.dataset))
    embeddings, labels = get_embeddings(
        model,
        train_loader,
        num_dim=2048,
        dtype=torch.double,
        device=device,
        storage_device=storage_device,
    )

    os.makedirs(_EMBEDDINGS_DIR, exist_ok=True)
    torch.save(embeddings.cpu(), os.path.join(_EMBEDDINGS_DIR, _X_EMBEDDINGS_FILE))
    torch.save(labels.cpu(), os.path.join(_EMBEDDINGS_DIR, _Y_EMBEDDINGS_FILE))


def _evaluate_ddu(model, args, device, storage_device, subsample_step=20):
    """Fit a class-conditional GMM on saved embeddings and score val vs OOD data."""
    val_loader, ood_loader = test_loaders(args.data_folder, args.ood_folder, args.b, subsample=False)

    x_train = torch.load(os.path.join(_EMBEDDINGS_DIR, _X_EMBEDDINGS_FILE))
    labels = torch.load(os.path.join(_EMBEDDINGS_DIR, _Y_EMBEDDINGS_FILE))
    # Keep every `subsample_step`-th training embedding to make the GMM fit cheaper.
    keep = range(0, len(x_train), subsample_step)
    x_train = x_train[keep]
    labels = labels[keep]

    # Sanity check: the model's linear head should classify its own embeddings.
    with torch.no_grad():
        logits = model.linear(x_train[:1000].float().to(device))
    predictions = torch.argmax(logits, dim=-1).cpu()
    print('Accuracy', np.mean((predictions == labels[:1000]).numpy()))

    print('Fitting GMM on', len(x_train), 'embeddings')
    gaussians_model, _ = gmm_fit(embeddings=x_train, labels=labels, num_classes=1000)
    print('GMM fitted')

    measure = logsumexp  # logsumexp or entropy

    ues = measure(gmm_evaluate(
        model, gaussians_model, val_loader, device=device, num_classes=1000,
        storage_device=storage_device,
    )[0])
    ood_ues = measure(gmm_evaluate(
        model, gaussians_model, ood_loader, device=device, num_classes=1000,
        storage_device=storage_device,
    )[0])

    # Scores are negated before plotting/dumping.
    described_plot(-ues, -ood_ues, args.ood_name, args.net, 'DDU')
    dump_ues(-ues, -ood_ues, 'ddu', 'imagenet', args.ood_name)


def main(mode='evaluate'):
    """Run the ImageNet DDU experiment.

    Args:
        mode: ``'generate'`` extracts train-set embeddings with the spectral
            model and saves them under ``checkpoint/``. ``'evaluate'`` (the
            default) loads those embeddings, fits a GMM and dumps/plots
            in-distribution vs OOD uncertainty scores.

    Raises:
        ValueError: if ``mode`` is not one of the two supported values.
    """
    if mode not in ('generate', 'evaluate'):
        raise ValueError(f"Unknown mode {mode!r}; expected 'generate' or 'evaluate'")

    args = parse_args()
    model = load_model('spectral')
    model.eval()
    device = 'cuda'
    storage_device = 'cpu'

    if mode == 'generate':
        _generate_train_embeddings(model, args.b, device, storage_device)
    else:
        _evaluate_ddu(model, args, device, storage_device)


# Entry point: run the experiment only when executed as a script, not on import.
if __name__ == '__main__':
    main()
