import sys
import torch

sys.path.append('.')
from image_uncertainty.cifar.cifar_evaluate import (
    load_model, get_eval_args, described_plot, cifar_test,
    misclassification_detection
)
from image_uncertainty.uncertainty.methods import mcd_ue, mcd_runs


def main(repeats: int = 100, dropout_rate: float = 0.05):
    """Evaluate MC-dropout uncertainty on CIFAR in-distribution vs. OOD data.

    Reads evaluation options from the command line (``get_eval_args``) and
    builds the in-distribution and OOD loaders with ``cifar_test``. Loads the
    network with ``mc_dropout=True`` and collects ``repeats`` stochastic
    forward passes per loader via ``mcd_runs``. Then, for each acquisition
    function, it:

    - computes uncertainty estimates with ``mcd_ue``;
    - runs ``misclassification_detection`` on the in-distribution estimates;
    - plots ID vs. OOD estimates with ``described_plot``.

    Args:
        repeats: Number of stochastic passes (T) per input. Defaults to 100,
            the value that was previously hard-coded.
        dropout_rate: Dropout probability forwarded to ``load_model``.
            Defaults to 0.05, the value that was previously hard-coded.
    """
    args = get_eval_args()
    print(args.__dict__)
    print(args.weights)

    # NOTE(review): the second positional argument presumably selects the
    # in-distribution (False) vs. OOD (True) split — confirm in cifar_test.
    test_loader = cifar_test(args.b, False, args.ood_name)
    ood_loader = cifar_test(args.b, True, args.ood_name)

    model = load_model(
        args.net, args.weights, args.gpu,
        dropout_rate=dropout_rate, mc_dropout=True,
    )
    # NOTE(review): eval() is called although MC dropout needs stochastic
    # dropout at inference; presumably load_model(mc_dropout=True) keeps the
    # dropout layers active in eval mode — confirm.
    model.eval()

    # Inference only: no autograd graph is needed for the stochastic passes,
    # so skip building one (saves memory over `repeats` passes).
    with torch.no_grad():
        id_runs, correct = mcd_runs(test_loader, model, args.gpu, repeats)
        ood_runs, _ = mcd_runs(ood_loader, model, args.gpu, repeats)

    # An empty test set would otherwise crash with ZeroDivisionError.
    accuracy = sum(correct) / len(correct) if len(correct) else float('nan')

    for acquisition in ('max_prob', 'entropy', 'std', 'bald'):
        print(acquisition)
        ues = mcd_ue(id_runs, acquisition)
        misclassification_detection(correct, ues)
        ood_ues = mcd_ue(ood_runs, acquisition)

        described_plot(
            ues, ood_ues, args.ood_name, args.net, accuracy,
            f'dropout ({acquisition}), T={repeats}'
        )


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
