import json
import re
from src.utils.data_generation.tokenization import *
from src.core.configuration.datagen_conf import *

def main(path, write_path):
    """Convert a file of annotated queries into a SQuAD-style JSON file.

    Each line of the input file is treated as one record; only the text
    before the first ``|`` on the line is used as the query. Every query is
    turned into a passage by ``convert_to_squad`` and all passages are stored
    under a single group titled ``'hand-crafted'``.

    Args:
        path: Path of the annotated query file to read.
        write_path: Path of the JSON file to write.
    """
    # Question ids continue from the configured starting number and are
    # threaded through every call so they stay unique across the file.
    q_count = QUESTION_NUMBER
    paragraphs = []

    # Context managers guarantee the handles are closed even when a line
    # fails to convert.
    with open(path, 'r') as reader:
        for line in reader:
            query = line.replace('\n', '').split("|")[0]
            passage, q_count = convert_to_squad(query, q_count)
            paragraphs.append(passage)

    data = {'data': [{'title': 'hand-crafted', 'paragraphs': paragraphs}]}

    # The output is opened only after every line converted successfully, so a
    # failed run no longer truncates an existing output file.
    with open(write_path, 'w') as writer:
        json.dump(data, writer)


def convert_to_squad(query, q_count):
    """Build one SQuAD-style passage from a ``vids-``-annotated query.

    The context is the query with the ``vids-xxx(`` tag prefixes and all
    parentheses removed, underscores turned into spaces, and everything
    lower-cased. For every token returned by ``tokenize_query`` that contains
    ``vids-``, one question/answer entry is produced: the three-letter tag
    selects the question, and the tagged text (normalised the same way as the
    context) is the answer.

    Args:
        query: The annotated query string.
        q_count: Integer used for the id of the next question.

    Returns:
        A tuple ``(passage, q_count)`` where ``passage`` is a dict with keys
        ``'context'`` and ``'qas'``, and ``q_count`` has been advanced by the
        number of questions generated.

    Raises:
        KeyError: If a tagged token carries a tag that has no question.
        IndexError: If an answer text cannot be found in the context.
    """
    question_map = {
        "atr": "what is the target attribute?",
        "agg": "what is the target aggregator?",
        "flt": "what is the target filtering attribute?",
        "flo": "what is the target filtering operator?",
        "prw": "what is the prediction window?",
        "num": "what number is used?",
        "ent": "what is the target entity?"
    }
    # Prefixes are stripped one after another in this fixed order.
    tag_prefixes = (
        "vids-atr(", "vids-ent(", "vids-agg(", "vids-flt(",
        "vids-flo(", "vids-num(", "vids-prw(",
    )

    context_query = query
    for prefix in tag_prefixes:
        context_query = context_query.replace(prefix, "")
    context_query = (
        context_query.replace("(", "").replace(")", "").replace("_", " ").lower()
    )

    question_answer_list = []
    for token in tokenize_query(query):
        # Only tokens carrying a schema flag produce a question.
        if "vids-" not in token:
            continue

        # Layout assumed by the slicing: "vids-" + 3-letter tag + "(" + text + ")".
        token_type = token[5:8]
        extracted_token = token[9:-1]

        question = question_map[token_type]
        answer_text = (
            extracted_token.replace("_", " ").replace("(", "").replace(")", "").lower()
        )

        # Literal substring search: the answer is plain text, not a regular
        # expression, so characters such as "." or "+" must match themselves.
        # The first occurrence in the context is used.
        answer_start = context_query.find(answer_text)
        if answer_start < 0:
            # IndexError is what the previous lookup raised in this situation.
            raise IndexError(
                "answer %r not found in context %r" % (answer_text, context_query)
            )
        answer_end = answer_start + len(answer_text)

        question_answer_list.append({
            'question': question,
            'answers': [{
                'text': answer_text,
                'answer_start': answer_start,
                'answer_end': answer_end,
            }],
            'id': hex(q_count),
        })
        q_count += 1

    passage = {'context': context_query, 'qas': question_answer_list}
    return passage, q_count



# Script-level run: convert the annotated query file into SQuAD-style JSON.
main(
    path="src/data/test_data/squad_format/student_perf_annotated.txt",
    write_path="src/data/test_data/squad_format/student_perf.json",
)
# Disabled invocation kept for reference; it passes three arguments, whereas main() takes two.
# main("src/data/test_data/conll_format/student_perf_wo_annotation.txt", "src/data/test_data/conll_format/student_perf.txt", "src/data/fine_tuning/static_attributes.json")