import argparse
from time import time
t0 = time()

import numpy as np
import torch
import torch.utils.data.distributed

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from image_uncertainty.uncertainty import metrics
from image_uncertainty.uncertainty.methods import mcd_ue, mcd_runs
from imagenet import dataset_image, config_ood
from imagenet.models import load_model
from experiments.imagenet_discrete import test_loaders

t1 = time()

# Report how long the imports above took (t0 is taken right after importing `time`).
print(f'Import time, {t1 - t0}')
#%%

# Module-level experiment settings; `arch` selects the model loaded in main().
arch, num_workers = "dropout", 3
data = "/gpfs/gpfs0/datasets/ImageNet/ILSVRC2012"
ood_list = config_ood.ood_list
ood_len = len(ood_list)



def mcd_runs_ue(loader, model, acquisitions=None, repeats=10, last_iter=int(1e9), device=None):
    """Run Monte-Carlo dropout inference and collect per-sample uncertainty scores.

    Each batch is passed through ``model`` ``repeats`` times. The softmaxed
    outputs of all passes are handed to ``mcd_ue`` once per acquisition
    function, and the mean of the softmaxed passes gives the prediction that
    is compared against the label.

    Args:
        loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        model: Callable mapping a batch of images to logits (classes on the
            last dimension).
        acquisitions: Acquisition names passed to ``mcd_ue``; defaults to
            ``['max_prob', 'entropy', 'std', 'bald']``.
        repeats: Number of forward passes per batch; must be at least 1.
        last_iter: Stop after this many batches; ``None`` means no limit.
        device: Device each batch is moved to. Defaults to CUDA when it is
            available, otherwise CPU.

    Returns:
        ``(results, correct)``: ``results`` maps each acquisition name to a
        list of per-sample scores produced by ``mcd_ue``; ``correct`` is a
        list of bools telling whether each sample's averaged prediction
        matched its label.

    Raises:
        ValueError: If ``repeats`` is less than 1.
    """
    if acquisitions is None:
        acquisitions = ['max_prob', 'entropy', 'std', 'bald']
    if repeats < 1:
        # torch.stack would otherwise fail later with an opaque error on [].
        raise ValueError(f"repeats must be >= 1, got {repeats}")
    if device is None:
        # Resolve the device here instead of hard-coding .cuda(), so the
        # function also works on CPU-only hosts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    correct = []
    results = {acquisition: [] for acquisition in acquisitions}

    with torch.no_grad():
        for n_iter, (images, labels) in enumerate(loader):
            if last_iter is not None and n_iter >= last_iter:
                break
            print("iteration: {}\ttotal {} iterations".format(n_iter + 1, len(loader)))
            images = images.to(device)
            labels = labels.to(device)

            # Stack the `repeats` forward passes along a new leading dimension.
            preds = torch.stack([model(images) for _ in range(repeats)])
            softmaxed = torch.softmax(preds, dim=-1)

            # Predict the class with the highest mean probability over passes.
            pred = softmaxed.mean(dim=0).argmax(dim=-1)
            correct.extend((pred == labels).cpu().tolist())

            batch_runs = softmaxed.detach().cpu().numpy()
            for acquisition in acquisitions:
                results[acquisition].extend(mcd_ue(batch_runs, acquisition))

    return results, correct



def main():
    """Evaluate MC-dropout uncertainty on validation vs. OOD data and plot it.

    Parses the CLI options, builds the validation and OOD loaders and the
    dropout model, then collects uncertainty scores for each acquisition
    function on both loaders. Each pair of score arrays is passed to
    ``metrics.uncertainty_plot`` together with the validation accuracy.
    Invalid options (``--repeats`` < 1, or a run that evaluates no validation
    samples) are reported through ``argparse`` instead of crashing later.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeats', type=int, default=100)
    parser.add_argument('--dropout-rate', type=float, default=0.01)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--last-iter', type=int, default=int(1e10))

    args = parser.parse_args()
    if args.repeats < 1:
        parser.error('--repeats must be at least 1')
    torch.manual_seed(42)

    val_loader, ood_loader = test_loaders(args.batch_size)
    model = load_model(arch, dropout_rate=args.dropout_rate)
    # NOTE(review): eval() switches standard nn.Dropout layers off; presumably
    # the model from load_model keeps its dropout active for MC sampling —
    # confirm, otherwise all `repeats` passes are identical.
    model.eval()

    acquisitions = ['max_prob', 'entropy', 'std', 'bald']

    test_ues, correct = mcd_runs_ue(val_loader, model, acquisitions, args.repeats, args.last_iter)
    if not correct:
        # Avoid a ZeroDivisionError below (e.g. --last-iter 0 or an empty loader).
        parser.error('no validation samples were evaluated; check --last-iter')
    ood_ues, _ = mcd_runs_ue(ood_loader, model, acquisitions, args.repeats, args.last_iter)
    accuracy = sum(correct) / len(correct)

    for acquisition in acquisitions:
        print(acquisition)
        metrics.uncertainty_plot(
            np.array(test_ues[acquisition]), np.array(ood_ues[acquisition]),
            directory="./",
            title=f"Uncertainty ImageNet, OOD is part of birds, {acquisition}",
            accuracy=accuracy
        )

if __name__ == '__main__':
    main()
