#%%
import os
import sys
from pathlib import Path
import torch
from tqdm import tqdm
import numpy as np
import torchvision
sys.path.append('.')

from imagenet.models import load_model
from experiments.imagenet_discrete import test_loaders, parse_args, dump_ues

from image_uncertainty.cifar.cifar_evaluate import described_plot
from image_uncertainty.uncertainty.methods import mcd_ue, mcd_runs


#%%
# Parse the command line first; data_folder / ood_folder are taken from it.
# Example overrides for a local data layout:
#   args.data_folder = 'experiments/data/ILSVRC2012/val'
#   args.ood_folder = 'experiments/data/imagenet_o'
args = parse_args()

# Ensemble members are always read from this directory; this assignment
# replaces any checkpoint_directory value parse_args() may have produced.
args.checkpoint_directory = Path('experiments/checkpoint/imagenet_ensembles')

# Number of ensemble members kept for evaluation (models[:MODEL_SIZE] below).
MODEL_SIZE = 5


def load_models(base_dir, arch='resnet50', limit=None):
    """Load every checkpoint entry found directly inside ``base_dir``.

    Entries are visited in sorted (lexicographic) name order, so the order of
    the returned list is deterministic across runs and file systems.

    Args:
        base_dir: Directory holding the checkpoints. May be a ``str`` or any
            path-like object; it is normalised to ``Path`` so the ``/`` join
            below also works when a plain string is passed.
        arch: Architecture name forwarded to ``load_model`` for every
            checkpoint. Defaults to ``'resnet50'`` (the previous hard-coded
            value).
        limit: If not ``None``, only the first ``limit`` entries (in sorted
            order) are loaded, avoiding the cost of loading models that the
            caller would discard anyway.

    Returns:
        list: One object per loaded checkpoint, as returned by ``load_model``.
    """
    base_dir = Path(base_dir)
    # NOTE(review): every directory entry (including any stray non-checkpoint
    # files) is passed to load_model -- confirm the directory only holds
    # checkpoints.
    checkpoint_names = sorted(os.listdir(base_dir))
    if limit is not None:
        checkpoint_names = checkpoint_names[:limit]
    return [load_model(arch, base_dir / name) for name in checkpoint_names]


models = load_models(args.checkpoint_directory)
# Fail with an explicit message instead of an opaque IndexError on models[0]
# when the checkpoint directory is empty.
if not models:
    raise FileNotFoundError(
        f'No checkpoints found in {args.checkpoint_directory}'
    )
print(models[0])

#%%
# Evaluate the first MODEL_SIZE ensemble members on the validation set and on
# the OOD set, then score each uncertainty acquisition function.
val_loader, ood_loader = test_loaders(args.data_folder, args.ood_folder, 32)
# NOTE(review): all checkpoints in the directory were loaded above; only the
# first MODEL_SIZE are kept here.
models = models[:MODEL_SIZE]

# In-distribution pass; the second return value is averaged into accuracy
# (presumably per-sample correctness flags -- confirm against mcd_runs).
id_runs, correct = mcd_runs(val_loader, models, True, ensemble=True)
accuracy = sum(correct) / len(correct)

# OOD pass; the second return value is not used for OOD data.
ood_runs, _ = mcd_runs(ood_loader, models, True, ensemble=True)

for acquisition in ('max_prob', 'entropy', 'std', 'bald'):
    print(acquisition)
    ues, ood_ues = mcd_ue(id_runs, acquisition), mcd_ue(ood_runs, acquisition)

    described_plot(ues, ood_ues, 'Imagenet ensembles', 'resnet50', accuracy,
                   f'ensemble ({acquisition}), T={len(models)}')

    dump_ues(ues, ood_ues, f'ensemble_{acquisition}', 'imagenet', args.ood_name)
