'''
- csr.py
- This file handles formatting artificial data into the common gen format for use in training Q/A techniques for VIDS
'''

# External imports
import errno
import random
import os
import re
import json

# Internal imports
from src.core.configuration.datagen_conf import *
from src.utils.data_generation.tokenization import *

'''
----------format_csr----------
- Outputs the desired query set to a Common Gen format file (for fine-tuning NER techniques)
-----Inputs-----
- schema_name - The name of the schema to use
- fold_num - The fold number to format
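- batch_size - The desired batch size; when > 0 each split is truncated to a multiple of it (0, the default, keeps every query; -1 keeps every query and also returns the annotated data)
-----Output-----
- The folds, split into train/test/valid, written to files under the artificial data location; the annotated records are returned only when batch_size is -1 (otherwise nothing is returned)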
'''
def format_csr(schema_name, fold_num, batch_size=0):
    '''
    Randomly split a raw query file into test/train/valid sets, annotate every
    kept query with csr_annotate_query, and write one output file per split.

    -----Inputs-----
    - schema_name - The name of the schema to use (a directory under
      {ARTIFICIAL_DATA_LOC}/csr/)
    - fold_num - "general" or "small" select raw/general.txt or raw/small.txt
      and write to general/ or small/; any other value reads raw/general.txt
      and writes to fold{fold_num}/
    - batch_size - When > 0, each split is truncated so that its size is a
      multiple of batch_size. 0 (default) and negative values keep every
      query. -1 additionally makes the function return the annotated data.
    -----Output-----
    - The annotated records of all splits (test, train, valid order) when
      batch_size == -1, otherwise None. The files test.jsonl, train.jsonl and
      valid.jsonl are always (re)written in the fold directory.
    '''
    q_count = QUESTION_NUMBER
    print("GENERATION: Formatting fold {} queries to Common Gen standard".format(fold_num))
    # Initialize the splits and weights
    splits = ["test", "train", "valid"]
    weights = (10, 90 * (TRAIN_WEIGHT/100), 90 * (VALID_WEIGHT/100))

    # Resolve the raw input file and the output directory for this fold
    base_path = f"{ARTIFICIAL_DATA_LOC}/csr/{schema_name}/"
    if fold_num == "general":
        fold_path = base_path + "general/"
        raw_path = base_path + "raw/general.txt"
    elif fold_num == "small":
        fold_path = base_path + "small/"
        raw_path = base_path + "raw/small.txt"
    else:
        fold_path = base_path + "fold{}/".format(fold_num)
        raw_path = base_path + "raw/general.txt"

    # Randomly assign every raw query to a split. The queries are bucketed in
    # memory rather than in temporary files: the annotated records are all held
    # in memory further down anyway, and this way nothing is left open or left
    # on disk if annotation fails part-way through.
    raw_queries = {split: [] for split in splits}
    with open(raw_path, "r", encoding="utf8") as file_reader:
        for query in file_reader:
            label = random.choices(splits, weights)[0]
            raw_queries[label].append(query)

    # Make the directory if it doesn't exist
    os.makedirs(fold_path, exist_ok=True)

    data_for_ret = []
    for split in splits:
        queries = raw_queries[split]
        keep = len(queries)
        # Drop the trailing queries that would not fill a complete batch
        if batch_size > 0:
            extra_examples = keep % batch_size
            print("GENERATION: Fold {}, {} set has {} queries. Truncating to {} queries.".format(fold_num, split, keep, keep - extra_examples))
            keep -= extra_examples
        else:
            print("GENERATION: Fold {}, {} set has {} queries.".format(fold_num, split, keep))

        # Annotate every kept query; q_count keeps running across splits so
        # each record gets a distinct concept_set_idx
        paragraphs = []
        for query in queries[:keep]:
            q_count += 1
            paragraphs.append(csr_annotate_query(query.replace("\n", ""), q_count))

        # NOTE: despite the .jsonl extension, the whole split is written as a
        # single JSON array (one json.dump call), not one object per line
        with open(fold_path + "{}.jsonl".format(split), "w", encoding="utf8") as file_writer:
            json.dump(paragraphs, file_writer)
        data_for_ret.extend(paragraphs)

    if batch_size == -1:
        return data_for_ret


def get_from_raw_csr(schema_name):
    '''
    Annotate every query in a schema's raw/small.txt and return the records.

    -----Inputs-----
    - schema_name - The name of the schema to read, or 'combined' to read the
      three built-in schemas one after another
    -----Output-----
    - A list with one csr_annotate_query record per raw query, in file order.
      The record index restarts at 1 for every schema.
    '''
    paragraphs = []
    # NOTE(review): 'online_delivary' is used as a path component below, so it
    # presumably mirrors the directory name on disk - confirm before "fixing" it
    schemas = ['flight_delay', 'student_perf', 'online_delivary'] if schema_name == 'combined' else [schema_name]
    for schema in schemas:
        base_path = f"{ARTIFICIAL_DATA_LOC}/csr/{schema}/"
        fold_path = base_path + "small/"
        # The context manager guarantees the reader is closed, even if
        # annotation raises
        with open(base_path + "raw/small.txt", "r", encoding="utf8") as file_reader:
            # Make the directory if it doesn't exist
            os.makedirs(fold_path, exist_ok=True)

            for q_count, query in enumerate(file_reader, start=1):
                # Remove the newline, then annotate the query
                paragraphs.append(csr_annotate_query(query.replace("\n", ""), q_count))

    return paragraphs
   

def csr_annotate_query(query, q_count):
    '''
    Build one Common Gen style record from a raw, schema-tagged query.

    -----Inputs-----
    - query - The raw query text; tagged spans look like vids-xxx(...)
    - q_count - The index stored in the record as "concept_set_idx"
    -----Output-----
    - A dict with the keys "concept_set_idx", "concepts" (the words of every
      tagged span), "target" (the query with the tags stripped, lower-cased)
      and "type_maps" (a list of (tag type, span text) tuples)
    '''
    tokens = tokenize_query(query)

    # Build the plain-text target: drop each known tag prefix in turn, then any
    # remaining parentheses, then turn underscores into spaces and lower-case
    target = query
    for marker in ("vids-atr(", "vids-ent(", "vids-agg(", "vids-flt(", "vids-flo(", "vids-num(", "vids-prw("):
        target = target.replace(marker, "")
    for paren in ("(", ")"):
        target = target.replace(paren, "")
    target = target.replace("_", " ").lower()

    words = []
    typed_spans = []
    for tok in tokens:
        # Placeholder brackets are stripped before the tag is inspected
        if "<" in tok or ">" in tok:
            tok = re.sub("[<>]", "", tok)

        # Only schema-tagged tokens contribute concepts
        if "vids-" not in tok:
            continue

        # Characters 5-7 carry the three-letter tag type; the span text starts
        # at index 9 and the final character is dropped
        tag = tok[5:8]
        span = tok[9:-1]
        span_text = span.replace("_", " ").replace("(", "").replace(")", "").lower()

        typed_spans.append((tag, span_text))
        words.extend(span_text.split())

    return {
        "concept_set_idx": q_count,
        "concepts": words,
        "target": target,
        "type_maps": typed_spans,
    }