import sys

import torch
import numpy as np

sys.path.append('.')
from image_uncertainty.cifar.cifar_evaluate import (
    load_model, get_eval_args, described_plot, cifar_test, misclassification_detection
)
from image_uncertainty.cifar.cifar_datasets import get_training_dataloader, settings
from spectral_normalized_models.ddu import (
    gmm_fit, gmm_evaluate, get_embeddings, logsumexp, entropy
)
from experiments.imagenet_discrete import dump_ues


# Example argument vector for local debugging; uncomment to run without passing CLI arguments:
# sys.argv = 'experiments/cifar_nuq.py --gpu --ood-name=smooth --data-seed=42 --net=resnet50_spectral'.split()
def _gmm_logits_from_model(args):
    """Fit the class-conditional GMM on live embeddings and score test/OOD data.

    Builds the CIFAR-100 test, OOD and training loaders, loads the network,
    extracts 2048-dim training embeddings, fits the GMM with ``gmm_fit`` and
    scores both the test and OOD loaders with ``gmm_evaluate``.

    Returns:
        ``(logits, ood_logits)`` as returned by ``gmm_evaluate``.
    """
    test_loader = cifar_test(args.b, False, args.ood_name)
    ood_loader = cifar_test(args.b, True, args.ood_name)

    train_loader, _val_loader = get_training_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=4,
        batch_size=args.b,
        shuffle=True,
        ood_name=args.ood_name,
        seed=args.data_seed,
    )

    model = load_model(args.net, args.weights, args.gpu)
    model.eval()

    # Previously hard-coded to 'cuda' even though the model placement is driven
    # by args.gpu; assumes --gpu is a boolean flag (TODO confirm in get_eval_args).
    device = 'cuda' if args.gpu else 'cpu'
    storage_device = 'cpu'

    embeddings, labels = get_embeddings(
        model,
        train_loader,
        num_dim=2048,
        dtype=torch.double,
        device=device,
        storage_device=storage_device,
    )
    print(labels.shape)

    gaussians_model, _jitter_eps = gmm_fit(embeddings=embeddings, labels=labels, num_classes=100)
    logits, _labels = gmm_evaluate(
        model, gaussians_model, test_loader,
        device=device, num_classes=100, storage_device=storage_device,
    )
    # Use the same device settings as the test pass (previously hard-coded here).
    ood_logits, _ood_labels = gmm_evaluate(
        model, gaussians_model, ood_loader,
        device=device, num_classes=100, storage_device=storage_device,
    )
    return logits, ood_logits


def _gmm_logits_from_cache():
    """Fit the GMM on cached embeddings and score cached test/OOD embeddings.

    Reads ``t_x_train.npy``, ``t_y_train.npy``, ``t_x_test.npy`` and
    ``t_x_ood.npy`` from the current working directory.

    Returns:
        ``(logits, ood_logits)`` from ``gaussians_model.log_prob``.
    """
    x_train = np.load('t_x_train.npy')
    y_train = np.load('t_y_train.npy')
    x_test = np.load('t_x_test.npy')
    x_ood = np.load('t_x_ood.npy')

    embeddings = torch.tensor(x_train)
    labels = torch.tensor(y_train)
    gaussians_model, _jitter_eps = gmm_fit(embeddings=embeddings, labels=labels, num_classes=100)

    # Insert a singleton axis at dim 1; presumably so log_prob broadcasts each
    # sample against every per-class component (TODO confirm against gmm_fit).
    logits = gaussians_model.log_prob(torch.tensor(x_test)[:, None, :].float())
    ood_logits = gaussians_model.log_prob(torch.tensor(x_ood)[:, None, :].float())
    return logits, ood_logits


def main():
    """Run a DDU-style uncertainty evaluation on CIFAR-100 against an OOD set.

    A class-conditional Gaussian model is fitted on training embeddings, either
    computed live from the network or loaded from cached ``.npy`` files. The
    in-distribution test set and the OOD set are then scored with it. The
    uncertainty estimate for each sample is the negated ``logsumexp`` of its
    per-class log-densities. The estimates are plotted with ``described_plot``
    and saved with ``dump_ues``. Arguments come from ``get_eval_args``.
    """
    # The original body started with an unconditional `return`, which made the
    # whole evaluation unreachable; it has been removed.
    args = get_eval_args()
    print(args.__dict__)
    print(args.weights)
    # NOTE(review): caching is force-disabled, so the cached branch only runs
    # if this assignment is changed.
    args.cached = False

    if args.cached:
        logits, ood_logits = _gmm_logits_from_cache()
    else:
        logits, ood_logits = _gmm_logits_from_model(args)

    # Higher value = lower (log) GMM density = more uncertain; assumes
    # logsumexp reduces over the class axis (TODO confirm in ddu module).
    ues = -logsumexp(logits)
    ues_ood = -logsumexp(ood_logits)

    described_plot(ues, ues_ood, args.ood_name, args.net, 'DDU')
    dump_ues(ues, ues_ood, f'ddu_{args.data_seed}', 'cifar', args.ood_name)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
