import os
import sys

import torchvision.transforms as transforms
import torch

sys.path.append(".")

from image_uncertainty.cifar.cifar_evaluate import described_plot
from experiments.imagenet_discrete import test_loaders, normalize, parse_args, dump_ues
from spectral_normalized_models.ddu import (
    gmm_fit, gmm_evaluate, get_embeddings, logsumexp
)

from imagenet.models import load_model
from imagenet import dataset_image


def get_train_loader(data_dir=None, batch_size=32, subsample=False,
                     subsample_step=20, num_workers=8):
    """Build a non-shuffled DataLoader over the ImageNet ``train`` split.

    Args:
        data_dir: Root directory that contains a ``train`` subfolder. When
            ``None``, the hard-coded cluster path below is used.
        batch_size: Number of images per batch.
        subsample: If True, keep only every ``subsample_step``-th image
            (dataset indices 0, step, 2*step, ...) to make embedding
            extraction cheaper.
        subsample_step: Stride used when ``subsample`` is True. Defaults to
            20, the value previously hard-coded here.
        num_workers: Number of DataLoader worker processes (previously fixed
            at 8).

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=False`` and
        ``pin_memory=True``.

    Raises:
        ValueError: If ``subsample`` is True and ``subsample_step`` < 1.
    """
    if data_dir is None:
        data_dir = '/gpfs/gpfs0/datasets/ImageNet/ILSVRC2012/'
    traindir = os.path.join(data_dir, "train")

    # NOTE(review): RandomResizedCrop / RandomHorizontalFlip are random
    # augmentations, so embeddings extracted through this loader differ from
    # run to run. Confirm this is intended for GMM fitting rather than a
    # deterministic Resize/CenterCrop pipeline.
    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    train_dataset = dataset_image.ImageFolder(traindir, train_transform)

    if subsample:
        if subsample_step < 1:
            raise ValueError(
                f"subsample_step must be >= 1, got {subsample_step}"
            )
        # Keep a strided subset of indices (every `subsample_step`-th sample).
        kept_indices = range(0, len(train_dataset), subsample_step)
        train_dataset = torch.utils.data.Subset(train_dataset, kept_indices)

    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )


def main():
    """Fit a DDU GMM on ImageNet train embeddings and score val vs. OOD data.

    Steps:
      1. Parse CLI args and build subsampled val/OOD loaders
         (``test_loaders``) and a subsampled train loader.
      2. Load the ``'spectral'`` model and extract 2048-d train embeddings
         (``get_embeddings``, double precision).
      3. Fit a per-class Gaussian model with ``gmm_fit``.
      4. Score the val and OOD loaders with ``gmm_evaluate`` and reduce the
         first returned value with ``logsumexp``.
      5. Plot (``described_plot``) and save (``dump_ues``) the negated scores.
    """
    args = parse_args()
    print(args)

    num_classes = 1000  # ImageNet-1k
    device = 'cuda'
    storage_device = 'cpu'

    val_loader, ood_loader = test_loaders(
        args.data_folder, args.ood_folder, args.b, subsample=True
    )
    train_loader = get_train_loader(batch_size=args.b, subsample=True)
    # Report the actual number of training samples; len(loader) * batch_size
    # over-counts whenever the final batch is partial.
    print('training len', len(train_loader.dataset))

    model = load_model('spectral')
    model.eval()

    embeddings, labels = get_embeddings(
        model,
        train_loader,
        num_dim=2048,
        dtype=torch.double,
        device=device,
        storage_device=storage_device,
    )

    # The second return value (jitter epsilon) is not used below.
    gaussians_model, _jitter_eps = gmm_fit(
        embeddings=embeddings, labels=labels, num_classes=num_classes
    )

    # Reduction applied to gmm_evaluate's first output. Only logsumexp is
    # imported here; swap in another reduction (e.g. an entropy function)
    # if needed.
    measure = logsumexp

    def score(loader):
        # Single code path so the val and OOD passes always use the same
        # device / storage-device / class-count settings.
        return measure(gmm_evaluate(
            model,
            gaussians_model,
            loader,
            device=device,
            num_classes=num_classes,
            storage_device=storage_device,
        )[0])

    ues = score(val_loader)
    ood_ues = score(ood_loader)

    # Scores are negated before plotting/dumping. Presumably described_plot
    # and dump_ues expect "higher = more uncertain"; confirm against those
    # helpers.
    described_plot(-ues, -ood_ues, args.ood_name, args.net, 'DDU')
    dump_ues(-ues, -ood_ues, 'ddu', 'imagenet', args.ood_name)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
