'''
- script.py
- This file evaluates all given configurations, generating and saving the results under src/data/output
'''


# External imports
import json
from datetime import datetime
from alive_progress import alive_bar

# Internal imports
from src.core.interface import dialog, ranking, annotation
from src.core.petel.petel import petel
from src.utils.misc.schema import load_schema
import src.utils.misc.directory_ops as directory

# Configuration Parameters
# Global evaluation config (JSON); this file reads its "features", "schema" and "baseline" entries
CONF_LOC = "src/core/evaluation/global_config.json"
# Directory holding one "<schema>.txt" query file per schema
COMBINED_QUERIES_LOC = "src/data/test_data/combined/"
# Root directory for the result files; one sub-directory per schema is placed under it
OUTPUT_LOC = "src/data/output/acl/baseline"


def create_custom_output_files():
    """Open one CSV result file per (feature, schema, baseline) combination.

    Reads the "features", "schema" and "baseline" entries from the global
    configuration at CONF_LOC, asks ``directory.create`` for one output
    directory per schema under OUTPUT_LOC, and opens a fresh file named
    ``{feature}-{baseline}-{MM-DD-YYYY}.csv`` inside the schema's directory.

    NOTE(review): the previous version built the key and file name from a
    variable called ``function`` that was never defined (NameError on the
    first iteration).  The configured baselines were loaded but unused, so
    they are assumed to be the intended third component -- confirm.

    Returns:
        dict: maps ``"{feature}-{schema}-{baseline}"`` to a file object opened
        for writing.  The caller is responsible for closing every file.

    Raises:
        OSError: if a file cannot be opened; files opened so far are closed
            before the error propagates.
        KeyError: if the configuration lacks one of the expected entries.
    """
    with open(CONF_LOC) as global_config:
        global_configuration = json.load(global_config)
    features = global_configuration["features"]
    schema_names = global_configuration["schema"]
    baselines = global_configuration["baseline"]

    output_dir = {}
    for schema_name in schema_names:
        output_dir[schema_name] = f"{OUTPUT_LOC}/{schema_name}/"
        directory.create(output_dir[schema_name])

    # Compute the date once so every file of a run carries the same stamp,
    # even when the run crosses midnight.
    date_stamp = datetime.now().strftime('%m-%d-%Y')

    output_files = {}
    try:
        for feature in features:
            for schema_name in schema_names:
                for baseline in baselines:
                    file_index = f"{feature}-{schema_name}-{baseline}"
                    file_name_ext = f"{feature}-{baseline}-{date_stamp}.csv"
                    file_full_name = output_dir[schema_name] + file_name_ext
                    output_files[file_index] = open(file_full_name, "w")
    except OSError:
        # Do not leak the handles that were opened before the failure.
        for opened_file in output_files.values():
            opened_file.close()
        raise
    return output_files

def _parse_test_case(test_case, feature_names):
    """Split one test-case line into its query and per-feature ground truth.

    A line has the form ``query|field_1|field_2|...``; the i-th field after
    the query belongs to the i-th name in ``feature_names``.  Each field is
    upper-cased and stripped, and only the part before the first comma is kept.

    Raises:
        IndexError: if the line has fewer fields than ``feature_names``.
    """
    fields = test_case.split("|")
    query = fields[0]
    ground_truth = {}
    for position, name in enumerate(feature_names, start=1):
        ground_truth[name] = fields[position].upper().strip().split(",")[0]
    return query, ground_truth


def start(custom_configuration=True):
    """Replay the combined queries for every baseline/schema and record the MRR.

    For each (baseline, schema) pair from the global configuration, every
    non-blank line of ``COMBINED_QUERIES_LOC + schema + ".txt"`` is treated as
    a test case.  For each case the loop counts how many guesses are needed
    per feature until the top-ranked guess equals the ground truth, and the
    mean reciprocal rank per feature (``entity`` excluded) is appended to the
    matching CSV file returned by ``create_custom_output_files``.

    Args:
        custom_configuration: Currently unused; kept so existing callers
            keep working.

    NOTE(review): ``get_response`` and ``petel_exp`` are neither defined nor
    imported in this file, so the guessing loop raises NameError as written.
    They have to be wired up (presumably from the imported interface/petel
    modules) before this function can run.
    """
    # Open output files for all results
    output_files = create_custom_output_files()
    try:
        with open(CONF_LOC) as global_config:
            global_configuration = json.load(global_config)
        schema_names = global_configuration["schema"]
        baselines = global_configuration["baseline"]

        for baseline in baselines:
            for schema in schema_names:
                # Read the cases once: this gives the progress-bar total
                # without opening the file a second time, and the handle is
                # released even if the evaluation below fails.  Blank lines
                # carry no fields and are skipped.
                with open(COMBINED_QUERIES_LOC + schema + ".txt", "r") as query_file:
                    test_cases = {
                        "combined": [line for line in query_file if line.strip()]
                    }

                # Load the schema
                # NOTE(review): loaded_schema is not consumed below; presumably
                # it is meant to initialise the missing petel_exp -- confirm.
                with open(f'src/data/test_data/schema/{schema}.schema', "r") as schema_file:
                    loaded_schema = load_schema(schema_file.read())

                # Start a new row labelled with the baseline in every csv file
                for output_file in output_files.values():
                    output_file.write("\n" + baseline)

                # Begin testing the queries
                for query_type, cases in test_cases.items():
                    # Running sum of reciprocal ranks per feature
                    reciprocal_rank = {"attribute": 0.0, "filter": 0.0, "filter_operation": 0.0,
                                       "aggregator": 0.0, "entity": 0.0, "prediction_window": 0.0}
                    print("TESTING: Running " + query_type + " queries")
                    with alive_bar(len(cases), bar="smooth", spinner="classic") as bar:
                        for test_case in cases:
                            query, ground_truth = _parse_test_case(test_case, reciprocal_rank)
                            # Number of guesses spent on each feature
                            counter = dict.fromkeys(ground_truth, 0)
                            num_features_remaining = len(ground_truth)
                            while num_features_remaining > 0:
                                # NOTE(review): get_response is undefined here, and
                                # the simulated reply in `query` is never passed to it.
                                get_response()
                                # NOTE(review): blocks on stdin for every guess and the
                                # result is ignored -- looks like a debugging pause; confirm.
                                input("> ")

                                active_feature = petel_exp.active_feature
                                # Count the guess against the feature being asked about.
                                # (Previously `counter += 1` on the dict: TypeError.)
                                counter[active_feature] += 1

                                predicted_attr = petel_exp.rankings[active_feature][0]["name"].replace("_", " ").upper()
                                g_truth = ground_truth[active_feature]
                                if predicted_attr == g_truth.upper():
                                    # Correct guess: confirm and mark the feature as done
                                    query = "yes"
                                    num_features_remaining -= 1
                                else:
                                    # Wrong guess: reply with the correction
                                    query = "no its " + g_truth.lower()

                            # Accumulate the reciprocal rank of this case
                            for key in reciprocal_rank:
                                if key == 'entity':
                                    continue
                                # NOTE(review): raises ZeroDivisionError if the feature
                                # was never the active one during this case.
                                reciprocal_rank[key] += 1 / counter[key]

                            bar()

                    # Finish calculating the MRR: divide by the number of cases
                    # (the old running index ended at n + 1, deflating the mean).
                    num_cases = len(cases)
                    for key in reciprocal_rank:
                        if key == 'entity':
                            continue
                        mean_rr = reciprocal_rank[key] / num_cases if num_cases else 0.0
                        # NOTE(review): the third key component used to be the
                        # undefined name `similarity_function`; `baseline` is assumed.
                        file_index = "{}-{}-{}".format(key, schema, baseline)
                        output_files[file_index].write(f", {mean_rr}")
                        output_files[file_index].flush()
    finally:
        # Close all output files, also when the evaluation aborts
        for output_file in output_files.values():
            output_file.close()


# Run the evaluation when this file is executed.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so a plain
# import of this module triggers the full run as well.
start()
