import argparse
import numpy as np
import pandas as pd
import torch
import yaml
import os
from my_code.data_utils.DataloaderLoaders import get_test_data
from my_code.testing_utils.TestingFunctions import accuracy_topk
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from my_code.models.model_code import get_model


def _build_parser():
    """Create the command-line parser for the model testing script.

    Returns:
        argparse.ArgumentParser: parser exposing the dataset, model, config,
        device and input/output directory options used by ``main``.
    """
    # NOTE(review): the description says "Train" although this script only
    # evaluates saved models; left unchanged so the CLI help text is the same.
    parser = argparse.ArgumentParser(description='Train a chosen model with different algorithms')
    parser.add_argument('--dataset-name', help='The dataset to use', type=str, default='mnist')
    parser.add_argument('--model-name', help='The model Name', required=True)
    parser.add_argument('--model-dir', help='The directory to the model saves', type=str,
                        default='./outputs/models/')
    parser.add_argument('--test-method', help='The testing method', type=str, default='traditional')
    parser.add_argument('--data-dir', help='Directory for the data to be saved and loaded',
                        type=str, default='./data/')
    parser.add_argument('-v', '--verbose', help='Whether to print information as the script runs',
                        action='store_true')
    parser.add_argument('--config-file', help='The config file containing the model parameters and training methods',
                        type=str, default='./config.yaml')
    parser.add_argument('--device', help='Device to run the models on.',
                        type=str, default='auto')
    parser.add_argument('--test-dir', help='The directory to save the model test results',
                        type=str, default='./outputs/test_results/')
    return parser


def _load_config(config_path):
    """Read the YAML config file once and return its parsed content.

    Args:
        config_path (str): path to the YAML configuration file.

    Returns:
        The parsed YAML document (indexed like a dict by the caller).
    """
    # The context manager closes the file handle, which the previous
    # ``yaml.load(open(...))`` form never did.
    with open(config_path, 'r') as config_file:
        # NOTE(review): FullLoader can build some arbitrary Python objects;
        # switch to yaml.safe_load if the config may come from an untrusted source.
        return yaml.load(config_file, Loader=yaml.FullLoader)


def _find_model_files(model_dir, model_name, dataset_name):
    """List the saved state-dict files in ``model_dir`` that belong to a run.

    A file name matches when it contains ``'<model_name>-'``,
    ``'<dataset_name>-'``, ``'state_dict'`` and ``'all_trained'``.

    Args:
        model_dir (str): directory that is scanned (not recursively).
        model_name (str): model name given on the command line.
        dataset_name (str): dataset name given on the command line.

    Returns:
        list: matching file names (without the directory), sorted so that the
        order of the result rows does not depend on the file system.
    """
    return sorted(
        file_name
        for file_name in os.listdir(model_dir)
        if model_name + '-' in file_name
        and 'state_dict' in file_name
        and 'all_trained' in file_name
        and dataset_name + '-' in file_name
    )


def _collect_metrics(output, test_targets, run_name, n_classes):
    """Compute the metric rows for a single evaluated model.

    Top-1 and top-2 accuracy are always reported; top-5 accuracy is added when
    there are more than 10 classes; recall, precision and F1 are added when
    there are exactly 2 classes.

    Args:
        output: tensor returned by ``model.predict``; the predicted class is
            the index of the maximum along dim 1.
        test_targets: tensor holding the ground-truth labels.
        run_name (str): value written to the ``'Run'`` column.
        n_classes (int): number of distinct values in ``test_targets``.

    Returns:
        dict: ``'Run'``, ``'Metric'`` and ``'Value'`` lists of equal length.
    """
    topk = (1, 2, 5) if n_classes > 10 else (1, 2)
    metric_names = ['Accuracy', 'Top 2 Accuracy', 'Top 5 Accuracy'][:len(topk)]
    metric_values = [
        accuracy.item()
        for accuracy in accuracy_topk(output, test_targets, topk=topk)
    ]

    if n_classes == 2:
        _, predictions = output.max(dim=1)
        # scikit-learn works on host arrays; move the tensors off any
        # accelerator before handing them over.
        y_true = test_targets.detach().cpu().numpy()
        y_pred = predictions.detach().cpu().numpy()
        for metric_name, scorer in (('Recall', recall_score),
                                    ('Precision', precision_score),
                                    ('F1', f1_score)):
            metric_names.append(metric_name)
            metric_values.append(scorer(y_true, y_pred))

    return {
        'Run': [run_name] * len(metric_names),
        'Metric': metric_names,
        'Value': metric_values,
    }


def main(argv=None):
    """Evaluate every saved model matching the CLI arguments and save a CSV.

    Args:
        argv (list, optional): argument list to parse instead of
            ``sys.argv[1:]``. Defaults to ``None`` (use the process arguments).

    The results are written to
    ``<test-dir>/<model-name>-<dataset-name>-results.csv`` with the columns
    ``Run``, ``Metric`` and ``Value``.
    """
    args = _build_parser().parse_args(argv)

    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    # loading the test config
    config = _load_config(args.config_file)
    test_config = config['testing-procedures'][args.test_method]
    if args.verbose:
        print('Testing config:', test_config)

    # find the saved models in the directory that match the model name given
    model_files = _find_model_files(args.model_dir, args.model_name, args.dataset_name)
    if not model_files:
        print('Warning: no saved models matched the given model and dataset names.')

    # prints the models to be tested
    if args.verbose:
        print('Testing the following models:')
        for model_file in model_files:
            print(model_file)

    # load model config and the data
    model_config = config[args.model_name]
    test_loader, test_targets = get_test_data(args, test_config=test_config)

    # The targets are the same for every model, so count the classes once.
    n_classes = len(torch.unique(test_targets))

    # test the models
    result_frames = []
    for model_file in model_files:

        # load the model state dict from file; map_location lets a checkpoint
        # saved on another device be restored onto the requested one
        model_sd = torch.load(os.path.join(args.model_dir, model_file),
                              map_location=device)

        model_class = get_model(model_config)
        # copy so the per-run model name does not leak into the shared config
        model_args = dict(model_config['model_params'])
        save_args = model_config['save_params']
        model_args['model_name'] = model_file.split('-seed')[0]
        model = model_class(
            device=device,
            verbose=args.verbose,
            source_fit=model_config['train_params']['source_fit'],
            **save_args,
            **model_args)

        # load the state dict into the model
        model.load_state_dict(model_sd)

        # ========== make predictions ==========
        output = model.predict(test_loader=test_loader)
        run_metrics = _collect_metrics(output, test_targets,
                                       model_args['model_name'], n_classes)
        result_frames.append(pd.DataFrame(run_metrics))

    # collate and save results; a single concat avoids re-copying the
    # accumulated frame on every iteration
    results = pd.concat(result_frames) if result_frames else pd.DataFrame()

    # make sure the output directory exists so the evaluation work is not lost
    os.makedirs(args.test_dir, exist_ok=True)
    results_path = os.path.join(
        args.test_dir,
        args.model_name + '-' + args.dataset_name + '-results.csv')
    results.to_csv(results_path, index=False)


if __name__ == "__main__":
    main()