'''
- compare.py
- This file compares two given configurations on a given data set, and outputs the qualitative examples of a conflict
'''

# Settings for one comparison run. "conf1" and "conf2" describe the two
# pipelines being compared; they are deliberately separate dict objects so
# that each one can be modified independently.
configuration = {
    # First pipeline: NER model, embedding, similarity metric and model path
    "conf1": {
        "model": "xlnet",
        "embedding": "use",
        "similarity_function": "cosine",
        "path": "models/ner/student_perf/xlnet/xlnet",
    },
    # Second pipeline, described with the same four keys
    "conf2": {
        "model": "xlnet",
        "embedding": "use",
        "similarity_function": "cosine",
        "path": "models/ner/student_perf/xlnet/xlnet",
    },
    # Name of the schema file and of the query file to test against
    "schema": "student_perf",
    # Schema feature under test; selects the ground-truth column of each query line
    "feature": "filter_operation",
    # Key into QUERY_LOC that selects the query set
    "difficulty": "combined",
    # Filter deciding which queries are written to the output file
    "criteria": {
        "feature": "num_guesses",
        "comparison": "greater",
        "threshold": 3,
    },
}
# Supported values for configuration["criteria"]:
#   feature: annotation, num_guesses
#   comparison: equal, greater, less


# External Imports
import json
from datetime import datetime
from alive_progress import alive_bar

# Internal imports
from src.core.interface import dialog, ranking, annotation
from src.core.petel.petel import petel
from src.utils.misc.schema import load_schema
import src.utils.misc.directory_ops as directory

# Configuration Parameters
# Directory holding the query files, keyed by the "difficulty" setting
QUERY_LOC = {
    "easy": "src/data/test_data/easy_queries/",
    "hard": "src/data/test_data/hard_queries/",
    "combined": "src/data/test_data/combined/",
}

# Directory that receives the comparison reports
OUTPUT_LOC = "src/data/output/comparisons/"

def start():
    """Run both configurations over every test query and report the matching examples.

    Reads the module-level ``configuration``, loads the schema, the embeddings
    and the NER models for "conf1" and "conf2", then replays every line of the
    query file against both configurations until each one ranks the ground
    truth first. Queries accepted by ``test_queries`` are written to a
    timestamped report under ``OUTPUT_LOC``.

    Returns:
        None. If ``configuration["feature"]` is not a known schema feature, a
        message is printed and nothing is loaded or written.
    """
    # Column of the "|"-separated query line that holds the ground truth of each feature
    truth_columns = {
        "entity": 5,
        "attribute": 1,
        "filter": 2,
        "filter_operation": 3,
        "aggregator": 4,
    }
    feature = configuration["feature"]
    # Validate up front so a bad feature neither loads models nor leaves a stray report file
    if feature not in truth_columns:
        print("Invalid schema feature detected. Aborting the test.")
        return
    truth_index = truth_columns[feature]

    confs = [configuration["conf1"], configuration["conf2"]]

    # Load the schema
    with open('src/data/test_data/schema/' + configuration["schema"] + '.schema', 'r') as schema_file:
        loaded_schema = load_schema(schema_file.read())

    # Print the current configuration
    print("CONFIGURATION: Comparing", confs[0]["model"], "on", confs[0]["embedding"],
        "with", confs[0]["similarity_function"], "distance metric &", confs[1]["model"],
        "on", confs[1]["embedding"], "with", confs[1]["similarity_function"], "distance metric")

    # Load the embeddings and model techniques
    loaded_embeddings = [ranking.load_embedding(conf["embedding"], loaded_schema) for conf in confs]
    print("LOADING: Loaded the embeddings:", confs[0]["embedding"], "&", confs[1]["embedding"])
    loaded_ner_model = [annotation.load_ner_model(conf["model"], model_path=conf["path"]) for conf in confs]
    print("LOADING: Loaded the NER techniques:", confs[0]["model"], "&", confs[1]["model"])

    # Load the queries once; the line count also sizes the progress bar
    delimiter = "--------------------------------------------------"
    query_path = QUERY_LOC[configuration["difficulty"]] + configuration["schema"] + ".txt"
    with open(query_path, "r") as query_file:
        test_cases = query_file.readlines()

    output_path = "{}compare_{}_{}_{}_{}_{}_{}_&_{}_{}_{}_{}.txt".format(
        OUTPUT_LOC, configuration["schema"], feature, configuration["difficulty"],
        confs[0]["model"], confs[0]["embedding"], confs[0]["similarity_function"],
        confs[1]["model"], confs[1]["embedding"], confs[1]["similarity_function"],
        datetime.now().strftime("%m-%d-%Y_%H-%M"))

    # The context manager guarantees the report is flushed and closed, even on error
    with open(output_path, "w") as output:
        # Print the initial configuration to the output file
        output.write("Configuration 1: {} on {} with {} distance\n".format(
            confs[0]["model"], confs[0]["embedding"], confs[0]["similarity_function"]))
        output.write("Configuration 2: {} on {} with {} distance\n".format(
            confs[1]["model"], confs[1]["embedding"], confs[1]["similarity_function"]))
        output.write("Testing on: {} queries for the {} schema\n".format(
            configuration["difficulty"], configuration["schema"]))
        output.write("Entries below have {} {} with a threshold of {}\n".format(
            configuration["criteria"]["comparison"], configuration["criteria"]["feature"],
            str(configuration["criteria"]["threshold"])))
        output.write(delimiter + "\n" + delimiter + "\n\n")

        # For each query, test it between the two configurations
        print("TESTING: Comparing the configurations on", configuration["difficulty"], configuration["schema"], "queries")
        with alive_bar(len(test_cases), bar="smooth", spinner="classic") as bar:
            for i, test_case in enumerate(test_cases, start=1):
                # A blank line has no ground-truth column; skip it but keep the numbering
                if not test_case.strip():
                    bar()
                    continue

                # Get the query and the ground truth (first comma-separated value of the column)
                fields = test_case.split("|")
                origin_query = fields[0]
                ground_truth = fields[truth_index].upper().strip().split(',')[0]

                # Initialize the PeTEL expressions
                petel_exp = [petel(loaded_schema), petel(loaded_schema)]
                for expression in petel_exp:
                    expression.active_feature = feature

                # Annotate the original query once with each configuration's NER model
                annotations = [
                    annotation.annotate_text(origin_query, feature, confs[k]["model"], loaded_ner_model[k])
                    for k in range(2)
                ]

                # Follow-up utterance sent after a wrong guess.
                # NOTE(review): conf1 is told the answer while conf2 only gets "no";
                # kept as in the original - confirm the asymmetry is intended.
                feedback = ["no its " + ground_truth, "no"]

                query = [origin_query, origin_query]
                counter = [0, 0]
                stop_loop = [False, False]
                # NOTE(review): there is no cap on the number of attempts, so a
                # configuration that never ranks the ground truth first loops forever.
                while not (stop_loop[0] and stop_loop[1]):
                    # For each configuration, test the query against VIDS if it hasn't already been guessed
                    for k in range(2):
                        if stop_loop[k]:
                            continue
                        # Update the PeTEL expression
                        petel_exp[k] = dialog.run(query[k], petel_exp[k], loaded_schema, confs[k],
                            loaded_embeddings[k], loaded_ner_model[k])["expression"]
                        # Update the query and the counter
                        query[k] = feedback[k]
                        counter[k] += 1
                        # If the top-ranked guess matches the ground truth, mark the configuration as complete
                        predicted = petel_exp[k].rankings[feature][0]["name"].replace("_", " ").upper()
                        stop_loop[k] = predicted == ground_truth

                # If the query meets the configuration criteria, output it
                if test_queries(annotations, counter):
                    output.write("Query (#{}): {}\n".format(i, origin_query))
                    output.write("Ground Truth ({}): {}\n\n".format(feature, ground_truth))
                    output.write("Annotation 1: {}\n".format(annotations[0]))
                    output.write("Annotation 2: {}\n\n".format(annotations[1]))
                    output.write("Configuration 1 took {} attempt(s) to get the correct answer\n".format(counter[0]))
                    output.write("Configuration 2 took {} attempt(s) to get the correct answer\n\n".format(counter[1]))
                    output.write(delimiter + "\n\n")
                bar()

def test_queries(annotations, counters):
    # Get the comparison item
    return True
    comp_item = []
    threshold = configuration["criteria"]["threshold"]
    if configuration["criteria"]["feature"] == "num_guesses":
        comp_item = counters
        # Run the comparison
        if configuration["criteria"]["comparison"] == "equal":
            return comp_item[0] == comp_item[1]
        elif configuration["criteria"]["comparison"] == "greater":
            return comp_item[0] - threshold > comp_item[1]
        elif configuration["criteria"]["comparison"] == "less":
            return comp_item[0] < comp_item[1] - threshold
    elif configuration["criteria"]["feature"] == "annotation":
        comp_item = annotations
        # Run the comparison
        if configuration["criteria"]["comparison"] == "equal":
            if threshold == 0:
                return comp_item[0] == comp_item[1]
            else:
                same_query = False
                anot_len = [
                    len(comp_item[0].split()),
                    len(comp_item[1].split())
                ]
                if annot_len[0] > annot_len[1]:
                    temp = comp_item[0]
                    comp_item[0] = comp_item[1]
                    comp_item[1] = temp
                    temp = annot_len[0]
                    annot_len[0] = annot_len[1]
                    annot_len[1] = temp
                num_comp = 0
                if (annot_len[1] - annot_len[0]) > threshold:
                    return False
                else:
                    num_comp = (annot_len[1] - threshold) - annot_len[0] + 1
                for i in range(num_comp):
                    if not same_query:
                        comp_annot = comp_item[0].split()[i:]
                        #same_query = pass
                return same_query
        elif configuration["criteria"]["comparison"] == "greater":
            return len(comp_item[0].split()) - threshold > len(comp_item[1].split())
        elif configuration["criteria"]["comparison"] == "less":
            return len(comp_item[0].split()) < len(comp_item[1].split()) - threshold
    return False

# Run the comparison with the module-level `configuration`.
# NOTE: this call is not guarded, so it also executes whenever the module is imported.
start()

# Test the comparisons over a wide range of items (July 15th evaluation of XL-NET vs. RoBERTa)
'''timestamp = datetime.now().strftime("%m-%d-%Y")
output_loc_copy = OUTPUT_LOC
for embedding in ['fasttext','glove','word2vec']:
    for similarity in ['cosine', 'euclidean']:
        for threshold in range(1,10):
            OUTPUT_LOC = "{}{}_comparisons/{}/{}/".format(output_loc_copy, timestamp, embedding, similarity, threshold)
            # Make the directory
            directory.create(OUTPUT_LOC)
            # Modify the configuration
            configuration["criteria"]["threshold"] = threshold
            configuration["conf1"]["embedding"] = embedding
            configuration["conf2"]["embedding"] = embedding
            configuration["conf1"]["similarity_function"] = similarity
            configuration["conf2"]["similarity_function"] = similarity
            # Start the comparison
            start()'''