import json
import os
import pandas as pd
from src.core.configuration import fine_tuning_conf as ft_conf
from src.core.interface import fine_tuning as tuning
# Path of the global configuration JSON that search_for_hyperparam() loads.
# It is a relative path, so it is resolved against the current working directory.
CONF_LOC = "src/core/evaluation/global_config.json"

def search():
    """Run the hyperparameter search for each task, then save the best results.

    The search is launched for the 'ner' and 'squad' tasks, once with the
    default ``combined`` flag and once with ``combined=True``. Afterwards the
    per-task CSV logs under the hp_logs directory are scanned and the best
    hyperparameters are written to ``best_hp.json`` in that same directory.
    """
    search_for_hyperparam(task='ner')
    search_for_hyperparam(task='squad')
    # NOTE(review): search_for_hyperparam currently recomputes ``combined``
    # from each schema name, so these two calls behave like the two above.
    # Confirm whether they are still needed.
    search_for_hyperparam(task='ner', combined=True)
    search_for_hyperparam(task='squad', combined=True)

    csv_path = 'src/data/output/fine_tuning/hp_logs/'
    # os.path.join avoids the doubled separator that string formatting
    # produced because csv_path already ends with '/'.
    json_path = os.path.join(csv_path, 'best_hp.json')
    save_best_hyper_parameter(csv_path, json_path, ['squad', 'ner'])
        


def _best_hp_from_csv(csv_path):
    """Return the best-scoring cell of one hyperparameter grid CSV.

    The first column supplies the weight-decay value of each row, every other
    column header is parsed as a learning rate, and each cell is parsed as an
    F1 score. The result is a dict with keys 'weight_decay', 'learning_rate'
    and 'f1', or an empty dict when no cell reaches a score of at least 0.0
    (for example when every cell is NaN).
    """
    df = pd.read_csv(csv_path)
    columns = df.columns
    weight_decays = df.iloc[:, 0]
    best_f1 = 0.0
    best = {}
    for row in range(len(df)):
        for column in range(1, len(columns)):
            # Single positional lookup; avoids chained indexing through an
            # intermediate row Series.
            current_f1 = float(str(df.iloc[row, column]).strip())
            # '>=' means that on ties the later cell wins.
            if current_f1 >= best_f1:
                weight_decay = weight_decays.iloc[row]
                # Convert NumPy scalars (e.g. int64) to native Python values
                # so that json.dump can serialize them.
                if hasattr(weight_decay, "item"):
                    weight_decay = weight_decay.item()
                best['weight_decay'] = weight_decay
                best['learning_rate'] = float(columns[column].strip())
                best['f1'] = current_f1
                best_f1 = current_f1
                print(current_f1, best)
    return best


def save_best_hyper_parameter(hp_search_dir, json_path, tasks):
    """Collect the best hyperparameters per task/dataset/model and dump them as JSON.

    Expected layout: ``<hp_search_dir>/<task>/<dataset_name>/<model_name>.csv``.
    The model name is the part of the file name before the first '.'.

    Args:
        hp_search_dir: Directory containing one sub-directory per task.
        json_path: Path of the JSON file to write.
        tasks: Iterable of task names (sub-directory names) to scan.

    The JSON has the shape ``{task: {dataset_name: {model_name: best}}}``,
    where ``best`` holds 'weight_decay', 'learning_rate' and 'f1'.
    """
    best_hps = {}
    for task in tasks:
        best_hps[task] = {}
        csv_dir = os.path.join(hp_search_dir, task)
        for folder_name in os.listdir(csv_dir):
            schema_path = os.path.join(csv_dir, folder_name)
            # Stray files next to the dataset folders cannot be listed; skip them.
            if not os.path.isdir(schema_path):
                continue
            for file_name in os.listdir(schema_path):
                csv_path = os.path.join(schema_path, file_name)
                # Nested directories cannot be read as CSV; skip them.
                if not os.path.isfile(csv_path):
                    continue

                dataset_name, model_name = folder_name, file_name.split(".")[0]
                task_results = best_hps[task].setdefault(dataset_name, {})
                task_results[model_name] = _best_hp_from_csv(csv_path)

    # Open the output only after all CSVs were processed, so a failure above
    # does not leave a truncated file, and close it deterministically.
    with open(json_path, 'w', encoding="utf8", newline='\n') as writer_file:
        json.dump(best_hps, writer_file)

def search_for_hyperparam(combined=False, task='ner', output_dir='hp_search'):
    """Run tuning.hyperparameter_search for every dataset size, schema and model.

    The dataset sizes, schema names, model list and model-name mapping are
    read from the JSON file at CONF_LOC (keys "dataset_size", "schema",
    "ner_models" and "ner_model_names"). For each dataset size,
    ft_conf.HP_SEARCH_DATA_SIZE is set before the searches for that size run.
    The candidate parameters and F1 score of each search are printed.

    Args:
        combined: Currently has no effect; see the review note in the body.
        task: Task name forwarded to tuning.hyperparameter_search.
        output_dir: Accepted for interface compatibility; not used here.
    """
    # Load the configuration and release the file handle before the search
    # starts instead of keeping it open for the whole run.
    with open(CONF_LOC) as config_file:
        configuration = json.load(config_file)

    schema_names = configuration["schema"]
    # NOTE(review): the "ner_*" keys are used for every task, including
    # 'squad' -- confirm this is intended.
    model_names = configuration["ner_model_names"]
    model_list = configuration["ner_models"]
    dataset_sizes = configuration["dataset_size"]

    for dataset_size in dataset_sizes:
        ft_conf.HP_SEARCH_DATA_SIZE = dataset_size
        for schema_name in schema_names:
            # NOTE(review): the flag is derived from the schema name alone,
            # so the ``combined`` argument is ignored. This matches the
            # previous behaviour; confirm whether the argument should be
            # honoured instead.
            schema_is_combined = schema_name == "combined"
            for model in model_list:
                candidate_parameters, f1_score = tuning.hyperparameter_search(
                    schema_name,
                    model,
                    model_names[model],
                    combined=schema_is_combined,
                    task=task,
                    dataset_type="csr",
                )
                print(f"Model: {model}, Candidate Parameter: {candidate_parameters}, F1 score: {f1_score}")
                    
                   


# Module-level entry point: the full search runs whenever this file is
# executed or imported.
search()