import os

# Force the Hugging Face stack into offline mode. These switches are read when
# the libraries are imported (e.g. huggingface_hub evaluates HF_HUB_OFFLINE at
# import time), so they must be set BEFORE transformers / datasets / vllm are
# imported below; setting them afterwards is too late for this process.
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"

import json  # noqa: E402  (imports intentionally follow the env setup above)
import copy  # noqa: E402
import random  # noqa: E402
import argparse  # noqa: E402
from tqdm import tqdm  # noqa: E402
# from IPython import embed
from transformers import AutoTokenizer, AutoModelForCausalLM  # noqa: E402
from datasets import load_dataset  # noqa: E402
from vllm import LLM, SamplingParams  # noqa: E402
import torch  # noqa: E402

# Seed torch's global RNG once at import time.
torch.manual_seed(0)


def initialize(argv=None):
    """Parse the command-line options and seed Python's ``random`` module.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``model_name_or_path``,
        ``dataset_name``, ``max_tokens``, ``tensor_parallel_size``,
        ``temperature``, ``output_dir`` and ``enable_chunked_prefill``.

    Side effects:
        Calls ``random.seed(42)``.
    """
    parser = argparse.ArgumentParser("")
    parser.add_argument("--model_name_or_path", type=str, default='')
    # The default has to be a name that process_dataset knows how to load;
    # the previous default ('sharegpt') was not, so a default run crashed.
    parser.add_argument("--dataset_name", type=str, default='sharegpt_v3')
    parser.add_argument("--max_tokens", type=int, default=1024)
    # Defaults to every GPU visible to torch.
    parser.add_argument("--tensor_parallel_size", type=int, default=torch.cuda.device_count())
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--output_dir", type=str, default='')
    parser.add_argument("--enable_chunked_prefill", action="store_true")
    args = parser.parse_args(argv)

    random.seed(42)
    return args

def process_dataset(args, tokenizer):
    """Load a chat dataset from JSONL and build one generation prompt per record.

    Every non-blank line of the dataset file is parsed as a JSON list of chat
    messages. The last message provides the reference answer (its
    ``'content'`` field); all preceding messages are rendered into a prompt
    string with ``tokenizer.apply_chat_template``.

    Args:
        args: Namespace-like object; only ``args.dataset_name`` is read.
        tokenizer: Object exposing ``apply_chat_template(messages,
            tokenize=False, add_generation_prompt=True)``.

    Returns:
        A list of dicts with the keys ``id`` (0-based record index),
        ``raw_prompt`` (messages without the last one), ``prompt`` (rendered
        string), ``label`` (content of the last message) and ``source``
        (the dataset name).

    Raises:
        KeyError: If ``args.dataset_name`` is not a registered dataset.
    """
    data_dirs = {
        "sharegpt_v3": "path/to/sharegpt_V3_format_4k_filtered.jsonl",
        "ultrafeedback": "path/to/ultrafeedback.jsonl"
    }

    dataset_name = args.dataset_name
    if dataset_name not in data_dirs:
        # Same exception type as the plain dict lookup, but tells the user
        # which names are actually accepted.
        raise KeyError(
            f"Unknown dataset_name {dataset_name!r}; "
            f"expected one of {sorted(data_dirs)}"
        )
    data_dir = data_dirs[dataset_name]
    print(f"Loading {dataset_name}...")

    prompts = []
    with open(data_dir, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            # Blank lines are not valid JSON; skip them instead of crashing.
            if not line.strip():
                continue
            chat = json.loads(line)
            context = chat[:-1]
            # Render from a deep copy so the messages stored under
            # "raw_prompt" cannot be modified by the template call.
            prompt = tokenizer.apply_chat_template(
                copy.deepcopy(context),
                tokenize=False,
                add_generation_prompt=True,
            )
            prompts.append({
                "id": len(prompts),  # index among the records actually kept
                "raw_prompt": context,
                "prompt": prompt,
                "label": chat[-1]['content'],
                "source": dataset_name,
            })

    print(f"total instance number: {len(prompts)}")
    return prompts


# Script entry point: render prompts, generate with vLLM, dump predictions.
# The __main__ guard keeps this from running again if the module is re-imported
# by a worker process (vLLM's documentation asks for this guard when workers
# are started via multiprocessing).
if __name__ == "__main__":
    args = initialize()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False, trust_remote_code=True)
    prompts = process_dataset(args, tokenizer)

    # os.makedirs('') raises FileNotFoundError, and '' is the CLI default,
    # so fall back to the current directory.
    output_dir = args.output_dir or "."
    os.makedirs(output_dir, exist_ok=True)

    # Last path component of the model id/path. Trailing slashes are dropped
    # first so "path/to/model/" yields "model" rather than an empty string.
    model_name = args.model_name_or_path.rstrip('/').split('/')[-1].strip()

    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )
    llm = LLM(
        model=args.model_name_or_path,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=True,
        enable_chunked_prefill=args.enable_chunked_prefill,  # default is False
    )

    outputs = llm.generate([data['prompt'] for data in prompts], sampling_params)

    # Pair each output with the record at the same position in `prompts`;
    # only the first completion of each request is kept.
    results = [
        {
            "id": data['id'],
            "raw_prompt": data['raw_prompt'],
            "prompt": output.prompt,
            "pred": output.outputs[0].text,
            "label": data['label'],
            "source": data['source'],
            "model": model_name,
        }
        for data, output in zip(prompts, outputs)
    ]

    with open(os.path.join(output_dir, f"{model_name}_preds.json"), 'w', encoding='utf-8') as f:
        json.dump({
            "cnt": len(results),
            "data": results
        }, f, indent=4, ensure_ascii=False)
