import sys

import torch

sys.path.append('.')
from image_uncertainty.cifar.cifar_evaluate import  default_weights
from image_uncertainty.cifar.cifar_evaluate import (
    load_model, get_eval_args, described_plot, cifar_test, misclassification_detection
)
from image_uncertainty.uncertainty.methods import mcd_ue, mcd_runs


"""
Ensemble uncertainty estimation
It assumes you already trained the models and put them in checkpoint folder
"""

# Number of ensemble members; main() loads one set of weights per index in
# range(MODEL_SIZE).
MODEL_SIZE = 5


class EnsembleWrapper:
    """Present several models as one callable classifier.

    Calling the wrapper runs every member on the same input, averages their
    softmax probabilities and returns the logarithm of that average, so
    applying softmax to the result recovers the averaged probabilities.
    """

    def __init__(self, models):
        """Store the ensemble members.

        Args:
            models: Sequence of callables. Each is invoked as ``model(x)`` and
                the outputs are stacked, so they must all have the same shape.
                The last dimension is the one softmax is taken over.
        """
        self.models = models

    def __call__(self, x):
        """Return the log of the members' mean softmax probabilities for ``x``.

        Args:
            x: Input forwarded unchanged to every member.

        Returns:
            Tensor shaped like a single member's output, holding
            ``log(mean_i softmax(model_i(x)))`` along the last dimension.
        """
        preds = torch.stack([model(x) for model in self.models])
        # Work in log space: log(mean_i p_i) = logsumexp_i(log p_i) - log(n).
        # Taking log() of an averaged softmax directly yields -inf whenever
        # every member's probability for a class underflows to zero.
        member_log_probs = torch.log_softmax(preds, dim=-1)
        # new_tensor keeps the dtype and device of the predictions.
        log_count = member_log_probs.new_tensor(float(preds.shape[0])).log()
        return torch.logsumexp(member_log_probs, dim=0) - log_count


def main():
    """Evaluate the ensemble and plot its uncertainty per acquisition function.

    Parses the evaluation arguments, builds the two ``cifar_test`` loaders
    (used here as in-distribution and OOD data), loads ``MODEL_SIZE`` models
    from their default weights and collects their runs on both loaders. For
    each acquisition function it reports misclassification detection on the
    first loader and plots the two sets of uncertainty estimates together.
    """
    args = get_eval_args()
    print(vars(args))

    batch_size, ood_name = args.b, args.ood_name
    test_loader = cifar_test(batch_size, False, ood_name)
    ood_loader = cifar_test(batch_size, True, ood_name)

    # One model per ensemble index, each with its own default weights.
    models = [
        load_model(
            args.net,
            default_weights(args.net, ood_name, args.data_seed, member),
            args.gpu,
        )
        for member in range(MODEL_SIZE)
    ]

    id_runs, correct = mcd_runs(test_loader, models, args.gpu, ensemble=True)
    ood_runs = mcd_runs(ood_loader, models, args.gpu, ensemble=True)[0]
    accuracy = sum(correct) / len(correct)

    for acquisition in ['max_prob', 'entropy', 'std', 'bald']:
        id_scores = mcd_ue(id_runs, acquisition)
        misclassification_detection(correct, id_scores)
        ood_scores = mcd_ue(ood_runs, acquisition)

        title = f'ensemble ({acquisition}), T={len(models)}'
        described_plot(
            id_scores, ood_scores, ood_name, args.net, accuracy, title
        )


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()

