# -*- coding: utf-8 -*-
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import pathlib
import sys

import numpy as np

from calculate_metrics import calc_auroc, calc_tnr
from godin.utils_godin import show_stats, show_thresholding

sys.path.insert(0, str(pathlib.Path().absolute().parent))
from godin.datasets import get_datasets
from godin.load_models import get_model
from godin.parameters import get_args
from godin.utils_godin import generate_scores


def _score_and_save(model, device, loader, title, out_path, threshold,
                    data_type):
    """Score one data split, report on it, and persist the raw scores.

    Runs ``generate_scores`` on ``loader``, prints statistics via
    ``show_stats``, writes the scores to ``out_path`` with ``np.save`` and
    reports thresholding via ``show_thresholding``.

    Args:
        model: Network passed through to ``generate_scores``.
        device: Device argument passed through to ``generate_scores``.
        loader: Data split to score.
        title: Progress title forwarded to ``generate_scores``.
        out_path: Destination ``.npy`` path for the scores.
        threshold: Threshold forwarded to ``show_thresholding``.
        data_type: Split label forwarded to ``show_thresholding``.

    Returns:
        The scores returned by ``generate_scores``.
    """
    results = generate_scores(model, device, loader, title=title)
    show_stats(data=results)
    np.save(out_path, results)
    show_thresholding(threshold=threshold, results=results,
                      data_type=data_type)
    return results


def main():
    """Evaluate a model on in-distribution and OOD data and report metrics.

    Reads the run configuration from ``get_args``, loads the datasets and the
    model, then scores the validation, test, training and OOD splits. The
    scores of every split are saved as ``.npy`` files under ``./outputs`` and
    AUROC / TNR@TPR95 are computed from the ID-test and OOD scores. At the
    end, the metrics belonging to the highest average validation score and
    the overall highest AUROC are printed.
    """
    args = get_args()

    device = args.gpu
    model_dir = args.model_dir
    architecture = args.architecture
    similarity = args.similarity

    data_dir = args.data_dir
    data_ood = args.out_dataset
    data_ind = args.in_dataset
    batch_size = args.batch_size
    percentile_threshold = args.percentile_threshold

    noise_magnitudes = args.magnitudes

    # Create necessary directories. ``exist_ok`` avoids the check-then-create
    # race, and the output directory must exist before ``np.save`` is called.
    output_dir = './outputs'
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    train_data, val_data, test_data, open_data, num_classes = get_datasets(
        data_dir=data_dir, data_ood=data_ood, batch_size=batch_size,
        data_ind=data_ind, num_workers=args.num_workers)

    deconf_net = get_model(
        architecture=architecture, similarity=similarity, val_set=val_data,
        num_classes=num_classes, model_dir=model_dir, device=device,
        percentile_threshold=percentile_threshold)

    deconf_net.eval()

    # Initialise every "best" tracker up front so the final report does not
    # raise NameError when ``noise_magnitudes`` is empty.
    best_val_score = None
    best_val_auroc = None
    best_tnr = None
    best_auroc = None

    threshold = deconf_net.threshold
    if threshold is not None:
        threshold_str = f'_threshold_{percentile_threshold}'
    else:
        threshold_str = ''
    file_suffix = f'ind_{data_ind}_ood_{data_ood}{threshold_str}.npy'

    # (title, output file prefix, data_type label, data split) per evaluation,
    # in the order in which they are run and reported.
    splits = (
        ('Validating', 'id_val', 'val', val_data),
        ('Testing ID', 'id_test', 'test', test_data),
        ('Training ID', 'id_train', 'train', train_data),
        ('Testing OOD', 'ood_test', 'ood', open_data),
    )

    # score_functions = ['h', 'g', 'logit']
    score_functions = ['h']
    for score_func in score_functions:
        print(f'Score function: {score_func}')
        for noise_magnitude in noise_magnitudes:
            print(f'Noise magnitude {noise_magnitude:.5f}')

            # NOTE(review): neither ``score_func`` nor ``noise_magnitude`` is
            # passed to ``generate_scores``; they are only printed here.
            # Confirm whether the scoring is meant to depend on them.
            results = {}
            for title, prefix, data_type, loader in splits:
                # Each split gets its own file; validation scores previously
                # shared (and were overwritten by) the training file name.
                results[prefix] = _score_and_save(
                    deconf_net, device, loader, title,
                    f'{output_dir}/{prefix}_results_{file_suffix}',
                    threshold, data_type)

            id_val_results = results['id_val']
            id_test_results = results['id_test']
            id_train_results = results['id_train']
            ood_test_results = results['ood_test']
            avg_val_score = np.average(id_val_results)

            print('# of id_val_results: ', len(id_val_results))
            print('# of id_test_results: ', len(id_test_results))
            print('# of id_train_results: ', len(id_train_results))
            print('# of ood_test_results: ', len(ood_test_results))

            # The metrics are computed on the reciprocal of the scores.
            auroc = calc_auroc(1 / id_test_results, 1 / ood_test_results) * 100
            tnrATtpr95 = calc_tnr(1 / id_test_results, 1 / ood_test_results)
            print('AUROC:', auroc, 'TNR@TPR95:', tnrATtpr95)

            if best_auroc is None:
                best_auroc = auroc
            else:
                best_auroc = max(best_auroc, auroc)
            # Model selection uses the average validation score; the AUROC and
            # TNR of that same iteration are what get reported as "best".
            if best_val_score is None or avg_val_score > best_val_score:
                best_val_score = avg_val_score
                best_val_auroc = auroc
                best_tnr = tnrATtpr95

    print('best auroc: ', best_val_auroc, ' and tnr@tpr95 ', best_tnr)
    print('true best auroc:', best_auroc)


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
