import os
import sys
import re
import jsonlines
import random
import json
import time

from random import sample
from argparse import ArgumentParser

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from utils.json_reader import jsonl_loader, json_loader
from utils.set_random_seed import set_random_seed
from module_04_upgrade.prompt_template import prompt_template_1
from transformers import AutoTokenizer

def return_rand_list(selection_list, required_number):
    """Return ``required_number`` items drawn at random from ``selection_list``.

    Sampling is without replacement (``random.sample``), so each position of
    ``selection_list`` contributes at most once to the result.

    Args:
        selection_list: Sequence to draw from.
        required_number: How many items to draw.

    Returns:
        A new list containing the drawn items, in selection order.

    Raises:
        ValueError: If ``required_number`` is negative or larger than
            ``len(selection_list)`` (propagated from ``random.sample``).
    """
    return sample(selection_list, required_number)

def remove_trailing_period(text):
    """Drop a single trailing ``'.'`` from *text* when one is present.

    Exactly one period is removed (``"a.."`` becomes ``"a."``). Any other
    ending, including trailing whitespace, leaves *text* unchanged.
    """
    has_period = text.endswith('.')
    return text[:-1] if has_period else text

# System prompt prepended to every generated conversation.
_SYSTEM_PROMPT = (
    "You are an expert assistant who follows user instructions with precision. "
    "Always respond accurately and strictly obey all constraints on content, "
    "format, style, and wording."
)
# Value in ``constraint_instruction`` that marks an entry to be dropped.
_ERROR_PLACEHOLDER = "ERROR_404"


def _build_messages(record):
    """Build the chat-format message list for one input record.

    The conversation consists of, in order:

    1. the system prompt;
    2. only when ``record["question_refine"] == True``: ``question`` answered
       by ``short_answer``, normalized to end with exactly one ``"."``;
    3. ``new_question`` answered by ``original_answer``;
    4. only when at least one constraint other than ``"ERROR_404"`` remains:
       ``new_question`` followed by the shuffled, space-joined constraints,
       answered by ``answer``.

    Args:
        record: One entry of the input dataset. Always reads
            ``constraint_instruction``, ``new_question``, ``original_answer``
            and ``question_refine``. Reads ``question``/``short_answer`` and
            ``answer`` only when steps 2 and 4 apply, respectively.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    constraints = [
        c for c in record["constraint_instruction"] if c != _ERROR_PLACEHOLDER
    ]
    # Randomize constraint order per record (presumably so the training data
    # does not always present constraints in the same order).
    random.shuffle(constraints)
    new_question = record["new_question"].strip()

    messages = [{"role": "system", "content": _SYSTEM_PROMPT}]

    # Explicit == True (not truthiness) accepts exactly the values the
    # original comparison accepted.
    if record["question_refine"] == True:  # noqa: E712
        short_answer = remove_trailing_period(record["short_answer"])
        messages += [
            {"role": "user", "content": record["question"].strip()},
            {"role": "assistant", "content": short_answer + "."},
        ]

    messages += [
        {"role": "user", "content": new_question},
        {"role": "assistant", "content": record["original_answer"]},
    ]

    if constraints:
        # Strip each constraint so surrounding whitespace is removed
        # consistently, whether there is one constraint or several.
        constraint_input = " ".join(c.strip() for c in constraints)
        messages += [
            {"role": "user", "content": new_question + " " + constraint_input},
            {"role": "assistant", "content": record["answer"]},
        ]

    return messages


def create_qa_dataset(input_dataset, output_dataset, model_path, seed=None):
    """Render the input records as chat-template strings and write them as JSONL.

    Each record is converted to a multi-turn conversation (see
    ``_build_messages``). The conversation is rendered with the tokenizer's
    chat template (``tokenize=False``, no generation prompt) and written to
    ``output_dataset`` as one ``{"text": <prompt>}`` object per line. The
    elapsed time is printed at the end.

    Args:
        input_dataset: Path passed to ``json_loader``.
        output_dataset: Path of the JSONL file to (over)write.
        model_path: Path or name passed to ``AutoTokenizer.from_pretrained``.
        seed: Optional seed for the module-level ``random`` generator, which
            makes the constraint shuffling reproducible. ``None`` (the
            default) leaves the generator untouched.
    """
    start = time.perf_counter()

    if seed is not None:
        random.seed(seed)

    # NOTE(review): the CLI default input is a .jsonl file, but this calls
    # json_loader (jsonl_loader is also imported). Confirm which loader is
    # intended.
    records = json_loader(input_dataset)
    conversations = [_build_messages(record) for record in records]

    # Load the tokenizer before opening the output file, so that a bad
    # model_path does not leave a truncated/empty output file behind.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # NOTE(review): these two settings look irrelevant for tokenize=False
    # rendering. They are kept to preserve the original tokenizer setup.
    tokenizer.padding_side = "right"
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    with jsonlines.open(output_dataset, mode="w") as writer:
        for conversation in conversations:
            prompt = tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=False
            )
            writer.write({"text": prompt})

    print(f"Time taken: {time.perf_counter() - start}")

if __name__ == "__main__":
    # Command-line entry point: parse the dataset/model paths and build the dataset.
    cli = ArgumentParser()
    cli.add_argument(
        "--jsonl_input",
        type=str,
        default="dataset/natural_question/natural_question_04.jsonl",
        help="Input JSONL file path",
    )
    cli.add_argument(
        "--jsonl_output",
        type=str,
        default="dataset/natural_question/natural_question_finetune_llama.jsonl",
        help="Output JSONL file path",
    )
    cli.add_argument(
        "--model_path",
        type=str,
        default="models/Llama-2-7b-chat-hf",
        help="Please define the model path for the tokenizer",
    )
    cli_args = cli.parse_args()

    create_qa_dataset(
        input_dataset=cli_args.jsonl_input,
        output_dataset=cli_args.jsonl_output,
        model_path=cli_args.model_path,
    )