'''
- conll.py
- This file handles formatting artificial data into the CONLL-2003 format for use in training NER techniques for VIDS
'''

# External imports
import errno
import random
import os
import re

# Internal imports
from src.core.configuration.datagen_conf import *
from src.utils.data_generation.tokenization import *

'''
----------format_conll----------
- Outputs the desired query set to a CONLL-2003 format file (for fine-tuning NER techniques)
-----Inputs-----
- schema_name - The name of the schema to use
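- fold_num - The fold to format: "general", "small", the name of a raw file containing "lm", or a fold number
- batch_size - If greater than 0, each split is truncated to a multiple of this size (0 keeps every query)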
-----Output-----
- N/A - The folds, split into train/test/valid, in files in src/data
'''
def format_conll(schema_name, fold_num, batch_size=0):
    """Split a raw query file into test/train/valid sets written in CONLL-2003 format.

    Every line of the raw file is randomly assigned to one split, staged in a
    temporary file, then tokenized, annotated and written to
    ``<fold_path><split>.txt``. The temporary files are removed afterwards.

    Inputs:
        schema_name: Name of the schema; selects the directory under
            ``ARTIFICIAL_DATA_LOC/ner/``.
        fold_num: "general", "small", a raw file name containing "lm", or
            anything else (typically a fold number). Numbered folds read
            ``raw/general.txt`` and write to ``fold<fold_num>/``.
        batch_size: When greater than 0, each split is truncated to the
            largest multiple of ``batch_size`` not exceeding its size.

    Output:
        None. The splits are written to files under the fold directory.
    """
    print("GENERATION: Formatting fold {} queries to CONLL-2003 standard".format(fold_num))
    # Splits and their sampling weights: test is fixed at 10, the remaining 90
    # is shared between train and valid according to the configured percentages
    splits = ["test", "train", "valid"]
    weights = (10, 90 * (TRAIN_WEIGHT/100), 90 * (VALID_WEIGHT/100))

    # Resolve the raw input file and the output directory for this fold
    base_path = f"{ARTIFICIAL_DATA_LOC}/ner/{schema_name}/"
    if fold_num == "general":
        fold_path = base_path + "general/"
        raw_path = base_path + "raw/general.txt"
    elif fold_num == "small":
        fold_path = base_path + "small/"
        raw_path = base_path + "raw/small.txt"
    elif isinstance(fold_num, str) and "lm" in fold_num:
        # The isinstance guard keeps integer fold numbers from raising
        # TypeError on the substring test
        fold_path = base_path + "lm/"
        raw_path = f"{base_path}raw/{fold_num}"
    else:
        fold_path = base_path + "fold{}/".format(fold_num)
        raw_path = base_path + "raw/general.txt"

    temp_paths = {split: fold_path + "{}-temp.txt".format(split) for split in splits}
    file_counters = {split: 0 for split in splits}

    try:
        # Pass 1: randomly distribute the raw queries over the temporary split files
        with open(raw_path, "r", encoding="utf8") as file_reader:
            # Make the directory if it doesn't exist
            os.makedirs(fold_path, exist_ok=True)
            temp_files = {}
            try:
                for split in splits:
                    temp_files[split] = open(temp_paths[split], "w", encoding="utf8")
                for query in file_reader:
                    # Randomly assign a label to the query
                    label = random.choices(splits, weights)[0]
                    # Write the raw query to the appropriate file
                    temp_files[label].write(query)
                    file_counters[label] += 1
            finally:
                # Close whatever was opened, even if the distribution failed
                for temp_file in temp_files.values():
                    temp_file.close()

        # Pass 2: annotate each split (truncated if requested) and write it out
        for split in splits:
            if batch_size > 0:
                extra_examples = file_counters[split] % batch_size
                print("GENERATION: Fold {}, {} set has {} queries. Truncating to {} queries.".format(fold_num, split, file_counters[split], file_counters[split] - extra_examples))
                file_counters[split] -= extra_examples
            else:
                print("GENERATION: Fold {}, {} set has {} queries.".format(fold_num, split, file_counters[split]))

            remaining = file_counters[split]
            with open(temp_paths[split], "r", encoding="utf8") as temp_reader, \
                    open(fold_path + "{}.txt".format(split), "w", encoding="utf8") as writer:
                writer.write("-DOCSTART- -X- -X- O\n")
                for query in temp_reader:
                    if remaining <= 0:
                        # Everything past the truncation point is dropped
                        break
                    # Remove the newline and tokenize
                    tokens = tokenize_query(query.replace("\n", ""))
                    # Annotate the query
                    sentence = conll_annotate_query(tokens, format=CONLL_FORMAT)
                    # A blank line separates sentences
                    writer.write("\n")
                    for line in sentence:
                        writer.write(line + "\n")
                    remaining -= 1
    finally:
        # Delete the temporary files, on success and on failure alike
        for temp_path in temp_paths.values():
            if os.path.exists(temp_path):
                os.remove(temp_path)



'''
----------conll_annotate_query----------
- Changes the supplied query to a properly-formatted CONLL-2003 query
-----Inputs-----
- query - The tokenized query to annotate
- format - The tagging format to use: "iob2", or anything else for "iob" (the default)
-----Output-----
- conll_query - A list of CONLL-formatted lines, one string per output word
'''
def conll_annotate_query(query, format="iob"):
    """Convert a tokenized query into CONLL-2003 formatted lines.

    Each output line is ``"<word> <POS> <chunk tag> <entity tag>"``.

    Inputs:
        query: Iterable of string tokens. Tokens containing "vids-" are
            treated as schema entities: characters 5-7 give the entity type
            and characters 9 up to (but excluding) the last one give the
            value, whose words are separated by underscores.
        format: "iob2" tags the first word of every entity with ``B-``.
            Any other value is treated as "iob", where ``B-`` is only used
            when the previous token had the same entity type.

    Output:
        A list of strings, one per emitted word.
    """
    conll_query = []
    # Key parts-of-speech recognised in the queries
    verbs = ("PREDICT", "FIND", "HELP")
    preps = ("FOR", "OVER")
    adverbs = ("WHERE",)
    # For each keyword type: the POS tag to emit and the phrase it opens
    phrase_tags = {
        "verb": ("VB", "VP"),
        "prep": ("TO", "PP"),
        "adverb": ("TO", "ADVP"),
    }
    # Phrase the sentence is currently in ("NP", "VP", "PP" or "ADVP")
    current_phrase = "NP"
    # Type of the previous token, used to delimit adjacent IOB entities
    prev_token = ""

    for token in query:
        # Flag which part of the sentence we're on
        if token in verbs:
            token_type = "verb"
        elif token in preps:
            token_type = "prep"
        elif token in adverbs:
            token_type = "adverb"
        else:
            token_type = ""

        # If the token is a placeholder, remove the placeholder brackets
        if ("<" in token) or (">" in token):
            token = re.sub("[<>]", "", token)

        if "vids-" in token:
            # Schema entity: extract its type and value
            token_type = token[5:8]
            token = token[9:-1]
            # Empty values emit nothing
            if token != "":
                entity_tag = token_type.upper()
                # IOB2 always marks the start of an entity; IOB only does so
                # when it directly follows an entity of the same type
                mark_begin = (format == "iob2") or (prev_token == token_type)
                # Decide "first word" by position, not by value, so a repeated
                # word inside the entity is not tagged as a new beginning
                for index, word in enumerate(token.split("_")):
                    prefix = "B" if (mark_begin and index == 0) else "I"
                    conll_query.append("{} NN I-{} {}-{}".format(word, current_phrase, prefix, entity_tag))
        elif token_type in phrase_tags:
            # Keyword: opens a verb, prepositional or adverb phrase. If we
            # were already in that kind of phrase, "B-" delimits the new one
            pos_tag, phrase = phrase_tags[token_type]
            prefix = "B" if current_phrase == phrase else "I"
            conll_query.append("{} {} {}-{} O".format(token, pos_tag, prefix, phrase))
            current_phrase = phrase
        elif token != "":
            # The token is a noun (for our purposes at least)
            conll_query.append("{} NN I-{} O".format(token, current_phrase))

        # Store the current token type as the previous token
        prev_token = token_type
    # Return the finished query
    return conll_query