'''
- train_schema.py
- This file trains and evaluates every configured model (for the NER or
  SQuAD extraction task) on each schema listed in the global configuration.
'''


# External imports
import json
from math import comb
import os
import copy
from datetime import datetime

#from alive_progress import alive_bar

# Internal imports
from src.core.interface import data_generation as datagen, fine_tuning as tuning
from src.utils.misc.schema import load_schema
import src.utils.misc.directory_ops as directory
import json
import csv
import pandas as pd
from src.core.configuration import fine_tuning_conf as ft_conf

# Configuration Parameters
# Path to the JSON configuration file that main() loads. main() reads the
# keys "schema", "ner_models", "ner_model_names", "dataset_size" and "task"
# from it.
CONF_LOC = "src/core/evaluation/global_config.json"

def update_progress_file(task, schema_name, progress_parameters):
    """Write the training progress for one task/schema pair to its JSON log.

    The file is fully rewritten on every call, so it always holds only the
    most recent ``progress_parameters``.

    Args:
        task: Name of the extraction task; used as the first part of the
            log file name.
        schema_name: Name of the schema; used as the second part of the
            log file name.
        progress_parameters: JSON-serializable object to store.

    Raises:
        TypeError: If ``progress_parameters`` cannot be serialized to JSON.
        OSError: If the log directory or file cannot be created or written.
    """
    progress_file_location = f"src/data/output/fine_tuning/logs/custom/{task}-{schema_name}-training-progress.json"

    # Serialize first so that a serialization error cannot leave behind a
    # truncated (empty) progress file.
    serialized = json.dumps(progress_parameters)

    # Make sure the log directory exists; a fresh checkout may not have it.
    os.makedirs(os.path.dirname(progress_file_location), exist_ok=True)

    # The context manager closes the handle even if the write fails.
    with open(progress_file_location, "w") as progress_file:
        progress_file.write(serialized)


def _load_best_hyper_parameter(extraction_task, schema_name, model):
    """Return the stored hyperparameters for one task/schema/model combination.

    The JSON file is re-read on every call and closed before returning, so
    no file handle is held while a model is being fine-tuned.

    Raises:
        OSError: If the hyperparameter file cannot be opened.
        KeyError: If the file has no entry for the requested combination.
    """
    json_path = 'src/data/output/fine_tuning/hp_logs/best_hp.json'
    with open(json_path) as json_file:
        params = json.load(json_file)
    return params[extraction_task][schema_name][model]


def _train_one_model(extraction_task, schema_name, model, model_names,
                     overwrite_cached_data, combined, dataset_type):
    """Fine-tune a single model with its stored hyperparameters.

    The hyperparameters are looked up before the task is inspected, so a
    missing entry raises ``KeyError`` even for a task that is neither
    ``'ner'`` nor ``'squad'`` (for which no training is started).
    """
    hyper_parameter = _load_best_hyper_parameter(extraction_task, schema_name, model)

    if extraction_task == 'ner':
        tuning.fine_tune_model(
            schema_name,
            model,
            model_names[model],
            hyper_parameter,
            overwrite_cached_data=overwrite_cached_data,
            combined_training=combined,
            dataset_type=dataset_type,
        )
    elif extraction_task == 'squad':
        tuning.fine_tune_squad(
            schema_name,
            model,
            model_names[model],
            hyper_parameter,
            overwrite_cached_data=overwrite_cached_data,
            combined_training=combined,
            dataset_type=dataset_type,
        )


def _evaluate_one_model(extraction_task, schema_name, model, model_names,
                        combined, dataset_type, output_file):
    """Evaluate a single model and append its scores to ``output_file``.

    For ``'ner'`` the f1, precision and recall entries of the result are
    written; for ``'squad'`` only f1 is written. Any other task writes
    nothing.
    """
    if extraction_task == 'ner':
        result = tuning.evaluate_model(
            schema_name,
            model,
            model_names[model],
            use_handcrafted_data=True,
            combined_model=combined,
            dataset_type=dataset_type,
        )
        output_file.write("{}:\n\tf1: {}\n\tprecision: {}\n\trecall: {}\n".format(model, result["f1"], result["precision"], result["recall"]))
        # Flush after every model so partial results are on disk while the
        # remaining models are still running.
        output_file.flush()
    elif extraction_task == 'squad':
        result = tuning.evaluate_squad(
            schema_name,
            model,
            model_names[model],
            use_handcrafted_data=True,
            combined_model=combined,
            dataset_type=dataset_type,
        )
        output_file.write("{}:\n\tf1: {}\n\n".format(model, result["f1"]))
        output_file.flush()


def main(
    hp_search=False,
    task='none',
    universal=False,
    evaluation_only=False,
    dataset_type="csr",
    output_dir="src/data/output/fine_tuning/training/universal"):
    """Fine-tune and evaluate every configured model on every configured schema.

    The schemas, models, model names, dataset sizes and default task are
    read from the JSON file at ``CONF_LOC``. For each dataset size and each
    schema, one results text file is written into ``output_dir`` containing
    the evaluation scores of every model.

    Args:
        hp_search: Accepted for interface compatibility; not read by this
            function.
        task: Extraction task to run (``'ner'`` or ``'squad'``). The value
            ``'none'`` means "use the task from the configuration file".
        universal: When truthy, every schema is treated as combined, as is
            the schema literally named ``"combined"``.
        evaluation_only: When truthy, fine-tuning is skipped and the models
            are only evaluated.
        dataset_type: Passed through unchanged to the fine-tuning and
            evaluation calls.
        output_dir: Directory the results files are written to; it is
            created if necessary.

    Raises:
        OSError: If the configuration, hyperparameter or results file
            cannot be opened.
        KeyError: If a required configuration key or a hyperparameter
            entry is missing.
    """
    # Read the configuration and release the file handle immediately rather
    # than keeping it open for the whole training run.
    with open(CONF_LOC) as config_file:
        configuration = json.load(config_file)

    schema_names = configuration["schema"]
    model_list = configuration["ner_models"]
    model_names = configuration["ner_model_names"]
    dataset_sizes = configuration["dataset_size"]
    extraction_task = configuration["task"] if task == 'none' else task

    # Cached training data is never overwritten by this script.
    override_cache_flag = False

    for dataset_size in dataset_sizes:
        ft_conf.HP_SEARCH_DATA_SIZE = dataset_size

        for schema_name in schema_names:
            # Create the output directory if it doesn't already exist
            directory.create(f"{output_dir}")
            combined = schema_name == "combined" or universal

            if combined:
                file_name = f"combined-{schema_name}-{extraction_task}-results.txt"
            else:
                file_name = f"{schema_name}-{extraction_task}-results.txt"

            # NOTE(review): the file name does not include the dataset size
            # and the file is opened in "w" mode, so when several dataset
            # sizes are configured each pass overwrites the results of the
            # previous pass for the same schema. Confirm this is intended.
            #
            # The context manager guarantees the results file is closed even
            # if fine-tuning or evaluation raises part-way through.
            with open(f"{output_dir}/{file_name}", "w") as output_file:
                for model in model_list:
                    if not evaluation_only:
                        _train_one_model(
                            extraction_task,
                            schema_name,
                            model,
                            model_names,
                            override_cache_flag,
                            combined,
                            dataset_type,
                        )

                    _evaluate_one_model(
                        extraction_task,
                        schema_name,
                        model,
                        model_names,
                        combined,
                        dataset_type,
                        output_file,
                    )


# main()

# csv_path = 'src/data/output/fine_tuning/logs/hyperparameter-testing/'
# json_path = 'src/data/output/fine_tuning/logs/hyperparameter-testing/best_hp.json'
# save_best_hyper_parameter(csv_path, json_path, ['squad', 'ner'])