'''
- script.py
- Evaluates every configured (NER model, embedding, similarity function, schema)
  combination and saves the resulting mean-reciprocal-rank CSV files under
  src/data/output (see OUTPUT_LOC).
'''


# External imports
import json
from datetime import datetime
from alive_progress import alive_bar

# Internal imports
from src.core.interface import dialog, ranking, annotation
from src.core.petel.petel import petel
from src.utils.misc.schema import load_schema
import src.utils.misc.directory_ops as directory

# Configuration Parameters
# Global evaluation settings (keys read in this file: "schema", "embeddings",
# "ner_models", "similarity_functions", "features", "task", "model_mmr_config").
CONF_LOC = "src/core/evaluation/global_config.json"
# Not used by any live code in this file.
CONF_LIST_LOC = "src/core/evaluation/evaluation_config.json"
# Separate easy/hard query sets; not used by any live code in this file.
EASY_QUERIES_LOC = "src/data/test_data/easy_queries/"
HARD_QUERIES_LOC = "src/data/test_data/hard_queries/"
# Directory of "<schema>.txt" query files, one "query|ground truth|..." line per case.
COMBINED_QUERIES_LOC = "src/data/test_data/combined/"
# Root directory under which the per-schema result CSVs are written.
OUTPUT_LOC = "src/data/output/acl/exp"

def get_all_configurations():
    """Return the full cross product of evaluation configurations.

    Reads the global config at ``CONF_LOC``. The result is a list with one dict
    per (NER model, embedding, similarity function, schema) combination; the
    schema varies fastest and the model slowest. Every entry carries the global
    ``task``. The entries have no ``"path"`` key.
    """
    with open(CONF_LOC) as config_handle:
        settings = json.load(config_handle)

    # Pull every key up front so a missing one fails before any work is done.
    schemas = settings["schema"]
    embeddings = settings["embeddings"]
    ner_models = settings["ner_models"]
    sim_functions = settings["similarity_functions"]
    task_name = settings["task"]

    return [
        {"schema": schema, "embedding": emb, "model": ner_model,
         "similarity_function": sim_fn, "task": task_name}
        for ner_model in ner_models
        for emb in embeddings
        for sim_fn in sim_functions
        for schema in schemas
    ]

def get_custom_configurations():
    """Return evaluation configurations whose model details come from ``model_mmr_config``.

    Reads the global config at ``CONF_LOC``. The result is a list with one dict
    per (NER model, embedding, similarity function, schema) combination; the
    schema varies fastest. Unlike ``get_all_configurations``, each entry's
    ``task`` comes from ``model_mmr_config[model]["task"]``. Each entry also
    carries a ``path`` built from that task and ``model_mmr_config[model]["name"]``.
    """
    with open(CONF_LOC) as config_handle:
        settings = json.load(config_handle)

    # Pull every key up front so a missing one fails before any work is done.
    schemas = settings["schema"]
    embeddings = settings["embeddings"]
    ner_models = settings["ner_models"]
    sim_functions = settings["similarity_functions"]
    mmr_by_model = settings["model_mmr_config"]

    def build_entry(ner_model, emb, sim_fn, schema):
        mmr_config = mmr_by_model[ner_model]
        model_task = mmr_config['task']
        model_name = mmr_config['name']
        # NOTE(review): mmr_config is a dict (it is indexed by 'task'/'name'
        # above), so comparing it to the string 'custom' is always False and
        # the per-schema path is never chosen. Confirm which field was meant
        # to carry 'custom' before relying on per-schema models.
        if mmr_config == 'custom':
            path = f"models/lm/{model_task}/{schema}/{model_name}"
        else:
            path = f"models/lm/{model_task}/combined/{model_name}"
        # NER checkpoints are nested one directory deeper under their own name.
        if model_task == 'ner':
            path = f"{path}/{model_name}"
        return {"schema": schema,
                "embedding": emb,
                "model": ner_model,
                "similarity_function": sim_fn,
                "task": model_task,
                "path": path}

    return [build_entry(ner_model, emb, sim_fn, schema)
            for ner_model in ner_models
            for emb in embeddings
            for sim_fn in sim_functions
            for schema in schemas]

def create_output_files():
    """Create one result CSV per (feature, schema, similarity function) and write its header.

    Files go to ``{OUTPUT_LOC}/{task}_{schema}/`` (created via
    ``directory.create``) and are named ``{feature}-{function}-{MM-DD-YYYY}.csv``.
    Each starts with a header row ``, emb1, emb2, ...`` listing the configured
    embeddings, without a trailing newline.

    Returns:
        dict mapping ``"{feature}-{schema}-{function}"`` to an open, writable
        file handle. The caller is responsible for closing them.

    Raises:
        Any error from reading the config or opening a file. Handles opened
        before the failure are closed first.
    """
    with open(CONF_LOC) as global_config:
        global_configuration = json.load(global_config)
    features = global_configuration["features"]
    schema_names = global_configuration["schema"]
    embeddings = global_configuration["embeddings"]
    similarity_functions = global_configuration["similarity_functions"]
    extraction_task = global_configuration["task"]

    output_dir = {}
    for schema_name in schema_names:
        output_dir[schema_name] = f"{OUTPUT_LOC}/{extraction_task}_{schema_name}/"
        directory.create(output_dir[schema_name])

    # Take the date once so every file from this run shares the same stamp,
    # even if file creation straddles midnight.
    date_stamp = datetime.now().strftime('%m-%d-%Y')
    header = "".join(", {}".format(embedding) for embedding in embeddings)

    output_files = {}
    try:
        for feature in features:
            for function in similarity_functions:
                for schema_name in schema_names:
                    file_index = f"{feature}-{schema_name}-{function}"
                    file_name_ext = f"{feature}-{function}-{date_stamp}.csv"
                    handle = open(output_dir[schema_name] + file_name_ext, "w")
                    output_files[file_index] = handle
                    handle.write(header)
                    handle.flush()
    except BaseException:
        # Don't leak the handles already opened; the caller never receives them.
        for handle in output_files.values():
            handle.close()
        raise
    return output_files

def create_custom_output_files():
    """Create the result CSVs used in custom-configuration mode and write their headers.

    One file is created per (feature, schema, similarity function), directly
    under ``{OUTPUT_LOC}/{schema}/`` (created via ``directory.create``; there is
    no task prefix). Files are named ``{feature}-{function}-{MM-DD-YYYY}.csv``.
    Each starts with a header row ``, emb1, emb2, ...`` listing the configured
    embeddings, without a trailing newline.

    Returns:
        dict mapping ``"{feature}-{schema}-{function}"`` to an open, writable
        file handle. The caller is responsible for closing them.

    Raises:
        Any error from reading the config or opening a file. Handles opened
        before the failure are closed first.
    """
    with open(CONF_LOC) as global_config:
        global_configuration = json.load(global_config)
    features = global_configuration["features"]
    schema_names = global_configuration["schema"]
    embeddings = global_configuration["embeddings"]
    similarity_functions = global_configuration["similarity_functions"]

    output_dir = {}
    for schema_name in schema_names:
        output_dir[schema_name] = f"{OUTPUT_LOC}/{schema_name}/"
        directory.create(output_dir[schema_name])

    # Take the date once so every file from this run shares the same stamp,
    # even if file creation straddles midnight.
    date_stamp = datetime.now().strftime('%m-%d-%Y')
    header = "".join(", {}".format(embedding) for embedding in embeddings)

    output_files = {}
    try:
        for feature in features:
            for function in similarity_functions:
                for schema_name in schema_names:
                    file_index = f"{feature}-{schema_name}-{function}"
                    file_name_ext = f"{feature}-{function}-{date_stamp}.csv"
                    handle = open(output_dir[schema_name] + file_name_ext, "w")
                    output_files[file_index] = handle
                    handle.write(header)
                    handle.flush()
    except BaseException:
        # Don't leak the handles already opened; the caller never receives them.
        for handle in output_files.values():
            handle.close()
        raise
    return output_files

# Ground-truth columns of a test-case line (after the query), in file order.
# These are also the keys of the per-query turn counters and the MRR totals.
_TEST_CASE_FEATURES = ("attribute", "filter", "filter_operation",
                       "aggregator", "entity", "prediction_window")


def _load_schema_file(schema_name):
    """Read and parse ``src/data/test_data/schema/<schema_name>.schema``."""
    with open('src/data/test_data/schema/' + schema_name + '.schema', "r") as schema_file:
        return load_schema(schema_file.read())


def _read_test_cases(schema_name):
    """Return the non-blank lines of the combined query file for *schema_name*.

    Blank lines (e.g. a trailing empty line) carry no query and would make the
    ``|`` split in ``_parse_test_case`` fail, so they are skipped.
    """
    with open(COMBINED_QUERIES_LOC + schema_name + ".txt", "r") as query_file:
        return [line for line in query_file if line.strip()]


def _parse_test_case(test_case):
    """Split a ``query|gt_1|...|gt_6`` line into the query and a ground-truth dict.

    Each ground-truth column is upper-cased and stripped. Only its first
    comma-separated value is kept.
    """
    fields = test_case.split("|")
    ground_truth = {
        feature: fields[column].upper().strip().split(",")[0]
        for column, feature in enumerate(_TEST_CASE_FEATURES, start=1)
    }
    return fields[0], ground_truth


def _count_turns(query, ground_truth, loaded_schema, configuration,
                 loaded_embedding, loaded_ner, task):
    """Simulate a user answering the dialog until every feature is confirmed.

    After each ``dialog.run`` turn, the top-ranked candidate for the active
    feature is normalised (``_`` -> space, upper-case) and compared with the
    ground truth. The simulated user replies "yes" on a match and "no"
    otherwise. The loop ends after as many "yes" replies as there are
    ground-truth features.

    Returns:
        dict mapping each feature to the number of turns it was active.
    """
    counter = dict.fromkeys(ground_truth, 0)
    petel_exp = petel(loaded_schema)
    num_features_remaining = len(ground_truth)
    while num_features_remaining > 0:
        petel_exp = dialog.run(
            query, petel_exp, loaded_schema, configuration, loaded_embedding,
            loaded_ner, task=task)["expression"]
        active = petel_exp.active_feature
        counter[active] += 1
        predicted = petel_exp.rankings[active][0]["name"].replace("_", " ").upper()
        if predicted == ground_truth[active].upper():
            query = "yes"
            num_features_remaining -= 1
        else:
            query = "no"
    return counter


def start(custom_configuration=True):
    """Evaluate every configuration and append per-feature MRR scores to the result CSVs.

    Args:
        custom_configuration: when True (default), configurations come from
            ``get_custom_configurations`` and results go to the files from
            ``create_custom_output_files``. Otherwise ``get_all_configurations``
            and ``create_output_files`` are used.

    For each configuration, the schema, embedding and NER model are loaded.
    The previous embedding/model is reused only when its load arguments are
    unchanged. Every combined test query is then run through the simulated
    dialog. The mean reciprocal rank of each feature except ``entity`` is
    appended to the CSV keyed ``"{feature}-{schema}-{similarity_function}"``.
    A new CSV row, labelled with the model name, is started whenever the
    model changes. All output files are closed on return or error.
    """
    configurations = get_custom_configurations() if custom_configuration else get_all_configurations()
    for configuration in configurations:
        print(configuration)

    output_files = create_custom_output_files() if custom_configuration else create_output_files()

    loaded_embedding = {}
    loaded_ner = {}
    prev_embedding_key = None
    prev_ner_key = None
    prev_model = ""
    try:
        for configuration in configurations:
            model = configuration["model"]
            embedding = configuration["embedding"]
            similarity_function = configuration["similarity_function"]
            schema = configuration["schema"]
            task = configuration['task']
            # get_all_configurations() sets no "path"; use None instead of
            # raising KeyError so the non-custom mode can run.
            model_path = configuration.get('path')

            test_cases = {"combined": _read_test_cases(schema)}
            loaded_schema = _load_schema_file(schema)

            print("CONFIGURATION: Testing " + model + " on " + embedding + " for " +
                  schema + " with " + similarity_function + " distance metric")

            # The embedding is built from the loaded schema, so a cached one may
            # only be reused when both the embedding name and the schema match.
            embedding_key = (embedding, schema)
            if embedding_key == prev_embedding_key:
                print("LOADING: Embedding same as previous, no need to reload")
            else:
                loaded_embedding.clear()
                loaded_embedding = ranking.load_embedding(embedding, loaded_schema)
                prev_embedding_key = embedding_key
                print("LOADING: Embedding loaded")

            # Reload the NER model whenever any of its load arguments change.
            ner_key = (model, task, model_path)
            if ner_key == prev_ner_key:
                print("LOADING: model Technique same as previous, no need to reload")
            else:
                if prev_ner_key is not None:
                    loaded_ner.clear()
                loaded_ner = annotation.load_ner_model(
                    model_name=model,
                    task=task,
                    model_path=model_path)
                prev_ner_key = ner_key
                print("LOADING: model Technique loaded")

            # Each CSV holds one row per model: start a new row when it changes.
            if model != prev_model:
                prev_model = model
                for output_file in output_files.values():
                    output_file.write("\n" + model)

            for query_type, cases in test_cases.items():
                reciprocal_rank = dict.fromkeys(_TEST_CASE_FEATURES, 0.0)
                print("TESTING: Running " + query_type + " queries")
                with alive_bar(len(cases), bar="smooth", spinner="classic") as bar:
                    for test_case in cases:
                        query, ground_truth = _parse_test_case(test_case)
                        counter = _count_turns(query, ground_truth, loaded_schema,
                                               configuration, loaded_embedding,
                                               loaded_ner, task)
                        for key in reciprocal_rank:
                            if key == 'entity':  # entity is not scored
                                continue
                            reciprocal_rank[key] += 1 / counter[key]
                        bar()
                # Mean over the number of scored cases; an empty query file
                # yields 0.0 rather than a ZeroDivisionError.
                num_cases = max(len(cases), 1)
                for key in reciprocal_rank:
                    if key == 'entity':
                        continue
                    reciprocal_rank[key] /= num_cases
                    file_index = "{}-{}-{}".format(key, schema, similarity_function)
                    output_files[file_index].write(f", {reciprocal_rank[key]}")
                    output_files[file_index].flush()
    finally:
        for output_file in output_files.values():
            output_file.close()


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    start()
