import sys
import torch
import numpy as np
from tqdm import tqdm

sys.path.append(".")
from image_uncertainty.cifar.cifar_evaluate import (
    load_model, get_eval_args, described_plot, cifar_test,
    misclassification_detection
)
from image_uncertainty.cifar.cifar_datasets import get_training_dataloader
from image_uncertainty.cifar import settings
from nuq import NuqClassifier, get_kernel
from experiments.imagenet_discrete import dump_ues


def get_embeddings(model, loader, gpu=None):
    """Run ``model`` over every batch of ``loader`` and collect the outputs.

    Args:
        model: callable mapping an image batch to an output tensor; presumably
            a ``torch.nn.Module`` already switched to eval mode by the caller.
        loader: iterable of ``(images, batch_labels)`` pairs of tensors
            (``batch_labels`` must support ``.tolist()``).
        gpu: if truthy, move each image batch to CUDA before the forward pass.
            When ``None`` (the default) the value falls back to the
            module-level ``args.gpu`` set by the ``__main__`` block. If no
            such global exists (e.g. the module was imported), it falls back
            to ``False``.

    Returns:
        Tuple ``(embeddings, labels)``. ``embeddings`` is the per-batch model
        outputs concatenated along axis 0 (as a numpy array). ``labels`` is a
        numpy array of the labels in loader order.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if gpu is None:
        # Backward compatibility: historically this function read the global
        # CLI namespace directly, which only exists when run as a script.
        cli_args = globals().get("args")
        gpu = getattr(cli_args, "gpu", False)

    labels = []
    embeddings = []
    # Pure inference: disable autograd once for the whole pass instead of
    # re-entering the context manager on every batch.
    with torch.no_grad():
        for images, batch_labels in tqdm(loader):
            if gpu:
                images = images.cuda()
            embeddings.append(model(images).cpu().numpy())
            labels.extend(batch_labels.tolist())

    if not embeddings:
        # np.concatenate([]) would fail with an opaque message.
        raise ValueError("get_embeddings: loader yielded no batches")

    return np.concatenate(embeddings), np.array(labels)


def main(args):
    """Fit a NuqClassifier on train-set model outputs and plot OOD scores.

    Pipeline:
      1. Build CIFAR-100 train/test loaders and an OOD loader for
         ``args.ood_name``.
      2. Load the network and collect its outputs on each split via
         ``get_embeddings``.
      3. Fit ``NuqClassifier`` on the train outputs/labels.
      4. For each uncertainty type ('epistemic', 'total', 'aleatoric'),
         compare in-distribution vs. OOD scores with ``described_plot``.

    Args:
        args: parsed CLI namespace (from ``get_eval_args``). Attributes read
            here: ``b`` (batch size), ``ood_name``, ``data_seed``, ``net``,
            ``weights``, ``gpu``.

    Raises:
        Any exception raised while plotting is re-raised. If a terminal is
        attached, a post-mortem debugger is opened first.
    """
    train_loader, _val_loader = get_training_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=4,
        batch_size=args.b,
        shuffle=True,
        ood_name=args.ood_name,
        seed=args.data_seed
    )

    test_loader = cifar_test(args.b, False, args.ood_name)
    ood_loader = cifar_test(args.b, True, args.ood_name)

    model = load_model(args.net, args.weights, args.gpu)
    model.eval()

    x_train, y_train = get_embeddings(model, train_loader)
    # Only the train labels are needed; NUQ scores need no targets.
    x_test, _ = get_embeddings(model, test_loader)
    x_ood, _ = get_embeddings(model, ood_loader)

    nuq = NuqClassifier(strategy="isj", tune_bandwidth=True, n_neighbors=80, precise_computation=True)
    nuq.fit(X=x_train, y=y_train)

    ues_test = nuq.predict_uncertainty(x_test)
    ues_ood = nuq.predict_uncertainty(x_ood)
    try:
        for ue_type in ['epistemic', 'total', 'aleatoric']:
            print(ue_type)
            described_plot(
                ues_test[ue_type], ues_ood[ue_type], args.ood_name, args.net,
                title_extras='Nadaraya-Watson'
            )
    except Exception:
        # Previously: a bare `except:` + ipdb.set_trace(). That needed an
        # undeclared package, hung batch jobs and swallowed the error.
        # Now debug only when a human is attached, and always propagate.
        if sys.stdin is not None and sys.stdin.isatty():
            import pdb
            pdb.post_mortem()
        raise


if __name__ == '__main__':
    # Bound at module scope on purpose: get_embeddings() can look up
    # `args.gpu` in this module's globals.
    args = get_eval_args()
    # Echo the full parsed configuration, then the weights argument that is
    # handed to load_model(), so runs are easy to identify in logs.
    print(vars(args))
    print(args.weights)
    main(args)
