import re
import string
import random 
import jsonlines
import sys 
import os
import json

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from random import sample
from nltk.tokenize import sent_tokenize, word_tokenize
from reader.llama_reader import llama_reader, llama_reader_batch
from reader.vllm_reader import vllm_reader, vllm_reader_batch
from reader.ds2_reader import ds2_reader_batch
from utils.json_reader import jsonl_loader, json_loader
from utils.model_loader import DS2_Model, Model, VLLM_Model
from module_04_upgrade.prompt_template import *
from module_04_upgrade.example import *

def _run_batch(model, tokenizer, params, queries, system_prompts, batch_size=100):
    """Send one batch of prompts to the model and return one output per query.

    Returns an empty list without calling the model when ``queries`` is empty.

    Raises:
        ValueError: if the reader returns a different number of outputs than
            queries, since the callers pair outputs with rows by position.
    """
    if not queries:
        return []
    outputs = vllm_reader_batch(
        model, tokenizer, params, queries, system_prompts, batch_size=batch_size
    )
    if len(outputs) != len(queries):
        raise ValueError(
            f"expected {len(queries)} model outputs, got {len(outputs)}"
        )
    return outputs


def _has_any_constraint(selected_constraints, wanted):
    """Return True if any name in ``wanted`` is a value of ``selected_constraints``."""
    chosen = selected_constraints.values()
    return any(name in chosen for name in wanted)


def _paragraph_text(paragraph):
    """Return a paragraph as a single string.

    A paragraph is either an iterable of strings (joined with spaces) or,
    for rows whose answer was rewritten in the first stage, already a plain
    string. Joining a plain string would insert a space between every
    character, so it is returned unchanged.
    """
    if isinstance(paragraph, str):
        return paragraph
    return " ".join(paragraph)


def _rewrite_answers_for_format(json_list, model, tokenizer, params):
    """First stage: rewrite answers of rows that carry a format constraint.

    Rows are updated in place: ``answer`` becomes the model output and
    ``answer_split`` becomes a one-element list holding that output.
    """
    format_keys = ("9_text_format", "5_document_format")
    row_indices = []
    queries = []

    for i, line in enumerate(json_list):
        selected_constraints = line["selected_constraints"]
        for key in format_keys:
            if key not in selected_constraints:
                continue
            # NOTE(security): eval() executes whatever expression is stored in
            # the dataset. Only run this on trusted input files; confirm
            # whether the stored value could be resolved without eval.
            example = eval(selected_constraints[key])
            queries.append(
                answer_rewrite_prompt.format(
                    answer=line["answer"],
                    constraint=line["constraints_instructions"][0],
                    example=example,
                )
            )
            row_indices.append(i)

    outputs = _run_batch(
        model, tokenizer, params, queries, [system_prompt_template] * len(queries)
    )

    # Pair every output with the row that produced its query. A row holding
    # both format keys contributes two queries; pairing by position keeps the
    # following rows aligned, and the later rewrite wins for that row.
    for i, rewritten in zip(row_indices, outputs):
        json_list[i]["answer"] = rewritten
        json_list[i]["answer_split"] = [rewritten]


def _generate_titles(json_list, model, tokenizer, params):
    """Second stage: set ``title`` on every row.

    Rows with a title constraint get a one-element list holding the model
    output for their (possibly rewritten) answer; all other rows get ``[]``.
    """
    title_constraints = ("title", "title_all_caps", "title_no_caps", "title_bracket")

    row_indices = [
        i
        for i, line in enumerate(json_list)
        if _has_any_constraint(line["selected_constraints"], title_constraints)
    ]
    queries = [
        title_prompt_template.format(answer=json_list[i]["answer"])
        for i in row_indices
    ]

    outputs = _run_batch(
        model, tokenizer, params, queries, [system_prompt_template] * len(queries)
    )

    for line in json_list:
        line["title"] = []
    for i, title in zip(row_indices, outputs):
        json_list[i]["title"] = [title]


def _generate_paragraph_titles(json_list, model, tokenizer, params):
    """Third stage: generate one title per paragraph for constrained rows.

    For rows with a paragraph-title constraint, ``title`` is replaced by the
    list of model outputs, one per entry of ``answer_split``. This overwrites
    any title written by the second stage for the same row.
    """
    title_constraints = (
        "paragraph_title",
        "paragraph_title_all_caps",
        "paragraph_title_no_caps",
        "paragraph_title_enclose",
    )

    row_indices = []
    queries_per_row = []
    for i, line in enumerate(json_list):
        if not _has_any_constraint(line["selected_constraints"], title_constraints):
            continue
        queries_per_row.append(
            [
                title_prompt_template.format(answer=_paragraph_text(paragraph))
                for paragraph in line["answer_split"]
            ]
        )
        row_indices.append(i)

    # The reader takes a flat list, so flatten here and slice the outputs
    # back per row below using each row's query count.
    flat_queries = [query for row_queries in queries_per_row for query in row_queries]
    outputs = _run_batch(
        model,
        tokenizer,
        params,
        flat_queries,
        [system_prompt_template] * len(flat_queries),
    )

    cursor = 0
    for i, row_queries in zip(row_indices, queries_per_row):
        count = len(row_queries)
        json_list[i]["title"] = outputs[cursor:cursor + count]
        cursor += count


def llm_feature(jsonl_input, jsonl_output):
    """Add LLM-generated rewrites and titles to a dataset file.

    Loads the records from ``jsonl_input`` with ``json_loader``, loads the
    ``qwen_32b_model`` configuration from ``configs/config.yml`` and runs
    three model passes over the records:

    1. rewrite ``answer`` / ``answer_split`` for rows with a text or document
       format constraint,
    2. generate a document title for rows with a title constraint,
    3. generate per-paragraph titles for rows with a paragraph-title
       constraint.

    The updated records are written to ``jsonl_output`` as one indented JSON
    document (via ``json.dump``), regardless of the file extension.

    Args:
        jsonl_input: path of the input file read by ``json_loader``.
        jsonl_output: path of the file to write.
    """
    json_list = json_loader(jsonl_input)

    # define the parameters
    path_to_yml = "configs/config.yml"
    model_name = "qwen_32b_model"

    # load the model
    language_model = VLLM_Model(model_name, path_to_yml)
    model = language_model.load_model()
    tokenizer = language_model.load_tokenizer()
    params = language_model.load_config()

    # Order matters: stages two and three read the answers produced by stage
    # one, and stage three overwrites titles written by stage two.
    _rewrite_answers_for_format(json_list, model, tokenizer, params)
    _generate_titles(json_list, model, tokenizer, params)
    _generate_paragraph_titles(json_list, model, tokenizer, params)

    with open(jsonl_output, "w", encoding="utf-8") as w:
        json.dump(json_list, w, indent=2, ensure_ascii=False)

if __name__ == "__main__":
    # Script entry point: run the pipeline on the natural_question dataset,
    # reading the "03" file and writing the "04" file.
    llm_feature(
        jsonl_input="dataset/natural_question/natural_question_03.jsonl",
        jsonl_output="dataset/natural_question/natural_question_04.jsonl",
    )