#%%
import ast
import os
from argparse import ArgumentParser
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms

# Use the current CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import sys
sys.path.append("..")
sys.path.append(".")

from image_uncertainty.imagenet import dataset_image
from image_uncertainty.imagenet.models import load_model
from image_uncertainty.cifar.cifar_evaluate import (
    described_plot, maxprob_ue, misclassification_detection
)
#%%

# Per-channel RGB mean/std standardization applied to model inputs.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Evaluation-time preprocessing pipeline, applied in order: resize
# (an int size matches the shorter image edge to 256), take the central
# 224x224 crop, convert to a tensor, then standardize with `normalize`.
_inference_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
inference_transforms = transforms.Compose(_inference_steps)

def denormalize(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel mean/std normalization so an image can be displayed.

    The defaults are the same ImageNet statistics that ``normalize`` uses, so
    ``denormalize`` is its exact inverse for images produced by
    ``inference_transforms``.

    Args:
        image: normalized image tensor with channels on dim -3, i.e. shape
            ``(C, H, W)`` or batched ``(N, C, H, W)``.
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided out.

    Returns:
        Tensor with the same shape, dtype and device as ``image``, rescaled
        back to the pre-normalization pixel range.
    """
    # Shape (C, 1, 1) so the statistics broadcast over H and W (and over a
    # leading batch dimension if present).
    mean = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    return image * std + mean


# Class-index -> human-readable name mapping, used by show() for plot titles.
# The file is expected to contain a single Python literal (presumably a dict
# such as {0: 'tench, Tinca tinca', ...} — confirm). ast.literal_eval parses
# literals only, so unlike eval() it cannot execute code embedded in the file,
# and read_text() closes the file handle that the bare open() leaked.
names = ast.literal_eval(Path('imagenet/class_names.txt').read_text(encoding='utf-8'))



def show(image_tensor, label=None):
    """Display a normalized image tensor with matplotlib, optionally titled.

    Args:
        image_tensor: normalized ``(C, H, W)`` image tensor (e.g. one sample
            produced with ``inference_transforms``); it may live on any device.
        label: optional title. An integer class id (Python int, NumPy integer
            or one-element tensor) is printed and looked up in ``names``;
            any other non-None value is printed and used as the title
            directly. Class id 0 is a valid label and is shown.
    """
    if isinstance(label, torch.Tensor):
        # DataLoader yields labels as tensors; unwrap a single class id.
        label = label.item()

    # Compare against None explicitly: `if label:` would drop class id 0.
    if label is not None:
        print(label)
        if isinstance(label, (int, np.integer)):
            plt.title(names[int(label)])
        else:
            plt.title(label)

    # .numpy() only works on host tensors, so move to CPU first; then go from
    # CHW to HWC for imshow and clip to the valid float RGB range [0, 1].
    pixels = denormalize(image_tensor.detach().cpu()).numpy()
    plt.imshow(np.clip(np.moveaxis(pixels, 0, 2), 0.0, 1.0))
    plt.show()


def test_loaders(data_folder, ood_folder, batch_size, transforms=inference_transforms,
                 subsample=False, shuffle=False, num_workers=3, subsample_step=25):
    """Build DataLoaders for the in-distribution and OOD image folders.

    Args:
        data_folder: root directory of the in-distribution image folder.
        ood_folder: root directory of the out-of-distribution image folder.
        batch_size: batch size for both loaders.
        transforms: preprocessing passed to ``dataset_image.ImageFolder`` for
            every image. Note this parameter shadows the torchvision
            ``transforms`` module inside this function.
        subsample: if True, keep only every ``subsample_step``-th sample of
            the in-distribution set. The OOD set is never subsampled.
        shuffle: whether both loaders shuffle their samples.
        num_workers: number of loader worker processes per loader.
        subsample_step: stride used when ``subsample`` is True.

    Returns:
        ``(val_loader, ood_loader)``, two ``torch.utils.data.DataLoader``s.

    Raises:
        ValueError: if ``subsample`` is True and ``subsample_step`` < 1.
    """
    if subsample and subsample_step < 1:
        raise ValueError(f"subsample_step must be >= 1, got {subsample_step}")

    def make_loader(directory, use_subset):
        # Wrap one image folder in a DataLoader, optionally stride-subsampled.
        dataset = dataset_image.ImageFolder(directory, transforms)
        if use_subset:
            # Deterministic subsample: indices 0, step, 2*step, ...
            kept_indices = range(0, len(dataset), subsample_step)
            dataset = torch.utils.data.Subset(dataset, kept_indices)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            # Pinned host memory only speeds up host-to-GPU copies.
            pin_memory=torch.cuda.is_available(),
        )

    val_loader = make_loader(data_folder, use_subset=subsample)
    ood_loader = make_loader(ood_folder, use_subset=False)

    return val_loader, ood_loader


def dump_ues(ues, ood_ues, method, dataset_name, ood_name):
    """Save in-distribution and OOD uncertainty estimates as ``.npy`` files.

    Writes ``checkpoint/<dataset_name>/<name>_ues.npy`` and
    ``checkpoint/<dataset_name>/<name>_ood_ues.npy`` relative to the current
    working directory, where ``<name>`` is ``f"{method}_{ood_name}"`` with
    spaces replaced by underscores. Missing directories are created.

    Args:
        ues: uncertainty estimates on the in-distribution data; passed
            through ``np.array`` before saving.
        ood_ues: uncertainty estimates on the OOD data; same treatment.
        method: uncertainty method name, used in the file name.
        dataset_name: sub-directory of ``checkpoint`` to write into.
        ood_name: OOD dataset name, used in the file name.
    """
    base_dir = Path('checkpoint') / dataset_name
    # parents=True also creates 'checkpoint' itself (os.mkdir raised when it
    # was missing); exist_ok=True avoids the check-then-create race.
    base_dir.mkdir(parents=True, exist_ok=True)

    name = f"{method}_{ood_name}".replace(" ", "_")

    # NOTE(review): main() indexes ues['total'], so ues may be a dict; then
    # np.array(ues) is a 0-d object array that np.load only reads back with
    # allow_pickle=True — confirm this is intended.
    for suffix, values in (("ues", ues), ("ood_ues", ood_ues)):
        with open(base_dir / f"{name}_{suffix}.npy", 'wb') as f:
            np.save(f, np.array(values))


def parse_args():
    """Parse command-line options for the ImageNet uncertainty evaluation.

    Default data locations point at fixed paths on a specific cluster
    filesystem. Besides the parsed options, the returned namespace always
    carries ``gpu=True``; there is no command-line switch for it.
    ``--x-type`` is accepted but not read by ``main()`` in this file.
    """
    default_ind_folder = "/gpfs/gpfs0/datasets/ImageNet/ILSVRC2012/val"
    default_ood_folder = '/gpfs/gpfs0/k.fedyanin/space/imagenet_o'

    parser = ArgumentParser()
    add = parser.add_argument
    add('--net', type=str, default='resnet50')
    add('-b', type=int, default=32)
    add('--data-folder', type=str, default=default_ind_folder)
    add('--ood-folder', type=str, default=default_ood_folder)
    add('--ood-name', type=str, default='imagenet-o')
    add('--x-type', type=str, default='logits',
        help='logits or embeddings (from last layer)')

    parsed = parser.parse_args()
    # GPU usage is forced on regardless of the command line.
    parsed.gpu = True
    return parsed

#%%
def main():
    """Evaluate discrete max-prob and entropy uncertainty on ImageNet vs OOD.

    For each acquisition ('max_prob', then 'entropy') the model is scored on
    the subsampled in-distribution loader and on the OOD loader. The
    estimates are then plotted, used for misclassification detection, and
    saved via ``dump_ues``.
    """
    args = parse_args()
    model = load_model(args.net)
    ind_loader, ood_loader = test_loaders(
        args.data_folder, args.ood_folder, args.b, subsample=True
    )

    for acq in ('max_prob', 'entropy'):
        print(acq)
        ind_ues, is_correct = maxprob_ue(ind_loader, model, args.gpu, acq)
        ood_ues, _unused = maxprob_ue(ood_loader, model, args.gpu, acq)
        # Mean of the per-sample correctness flags (assumed 0/1 values).
        acc = sum(is_correct) / len(is_correct)

        described_plot(ind_ues, ood_ues, args.ood_name, args.net, acc,
                       f'discrete {acq}', 'imagenet')
        misclassification_detection(is_correct, ind_ues['total'])
        dump_ues(ind_ues, ood_ues, f'discrete_{acq}', 'imagenet', args.ood_name)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
