import ast
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from tqdm import tqdm

sys.path.append(".")

from image_uncertainty.cifar.cifar_evaluate import described_plot
from experiments.imagenet_discrete import test_loaders, normalize, parse_args
from image_uncertainty.spectral_normalized_models import (
    gmm_fit, get_embeddings, logsumexp
)

from image_uncertainty.imagenet.models import load_model
from image_uncertainty.imagenet import dataset_image


def get_train_loader(data_dir=None, batch_size=32, subsample=False,
                     num_workers=8, subsample_step=200):
    """Build a DataLoader over the ImageNet ``train`` split.

    Args:
        data_dir: ImageNet root that contains a ``train`` sub-directory.
            Defaults to ``/gpfs/gpfs0/datasets/ImageNet/ILSVRC2012/``.
        batch_size: number of samples per batch.
        subsample: if True, keep only every ``subsample_step``-th image
            (by dataset index) via ``torch.utils.data.Subset``.
        num_workers: number of DataLoader worker processes.
        subsample_step: stride used when ``subsample`` is True.

    Returns:
        A ``torch.utils.data.DataLoader`` that iterates in dataset order
        (``shuffle=False``) with ``pin_memory=True``.

    Raises:
        ValueError: if ``subsample`` is True and ``subsample_step`` < 1.
    """
    if data_dir is None:
        data_dir = '/gpfs/gpfs0/datasets/ImageNet/ILSVRC2012/'
    traindir = os.path.join(data_dir, "train")

    # NOTE(review): RandomResizedCrop / RandomHorizontalFlip are random
    # transforms, so the same image can yield different tensors on each pass.
    # Confirm this is intended when the loader is used to extract embeddings.
    train_dataset = dataset_image.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    if subsample:
        if subsample_step < 1:
            raise ValueError(
                f"subsample_step must be >= 1, got {subsample_step}"
            )
        indices = range(0, len(train_dataset), subsample_step)
        train_dataset = torch.utils.data.Subset(train_dataset, indices)

    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,  # deterministic order; presumably so embeddings align with labels
        num_workers=num_workers,
        pin_memory=True,
    )

def denormalize(image, mean=0.45, std=0.225):
    """Undo a ``(x - mean) / std`` normalization for display purposes.

    Args:
        image: normalized image (tensor, array or scalar).
        mean: value added back after scaling. May be a scalar or anything
            that broadcasts against ``image`` (e.g. a ``(C, 1, 1)`` tensor
            for per-channel statistics).
        std: scale factor; same broadcasting rules as ``mean``.

    Returns:
        ``image * std + mean``, with the same type as the arithmetic result.

    NOTE(review): the scalar defaults 0.45 / 0.225 look like single-value
    approximations of per-channel ImageNet statistics. Confirm they match
    ``normalize`` if exact colors matter.
    """
    return image * std + mean

# Class-name lookup used by `show` as `names[label]` for integer labels
# (presumably a dict or list keyed by class index — confirm file format).
# Parsed with ast.literal_eval instead of eval so only a Python literal is
# accepted and no code in the file can execute; `with` closes the handle.
# NOTE(review): the path is relative to the current working directory and is
# read at import time.
with open('imagenet/class_names.txt', 'r', encoding='utf-8') as _names_file:
    names = ast.literal_eval(_names_file.read())


def show(image_tensor, label=None):
    """Display a normalized CHW image tensor with matplotlib.

    Args:
        image_tensor: normalized image tensor with channels first; it is
            denormalized, moved to CPU, and transposed to HWC for ``imshow``.
        label: optional title. An integer class index (Python int, numpy
            integer, or 1-element tensor) is looked up in ``names``. Any
            other value is used as the title directly. ``None`` means no
            title. The label is also printed.
    """
    if label is not None:
        # Unwrap single-element tensors (e.g. labels yielded by a DataLoader).
        if isinstance(label, torch.Tensor) and label.numel() == 1:
            label = label.item()
        print(label)
        # Check against None rather than truthiness so class 0 still gets a
        # title. Bools are excluded so True/False are not treated as 1/0.
        if isinstance(label, (int, np.integer)) and not isinstance(label, bool):
            plt.title(names[int(label)])
        else:
            plt.title(label)

    image = denormalize(image_tensor)
    if isinstance(image, torch.Tensor):
        # detach/cpu so tensors on GPU or with autograd history can be shown.
        image = image.detach().cpu().numpy()
    # imshow expects HWC and clips float RGB to [0, 1] with a warning; clip
    # explicitly to get the same image without the warning.
    plt.imshow(np.clip(np.moveaxis(image, 0, 2), 0.0, 1.0))
    plt.show()



# Directory holding precomputed train/val/OOD embeddings for evaluate mode.
_EMBEDDINGS_DIR = Path('/home/mephody_bro/imagenet_embeddings_full')


def _eval_embeddings(gmm, embeddings, batch_size=50):
    """Score embeddings under the fitted GMM in mini-batches.

    Args:
        gmm: model returned by ``gmm_fit``; only its ``log_prob`` is used.
        embeddings: tensor of embeddings batched along dim 0.
        batch_size: rows per ``log_prob`` call, bounding peak memory.

    Returns:
        ``gmm.log_prob(batch[:, None, :])`` for every batch, concatenated
        along dim 0.

    Raises:
        ValueError: if ``embeddings`` is empty.
    """
    if len(embeddings) == 0:
        raise ValueError("no embeddings to evaluate")
    chunks = []
    with torch.no_grad():
        for start in tqdm(range(0, len(embeddings), batch_size)):
            batch = embeddings[start:start + batch_size]
            # Singleton axis so each row broadcasts against every mixture
            # component (presumably one per class — confirm against gmm_fit).
            chunks.append(gmm.log_prob(batch[:, None, :]))
    # Concatenate once; growing a tensor per batch copies quadratically.
    return torch.cat(chunks)


def _generate_train_embeddings(args, device, storage_device='cpu'):
    """Extract training-set embeddings with the spectral model and save them.

    Writes ``x_train_embeddings.pt`` and ``y_train_embeddings.pt`` to the
    current working directory.
    """
    model = load_model('spectral')
    model.eval()
    train_loader = get_train_loader(batch_size=args.b, subsample=False)
    print('training len', len(train_loader.dataset))
    embeddings, labels = get_embeddings(
        model,
        train_loader,
        num_dim=2048,
        dtype=torch.double,
        device=device,
        storage_device=storage_device,
    )
    torch.save(embeddings.cpu(), 'x_train_embeddings.pt')
    torch.save(labels.cpu(), 'y_train_embeddings.pt')


def _evaluate_ddu(args, base_dir=_EMBEDDINGS_DIR, num_classes=1000,
                  train_step=20):
    """Fit a GMM on train embeddings and plot val vs OOD uncertainty.

    Args:
        args: parsed CLI args; uses ``ood_name`` and ``net``.
        base_dir: directory with ``train_/val_`` embeddings + targets and
            ``ood_embeddings_<ood_name>.npy``.
        num_classes: samples whose target is >= this are dropped.
        train_step: keep every ``train_step``-th training sample to bound
            GMM fitting cost.
    """
    base_dir = Path(base_dir)
    x_train = np.load(base_dir / 'train_embeddings.npy')
    y_train = np.load(base_dir / 'train_targets.npy')
    x_val = np.load(base_dir / 'val_embeddings.npy')
    y_val = np.load(base_dir / 'val_targets.npy')
    x_ood = np.load(base_dir / f'ood_embeddings_{args.ood_name}.npy')

    x_train = x_train[::train_step]
    y_train = y_train[::train_step]

    train_mask = y_train < num_classes
    x_train, y_train = x_train[train_mask], y_train[train_mask]
    x_val = x_val[y_val < num_classes]

    print(x_train.shape)

    # Optional sanity check: a k-NN classifier on standardized embeddings
    # should reach a reasonable val accuracy if the embeddings are sound.

    print("Start fitting GMM")
    gaussians_model, _jitter_eps = gmm_fit(
        embeddings=torch.as_tensor(x_train),
        labels=torch.as_tensor(y_train),
        num_classes=num_classes,
    )

    print(" measure the uncertainty")
    measure = logsumexp  # logsumexp or entropy
    ues = measure(
        _eval_embeddings(gaussians_model, torch.as_tensor(x_val))
    ).numpy()
    ood_ues = measure(
        _eval_embeddings(gaussians_model, torch.as_tensor(x_ood))
    ).numpy()

    # Negate so that a larger value means more uncertain.
    described_plot(-ues, -ood_ues, args.ood_name, args.net, 'DDU')


def main(mode='evaluate'):
    """Entry point.

    Args:
        mode: ``'generate'`` extracts and saves training embeddings with the
            spectral model. ``'evaluate'`` (default) fits a GMM on
            precomputed embeddings and plots in-distribution vs OOD
            uncertainty.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if mode not in ('generate', 'evaluate'):
        raise ValueError(
            f"unknown mode {mode!r}; expected 'generate' or 'evaluate'"
        )
    args = parse_args()

    if mode == 'generate':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        _generate_train_embeddings(args, device=device, storage_device='cpu')
    else:
        _evaluate_ddu(args)


if __name__ == "__main__":
    # Only run the pipeline when executed as a script, not on import.
    main()
