import sys
import os
import json
import spacy

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils.json_reader import jsonl_loader, json_loader
import re

# Trailing punctuation that marks the first chunk as a real sentence, not a title.
_SENTENCE_ENDINGS = ('.', '!', '?')
# Trailing punctuation that marks the first chunk as a title/lead-in (nested shape only).
_TITLE_ENDINGS = (',', ':')
# A first chunk with at most this many words is considered short enough to be a title.
_MAX_TITLE_WORDS = 8


def _drop_first_line(text):
    """Return *text* without its first line; unchanged when it has no newline."""
    _, newline, rest = text.partition('\n')
    return rest if newline else text


def _strip_title_nested(paragraphs):
    """Strip a leading title from a list of paragraphs (lists of sentences), in place.

    Returns a ``(processed, unprocessed)`` pair of 0/1 counters.
    """
    processed = unprocessed = 0
    first = paragraphs[0]

    if first:
        first[0] = _drop_first_line(first[0])
        head = first[0].strip()
        if head.endswith(_SENTENCE_ENDINGS):
            unprocessed = 1
        elif len(head.split()) <= _MAX_TITLE_WORDS or head.endswith(_TITLE_ENDINGS):
            # Only drop the title when some other sentence remains; otherwise
            # the record would be left with no content at all.
            if len(first) > 1 or any(paragraphs[1:]):
                first.pop(0)
                processed = 1

    # A first paragraph that is (or has become) empty carries no text.
    if not first:
        paragraphs.pop(0)

    if paragraphs and paragraphs[0]:
        paragraphs[0][0] = remove_repeated_prefix_once(paragraphs[0][0])
    return processed, unprocessed


def _strip_title_flat(sentences):
    """Strip a leading title from a non-empty flat list of sentences, in place.

    Returns a ``(processed, unprocessed)`` pair of 0/1 counters.
    """
    processed = unprocessed = 0
    sentences[0] = _drop_first_line(sentences[0])

    if sentences[0].strip().endswith(_SENTENCE_ENDINGS):
        unprocessed = 1
    elif len(sentences[0].split()) <= _MAX_TITLE_WORDS and len(sentences) > 1:
        sentences.pop(0)
        processed = 1

    sentences[0] = remove_repeated_prefix_once(sentences[0])
    return processed, unprocessed


def remove_title(json_list):
    """Strip a leading title from the answer of every record.

    Each record is a dict whose ``"answer_split"`` value is either a flat list
    of sentence strings or a list of paragraphs (each a list of sentence
    strings).  For the first sentence of a record:

    * everything up to and including its first newline is discarded;
    * if the remainder ends with ``.``, ``!`` or ``?`` it is kept and the
      record is counted as unprocessed;
    * otherwise it is removed as a title (and counted as processed) when it
      has at most 8 words -- or, for the nested shape only, ends with ``,``
      or ``:`` -- provided other content remains in the record.

    The new first sentence is then passed through
    ``remove_repeated_prefix_once`` and the ``"answer"``, ``"answer_split"``
    and ``"original_answer"`` keys are rewritten from the cleaned split.

    Records are modified in place; the returned list holds the same dict
    objects.  A record with an empty ``"answer_split"`` is passed through
    with an empty answer.  A record whose first element is neither a list
    nor a string is omitted from the result.

    Args:
        json_list: iterable of record dicts, each with an ``"answer_split"`` key.

    Returns:
        List of the processed record dicts.
    """
    new_json_list = []
    total_processed = 0
    total_unprocessed = 0

    for line in json_list:
        answer_split = line["answer_split"]

        if not answer_split:
            # Nothing to strip; keep the record instead of failing on [0].
            answer = original_answer = ""
        elif isinstance(answer_split[0], list):
            processed, unprocessed = _strip_title_nested(answer_split)
            total_processed += processed
            total_unprocessed += unprocessed
            answer = "\n\n".join(" ".join(paragraph) for paragraph in answer_split)
            original_answer = " ".join(" ".join(paragraph) for paragraph in answer_split)
        elif isinstance(answer_split[0], str):
            processed, unprocessed = _strip_title_flat(answer_split)
            total_processed += processed
            total_unprocessed += unprocessed
            answer = original_answer = " ".join(answer_split)
        else:
            # Unrecognised shape: left out of the result.
            continue

        line["answer"] = answer
        line["answer_split"] = answer_split
        line["original_answer"] = original_answer
        new_json_list.append(line)

    print(f"Total processed: {total_processed}")
    print(f"Total unprocessed: {total_unprocessed}")

    return new_json_list

def remove_repeated_prefix_once(text, max_n=6):
    """Drop one copy of an immediately repeated leading word sequence.

    The text is split on whitespace and, for sizes 1 up to ``max_n`` (capped
    at half the word count), the first ``size`` words are compared with the
    ``size`` words that follow.  At the smallest size where they are equal,
    the first copy is removed and the remaining words are returned joined by
    single spaces.  When no size matches, ``text`` is returned unchanged.
    """
    tokens = text.split()
    largest_size = min(max_n, len(tokens) // 2)

    # Smallest prefix length whose words are repeated right after it, if any.
    repeat_size = next(
        (
            size
            for size in range(1, largest_size + 1)
            if tokens[:size] == tokens[size:size * 2]
        ),
        None,
    )

    if repeat_size is None:
        return text
    return ' '.join(tokens[repeat_size:])

if __name__ == '__main__':
    # Load the records, strip leading titles, then tag every record with the
    # whitespace-delimited word count of its cleaned answer.
    source_path = "dataset/natural_question/natural_question_02.jsonl"
    cleaned_records = remove_title(json_loader(source_path))

    for record in cleaned_records:
        record["word_count"] = len(record["answer"].split())

    # The whole list is written as one indented JSON array.
    target_path = "dataset/natural_question/natural_question_02_title.jsonl"
    with open(target_path, "w", encoding="utf-8") as out_file:
        json.dump(cleaned_records, out_file, indent=2, ensure_ascii=False)
