'''
- artificial_queries.py
- This file handles generating data based on automatically-generated artificial queries
'''

# External imports
import json
import random
import os
import re
from random import sample

from random import randint
from turtle import update
import nltk
from nltk.corpus import wordnet
from itertools import permutations
import shutil

from numpy import block

# Internal imports
from src.core.configuration.datagen_conf import *
from src.core.interface import annotation
from src.utils.misc.schema import load_schema
from src.utils.misc.io import *
from src.utils.data_generation import artificial_queries as generator, conll, gmb_formatter as gmb
'''
----------generate_templates----------
- Generates templates of the artificial user queries given the desired schema
-----Inputs-----
- type - "general" or "small" to specify how to generate the data
-----Output-----
- queries - The list of artificial query templates
'''
def generate_templates(type="general"):
    """Build artificial query templates from the JSON query-stem files.

    Every regular file in ``QUERY_STEMS_DIR`` is parsed as a JSON object that
    maps a block name to a block holding ``repeatable``, ``drop_probability``
    and ``values`` (attribute name -> list of candidate strings). The stems
    built from all files are then rendered to strings.

    Args:
        type: "general" or "small". Kept for interface compatibility; both
            values currently use ``FRAME_SCALE``.

    Returns:
        A list of query strings. Each stem contributes its full text and, with
        a probability of 0.5, a "negative" variant from which attributes were
        randomly dropped according to their ``drop_probability``.
    """
    print("GENERATION: Creating artificial query templates")

    # NOTE(review): "small" and "general" share the same scale here; confirm
    # whether a dedicated small-scale constant was intended.
    frame_scaler = FRAME_SCALE

    queries_from_all_file = []
    for file_name in os.listdir(QUERY_STEMS_DIR):
        f_loc = os.path.join(QUERY_STEMS_DIR, file_name)
        if not os.path.isfile(f_loc):
            # The entry exists (it was just listed) but is not a regular file.
            print("GENERATION: Skipping {}: not a regular file".format(f_loc))
            continue
        # The context manager closes the handle even when the JSON is invalid.
        with open(f_loc, 'r', encoding="utf8", newline="\n") as stem_file:
            query_stems = json.loads(stem_file.read())
        queries_from_all_file.extend(_expand_query_stems(query_stems, frame_scaler))

    query_strings = _render_query_strings(queries_from_all_file, 0.5)
    print("GENERATION: {} artificial query stems created".format(len(query_strings)))
    return query_strings


def _expand_query_stems(query_stems, frame_scaler):
    """Turn the blocks of one stem file into a list of stem dicts."""
    queries = []
    for query_block_name in query_stems:
        query_block = query_stems[query_block_name]

        # Repeatable blocks are applied a random number of times.
        loop_count = 1
        if query_block['repeatable']:
            loop_count = random.randrange(1, MAX_REPEATING_BLOC_COUNT + 1)

        touched_stems = []
        for _ in range(loop_count):
            for attr in query_block['values']:
                val_list = query_block['values'][attr]
                random.shuffle(val_list)
                for val in val_list:
                    if query_block_name == "prefix":
                        # A prefix value starts a brand-new stem.
                        queries.append({
                            attr: {
                                "value": val,
                                "drop": query_block['drop_probability']
                            }
                        })
                        continue

                    # NOTE(review): stems are updated in place and the same
                    # dict object is collected once per value, so the output
                    # holds repeated references whose attribute ends up with
                    # the last value assigned. Looks unintended (a copy per
                    # value may have been meant) - confirm before changing,
                    # since it alters the size of the generated data set.
                    for position, stem in enumerate(queries):
                        stem[attr] = {
                            "value": val,
                            "drop": query_block['drop_probability']
                        }
                        touched_stems.append(stem)
                        # Stop once the index of the stem just handled
                        # exceeds the frame scale.
                        if position > frame_scaler:
                            break
        queries.extend(touched_stems)
    return queries


def _render_query_strings(stems, neg_sample_prob):
    """Render stem dicts to query strings, adding random negative variants."""
    query_strings = []
    for stem in stems:
        full_parts = []
        kept_parts = []
        for key in stem:
            q_str = stem[key]["value"]
            full_parts.append(q_str)
            # An attribute survives in the negative sample when its drop
            # probability is below the random draw.
            if stem[key]["drop"] < random.uniform(0, 1):
                kept_parts.append(q_str)
        query_strings.append(" ".join(full_parts).strip())
        # NOTE(review): the negative sample can be empty or identical to the
        # full query; it is still emitted, as before.
        if not neg_sample_prob < random.uniform(0, 1):
            query_strings.append(" ".join(kept_parts).strip())
    return query_strings




'''
----------fill_templates----------
- Generates the artificial user queries given the desired schema and the templates
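-----Inputs-----
- templates - The template queries to fill
- schema_name - The name of the schema to use
- type - "general" or "small" to specify how to fill the templates
- format - The output format name, used as a directory level of the output path (defaults to "conll")
-----Output-----
- N/A - The filled queries are sampled and written to raw/small.txt or raw/general.txt under the output path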
'''
def fill_templates(templates, schema_name, type="general", format="conll"):
    """Fill the ``VIDS-XXX`` frames of the templates from a schema and save them.

    The schema is read from ``src/data/test_data/schema/<schema_name>.schema``.
    Filling proceeds in rounds; each round replaces the first remaining frame
    of every template with ``vids-xxx(VALUE)`` for the features of the matching
    schema section. The result is sampled with ``sample_percentage`` and
    written, one query per line, to
    ``<ARTIFICIAL_DATA_LOC>/<format>/<schema_name>/raw/``.

    Args:
        templates: List of template strings that may contain frame markers.
        schema_name: Name of the schema file (without extension) to load.
        type: "small" selects the SMALL_* scales, ``small.txt`` and a 0.3
            sampling rate; anything else selects the general scales,
            ``general.txt`` and a 0.6 sampling rate.
        format: Output format name, used as a directory level of the path.

    Returns:
        None. The queries are written to a file.
    """
    print("GENERATION: Filling artificial query stems from the {} schema".format(schema_name))
    # Import the schema; the context manager closes the file even on failure
    with open('src/data/test_data/schema/' + schema_name + '.schema') as schema_file:
        schema = load_schema(schema_file.read())

    is_small = type == "small"
    perm_scaler = SMALL_PERM_SCALE if is_small else PERM_SCALE
    fill_scaler = SMALL_FILL_SCALE if is_small else FILL_SCALE

    # One round consumes one frame per template, so the template with the most
    # frames decides how many rounds are required. Counting rounds (instead of
    # stopping when the list stops growing) guarantees no template is written
    # with an unfilled frame just because a round did not add new entries.
    queries = templates
    rounds = max((template.count("VIDS-") for template in templates), default=0)
    for _ in range(rounds):
        new_templates = []
        for template in queries:
            new_templates.extend(
                _fill_first_frame(template, schema, perm_scaler, fill_scaler))
        queries = list(set(new_templates))

    print("GENERATION: {} artificial queries created by filling the templates".format(len(queries)))

    sample_prob = 0.3 if is_small else 0.6
    # Hand over a copy so the caller's template list is never reordered.
    queries = sample_percentage(list(queries), sample_prob)

    # Write the sampled, unformatted queries to a file
    raw_dir = f"{ARTIFICIAL_DATA_LOC}/{format}/{schema_name}/" + "raw/"
    os.makedirs(raw_dir, exist_ok=True)
    file_name = "small.txt" if is_small else "general.txt"
    with open(raw_dir + file_name, "w", encoding="utf8") as write_file:
        write_file.writelines(query + "\n" for query in queries)


# Maps a frame marker to the schema section that supplies its values.
_FILL_FRAME_SECTIONS = {
    "VIDS-ENT": "entity",
    "VIDS-ATR": "attribute",
    "VIDS-FLT": "filter",
    "VIDS-FLO": "filter_operation",
    "VIDS-AGG": "aggregator",
    "VIDS-PRW": "prediction_window",
    "VIDS-NUM": "numbers"
}

# Sections that are filled with the feature name only (no description,
# permutation or synonym variants).
_FILL_NAME_ONLY_SECTIONS = frozenset(
    ("aggregator", "filter_operation", "numbers", "prediction_window"))


def _fill_first_frame(template, schema, perm_scaler, fill_scaler):
    """Return the templates obtained by filling the first frame of *template*."""
    frame_loc = template.find("VIDS-")
    if frame_loc < 0:
        # Nothing left to fill; carry the template over unchanged.
        return [template]

    # All markers are assumed to be exactly 8 characters long.
    frame_marker = template[frame_loc:frame_loc + 8]
    section = _FILL_FRAME_SECTIONS[frame_marker]
    head = template[:frame_loc] + frame_marker.lower() + "("
    tail = ")" + template[frame_loc + 8:]

    def fill(value):
        # Splice the value in directly (rather than via str.format) so braces
        # in previously filled values cannot break the substitution.
        return head + value.upper().replace(" ", "_") + tail

    feature_list = schema[section]
    random.shuffle(feature_list)

    filled = []
    fill_counter = 0
    for feature in feature_list:
        if feature["name"].upper() == "NONE":
            continue
        filled.append(fill(feature["name"]))

        variants = []
        if section not in _FILL_NAME_ONLY_SECTIONS:
            # Drop parenthesised remarks before using the texts
            description = re.sub(r'\([^()]*\)', '', feature["description"].strip())
            name = re.sub(r'\([^()]*\)', '', feature['name'].strip())
            description = annotation.cleanse_input(description)

            # Prefer the description most of the time, fall back to the name
            if 0.3 < random.uniform(0, 1) and description != "":
                filled.append(fill(description))
            else:
                filled.append(fill(name))

            # Word-order permutations of the description
            perms = calculate_permutations(description)
            random.shuffle(perms)
            variants.extend(fill(perm) for perm in perms[:5])
            # Synonym substitutions for the longer words of the description
            variants.extend(fill(text) for text in _fill_synonym_texts(description))

        # Add the extra variants for roughly half of the features, capped by
        # the permutation scale when it is positive.
        random.shuffle(variants)
        if 0.50 < random.uniform(0, 1):
            if perm_scaler > 0:
                filled.extend(variants[:perm_scaler])
            else:
                filled.extend(variants)

        # A positive fill scale limits how many features are used per frame.
        fill_counter += 1
        if (fill_scaler > 0) and (fill_counter >= fill_scaler):
            break
    return filled


def _fill_synonym_texts(description):
    """Return copies of *description* with words replaced by WordNet lemmas."""
    texts = []
    for word in description.split():
        if len(word) <= 3:
            continue
        for syn in wordnet.synsets(word):
            counter = 0
            for lemma in syn.lemmas():
                if counter > NUM_SYNONYMS:
                    break
                if lemma.name() != word:
                    counter += 1
                    texts.append(description.replace(word, lemma.name()))
    return texts


def fill_templates_ongoing(templates, schema_name, type="general", format="conll"):
    """Fill template frames with randomly drawn schema features and save them.

    Work-in-progress sibling of ``fill_templates``: instead of walking every
    feature of a schema section, each frame is filled with randomly drawn
    features. Every filled frame has the form ``vids-xxx(VALUE)``; for
    descriptive sections a second variant built from the cleaned feature
    description is added as well.

    Args:
        templates: List of template strings that may contain frame markers.
        schema_name: Name of the schema file (without extension) to load.
        type: "small" selects SMALL_FILL_SCALE and ``small.txt``; anything
            else selects FILL_SCALE and ``general.txt``.
        format: Output format name, used as a directory level of the path.
            Previously the builtin ``format`` function leaked into the path.

    Returns:
        None. The queries are written to
        ``<ARTIFICIAL_DATA_LOC>/<format>/<schema_name>/raw/``.
    """
    print("GENERATION: Filling artificial query stems from the {} schema".format(schema_name))
    # Import the schema; the context manager closes the file even on failure
    with open('src/data/test_data/schema/' + schema_name + '.schema') as schema_file:
        schema = load_schema(schema_file.read())

    # Maps a frame marker to the schema section that supplies its values
    marker = {
        "VIDS-ENT": "entity",
        "VIDS-ATR": "attribute",
        "VIDS-FLT": "filter",
        "VIDS-FLO": "filter_operation",
        "VIDS-AGG": "aggregator",
        "VIDS-PRW": "prediction_window",
        "VIDS-NUM": "numbers"
    }
    # Sections that only ever get the feature name
    name_only = ("aggregator", "filter_operation", "prediction_window", "numbers")

    fill_scaler = SMALL_FILL_SCALE if type == "small" else FILL_SCALE

    # NOTE(review): only the first template is expanded, as before. This looks
    # like a debugging limit - confirm whether all templates should be used.
    queries = templates[:1]

    # One round consumes one frame per template, so the number of rounds is
    # bounded by the largest frame count. Filled frames keep a lower-case
    # marker and therefore are not picked up again as unfilled frames.
    rounds = max((template.count("VIDS-") for template in queries), default=0)
    for _ in range(rounds):
        new_templates = []
        for template in queries:
            locations = [m.start() for m in re.finditer('VIDS-', template)]
            if not locations:
                new_templates.append(template)
                continue

            # Pick one of the remaining frames at random
            frame_loc = random.choice(locations)
            frame_marker = template[frame_loc:frame_loc + 8]  # markers are 8 characters long
            section = marker[frame_marker]
            head = template[:frame_loc] + frame_marker.lower() + "("
            tail = ")" + template[frame_loc + 8:]

            candidates = [
                feature for feature in schema[section]
                if feature["name"].upper() != "NONE"
            ]
            if not candidates:
                # No usable feature for this frame: the template is dropped,
                # matching what fill_templates does in the same situation.
                continue

            # A positive fill scale is the number of random draws; otherwise
            # every candidate is used once (0 means "no limit", as in
            # fill_templates).
            if fill_scaler > 0:
                chosen = [random.choice(candidates) for _ in range(fill_scaler)]
            else:
                chosen = candidates

            for feature in chosen:
                new_templates.append(
                    head + feature["name"].upper().replace(" ", "_") + tail)
                if section in name_only:
                    continue
                # Description variant: strip parenthesised remarks and clean
                description = re.sub(r'\([^()]*\)', '', feature["description"])
                description = annotation.cleanse_input(description)
                description = description.upper().replace(" ", "_")
                if description:
                    new_templates.append(head + description + tail)
        queries = list(set(new_templates))

    # NOTE(review): permutation and synonym variants were computed here before
    # but their results were discarded; fill_templates is the variant that
    # actually emits them.
    print_list(queries[:20])
    print("GENERATION: {} artificial queries created by filling the templates".format(len(queries)))

    # Write the unformatted queries to a file
    raw_dir = f"{ARTIFICIAL_DATA_LOC}/{format}/{schema_name}/" + "raw/"
    os.makedirs(raw_dir, exist_ok=True)
    file_name = "small.txt" if type == "small" else "general.txt"
    with open(raw_dir + file_name, "w", encoding="utf8") as write_file:
        write_file.writelines(query + "\n" for query in queries)



'''
----------calculate_permutations----------
- Builds randomly ordered word permutations of a sentence
-----Inputs-----
- sentence - The whitespace-separated text whose words are reordered
- limit - The maximum number of permutations to return (defaults to NUM_PERM)
-----Output-----
- sentences - A shuffled list of at most `limit` distinct word orderings,
              each with its words joined by single spaces
'''
def calculate_permutations(sentence, limit=None):
    if limit is None:
        limit = NUM_PERM
    limit = max(limit, 0)
    words = sentence.split()

    # Up to 6 words means at most 720 orderings: cheap enough to enumerate
    # them all, drop duplicates (repeated words) and shuffle.
    if len(words) <= 6:
        sentences = list(dict.fromkeys(
            " ".join(ordering) for ordering in permutations(words)))
        random.shuffle(sentences)
        return sentences[:limit]

    # Longer sentences have far too many orderings to enumerate (n! grows
    # explosively), so draw random orderings until enough distinct ones are
    # found. The attempt budget keeps the loop finite when repeated words
    # leave fewer distinct orderings than requested.
    sentences = []
    seen = set()
    attempts_left = limit * 20
    while len(sentences) < limit and attempts_left > 0:
        attempts_left -= 1
        candidate = " ".join(random.sample(words, len(words)))
        if candidate not in seen:
            seen.add(candidate)
            sentences.append(candidate)
    return sentences

# generate_templates()

def sample_percentage(query_templates, percentage):
    """Return a random sample holding the given fraction of the templates.

    Args:
        query_templates: List of templates to draw from. It is left untouched.
        percentage: Fraction of the list to keep, e.g. 0.3 for 30 percent.
            The sample size is ``int(len(query_templates) * percentage)``.

    Returns:
        A new list with the sampled templates in random order.

    Raises:
        ValueError: If the computed sample size is negative or larger than
            the list (raised by ``random.sample``).
    """
    sample_size = int(len(query_templates) * percentage)
    # random.sample already picks in random order, so the caller's list does
    # not need to be shuffled (and mutated) first.
    sampled_list = random.sample(query_templates, sample_size)
    print("\nFrom the template list of {} we have sampled {}\n".format(len(query_templates), len(sampled_list)))
    return sampled_list

# query_templates = generate_templates(type="general")
        
#         # Fill the general templates & write them to a file
# fill_templates(query_templates, "online_delivary", type="general")
# # # Format the general queries
# conll.format_conll("online_delivary", 1, 32)