import jsonlines
import re
import time
import sys
import os
import json
import spacy
import string

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from random import sample
from random import sample
from nltk.tokenize import word_tokenize
from argparse import ArgumentParser

from utils.model_loader import VLLM_Model
from module_02_feature.prompt_template import *
from reader.vllm_reader import vllm_reader, vllm_reader_batch
from utils.set_random_seed import set_random_seed
from utils.json_reader import jsonl_loader,json_loader
from utils.free_vllm import cleanup
from module_02_feature.locate_answer import locate_short_answer
from module_02_feature.remove_title import remove_title

def return_rand_list(selection_list, required_number):
    """Return ``required_number`` items drawn at random from ``selection_list``.

    The draw is delegated to ``random.sample``, so items are picked without
    replacement and the input is left unmodified.

    Args:
        selection_list: Population to draw from.
        required_number: How many items to draw.

    Returns:
        A new list holding the drawn items.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``required_number``
            is larger than the population (or negative).
    """
    return sample(selection_list, required_number)

def starts_with(text):
    """Tell whether ``text`` opens with whitespace or a short suffix fragment.

    Returns True when ``text`` begins with a whitespace character, or with one
    of the fragments "s", "re", "t", "ve" or "ll" immediately followed by a
    whitespace character; otherwise returns False.

    NOTE(review): these fragments look like the tails of English contractions
    (you're, don't, I've, we'll) left over after splitting on an apostrophe --
    confirm against the callers.
    """
    leading_fragment = re.match(r'^(s\s|re\s|\s|t\s|ve\s|ll\s)', text)
    return leading_fragment is not None

def extract_quoted_text(text):
    # Regular expression to find text within double or single quotes, allowing contractions like you're
    pattern = r'"([^"]+)"|\'([^\']+)\''
   
    # Extract matches
    matches = re.findall(pattern, text)
    
    # Flatten the list and remove empty strings
    return [match[0] if match[0] else match[1] for match in matches]

def extract_phrases(text):
    """Pull the items out of a numbered list embedded in ``text``.

    An item is whatever follows "<digits>." plus one whitespace character, up
    to the end of that line. Each item is returned with leading and trailing
    whitespace stripped.

    Args:
        text: String that may contain lines such as "1. first phrase".

    Returns:
        A list of the item texts in order of appearance (empty if none).
    """
    numbered_item = re.compile(r'\d+\.\s([^\n]+)')
    return [item.strip() for item in numbered_item.findall(text)]

def _pick_new_question(record):
    """Choose the ``new_question`` value for one dataset record.

    Picks one entry at random from ``record["new_question"]``. When that field
    is missing, is not a sequence, or is empty, falls back to the original
    question with a "?" appended.
    """
    try:
        return return_rand_list(record["new_question"], 1)[0]
    except (KeyError, TypeError, ValueError):
        # KeyError: no "new_question" field; TypeError: value cannot be
        # sampled (e.g. None); ValueError: empty candidate list.
        return record["question"] + "?"


def _split_answer(nlp, answer):
    """Split ``answer`` into sentences with the spaCy pipeline ``nlp``.

    Paragraphs are separated by a blank line ("\\n\\n"). A single-paragraph
    answer yields a flat list of sentences; a multi-paragraph answer yields one
    list of sentences per paragraph.

    Returns:
        Tuple ``(answer_split, sentence_count, paragraph_count)``.
    """
    paragraphs = answer.split("\n\n")
    if len(paragraphs) == 1:
        sentences = [sent.text.strip() for sent in nlp(answer).sents]
        return sentences, len(sentences), 1

    grouped = [
        [sent.text.strip() for sent in nlp(paragraph).sents]
        for paragraph in paragraphs
    ]
    return grouped, sum(len(group) for group in grouped), len(grouped)


def _join_sentences(answer_split):
    """Join a flat or per-paragraph sentence list into one space-separated string."""
    # Guard against an empty split before peeking at the first element.
    if answer_split and isinstance(answer_split[0], list):
        answer_split = [
            sentence for paragraph in answer_split for sentence in paragraph
        ]
    return " ".join(answer_split)


def dataset_preprocess(jsonl_input, jsonl_output, *, path_to_yml="configs/config.yml", model_name="qwen_32b_model"):
    """Build the feature-annotated dataset from ``jsonl_input`` and write it out.

    Steps, in order:
      1. Seed the RNG, load the spaCy pipeline and the language model.
      2. For every record, ask the model for keywords of the answer; answers
         with at least 100 tokens (``word_tokenize``) are additionally sent to
         the model with the paragraph-split prompt.
      3. Split each answer into sentences/paragraphs and assemble a record
         with counts and keywords.
      4. Run ``remove_title`` and ``locate_short_answer`` on the records.
      5. For records whose ``short_answer`` does not contain the text
         "direct answer", ask the model to refine the question; the previous
         ``new_question`` becomes ``question`` and the model output becomes
         ``new_question``. Every record gets a boolean ``question_refine``.
      6. Dump all records to ``jsonl_output``.

    Args:
        jsonl_input: Path of the input dataset, read with ``json_loader``.
            Each record must provide "answer" and "question"; "new_question"
            (a list of candidates) is optional.
        jsonl_output: Path of the output file. Despite the name, the file is
            written as a single indented JSON array via ``json.dump``.
        path_to_yml: Model configuration file handed to ``VLLM_Model``.
        model_name: Name of the model entry handed to ``VLLM_Model``.

    Returns:
        None. Progress and elapsed time are printed to stdout.
    """
    print(f"start preprocessing {jsonl_input}")
    # set random seed before anything else
    set_random_seed(42)

    # define spacy nlp tools
    nlp = spacy.load("en_core_web_sm")

    start = time.time()

    # load the model
    language_model = VLLM_Model(model_name, path_to_yml)
    model = language_model.load_model()
    tokenizer = language_model.load_tokenizer()
    params = language_model.load_config()

    # read the dataset
    json_list = json_loader(jsonl_input)

    # keyword-extraction prompts (one per record)
    keyword_queries = []
    keyword_system_prompts = []

    # paragraph-split prompts (long answers only)
    split_queries = []
    split_system_prompts = []
    long_answer_indices = []  # record indices that were sent for splitting
    answer_by_index = {}      # record index -> answer text used downstream

    # original word count of every answer
    word_count_list = []

    for i, record in enumerate(json_list):
        # Newlines are removed outright (not replaced by a space).
        answer = record["answer"].replace("\n", "")

        keyword_queries.append(Extract_keywords_user_prompt.format(text=answer))
        keyword_system_prompts.append(Extract_keywords_system_prompt)

        word_count = len(word_tokenize(answer))
        word_count_list.append(word_count)

        if word_count >= 100:
            split_queries.append(paragraph_split_user_prompt.format(text=answer))
            split_system_prompts.append(paragraph_split_system_prompt)
            long_answer_indices.append(i)
        else:
            # Short answers are used as-is.
            answer_by_index[i] = answer

    # Keep this call order: paragraph split first, then keyword extraction.
    split_outputs = vllm_reader_batch(model, tokenizer, params, split_queries, split_system_prompts, batch_size=100)
    keyword_outputs = vllm_reader_batch(model, tokenizer, params, keyword_queries, keyword_system_prompts, batch_size=100)

    # extract keywords from the model output
    keyword_lists = [extract_phrases(raw_output) for raw_output in keyword_outputs]

    # long answers are replaced by the model's paragraph-split version
    for record_index, split_text in zip(long_answer_indices, split_outputs):
        answer_by_index[record_index] = split_text

    records = []
    for i, line in enumerate(json_list):
        answer = answer_by_index[i]
        question = line["question"]

        # Only some datasets ship rewrite candidates; see _pick_new_question.
        new_question = _pick_new_question(line)

        answer_split, sentence_count, paragraph_count = _split_answer(nlp, answer)

        records.append({
            "id": i + 1,
            "question": question,
            "new_question": new_question,
            "answer": answer,
            "answer_split": answer_split,
            "word_count": word_count_list[i],
            "sentence_count": sentence_count,
            "paragraph_count": paragraph_count,
            "keywords": keyword_lists[i],
        })

    # free gpu memory
    # NOTE(review): `model` is still passed to locate_short_answer and
    # vllm_reader_batch below, i.e. after this cleanup call. Either cleanup
    # does not actually release the model, or those later calls are at risk --
    # confirm what utils.free_vllm.cleanup does before reordering.
    cleanup(language_model)

    # remove title from answer
    records_without_title = remove_title(records)

    # locate answer
    object_locate_answer = locate_short_answer(records_without_title, model, tokenizer, params)

    # question refine
    rewrite_tag = []  # indices of records whose question gets refined
    refine_queries = []
    refine_system_prompts = []

    for i, line in enumerate(object_locate_answer):
        new_question = line["new_question"]
        answer_split = line["answer_split"]
        short_answer = line["short_answer"]

        if "direct answer" in short_answer:
            continue

        full_answer = _join_sentences(answer_split)
        refine_queries.append(
            question_refine_prompt.format(answer=full_answer, short_answer=short_answer, question=new_question)
        )
        refine_system_prompts.append(question_refine_system_prompt)
        rewrite_tag.append(i)

    refined_questions = vllm_reader_batch(model, tokenizer, params, refine_queries, refine_system_prompts, batch_size=100)

    # Set for O(1) membership tests in the loop below.
    rewrite_indices = set(rewrite_tag)
    output_records = []
    next_refined = 0  # position in refined_questions; outputs follow rewrite_tag order

    for i, line in enumerate(object_locate_answer):
        if i in rewrite_indices:
            # replace the question with the new question to use in multi turn convo
            line["question"] = line["new_question"]
            line["new_question"] = refined_questions[next_refined]
            line["question_refine"] = True
            next_refined += 1
        else:
            line["question_refine"] = False
        output_records.append(line)

    with open(jsonl_output, "w", encoding="utf-8") as w:
        json.dump(output_records, w, indent=2, ensure_ascii=False)

    end = time.time()
    print(f"Time taken: {end-start}")

if __name__ == '__main__':
    # Command-line entry point: read the input/output paths (both optional,
    # with dataset defaults) and run the preprocessing pipeline on them.
    cli = ArgumentParser()
    for flag, default_path, help_text in (
        ("--jsonl_input", "dataset/natural_question/natural_question_01.jsonl", "Input JSONL file path"),
        ("--jsonl_output", "dataset/natural_question/natural_question_02.jsonl", "Output JSONL file path"),
    ):
        cli.add_argument(flag, type=str, default=default_path, help=help_text)
    options = cli.parse_args()
    dataset_preprocess(options.jsonl_input, options.jsonl_output)
