import random
import sys

import torch
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.append('.')
from image_uncertainty.cifar.cifar_evaluate import (
    load_model, get_eval_args, described_plot, cifar_test, misclassification_detection
)
from image_uncertainty.uncertainty.methods import mcd_ue
from image_uncertainty.cifar import settings


# Minimal pipeline that only converts a raw sample to a tensor. It is installed
# on the evaluation datasets by `patch_transform`, so that the random
# augmentations can be applied afterwards, once per TTA pass.
transform_void = transforms.Compose([transforms.ToTensor()])

# Random augmentation pipeline. Despite the name, `tta_runs` applies it at
# evaluation time to individual image tensors, hence the leading ToPILImage
# conversion before the PIL-based augmentations.
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=settings.CIFAR100_TRAIN_MEAN,
        std=settings.CIFAR100_TRAIN_STD,
    ),
])


def tta_runs(loader, model, gpu, repeats):
    """Run test-time augmentation (TTA) inference over every batch of a loader.

    Each batch is re-augmented with ``transform_train`` and passed through
    ``model`` ``repeats`` times. The softmax of every pass is kept, and the
    mean over passes determines the predicted class.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches. Every element
            of ``images`` must be a valid input for ``transform_train``.
        model: Callable mapping a stacked image batch to per-class scores
            (softmax is taken over the last dimension).
            NOTE(review): assumed to live on CUDA iff ``gpu`` is truthy --
            confirm against ``load_model``.
        gpu: When truthy, augmented images and labels are moved to CUDA.
            When falsy, nothing is moved, so the function also works on a
            CPU-only machine.
        repeats: Number of augmented forward passes per batch.

    Returns:
        Tuple ``(runs, correct)``:
            * ``runs`` -- NumPy array of softmax outputs with the pass index
              on axis 0 and batches concatenated along axis 1, or ``None``
              if the loader yielded no batches.
            * ``correct`` -- list of booleans, one per sample, telling whether
              the averaged prediction matches the label.
    """
    batch_outputs = []
    correct = []

    with torch.no_grad():
        for images, labels in tqdm(loader):
            if gpu:
                labels = labels.cuda()

            passes = []
            for _ in range(repeats):
                # A fresh random augmentation is drawn for every pass.
                augmented = torch.stack([transform_train(im) for im in images])
                if gpu:
                    # Honour the flag instead of moving to CUDA unconditionally.
                    augmented = augmented.cuda()
                passes.append(model(augmented))

            softmaxed = torch.softmax(torch.stack(passes), dim=-1)
            averaged = torch.mean(softmaxed, dim=0)
            # `averaged` already holds probabilities; a second softmax would
            # only compress the differences between classes.
            _, pred = averaged.topk(1)
            correct.extend((pred[:, 0] == labels).cpu().tolist())

            batch_outputs.append(softmaxed.detach().cpu().numpy())

    if not batch_outputs:
        return None, correct

    # Concatenate once at the end rather than re-copying the growing array
    # on every batch.
    return np.concatenate(batch_outputs, axis=1), correct


def patch_transform(loader):
    """Replace the transform of the loader's dataset(s) with ``transform_void``.

    When the loader wraps a ``ConcatDataset``, every member dataset is
    patched; otherwise the single dataset is patched directly. The loader is
    modified in place and nothing is returned.
    """
    dataset = loader.dataset
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        targets = dataset.datasets
    else:
        targets = [dataset]

    for target in targets:
        target.transform = transform_void


def main():
    """Evaluate TTA-based uncertainty on in-distribution and OOD data.

    Parses the evaluation arguments, seeds the random generators, loads the
    model in eval mode, runs ``tta_runs`` on both loaders and, for each
    acquisition function, reports misclassification detection and produces
    the described plot.
    """
    args = get_eval_args()
    print(vars(args))

    # Seed every random source with the same value for reproducible runs.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.data_seed)

    n_repeats = 20
    print(n_repeats, 'repeats tta')

    model = load_model(args.net, args.weights, args.gpu, mc_dropout=False)
    model.eval()

    # NOTE(review): the second argument of cifar_test presumably selects the
    # OOD variant of the data -- confirm against cifar_evaluate.
    test_loader = cifar_test(args.b, False, args.ood_name)
    patch_transform(test_loader)
    ood_loader = cifar_test(args.b, True, args.ood_name)
    patch_transform(ood_loader)

    id_runs, correct = tta_runs(test_loader, model, args.gpu, n_repeats)
    ood_runs, _ = tta_runs(ood_loader, model, args.gpu, n_repeats)

    accuracy = sum(correct) / len(correct)

    for method in ('max_prob', 'entropy', 'bald'):
        print(method)
        id_scores = mcd_ue(id_runs, method)
        ood_scores = mcd_ue(ood_runs, method)
        misclassification_detection(correct, id_scores)

        title = f'TTA ({method}), T={n_repeats}'
        described_plot(
            id_scores, ood_scores, args.ood_name, args.net, accuracy, title
        )


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
