'''
- fine_tuning.py
- This file handles interfacing the fine_tuning functions of VIDS
'''

# External imports
import copy
import os
import json
from datetime import datetime
from statistics import mode

# Internal imports
from src.core.configuration.fine_tuning_conf import *
import src.utils.ft_pytorch.run_ner as finetuning
import src.utils.misc.directory_ops as directory
from src.core.configuration import squad_config
import src.utils.squad.run_qa_beam_search_no_trainer as ft_w_beam
import src.utils.squad.run_qa_no_trainer as ft_wo_beam

'''
----------fine_tune_model----------
- Fine-tunes a given model on a given schema using the supplied parameters
-----Inputs-----
- schema_name - The schema to train the model on
- model - The name of the model to use (bert, roberta, xlnet, etc)
- model_name - The name of the model checkpoint to use (bert-base-cased, etc.)
- hyperparameters - A dictionary of hyperparameters to use
- overwrite_cached_data - Whether or not to re-create the cached dataset (defaults to False)
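- dataset_type - Which dataset variant to train on: 'small', 'csr' (the lm/ data), or anything else for the general data (defaults to 'small')
- combined_training - Whether or not to train on the data of all three schemas at once instead of only schema_name (defaults to False)
- hyper_parameter_search - Accepted, but not currently read by this function (defaults to False)
- task - Accepted, but not currently read by this function (defaults to 'ner')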
-----Output-----
- N/A - This function writes the fine-tuned models to files in models/
'''
def fine_tune_model(schema_name, model, model_name, hyperparameters, overwrite_cached_data=False, dataset_type='small', combined_training=False, hyper_parameter_search=False, task='ner'):
    '''
    Assembles the NER training settings and passes them to finetuning.finetune.
    -----Inputs-----
    - schema_name - The schema whose data is used (ignored for the data paths when combined_training is set)
    - model - The model type; also the last folder of the output directory
    - model_name - The checkpoint name, used for both the model and the tokenizer
    - hyperparameters - A dictionary that must contain "learning_rate" and "weight_decay"
    - overwrite_cached_data - Copied into the "overwrite_cache" setting
    - dataset_type - 'small', 'csr' (stored under lm/), or anything else for the general data
    - combined_training - Train on the flight_delay, online_delivary and student_perf data together
    - hyper_parameter_search - Not read here (see the NOTE below)
    - task - Not read here
    -----Output-----
    - N/A - Nothing is returned
    '''
    # Pick the data sub-folder and the root folder the model is written under
    if dataset_type == "small":
        variant, model_root = "small", "models"
    elif dataset_type == "csr":
        variant, model_root = "lm", "models/lm/ner"
    else:
        variant, model_root = "general", "models"

    # Combined runs read every schema and use shared "combined" cache/output folders
    if combined_training:
        source_schemas = ["flight_delay", "online_delivary", "student_perf"]
        group = "combined"
    else:
        source_schemas = [schema_name]
        group = schema_name

    data_root = "src/data/fine_tuning/ner"

    # Start from a private copy so the shared defaults are never modified
    settings = copy.deepcopy(DEFAULT_TRAINING_PARAMETERS)
    settings.update({
        "data_dir": [f"{data_root}/{name}/{variant}/" for name in source_schemas],
        "data_cache_dir": f"{data_root}/{group}/{variant}/",
        "output_dir": f"{model_root}/{group}/{model}/",
        "model_type": model,
        "model_name_or_path": model_name,
        "tokenizer_name": model_name,
        "learning_rate": hyperparameters["learning_rate"],
        "weight_decay": hyperparameters["weight_decay"],
        # NOTE(review): always True, regardless of the hyper_parameter_search
        # argument - confirm whether the argument should be used instead
        "is_hyper_parameter_search": True,
        "overwrite_cache": overwrite_cached_data,
    })

    # Train the model
    finetuning.finetune(settings)


def fine_tune_squad(schema_name, model, model_name, hyperparameters, dataset_type="small", hyper_parameter_search=False, combined_training=False, overwrite_cached_data=False, task='squad'):
    '''
    Assembles the SQuAD training configuration and runs the matching fine-tuning script.
    -----Inputs-----
    - schema_name - The schema whose data is used (ignored for the data paths when combined_training is set)
    - model - The model type; 'xlnet' is trained with ft_w_beam, everything else with ft_wo_beam
    - model_name - The checkpoint name, used for both the model and the tokenizer
    - hyperparameters - A dictionary that must contain "learning_rate" and "weight_decay"
    - dataset_type - "small", "csr" (stored under lm/), or anything else for the general data
    - hyper_parameter_search - Copied into configs.hyper_param_search
    - combined_training - Train on the flight_delay, online_delivary and student_perf data together
    - overwrite_cached_data - Copied into configs.overwrite_cache
    - task - Anything other than 'squad' makes this function return immediately
    -----Output-----
    - N/A - Nothing is returned
    '''
    if task != 'squad':
        return
    configs = squad_config.SQUADConfig().get_args()

    # Pick the data sub-folder and the root folder the model is written under
    if dataset_type == "small":
        variant, model_root = "small", "models/squad"
    elif dataset_type == "csr":
        variant, model_root = "lm", "models/lm/squad"
    else:
        variant, model_root = "general", "models/squad"

    # Combined runs read every schema and use shared "combined" cache/output folders
    if combined_training:
        schemas = ["flight_delay", "online_delivary", "student_perf"]
        group = "combined"
    else:
        schemas = [schema_name]
        group = schema_name

    data_root = "src/data/fine_tuning/squad"
    configs.data_dir = [f"{data_root}/{name}/{variant}/train.json" for name in schemas]
    configs.cache_dir = f"{data_root}/{group}/{variant}/"
    configs.output_dir = f"{model_root}/{group}/{model}/"

    configs.model_type = model
    configs.model_name_or_path = model_name
    configs.tokenizer_name = model_name
    configs.learning_rate = hyperparameters["learning_rate"]
    configs.weight_decay = hyperparameters["weight_decay"]
    configs.hyper_param_search = hyper_parameter_search

    # The dataset locations are published through module-level attributes of squad_config
    squad_config.TRAIN_DATASET_PATH = configs.data_dir
    squad_config.EVAL_DATASET_PATH = [f"src/data/test_data/squad_format/{name}.json" for name in schemas]

    # Honour the caller's request; this used to be hard-coded to False, which
    # silently discarded the overwrite_cached_data argument
    configs.overwrite_cache = overwrite_cached_data
    configs.per_device_train_batch_size = 32
    configs.num_train_epochs = 3

    if model == 'xlnet':
        ft_w_beam.finetune(configs)
    else:
        ft_wo_beam.finetune(configs)


'''
----------evaluate_model----------
- Evaluates a trained NER model for a given schema
-----Inputs-----
- schema_name - The schema to evaluate the model on
- model - The name of the model to use (bert, roberta, xlnet, etc)
- model_name - The name of the model checkpoint to use (bert-base-cased, etc.)
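- dataset_type - Which dataset variant to evaluate: "small", "csr", or anything else for the general data (defaults to "small")
- use_handcrafted_data - Legacy flag; eval_mode is set to "handcrafted" whether or not it is given (defaults to False)
- combined_model - Whether or not to evaluate the model trained on the combined schemas (defaults to False)
- task - Accepted, but not currently read by this function (defaults to 'ner')
-----Output-----
- results - The loss, precision, recall, and f1 scores for the model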
'''
def evaluate_model(schema_name, model, model_name, dataset_type="small", use_handcrafted_data=False, combined_model=False, task='ner'):
    '''
    Assembles the NER evaluation settings, runs finetuning.evaluation and returns its scores.
    -----Inputs-----
    - schema_name - The schema whose data the model is evaluated on
    - model - The model type; also the folder name the trained model is loaded from
    - model_name - The checkpoint name, used for the tokenizer
    - dataset_type - "small", "csr", or anything else for the general data
    - use_handcrafted_data - Has no effect: eval_mode is always "handcrafted"
    - combined_model - Load the model (and cache folder) of the combined training run
    - task - Not read here
    -----Output-----
    - A dictionary with the "loss", "precision", "recall" and "f1" entries of the evaluation result
    '''
    # Start from a private copy so the shared defaults are never modified
    training_parameters = copy.deepcopy(DEFAULT_TRAINING_PARAMETERS)

    # 'csr' models live under models/lm/ner/; every other dataset type shares models/
    model_root = "models/lm/ner" if dataset_type == "csr" else "models"
    # NOTE(review): 'csr' models are evaluated against the small/ data rather
    # than lm/ - kept as in the original, confirm this is intended
    data_variant = "small" if dataset_type in ("small", "csr") else "general"

    owner = "combined" if combined_model else schema_name
    dict_name = f"{model_root}/{owner}/{model}"
    data_root = "src/data/fine_tuning/ner"

    # data_dir is always a list. The general branch used to replace it with a
    # bare string when use_handcrafted_data was set, unlike every other place
    # that fills in this setting.
    training_parameters["data_dir"] = [f"{data_root}/{schema_name}/{data_variant}/"]
    training_parameters["data_cache_dir"] = f"{data_root}/{owner}/{data_variant}/"
    training_parameters["output_dir"] = dict_name + "/"
    # The trained model is loaded from a sub-folder named after the model type
    training_parameters["model_name_or_path"] = dict_name + "/" + model
    training_parameters["eval_mode"] = "handcrafted"

    training_parameters["model_type"] = model
    training_parameters["tokenizer_name"] = model_name
    training_parameters["do_predict"] = False

    print(training_parameters["model_name_or_path"])
    print("TRAINING: Evaluating the trained model on {} data".format(training_parameters["eval_mode"]))
    training_parameters["overwrite_cache"] = True

    # Evaluate the model, and return only the scores callers rely on
    results = finetuning.evaluation(training_parameters)
    return {metric: results[metric] for metric in ("loss", "precision", "recall", "f1")}




def evaluate_squad(schema_name, model, model_name, dataset_type="small", use_handcrafted_data=False, combined_model=False, task='squad'):
    '''
    Assembles the SQuAD evaluation configuration, runs the matching evaluation script and returns the f1 score.
    -----Inputs-----
    - schema_name - The schema whose test file the model is evaluated on
    - model - The model type; 'xlnet' is evaluated with ft_w_beam, everything else with ft_wo_beam
    - model_name - The checkpoint name, used for the tokenizer
    - dataset_type - "small", "csr" (stored under lm/), or anything else for the general data
    - use_handcrafted_data - Not read here
    - combined_model - Load the model of the combined training run
    - task - Anything other than 'squad' makes this function return None immediately
    -----Output-----
    - A dictionary holding the "f1" entry of the evaluation result
    '''
    if task != 'squad':
        return
    configs = squad_config.SQUADConfig().get_args()

    # Pick the data sub-folder and the root folder the model is loaded from
    if dataset_type == "small":
        variant, model_root = "small", "models/squad"
    elif dataset_type == "csr":
        variant, model_root = "lm", "models/lm/squad"
    else:
        variant, model_root = "general", "models/squad"

    if combined_model:
        train_schemas = ["flight_delay", "online_delivary", "student_perf"]
        owner = "combined"
    else:
        train_schemas = [schema_name]
        owner = schema_name

    trained_model_dir = f"{model_root}/{owner}/{model}/"
    configs.data_dir = [
        f"src/data/fine_tuning/squad/{name}/{variant}/train.json" for name in train_schemas
    ]
    configs.output_dir = trained_model_dir
    configs.model_name_or_path = trained_model_dir

    configs.model_type = model
    configs.tokenizer_name = model_name
    configs.do_predict = False
    configs.overwrite_cache = False

    # The dataset locations are published through module-level attributes of squad_config;
    # evaluation always uses the single test file of schema_name, even for a combined model
    squad_config.TRAIN_DATASET_PATH = configs.data_dir
    squad_config.EVAL_DATASET_PATH = [f"src/data/test_data/squad_format/{schema_name}.json"]

    # NOTE(review): the cache folder is always the lm/ one of schema_name, whatever
    # dataset_type is - confirm this is intended
    configs.cache_dir = f"src/data/fine_tuning/squad/{schema_name}/lm/"

    # Evaluate the model, and return its results
    if model == 'xlnet':
        results = ft_w_beam.evaluate(configs)
    else:
        results = ft_wo_beam.evaluate(configs)

    return {"f1": results["f1"]}


'''
----------hyperparameter_search----------
- Conducts a hyperparameter search for a given model & schema, returning the parameters that perform best
-----Inputs-----
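- schema_name - The schema to train the model on
- model - The name of the model to use (bert, roberta, xlnet, etc)
- model_name - The name of the model checkpoint to use (bert-base-cased, etc.)
- combined - Whether or not to train on the combined data of all schemas (defaults to False)
- task - Which fine-tuning task to search over, 'ner' or 'squad' (defaults to 'ner')
- dataset_type - The dataset variant passed on to the fine-tuning and evaluation functions (defaults to "small")
-----Output-----
- hyperparams - The top-performing hyperparameters for all the tests, followed by the best f1 score they achieved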
'''
def hyperparameter_search(schema_name, model, model_name, combined=False, task='ner', dataset_type="small"):
    '''
    Grid-searches WEIGHT_DECAYS x LEARNING_RATES, training and evaluating one model per pair.
    - Progress is saved to a JSON file after every pair, so an interrupted search resumes where it stopped
    - Every f1 score is also appended to a csv file (one row per weight decay, one column per learning rate)
    -----Inputs-----
    - schema_name - The schema to train and evaluate on
    - model - The model type (bert, roberta, xlnet, etc)
    - model_name - The checkpoint name (bert-base-cased, etc.)
    - combined - Train on (and evaluate the model of) the combined schemas
    - task - 'ner' or 'squad'; selects which fine-tune/evaluate pair is used
    - dataset_type - Passed on unchanged to the fine-tune/evaluate functions
    -----Output-----
    - A (best_hyperparameters, best_f1) tuple
    '''
    progress_filename = f"src/data/output/fine_tuning/hp_logs/{task}/{schema_name}-{model}-hyperparameter-search-progress.json"
    directory_path = f"src/data/output/fine_tuning/hp_logs/{task}/{schema_name}"

    # A leftover progress file means an earlier search was interrupted
    in_progress = os.path.exists(progress_filename)
    if in_progress:
        print("TRAINING: An interrupted {} hyperparameter search was detected. Resuming the search.".format(model))
        with open(progress_filename) as progress_file:
            progress_params = json.load(progress_file)
    else:
        progress_params = {
            "timestamp": datetime.now().strftime("%m-%d-%Y_%H_%M"),
            "learning_rates": copy.deepcopy(LEARNING_RATES),
            "weight_decays": copy.deepcopy(WEIGHT_DECAYS),
            "best_f1": 0,
            "best_hyperparameters": {
                "weight_decay": WEIGHT_DECAYS[0],
                "learning_rate": LEARNING_RATES[0]
            },
        }
        # Write the parameters to the file
        directory.create(directory_path)
        update_hyperparameter_progress(task, schema_name, model, progress_params)

    # Open the csv file for output: start a fresh table, or append when resuming
    directory.create(directory_path)
    file_path = f"{directory_path}/{model}.csv"
    open_mode = "a" if in_progress else "w"
    with open(file_path, open_mode) as output_file:
        if not in_progress:
            output_file.write('weight decay')
            for learning_rate in LEARNING_RATES:
                output_file.write(", " + str(learning_rate))
            output_file.write("\n")
        print('TRAINING: Writing hyperparameter testing output to:', file_path)
        output_file.flush()

        # Iterate over copies: the lists inside progress_params shrink as pairs complete
        data_rewritten = False
        weight_decay_list = copy.deepcopy(progress_params['weight_decays'])
        learning_rate_list = copy.deepcopy(progress_params['learning_rates'])
        for weight_decay in weight_decay_list:
            # Only label the row when it starts from the first learning rate;
            # a row resumed half-way was already labelled before the interruption
            if learning_rate_list and learning_rate_list[0] == LEARNING_RATES[0]:
                output_file.write(str(weight_decay))
                output_file.flush()
            for learning_rate in learning_rate_list:
                hyperparameters = {
                    "learning_rate": learning_rate,
                    "weight_decay": weight_decay
                }
                print("TRAINING: Fine-tuning a(n) {} test model for the following hyperparameters: {}".format(model, hyperparameters))
                temp_result = {}

                # Only the first training run of this call asks for the cached data to be rewritten
                rewrite_cache = not data_rewritten
                if task == 'ner':
                    fine_tune_model(schema_name, model, model_name, hyperparameters, overwrite_cached_data=rewrite_cache, dataset_type=dataset_type, hyper_parameter_search=True, combined_training=combined, task=task)
                    data_rewritten = True
                elif task == 'squad':
                    fine_tune_squad(schema_name, model, model_name, hyperparameters, dataset_type=dataset_type, hyper_parameter_search=True, combined_training=combined, task=task, overwrite_cached_data=rewrite_cache)
                    data_rewritten = True

                # Evaluate the model that was just trained. combined_model must mirror
                # combined_training, otherwise a combined NER run would be scored
                # against the per-schema model folder instead of the combined one.
                if task == 'ner':
                    temp_result = evaluate_model(schema_name, model, model_name, dataset_type=dataset_type, use_handcrafted_data=True, combined_model=combined, task=task)
                elif task == 'squad':
                    temp_result = evaluate_squad(schema_name, model, model_name, dataset_type=dataset_type, use_handcrafted_data=True, task=task, combined_model=combined)

                # Record the results for this set of hyperparameters in the csv
                output_file.write(", " + "{:0.4f}".format(temp_result["f1"]))
                output_file.flush()
                print('Fine-tuning output (L.R. {}, W.D. {}): {}'.format(learning_rate, weight_decay, temp_result["f1"]))

                # Store the best result, and continue
                if temp_result["f1"] > progress_params["best_f1"]:
                    progress_params["best_hyperparameters"] = {
                        "weight_decay": weight_decay,
                        "learning_rate": learning_rate
                    }
                    progress_params["best_f1"] = temp_result["f1"]

                # Mark this learning rate as done & write the progress parameters to the progress file
                progress_params["learning_rates"].pop(0)
                update_hyperparameter_progress(task, schema_name, model, progress_params)
            output_file.write("\n")
            output_file.flush()

            # The row is finished: the next weight decay starts from the full learning rate list
            progress_params["learning_rates"] = copy.deepcopy(LEARNING_RATES)
            learning_rate_list = copy.deepcopy(LEARNING_RATES)
            progress_params["weight_decays"].pop(0)
            update_hyperparameter_progress(task, schema_name, model, progress_params)

    # Delete the hyperparameter progress tracking file, since it's no longer needed
    os.remove(progress_filename)

    return progress_params["best_hyperparameters"], progress_params["best_f1"]




'''
----------update_hyperparameter_progress----------
- Updates the hyperparameter search progress file
-----Inputs-----
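- task - The fine-tuning task ('ner' or 'squad'); it selects the sub-folder the progress file lives in
- schema_name - The schema to train the model on
- model_name - The model that is being tested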
- progress_parameters - The current state of progress in the hyperparameter search
-----Output-----
- N/A - The desired output is written to the file
'''
def update_hyperparameter_progress(task, schema_name, model_name, progress_parameters):
    '''
    Overwrites the hyperparameter search progress file with the given state, as JSON.
    -----Inputs-----
    - task - The fine-tuning task; selects the sub-folder of hp_logs/ the file is stored in
    - schema_name - The schema being trained on (part of the file name)
    - model_name - The model being tested (part of the file name)
    - progress_parameters - The JSON-serialisable state of the search
    -----Output-----
    - N/A - The state is written to the file
    '''
    progress_file_location = f"src/data/output/fine_tuning/hp_logs/{task}/{schema_name}-{model_name}-hyperparameter-search-progress.json"

    # Make sure the target folder exists so the first write cannot fail on a missing directory
    os.makedirs(os.path.dirname(progress_file_location), exist_ok=True)

    # The context manager closes the file even when serialisation fails
    with open(progress_file_location, "w") as progress_file:
        json.dump(progress_parameters, progress_file)