import torch
import argparse
import json
import os
import random
from functools import partial
from typing import List, Dict, Union
from datasets import load_from_disk, load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
import datasets
from utils import load_single_dataset, save_dataset


def is_share_gpt_format(inputs: List):
    """Return True if ``inputs`` looks like a list of ShareGPT-style conversations.

    Only the first element is inspected. It must be a non-empty list whose
    first message is a dict carrying a "role" or "content" key. Empty input,
    or an empty first conversation, is reported as "not ShareGPT" instead of
    raising IndexError.
    """
    if len(inputs) == 0:
        return False
    first = inputs[0]
    if not isinstance(first, list) or len(first) == 0:
        return False
    message = first[0]
    return isinstance(message, dict) and ("role" in message or "content" in message)


def is_alpaca_format(inputs: List):
    """Return True if ``inputs`` looks like a list of Alpaca-style records.

    Only the first element is inspected. It must be a dict with both an
    "instruction" and an "output" key. Empty input returns False instead of
    raising IndexError.
    """
    if len(inputs) == 0:
        return False
    first = inputs[0]
    return isinstance(first, dict) and "instruction" in first and "output" in first


def is_prompts(inputs: List):
    """Return True if ``inputs`` is a list of plain prompt strings.

    Only the first element is inspected. Empty input returns False instead of
    raising IndexError.
    """
    return len(inputs) > 0 and isinstance(inputs[0], str)


def transform_alpaca_to_sharegpt(inputs: List[Dict]):
    """Convert Alpaca-style rows into ShareGPT-style message lists.

    Each row becomes ``[system?, (user, assistant)* from history, user]``:
    - an optional system message;
    - one user/assistant pair per ``history`` entry;
    - a final user message built from ``instruction``, plus ``input`` joined
      by a newline when ``input`` is non-empty.

    The row's ``output`` field is not included.

    A ``system``/``history`` that is missing or ``None`` is skipped, as is an
    empty ``input``. Without this, a loader that fills absent fields with
    ``None``, or the common Alpaca convention ``input == ""``, would produce
    ``None`` messages or a dangling trailing newline.
    """
    rst = []
    for row in tqdm(inputs):
        tmp = []
        system = row.get("system")
        if system is not None:
            tmp.append({"role": "system", "content": system})
        # `or []` also covers an explicit None history.
        for user_instruction, model_response in row.get("history") or []:
            tmp.extend([{"role": "user", "content": user_instruction},
                        {"role": "assistant", "content": model_response}])
        instruction = row["instruction"]
        if row.get("input"):
            instruction += f"\n{row['input']}"
        tmp.append({"role": "user", "content": instruction})
        rst.append(tmp)
    return rst


def transform_prompt_to_sharegpt(inputs: List[str]):
    """Wrap every plain prompt string as a single-turn ShareGPT conversation.

    Each prompt ``p`` becomes ``[{"role": "user", "content": p}]``. The result
    preserves the input order.
    """
    def as_single_turn(text):
        return [{"role": "user", "content": text}]

    return list(map(as_single_turn, inputs))


def load_input(data_path: str, prompt_key: str, dataset_split=None):
    """Load a dataset and return its rows together with chat-formatted prompts.

    Args:
        data_path: path/name forwarded to ``utils.load_single_dataset``.
        prompt_key: name of the column that holds the prompts.
        dataset_split: optional split name forwarded to the loader.

    Returns:
        ``(rows, conversations)``. ``rows`` is ``dataset.to_list()``.
        ``conversations`` holds one ShareGPT-style message list per row, in
        the same order.

    Raises:
        ValueError: if the prompt column holds none of the supported forms:
            plain strings, ShareGPT conversations, or Alpaca records.
    """
    dataset = load_single_dataset(data_path, dataset_split)
    # Copy into a plain list so later slicing/indexing does not depend on
    # whatever container the loader returns for a column.
    column = list(dataset[prompt_key])
    rows = dataset.to_list()
    if not column:
        return rows, []
    if is_share_gpt_format(column):
        return rows, column
    if is_prompts(column):
        return rows, transform_prompt_to_sharegpt(column)
    if is_alpaca_format(column):
        return rows, transform_alpaca_to_sharegpt(column)
    # Previously anything non-ShareGPT was wrapped as message "content",
    # silently feeding dicts/lists into the chat template.
    raise ValueError(
        f"Unsupported prompt format in column {prompt_key!r}: expected strings, "
        f"ShareGPT conversations or Alpaca records, got {type(column[0]).__name__}"
    )


def str2bool(v):
    """Parse a command-line boolean flag value for ``argparse``.

    A ``bool`` passes through unchanged. The strings yes/true/t/y/1 and
    no/false/f/n/0 (case-insensitive) map to True/False.

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse prints a
            usage error rather than a traceback.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def decode_cli_escapes(text):
    r"""Expand backslash escapes such as ``\n`` or ``\u4e00`` typed on the shell.

    The shorter ``text.encode('utf-8').decode('unicode_escape')`` re-reads
    the UTF-8 bytes of every non-ASCII character as Latin-1 and garbles it
    (e.g. Chinese text). Encoding as Latin-1 with ``backslashreplace`` turns
    each non-Latin-1 character into its own ``\uXXXX`` escape, which
    ``unicode_escape`` then restores unchanged.
    """
    return text.encode('latin-1', 'backslashreplace').decode('unicode_escape')


def _truncate_prompt(tokenizer, prompt, max_length):
    """Keep only the last ``max_length - 1`` tokens of an over-long prompt.

    Prompts shorter than ``max_length`` tokens are returned unchanged. Longer
    ones are left-truncated (the start of the prompt is dropped) and decoded
    back to text. Token counts come from ``tokenizer.encode`` with its default
    special-token handling.
    """
    token_ids = tokenizer.encode(prompt)  # encode once; reused for check and slice
    if len(token_ids) < max_length:
        return prompt
    keep = max_length - 1
    return tokenizer.decode(token_ids[len(token_ids) - keep:])


def _completion_field(request_output, field):
    """Return ``field`` from the completions of one vLLM request output.

    A single completion yields the bare value. Several completions yield a
    list in completion order.
    """
    values = [getattr(completion, field) for completion in request_output.outputs]
    return values if len(values) > 1 else values[0]


# Command-line entry point: load prompts, generate with vLLM, and attach the
# generations to the original rows (optionally written to --output_json).
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data', type=str, required=True)
    parser.add_argument("--dataset_split", type=str, required=False, default=None)
    parser.add_argument("--prompt_key", type=str, default="input")
    parser.add_argument("--output_key", type=str, default="output")
    parser.add_argument("--output_json", type=str, default=None)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--begin", type=int, default=None)
    parser.add_argument("--end", type=int, default=None)
    # When false, prompts are the raw message texts with no chat template.
    parser.add_argument("--apply_template", type=str2bool, default=True)
    # A single stop string forwarded verbatim to SamplingParams.
    parser.add_argument("--stop", type=str, default=None)
    parser.add_argument("--max_length", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--top_p", type=float, default=1)
    parser.add_argument("--top_k", type=int, default=-1)
    parser.add_argument("--presence_penalty", type=float, default=0.0)
    parser.add_argument("--num_outputs", type=int, default=1)
    parser.add_argument("--log_vllm_input", action="store_true")
    # Text appended to every prompt; shell escapes such as \n are expanded.
    parser.add_argument("--add_think_tag", type=str, default=None)
    parser.add_argument("--output_stop_reason", action="store_true")
    parser.add_argument("--output_finish_reason", action="store_true")
    parser.add_argument("--max_num_seqs", type=int, default=1024)
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.9)
    # LoRA parameters
    # LoRA adapter path; LoRA is enabled only when this is set.
    parser.add_argument("--lora_path", type=str, default=None, help="LoRA adapter 路径")
    # LoRA adapter name.
    parser.add_argument("--lora_name", type=str, default="lora", help="LoRA adapter 的名称")
    # Unique LoRA adapter ID. NOTE: currently unused; vLLM needs an int id > 0
    # and a fixed id is used below because only one adapter is loaded.
    parser.add_argument("--lora_id", type=str, default="default-lora-id", help="LoRA adapter 的唯一 ID")

    args = parser.parse_args()
    if not args.output_json:
        # Warn up front rather than after a long generation run.
        print("WARNING: --output_json is not set; generations will not be saved.")

    raw_dataset, input_conversation = load_input(args.data, args.prompt_key, args.dataset_split)

    # Optional [begin, end) shard. A None bound means "from the start"/"to the end".
    if args.begin is not None or args.end is not None:
        shard = slice(args.begin, args.end)
        raw_dataset = raw_dataset[shard]
        input_conversation = input_conversation[shard]

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
    print("tokenizer", tokenizer)
    if args.apply_template:
        prompts = [tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
                   for conv in input_conversation]
    else:
        # Raw completion mode: join the message texts without any chat markup.
        prompts = ["\n".join(message["content"] for message in conv) for conv in input_conversation]

    if args.add_think_tag:
        think_tag = decode_cli_escapes(args.add_think_tag)  # decode once, not per prompt
        prompts = [p + think_tag for p in prompts]

    # Drop a leading BOS rendered as text by the template. Presumably vLLM
    # re-adds BOS when tokenizing; confirm for the tokenizer in use.
    bos = tokenizer.bos_token
    prompts = [p[len(bos):] if bos and p.startswith(bos) else p for p in prompts]
    prompts = [_truncate_prompt(tokenizer, p, args.max_length) for p in prompts]

    llm_kwargs = dict(
        dtype="bfloat16",
        max_num_seqs=args.max_num_seqs,
        tensor_parallel_size=torch.cuda.device_count(),
        max_seq_len_to_capture=args.max_length,
        max_model_len=args.max_length,
        gpu_memory_utilization=args.gpu_memory_utilization,
        enable_lora=bool(args.lora_path),
    )
    if args.lora_path:
        # Set the LoRA rank cap only when LoRA is enabled instead of passing None.
        llm_kwargs["max_lora_rank"] = 128
    llm = LLM(args.model, **llm_kwargs)

    sampling_params = SamplingParams(
        presence_penalty=args.presence_penalty,
        n=args.num_outputs,
        top_k=args.top_k,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_length,
        stop=args.stop,  # None (the default) means no extra stop string
    )

    lora_request = None
    if args.lora_path:
        lora_request = LoRARequest(
            lora_name=args.lora_name,
            # vLLM requires a positive integer id. A single adapter is loaded
            # per run, so a fixed id suffices. abs(hash(path)) % 65536 could be
            # 0, and str hashes are randomized per process.
            lora_int_id=1,
            lora_path=args.lora_path,
        )

    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)

    texts = [_completion_field(o, "text") for o in outputs]
    stop_reasons = [_completion_field(o, "stop_reason") for o in outputs]
    finish_reasons = [_completion_field(o, "finish_reason") for o in outputs]

    for row, text, stop_reason, finish_reason in zip(raw_dataset, texts, stop_reasons, finish_reasons):
        row[args.output_key] = text
        if args.output_stop_reason:
            row["stop_reason"] = stop_reason
        if args.output_finish_reason:
            # Key spelling kept as-is for existing consumers of the JSON.
            row["finish_reasons"] = finish_reason

    if args.log_vllm_input:
        # Decode only on request: this detokenizes every prompt again.
        vllm_inputs = tokenizer.batch_decode([o.prompt_token_ids for o in outputs])
        for row, vllm_input in zip(raw_dataset, vllm_inputs):
            row["vllm_input"] = vllm_input

    if args.output_json:
        out_dir = os.path.dirname(args.output_json)
        if out_dir:  # os.makedirs("") raises for a bare filename like "out.json"
            os.makedirs(out_dir, exist_ok=True)
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(raw_dataset, f, ensure_ascii=False, indent=2)
        print(f"save to {args.output_json}")

"""


CUDA_VISIBLE_DEVICES=0 ~/verl_cs/.conda/bin/python  ~/verl_cs/scripts/gen_vllm.py --begin 0 --end 56712 --data  ~/datasets/PRIME-RL-Eurus-2-RL-Data/train_shuffled_math.parquet --dataset_split train --output_finish_reason --num_outputs 5 --prompt_key prompt --output_key responses --model  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/ --max_length 2560 --temperature 1 --output_json  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/0_56712_example.json --gpu_memory_utilization 0.95 &
CUDA_VISIBLE_DEVICES=1 ~/verl_cs/.conda/bin/python  ~/verl_cs/scripts/gen_vllm.py --begin 56712 --end 113424 --data  ~/datasets/PRIME-RL-Eurus-2-RL-Data/train_shuffled_math.parquet --dataset_split train --output_finish_reason --num_outputs 5 --prompt_key prompt --output_key responses --model  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/ --max_length 2560 --temperature 1 --output_json  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/56712_113424_example.json --gpu_memory_utilization 0.95 &
CUDA_VISIBLE_DEVICES=2 ~/verl_cs/.conda/bin/python  ~/verl_cs/scripts/gen_vllm.py --begin 113424 --end 170136 --data  ~/datasets/PRIME-RL-Eurus-2-RL-Data/train_shuffled_math.parquet --dataset_split train --output_finish_reason --num_outputs 5 --prompt_key prompt --output_key responses --model  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/ --max_length 2560 --temperature 1 --output_json  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/113424_170136_example.json --gpu_memory_utilization 0.95 &
CUDA_VISIBLE_DEVICES=3 ~/verl_cs/.conda/bin/python  ~/verl_cs/scripts/gen_vllm.py --begin 170136 --end 226848 --data  ~/datasets/PRIME-RL-Eurus-2-RL-Data/train_shuffled_math.parquet --dataset_split train --output_finish_reason --num_outputs 5 --prompt_key prompt --output_key responses --model  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/ --max_length 2560 --temperature 1 --output_json  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/170136_226848_example.json --gpu_memory_utilization 0.95 &
CUDA_VISIBLE_DEVICES=4 ~/verl_cs/.conda/bin/python  ~/verl_cs/scripts/gen_vllm.py --begin 226848 --end 283560 --data  ~/datasets/PRIME-RL-Eurus-2-RL-Data/train_shuffled_math.parquet --dataset_split train --output_finish_reason --num_outputs 5 --prompt_key prompt --output_key responses --model  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/ --max_length 2560 --temperature 1 --output_json  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/226848_283560_example.json --gpu_memory_utilization 0.95 &
CUDA_VISIBLE_DEVICES=5 ~/verl_cs/.conda/bin/python  ~/verl_cs/scripts/gen_vllm.py --begin 283560 --end 340272 --data  ~/datasets/PRIME-RL-Eurus-2-RL-Data/train_shuffled_math.parquet --dataset_split train --output_finish_reason --num_outputs 5 --prompt_key prompt --output_key responses --model  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/ --max_length 2560 --temperature 1 --output_json  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/283560_340272_example.json --gpu_memory_utilization 0.95 &
CUDA_VISIBLE_DEVICES=6 ~/verl_cs/.conda/bin/python  ~/verl_cs/scripts/gen_vllm.py --begin 340272 --end 396984 --data  ~/datasets/PRIME-RL-Eurus-2-RL-Data/train_shuffled_math.parquet --dataset_split train --output_finish_reason --num_outputs 5 --prompt_key prompt --output_key responses --model  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/ --max_length 2560 --temperature 1 --output_json  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/340272_396984_example.json --gpu_memory_utilization 0.95 &
CUDA_VISIBLE_DEVICES=7 ~/verl_cs/.conda/bin/python  ~/verl_cs/scripts/gen_vllm.py --begin 396984 --data  ~/datasets/PRIME-RL-Eurus-2-RL-Data/train_shuffled_math.parquet --dataset_split train --output_finish_reason --num_outputs 5 --prompt_key prompt --output_key responses --model  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/ --max_length 2560 --temperature 1 --output_json  ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/396984_end_example.json --gpu_memory_utilization 0.95 &


"""