#%%
from time import time
t0 = time()
import os

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms

from tqdm import tqdm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import sys
# sys.path.append('experiments')
sys.path.append('.')

from image_uncertainty.uncertainty import metrics
from image_uncertainty.uncertainty.methods import mcd_ue, mcd_runs
from imagenet import dataset_image, config_ood
from experiments.imagenet_discrete import load_model
from experiments.imagenet_discrete import normalize, parse_args, test_loaders, dump_ues

# Mark the end of the import phase; main() also measures its run time from t1.
t1 = time()

# Report how long the imports above took (t0 was taken right after importing time).
print(f'Import time, {t1-t0}')


# Deterministic loader-side preprocessing: resize, centre-crop to 240 and convert
# to a tensor. No normalisation happens here; transform_train applies `normalize`
# after its random augmentations.
transform_void = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(240),
    transforms.ToTensor(),
])

# Random augmentation pipeline applied to each already-loaded image tensor in
# validate(): back to PIL, random 224 crop, flip, slight rotation and colour
# jitter, then tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(0.5),
        transforms.ColorJitter(
            brightness=0.02, contrast=0.02, saturation=0.02, hue=0.02
        ),
        transforms.ToTensor(),
        normalize,
    ]
)


def tta_loaders(args):
    """Build the loaders used for test-time-augmentation evaluation.

    Forwards ``args.data_folder``, ``args.ood_folder`` and ``args.b`` to
    ``test_loaders`` together with the ``transform_void`` pipeline and
    ``subsample=True``.

    Returns:
        Whatever ``test_loaders`` returns; ``main`` unpacks it as a
        ``(test_loader, ood_loader)`` pair.
    """
    loaders = test_loaders(
        args.data_folder,
        args.ood_folder,
        args.b,
        transform_void,
        subsample=True,
    )
    return loaders


def validate(loader, model, gpu=True, repeats=1):
    """Run test-time-augmented inference over ``loader``.

    Every batch is fed through ``model`` ``repeats`` times; each pass uses a
    freshly sampled ``transform_train`` augmentation of every image. The
    per-pass outputs are softmaxed, and their mean over passes is used to pick
    the predicted class.

    Args:
        loader: iterable yielding ``(images, target)`` batches of tensors.
        model: callable mapping an image batch to per-class scores; assumed to
            return a ``(batch, classes)`` tensor.
        gpu: unused; kept so existing callers keep working. Placement is
            decided by the module-level ``device``.
        repeats: number of augmented forward passes per batch.

    Returns:
        A ``(runs, correct)`` tuple. ``runs`` is a numpy array of softmax
        outputs with passes on axis 0 and samples (all batches concatenated)
        on axis 1. ``correct`` is a list of booleans, one per sample, telling
        whether the averaged prediction matched the target.
    """
    model.eval()
    runs = []
    correct = []

    with torch.no_grad():
        for images, target in tqdm(loader):
            target = target.to(device)
            # transform_train starts with ToPILImage, which works on host data,
            # so the raw batch stays on the CPU and only the augmented
            # sub-batch is moved to the device.
            host_images = images.cpu()

            batch_predictions = []
            for _ in range(repeats):
                sub_batch = torch.stack(
                    [transform_train(img) for img in host_images]
                ).to(device)
                batch_predictions.append(model(sub_batch))

            softmaxed = torch.softmax(torch.stack(batch_predictions), dim=-1)

            # The mean of probabilities is already a probability vector; taking
            # the argmax directly avoids a second, redundant softmax.
            averaged = torch.mean(softmaxed, dim=0)
            pred = averaged.argmax(dim=-1)
            correct.extend((pred == target).cpu().tolist())

            runs.append(softmaxed.detach().cpu())
    return torch.cat(runs, dim=1).numpy(), correct


def main():
    """Evaluate test-time-augmentation uncertainty on the test and OOD sets.

    Parses the command-line arguments, loads the model, runs ``validate`` with
    10 augmented passes on both loaders, and then, for each acquisition
    function, computes uncertainty estimates, plots them and dumps them.
    Finally prints the time elapsed since the imports finished.
    """
    args = parse_args()
    model = load_model(args.net)
    test_loader, ood_loader = tta_loaders(args)

    n_passes = 10
    test_runs, correct = validate(test_loader, model, repeats=n_passes)
    accuracy = round(sum(correct) / len(correct), 3)
    ood_runs, _ = validate(ood_loader, model, repeats=n_passes)

    acquisitions = ('max_prob', 'entropy', 'std', 'bald')
    for acquisition in acquisitions:
        print(acquisition)
        ues = np.array(mcd_ue(test_runs, acquisition))
        ood_ues = np.array(mcd_ue(ood_runs, acquisition))

        plot_title = f"Uncertainty ImageNet, OOD is part of birds, {acquisition}"
        metrics.uncertainty_plot(
            ues,
            ood_ues,
            directory="./",
            title=plot_title,
            accuracy=accuracy,
        )

        dump_name = f'tta_{acquisition}'
        dump_ues(ues, ood_ues, dump_name, 'imagenet', args.ood_name)

    elapsed = time() - t1
    print("Run time", elapsed)


# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()