'''
- squad.py
- This file handles formatting artificial data into the SQUAD-v2 format for use in training Q/A techniques for VIDS
'''

# External imports
import contextlib
import errno
import itertools
import json
import os
import random
import re

# Internal imports
from src.core.configuration.datagen_conf import *
from src.utils.data_generation.tokenization import *

def _squad_paths(base_path, fold_num):
    """Return ``(fold_path, raw_path)``: output directory and raw query file for *fold_num*."""
    if fold_num == "general":
        return base_path + "general/", base_path + "raw/general.txt"
    if fold_num == "small":
        return base_path + "small/", base_path + "raw/small.txt"
    # Fold numbers may be ints, and ``"lm" in 3`` raises TypeError, so only
    # run the substring test on strings.
    if isinstance(fold_num, str) and "lm" in fold_num:
        return base_path + "lm/", f"{base_path}raw/{fold_num}"
    # Numbered folds are sampled from the general raw file.
    return base_path + "fold{}/".format(fold_num), base_path + "raw/general.txt"


def format_squad(schema_name, fold_num, batch_size=0):
    """
    ----------format_squad----------
    - Outputs the desired query set to SQuAD-v2 format files (for fine-tuning Q/A techniques)
    -----Inputs-----
    - schema_name - The name of the schema to use
    - fold_num - "general", "small", a raw file name containing "lm", or a fold number
    - batch_size - If > 0, each split is truncated to a multiple of this size
    -----Output-----
    - N/A - Writes test.json, train.json and valid.json under
      {ARTIFICIAL_DATA_LOC}/squad/{schema_name}/<fold directory>/
    """
    q_count = QUESTION_NUMBER
    print("GENERATION: Formatting fold {} queries to SQUAD-v2 standard".format(fold_num))
    splits = ["test", "train", "valid"]
    # Relative weights for random.choices: test gets 10, and the remaining 90
    # is divided by TRAIN_WEIGHT / VALID_WEIGHT (presumably percentages).
    weights = (10, 90 * (TRAIN_WEIGHT / 100), 90 * (VALID_WEIGHT / 100))
    base_path = f"{ARTIFICIAL_DATA_LOC}/squad/{schema_name}/"
    fold_path, raw_path = _squad_paths(base_path, fold_num)
    temp_paths = {split: fold_path + "{}-temp.txt".format(split) for split in splits}
    counts = dict.fromkeys(splits, 0)

    try:
        # Pass 1: randomly assign every raw query to a split, spooling it to
        # that split's temp file so the split sizes are known before annotating.
        with contextlib.ExitStack() as stack:
            reader = stack.enter_context(open(raw_path, "r", encoding="utf8"))
            os.makedirs(fold_path, exist_ok=True)
            temps = {
                split: stack.enter_context(open(temp_paths[split], "w", encoding="utf8"))
                for split in splits
            }
            for query in reader:
                label = random.choices(splits, weights)[0]
                temps[label].write(query)
                counts[label] += 1

        # Pass 2: annotate each split (truncated when batch_size > 0) and dump it as JSON.
        for split in splits:
            keep = counts[split]
            if batch_size > 0:
                keep -= counts[split] % batch_size
                print("GENERATION: Fold {}, {} set has {} queries. Truncating to {} queries.".format(
                    fold_num, split, counts[split], keep))
            else:
                print("GENERATION: Fold {}, {} set has {} queries.".format(fold_num, split, counts[split]))

            paragraphs = []
            with open(temp_paths[split], "r", encoding="utf8") as temp:
                # islice stops reading once the truncation limit is reached.
                for query in itertools.islice(temp, keep):
                    passage, q_count = squad_annotate_query(query.replace("\n", ""), q_count)
                    paragraphs.append(passage)

            with open(fold_path + "{}.json".format(split), "w", encoding="utf8") as writer:
                json.dump({"data": [{"title": split, "paragraphs": paragraphs}]}, writer)
    finally:
        # Temp files are scratch space only; remove them even if a step above failed.
        for path in temp_paths.values():
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)

# Question asked for each schema tag. Keys are the three-letter codes that
# follow "vids-" in a tagged token such as "vids-atr(...)".
_SQUAD_QUESTIONS = {
    "atr": "what is the target attribute?",
    "agg": "what is the target aggregator?",
    "flt": "what is the target filtering attribute?",
    "flo": "what is the target filtering operator?",
    "prw": "what is the prediction window?",
    "num": "what number is used?",
    "ent": "what is the target entity?",
}


def squad_annotate_query(query, q_count, tokenizer=None):
    """Turn one tagged query into a SQuAD-v2 paragraph (context plus Q/A pairs).

    The context is the query with every ``vids-XXX(`` prefix and all
    parentheses removed, underscores turned into spaces, and lower-cased.
    Each tagged token ``vids-XXX(value)`` becomes one question, chosen by its
    tag. Its answer is the cleaned value together with that value's character
    span in the context.

    Args:
        query: One raw query line, without the trailing newline.
        q_count: Integer id for the first question; ids are emitted as hex strings.
        tokenizer: Callable that splits *query* into tokens; defaults to
            ``tokenize_query``.

    Returns:
        ``(passage, q_count)``. ``passage`` is ``{"context": str, "qas": list}``,
        and the returned ``q_count`` is the next unused id.

    Raises:
        KeyError: A token carries a tag not present in ``_SQUAD_QUESTIONS``.
        ValueError: An answer's text cannot be found in the context.
    """
    if tokenizer is None:
        tokenizer = tokenize_query

    # Strip the prefix of every known tag, including "ent".
    context = query
    for tag in _SQUAD_QUESTIONS:
        context = context.replace(f"vids-{tag}(", "")
    context = context.replace("(", "").replace(")", "").replace("_", " ").lower()

    qas = []
    cursor = 0  # end of the previous answer; tagged tokens appear in query order
    for token in tokenizer(query):
        # Drop placeholder brackets such as "<vids-num(5)>".
        token = token.replace("<", "").replace(">", "")
        if "vids-" not in token:
            continue

        tag = token[5:8]
        question = _SQUAD_QUESTIONS[tag]
        answer_text = token[9:-1].replace("_", " ").replace("(", "").replace(")", "").lower()

        # Plain substring search: the answer is literal text, not a regex.
        # Searching from the previous answer's end keeps repeated words
        # (e.g. "age" inside "average") from all mapping to the first hit.
        start = context.find(answer_text, cursor)
        if start == -1:
            # Fall back to a whole-context search in case tokens were reordered.
            start = context.find(answer_text)
        if start == -1:
            raise ValueError(f"answer {answer_text!r} not found in context {context!r}")
        end = start + len(answer_text)
        cursor = end

        qas.append({
            "question": question,
            "answers": [{"text": answer_text, "answer_start": start, "answer_end": end}],
            "id": str(hex(q_count)),
        })
        q_count += 1

    return {"context": context, "qas": qas}, q_count