import os
import re
import json
import string
import argparse
import numpy as np
from tqdm import tqdm

from prompts import PROMPTS
import torch
import tiktoken
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams
from utils import *

def inference_vllm(dataset, model, max_tokens, temperature):
    """Generate one completion per row with vLLM and score it against the gold answers.

    Every row's ``"prompt"`` is sent to ``model.generate`` in a single batch,
    sampled with ``temperature``, ``max_tokens``, ``top_k=32`` and ``top_p=1``.
    Each row then gains four columns:

    * ``prediction`` -- text of the first completion for that row,
    * ``prompt_token_length`` -- number of prompt token ids reported by vLLM,
    * ``pred_ans`` -- answer extracted from the completion by
      ``extract_ans_from_response``,
    * ``f1_score`` -- best ``qa_f1_score`` against ``row["answers"]`` via
      ``drqa_metric_max_over_ground_truths``, or ``0`` if that returns ``None``.

    Generation results are matched to rows by position, so this assumes
    ``model.generate`` returns outputs in prompt order.

    Returns:
        A new dataset with the columns above added.
    """
    params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
        top_k=32,
        top_p=1,
    )
    prompts = [row["prompt"] for row in dataset]
    generations = model.generate(prompts, params)

    def _attach_result(row, row_idx):
        request_output = generations[row_idx]
        response_text = request_output.outputs[0].text
        predicted = extract_ans_from_response(response_text)
        score = drqa_metric_max_over_ground_truths(qa_f1_score, predicted, row["answers"])

        # Column insertion order is kept stable: prediction, prompt_token_length, pred_ans, f1_score.
        row["prediction"] = response_text
        row["prompt_token_length"] = len(request_output.prompt_token_ids)
        row["pred_ans"] = predicted
        row["f1_score"] = 0 if score is None else score
        return row

    return dataset.map(_attach_result, with_indices=True, num_proc=1)

def load_data(args, file_name, tokenizer):
    """Load ``{args.src_dir_path}/{file_name}.jsonl`` and attach chat-formatted prompts.

    Rows are first filtered on their ``context_length`` field (both bounds
    inclusive):

    * rows with ``context_length < args.test_min_length`` are dropped;
    * when ``args.test_max_length > 0``, rows with
      ``context_length > args.test_max_length`` are dropped too.

    If neither bound is positive, no filtering happens. Every remaining row then
    gets a ``"prompt"`` column. It holds the ``PROMPTS["short2long_qa"]`` template
    filled with the row's ``context`` and ``input``, wrapped as a single user turn
    by ``tokenizer.apply_chat_template``. The prompt is left untokenized and has
    the generation prompt appended.

    Args:
        args: Parsed CLI namespace; reads ``src_dir_path``, ``test_min_length``
            and ``test_max_length``.
        file_name: Dataset file stem (without the ``.jsonl`` suffix).
        tokenizer: Tokenizer providing ``apply_chat_template``.

    Returns:
        The filtered dataset with an added ``"prompt"`` column.
    """
    test_data_path = f"{args.src_dir_path}/{file_name}.jsonl"
    dataset = load_custom_dataset(test_data_path)

    min_len = args.test_min_length
    max_len = args.test_max_length
    # Filter before building prompts so no chat template is rendered for rows
    # that are about to be discarded. The lower bound is applied on its own as
    # well; it must not depend on an upper bound also being given.
    if min_len > 0 or max_len > 0:
        dataset = dataset.filter(
            lambda x: x["context_length"] >= min_len
            and (max_len <= 0 or x["context_length"] <= max_len)
        )

    def process(item):
        item["prompt"] = tokenizer.apply_chat_template(
            [{"role": "user", "content": PROMPTS["short2long_qa"].format(context=item["context"], input=item["input"])}],
            add_generation_prompt=True,
            tokenize=False,
        )
        return item

    dataset = dataset.map(process, num_proc=8)
    return dataset

def save_to_niah_format(dataset,output_dir_path,model_name):
    """Write one JSON result file per row plus an aggregate score file.

    Per-row files go to ``{output_dir_path}/graph_{model_name}/`` and are named
    ``{model_name}_len_{context_length}_depth_{depth_percent}_results.json``.
    Each holds the row's length/depth, its ``f1_score`` (as ``"score"``), the
    gold ``answers``, the extracted ``pred_ans`` (stored under ``"pre_ans"``)
    and the raw ``prediction`` text.

    ``{output_dir_path}/avg_score.json`` receives the mean ``f1_score`` x 100,
    rounded to 2 decimals, and the row count, both as strings.

    NOTE(review): not called from anywhere in this script.
    """
    graph_dir = f"{output_dir_path}/graph_{model_name}"
    ensure_path_exists(graph_dir)

    for row in dataset:
        length, depth = row["context_length"], row["depth_percent"]
        record = {
            "context_length": length,
            "depth_percent": depth,
            "score": row["f1_score"],
            "answer": row["answers"],
            "pre_ans": row["pred_ans"],
            "raw_response": row["prediction"],
        }
        save_data(record, f"{graph_dir}/{model_name}_len_{length}_depth_{depth}_results.json")

    mean_f1 = np.mean(dataset["f1_score"])
    summary = {
        "avg_score": f"{round(mean_f1 * 100, 2)}",
        "num_samples": str(dataset.num_rows),
    }
    save_data(summary, f"{output_dir_path}/avg_score.json")

def inference(args):
    """Evaluate a model with vLLM on every dataset named in ``args.eval_data_list``.

    For each comma-separated dataset name, results are written to
    ``{args.output_path}-{args.test_max_length}/{name}/{args.model_name}``:

    * the scored dataset, saved with ``save_to_disk`` under ``prediction/``;
    * ``avg_score.json``, holding the mean F1 x 100 (rounded to 2 decimals, as
      a string) and the sample count.

    A dataset whose ``avg_score.json`` already exists is skipped, so an
    interrupted run can be resumed. A dataset with no rows left after length
    filtering is skipped without writing a score.

    Args:
        args: Namespace produced by ``get_args``.
    """
    model = LLM(
        model=args.model_name_or_path,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=True,
        max_model_len=args.max_model_len,
    )
    # Use the same remote-code policy as the engine, so models that ship
    # custom tokenizer code can also be used for prompt building.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)

    for test_data_name in args.eval_data_list.split(","):
        print(f"############「{test_data_name}」-[{args.test_min_length}~{args.test_max_length}]###########")
        tmp_save_dir = f"{args.output_path}-{args.test_max_length}/{test_data_name}/{args.model_name}"
        ensure_path_exists(tmp_save_dir)
        prediction_save_path = f"{tmp_save_dir}/prediction"
        avg_score_path = f"{tmp_save_dir}/avg_score.json"
        if os.path.exists(avg_score_path):
            # Already evaluated: show where the cached score lives and move on.
            print(avg_score_path)
            continue

        # 1. Load data and build prompts.
        dataset = load_data(args, test_data_name, tokenizer)
        if dataset.num_rows == 0:
            # Nothing to evaluate (e.g. the length window excluded every row).
            # Skip, rather than crash on dataset[0] or average an empty column.
            print(f"No samples for {test_data_name} in [{args.test_min_length}~{args.test_max_length}]; skipping.")
            continue
        print("================example prompt================")
        print(dataset[0]["prompt"])
        print("================example prompt================")

        # 2. Inference and evaluation over every row that survived filtering.
        dataset = inference_vllm(dataset, model, max_tokens=args.max_tokens, temperature=args.temperature)
        print("================end inference_vllm================")

        # 3. Save the per-row predictions and the aggregate score.
        dataset.save_to_disk(prediction_save_path)
        avg_score = np.mean(dataset["f1_score"])
        avg_score_dict = {
            "avg_score": f"{round(avg_score*100,2)}",
            "num_samples": str(dataset.num_rows),
        }
        save_data(avg_score_dict, avg_score_path)
def get_args(argv=None):
    """Parse the evaluation script's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like a bare ``parse_args()``.

    Returns:
        argparse.Namespace with the data, output, model and sampling settings.
    """
    parser = argparse.ArgumentParser()

    # Data selection and output location.
    parser.add_argument("--src_dir_path", type=str)
    parser.add_argument("--eval_data_list", type=str)
    parser.add_argument("--test_min_length", type=int, default=0)
    parser.add_argument("--test_max_length", type=int, default=-1)
    parser.add_argument("--output_path", type=str)

    # Model / engine configuration.
    parser.add_argument("--model_name_or_path", type=str)
    parser.add_argument("--model_name", type=str)
    # Without a default argparse yields None, which is then handed to vllm.LLM.
    # That parameter expects an int; vLLM's own default is 1 (single GPU).
    parser.add_argument("--tensor_parallel_size", type=int, default=1)
    parser.add_argument("--max_model_len", type=int, default=140000)

    # Sampling.
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--temperature", type=float, default=0)
    # NOTE(review): not read by any code visible in this script; kept for CLI compatibility.
    parser.add_argument("--prompt_version", type=str, default="0406")

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse CLI options, then run every requested evaluation.
    inference(get_args())