"""
Response Sampling
"""
import argparse
import time
import random
import os

from functools import partial
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer
from vllm import SamplingParams,LLM
from utils import *

def show_prompt(item):
    """Print a prompt between two banner lines for visual inspection.

    Args:
        item: Either a list of chat messages (dicts with ``"role"`` and
            ``"content"`` keys), in which case each message's role and
            content are printed on separate lines, or any other object,
            which is printed as-is.
    """
    banner = "=====prompt====="
    print(banner)
    if not isinstance(item, list):
        # Plain (non-chat) prompt: print it directly.
        print(item)
    else:
        # Chat-style prompt: one role line and one content line per message.
        for message in item:
            print(message["role"])
            print(message["content"])
    print(banner)

# Seed Python's built-in `random` module at import time so any use of it is
# repeatable across runs.
# NOTE(review): nothing visible in this file draws from `random`, and this seed
# is not passed to SamplingParams — presumably it is here for helpers pulled in
# by `from utils import *`; confirm.
random.seed(1021)

def process_prompt_for_ruler_qa(example, template=None, inference_mode="all"):
    """Build the full QA prompt string for one dataset example.

    Every paragraph in ``example["all_docs"]`` is rendered through
    ``template["evidence_template_str"]`` (its ``{sentence}`` placeholder is
    replaced by the paragraph text); the rendered paragraphs are concatenated
    and wrapped with the context / question / answer affixes of the template.

    Args:
        example: Mapping with an ``"all_docs"`` iterable of paragraph strings
            and a ``"question"`` string.
        template: Mapping providing the keys ``"instruction"``,
            ``"evidence_template_str"``, ``"context_prefix"``,
            ``"context_suffix"``, ``"question_prefix"``, ``"question_suffix"``
            and ``"answer_prefix"``. ``None`` is treated as an empty template.
        inference_mode: Accepted for interface compatibility; not used when
            building the prompt.

    Returns:
        The assembled prompt string.

    Raises:
        KeyError: If ``template`` or ``example`` lacks a required key (this
            includes calling without a template).
    """
    # Avoid a shared mutable default; an empty template still fails with
    # KeyError on the first lookup, exactly as a literal `{}` would.
    if template is None:
        template = {}

    instruction = template["instruction"]
    doc_format = template["evidence_template_str"]
    # Join once instead of growing a string with `+=` per paragraph.
    content = "".join(
        doc_format.replace("{sentence}", paragraph)
        for paragraph in example["all_docs"]
    )
    return (
        instruction
        + "\nInput:\n"
        + template["context_prefix"]
        + content
        + template["context_suffix"]
        + template["question_prefix"]
        + example["question"]
        + template["question_suffix"]
        + template["answer_prefix"]
    )

def get_conversations(args, test_dataset, tokenizer, template_path=None):
    """Turn every dataset example into a single-turn chat conversation.

    The prompt template is loaded from ``template_path`` with ``load_json``
    (not defined in this file; presumably provided by ``from utils import *``),
    each example is rendered with ``process_prompt_for_ruler_qa`` and wrapped
    as ``[{"role": "user", "content": prompt}]``.

    Args:
        args: Parsed CLI arguments; only ``args.inference_mode`` is read.
        test_dataset: Dataset object exposing ``map``; its examples must carry
            the fields ``process_prompt_for_ruler_qa`` reads.
        tokenizer: Unused by this function; kept for interface compatibility.
        template_path: Path of the JSON prompt template. Required despite the
            ``None`` default.

    Returns:
        The ``"messages"`` column of the mapped dataset, one conversation per
        example.

    Raises:
        ValueError: If ``template_path`` is ``None``.
    """
    # Without a template there is nothing to render; fail with a clear message
    # instead of the UnboundLocalError the missing assignment used to cause.
    if template_path is None:
        raise ValueError("template_path is required to build the prompts")
    template = load_json(template_path)

    # format messages
    process_function = partial(
        process_prompt_for_ruler_qa,
        template=template,
        inference_mode=args.inference_mode,
    )

    def func_for_dataset(example):
        example["messages"] = [
            {"role": "user", "content": process_function(example)}
        ]
        return example

    processed_dataset = test_dataset.map(
        func_for_dataset, num_proc=8, load_from_cache_file=False
    )
    return processed_dataset["messages"]

def inference_temperature_sampling(generator, sampling_params, dataset, args):
    """Sample ``args.n`` responses per conversation in several generation passes.

    ``generator.chat`` is called ``args.n // args.n_iter`` times on the whole
    ``"conversation"`` column; the texts produced for each conversation across
    all passes are concatenated (pass order preserved) and stored in a new
    ``"predictions"`` field.

    Args:
        generator: Object exposing ``chat(conversations, sampling_params)`` and
            returning, per conversation, an object whose ``outputs`` is a
            sequence of items with a ``text`` attribute.
        sampling_params: Passed through to ``generator.chat`` unchanged. It is
            expected to request ``args.n_iter`` completions per prompt so that
            the passes add up to ``args.n``.
        dataset: Dataset with a ``"conversation"`` column and a ``map`` method
            supporting ``with_indices`` and ``num_proc``.
        args: Parsed CLI arguments; ``args.n`` and ``args.n_iter`` are read.

    Returns:
        The mapped dataset with the added ``"predictions"`` field.

    Raises:
        ValueError: If ``args.n_iter`` is not positive or ``args.n`` is not a
            multiple of it (checked before any generation is run).
        RuntimeError: If a conversation ends up with a number of predictions
            different from ``args.n``.
    """
    # Validate up front: a mismatch used to surface only after every
    # (expensive) generation pass had already finished.
    if args.n_iter <= 0 or args.n % args.n_iter != 0:
        raise ValueError(
            f"n ({args.n}) must be a multiple of a positive n_iter ({args.n_iter})"
        )

    num_iter = args.n // args.n_iter
    print(f"N: {args.n} - N_ITER: {args.n_iter} - Num Iteration: {num_iter}")
    conversation_list = dataset["conversation"]

    total_outputs = [
        generator.chat(conversation_list, sampling_params)
        for _ in range(num_iter)
    ]

    # Flatten to plain strings here so the function handed to `map` (which runs
    # in worker processes) closes over text only, not the raw generator outputs.
    predictions = [
        [
            completion.text
            for outputs in total_outputs
            for completion in outputs[idx].outputs
        ]
        for idx in range(len(conversation_list))
    ]
    for idx, texts in enumerate(predictions):
        if len(texts) != args.n:
            raise RuntimeError(
                f"conversation {idx} produced {len(texts)} predictions, expected {args.n}"
            )

    def process(item, idx):
        item["predictions"] = predictions[idx]
        return item

    return dataset.map(process, num_proc=8, with_indices=True)

def get_args():
    """Parse the command-line options of the sampling script.

    Options cover the model (``--model_name_or_path``,
    ``--tensor_parallel_size``, ``--max_model_len``), the data locations
    (``--test_dataset_path``, ``--tgt_save_dir_path``), the prompt
    (``--template_path``, ``--inference_mode``) and the sampling settings
    (``--max_tokens``, ``--temperature``, ``--n``, ``--n_iter``).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="THUDM/chatglm3-6b", help="Model name or path")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--test_dataset_path", type=str, default="data/musique_raw.jsonl", help="Path to the test dataset")
    # This value is joined with a generated file name, so it must be a
    # directory; the previous default duplicated the input *file* path.
    parser.add_argument("--tgt_save_dir_path", type=str, default="data", help="Directory in which the final dataset is saved")

    parser.add_argument("--template_path", type=str, default="template/cot.json", help="Path to the prompt template")
    parser.add_argument("--inference_mode", type=str, default="all_text", help="all text used to inference or only the key information")

    parser.add_argument("--max_model_len", type=int, default=128000)
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--temperature", type=float, default=0.85)
    # `n` is the total number of samples per prompt; `n_iter` is how many of
    # them are requested per generation pass.
    parser.add_argument("--n", type=int, default=16)
    parser.add_argument("--n_iter", type=int, default=8)

    return parser.parse_args()


def main():
    """Run the full sampling pipeline from the command line.

    Steps: parse arguments, load the tokenizer and the test dataset (a
    ``.jsonl`` file or a dataset directory saved to disk), build one chat
    conversation per example, sample responses with vLLM, and save the result
    under ``args.tgt_save_dir_path`` in the same format as the input.
    """
    args = get_args()
    print(args)
    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    # output_path
    # normpath drops a trailing separator, so basename is never empty for
    # inputs such as "models/foo/" or "data/my_dataset/".
    model_name = os.path.basename(os.path.normpath(args.model_name_or_path))
    file_name = os.path.basename(os.path.normpath(args.test_dataset_path))
    is_jsonl = args.test_dataset_path.endswith(".jsonl")
    if is_jsonl:
        # Strip only the trailing extension, not every ".jsonl" in the name.
        stem = file_name[: -len(".jsonl")]
        output_path = os.path.join(
            args.tgt_save_dir_path,
            f"{stem}_infer_with_{args.inference_mode}_by_{model_name}.jsonl",
        )
    else:
        output_path = os.path.join(
            args.tgt_save_dir_path,
            f"{file_name}_infer_with_{args.inference_mode}_by_{model_name}",
        )
    # Create the target directory before the expensive work so an unusable
    # location fails immediately rather than after generation.
    if args.tgt_save_dir_path:
        os.makedirs(args.tgt_save_dir_path, exist_ok=True)

    # load dataset
    if is_jsonl:
        test_dataset = load_dataset("json", data_files={"test": args.test_dataset_path}, split="test")
    else:
        test_dataset = load_from_disk(args.test_dataset_path)

    # Debugging aid: uncomment to run on a small subset.
    # test_dataset = test_dataset.select(range(100))
    conversation_list = get_conversations(args, test_dataset, tokenizer, template_path=args.template_path)
    # show prompt
    show_prompt(conversation_list[0])

    test_dataset = test_dataset.add_column("conversation", conversation_list)
    # load vllm models
    generator = LLM(
        model=args.model_name_or_path,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=True,
        max_model_len=args.max_model_len,
    )
    # inference parameters: each pass requests n_iter completions per prompt
    sampling_params = SamplingParams(temperature=args.temperature, max_tokens=args.max_tokens, n=args.n_iter)
    # generation
    start_time = time.time()
    result_dataset = inference_temperature_sampling(generator, sampling_params, test_dataset, args)
    execution_time = time.time() - start_time
    print(f"Execution Time: {execution_time:.2f} seconds")
    # save (the helper "conversation" column is not part of the output)
    result_dataset = result_dataset.remove_columns("conversation")
    if is_jsonl:
        result_dataset.to_json(output_path, force_ascii=False)
    else:
        result_dataset.save_to_disk(output_path)

# Run the sampling pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
