# -*- coding: utf-8 -*-
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np

import sys
import pathlib

# Put the parent of the current working directory at the front of the
# module search path; presumably this is what makes the `godin` imports
# below resolvable when the script is launched from inside the package
# directory.
# NOTE(review): `pathlib.Path()` is the current working directory, not this
# file's location, so the result depends on where the script is started
# from -- confirm that is intended.
sys.path.insert(0, str(pathlib.Path().absolute().parent))

from godin.datasets import get_dataset_train_val_test
from godin.load_models import get_model
from godin.parameters import get_args
from godin.utils_godin import show_stats, generate_scores
from godin.utils_godin import show_thresholding
from godin.utils_godin import get_model_accuracy


def main():
    """Evaluate the loaded model on the validation split and save its scores.

    Reads the run configuration from ``get_args()``, builds the data
    iterators and the model, then for the validation split:

    * prints the value returned by ``get_model_accuracy``,
    * prints score statistics and the thresholding summary,
    * prints the average score, and
    * writes the scores to ``id_test_results_<in_dataset>.npy`` in the
      current working directory.

    NOTE(review): the printed labels and the output file name say "test",
    but the data passed everywhere below is ``val_data`` -- confirm which
    split is intended.
    """
    args = get_args()
    device = args.gpu
    model_dir = args.model_dir
    data_dir = args.data_dir
    data_in = args.in_dataset
    batch_size = args.batch_size
    num_workers = args.num_workers
    percentile_threshold = args.percentile_threshold

    train_data, val_data, test_data = get_dataset_train_val_test(
        data_ind=data_in, batch_size=batch_size, data_dir=data_dir,
        num_workers=num_workers)

    # The validation set is handed to get_model together with the
    # percentile; `deconf_net.threshold` is read back further down.
    deconf_net = get_model(
        model_dir=model_dir, data_in=data_in, device=device, val_set=val_data,
        percentile_threshold=percentile_threshold)
    deconf_net.eval()

    accuracy = get_model_accuracy(model=deconf_net, data_iter=val_data,
                                  CUDA_DEVICE=device)
    print('model accuracy: ', accuracy)

    # id_train_results = generate_scores(
    #     deconf_net, device, train_data, title='Training ID')
    # print('average id_train scores: ', np.average(id_train_results))

    id_val_results = generate_scores(
        deconf_net, device, val_data, title='Testing ID')
    show_stats(data=id_val_results)
    print('percentile threshold: ', percentile_threshold)
    print('score threshold: ', deconf_net.threshold)
    show_thresholding(threshold=deconf_net.threshold, results=id_val_results,
                      data_type='test')

    print('average id_test scores: ', np.average(id_val_results))

    # Substitute the dataset name into the file name. The original literal
    # was never formatted, so every run wrote to the same file called
    # 'id_test_results_{data_in}.npy' regardless of the dataset.
    results_path = 'id_test_results_{data_in}.npy'.format(data_in=data_in)
    np.save(results_path, id_val_results)

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
