import json
import random
import argparse
import torch
from typing import List, Tuple, Optional
from transformers import PreTrainedTokenizerBase, AutoTokenizer
import datasets
from vllm import LLM, SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
    shuffle_dataset: bool,
    start: int = 0,
) -> List[Tuple[str, int, Optional[int]]]:
    """Load, chat-format, length-filter and randomly sample prompts.

    One user message is taken from each conversation (the first turn for
    ShareGPT, the first turn whose role is ``'user'`` for lmsys), rendered
    through the tokenizer's chat template with a generation prompt appended,
    and kept only if its tokenized length is between 4 and 1024 tokens.

    Args:
        dataset_path: Dataset selector, either ``"sharegpt"`` (read from
            ``datasets/ShareGPT_V3_unfiltered_cleaned_split.json`` relative to
            the working directory) or ``"lmsys"`` (fetched with
            ``datasets.load_dataset("lmsys/lmsys-chat-1m")``).
        num_requests: Number of prompts to return.
        tokenizer: Tokenizer providing ``apply_chat_template`` and used to
            measure prompt lengths.
        fixed_output_len: Output length recorded for every request. When
            ``None`` the output-length checks are skipped and ``None`` is
            recorded instead.
        shuffle_dataset: Accepted for interface compatibility; this function
            does not currently use it.
        start: Index of the first conversation to consider.

    Returns:
        ``num_requests`` tuples of ``(formatted_prompt, prompt_len,
        fixed_output_len)``, drawn with ``random.sample``.

    Raises:
        ValueError: If ``fixed_output_len`` is below 4, if ``dataset_path`` is
            not a supported dataset, or if fewer than ``num_requests`` prompts
            survive filtering.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    min_prompt_len = 4
    max_prompt_len = 1024
    max_total_len = 2000000

    # Load 20% more conversations than requested so that enough prompts
    # remain after the length filter below.
    window = int(num_requests * 1.2)

    prompts: List[str] = []
    if dataset_path == "sharegpt":
        with open("datasets/ShareGPT_V3_unfiltered_cleaned_split.json",
                  encoding="utf-8") as f:
            records = json.load(f)
        # Only conversations with at least two turns are considered; the
        # first turn is used as the user prompt.
        first_turns = [
            record["conversations"][0]["value"]
            for record in records
            if len(record["conversations"]) >= 2
        ][start:start + window]
        for user_text in first_turns:
            chat = [{"role": "user", "content": user_text}]
            prompts.append(
                tokenizer.apply_chat_template(
                    chat, tokenize=False, add_generation_prompt=True))
    elif dataset_path == "lmsys":
        dataset = datasets.load_dataset("lmsys/lmsys-chat-1m")['train']
        # Clamp the window to the dataset size: an out-of-range selection
        # would raise IndexError, whereas a short window is reported by the
        # "Not enough valid prompts" error below.
        total_rows = len(dataset)
        first_row = min(start, total_rows)
        last_row = min(start + window, total_rows)
        rows = dataset.select(range(first_row, last_row)) \
            if first_row < last_row else []
        for row in rows:
            user_text = next(
                (turn['content'] for turn in row['conversation']
                 if turn['role'] == 'user'),
                None,
            )
            if user_text is None:
                continue
            chat = [{"role": "user", "content": user_text}]
            # NOTE(review): add_generation_prompt_token is not passed in the
            # sharegpt branch and does not look like a standard
            # apply_chat_template parameter -- presumably it is just forwarded
            # to the template as an extra variable; confirm it is needed.
            prompts.append(
                tokenizer.apply_chat_template(
                    chat,
                    tokenize=False,
                    add_generation_prompt=True,
                    add_generation_prompt_token=False))
    else:
        raise ValueError(
            f"Unknown dataset {dataset_path!r}; expected 'sharegpt' or 'lmsys'")

    # Tokenize in one batch purely to measure prompt lengths.
    token_ids_per_prompt = tokenizer(prompts).input_ids if prompts else []

    filtered_dataset: List[Tuple[str, int, Optional[int]]] = []
    for prompt, token_ids in zip(prompts, token_ids_per_prompt):
        prompt_len = len(token_ids)
        if prompt_len < min_prompt_len or prompt_len > max_prompt_len:
            # Prune too short / too long prompts.
            continue
        if (fixed_output_len is not None
                and prompt_len + fixed_output_len > max_total_len):
            # Prune requests whose total length is too long.
            continue
        filtered_dataset.append((prompt, prompt_len, fixed_output_len))

    print(f"Total prompts after filtering: {len(filtered_dataset)}")
    print(f"Number of requests to sample: {num_requests}")

    if len(filtered_dataset) < num_requests:
        raise ValueError(f"Not enough valid prompts after filtering. Got {len(filtered_dataset)}, need {num_requests}")

    # Sample the requests.
    return random.sample(filtered_dataset, num_requests)

def main(args: argparse.Namespace):
    """Sample prompts, run them through vLLM and save the generations.

    Seeds ``random`` with ``args.seed``, samples ``args.num_prompts`` prompts
    via ``sample_requests``, generates up to ``args.output_len`` tokens for
    each with a vLLM ``LLM`` engine, and writes one JSON object per line
    (keys ``"prompt"`` and ``"generated"``) to a ``.jsonl`` file in the
    working directory whose name encodes the run settings.

    Args:
        args: Parsed command-line options. ``dataset``, ``num_prompts``,
            ``output_len``, ``shuffle_dataset``, ``start``, ``model``,
            ``tokenizer``, ``tensor_parallel_size``, ``seed``, ``dtype``,
            ``gpu_memory_utilization`` and ``temperature`` are required;
            the optional engine settings listed below are forwarded only
            when present and different from their defaults.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=True)
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                               args.output_len, args.shuffle_dataset,
                               args.start)
    prompts = [request[0] for request in requests]

    # Initialize vLLM
    engine_kwargs = {
        "model": args.model,
        "tokenizer": args.tokenizer,
        "tensor_parallel_size": args.tensor_parallel_size,
        "seed": args.seed,
        "dtype": args.dtype,
        "gpu_memory_utilization": args.gpu_memory_utilization,
        # NOTE(review): vLLM documents the "dummy" load format as using
        # randomly initialised weights rather than the checkpoint -- confirm
        # this is intended for a script that saves the generated text.
        "load_format": "dummy",
    }
    # These options are exposed on the command line; forward each one only
    # when it differs from the parser's default so that a default invocation
    # constructs the engine with exactly the arguments above.
    optional_engine_defaults = {
        "max_model_len": None,
        "enforce_eager": False,
        "kv_cache_dtype": "auto",
        "quantization_param_path": None,
        "device": "cuda",
        "enable_prefix_caching": False,
        "download_dir": None,
    }
    for option, default in optional_engine_defaults.items():
        value = getattr(args, option, default)
        if value != default:
            engine_kwargs[option] = value
    llm = LLM(**engine_kwargs)

    # Run inference
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=1.0,
        max_tokens=args.output_len,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Save results. The file name records dataset, model base name,
    # temperature, seed, output length, prompt count (or "start:end" range
    # when a non-zero start offset is used) and the shuffle flag.
    model_name = args.model[args.model.rfind('/') + 1:]
    if args.start == 0:
        count_tag = str(args.num_prompts)
    else:
        count_tag = f"{args.start}:{args.start + args.num_prompts}"
    save_file_name = (
        f"{args.dataset}-{model_name}-t{args.temperature}-s{args.seed}"
        f"-l{args.output_len}-c{count_tag}-r{args.shuffle_dataset}.jsonl"
    )

    with open(save_file_name, "w", encoding="utf-8") as outfile:
        for output in outputs:
            result_json = {"prompt": output.prompt,
                           "generated": output.outputs[0].text}
            outfile.write(json.dumps(result_json) + "\n")

if __name__ == "__main__":
    # Script entry point: declare the command-line options, default the
    # tokenizer to the model when none is given, and run main().
    cli = argparse.ArgumentParser(description="Generate dataset using vLLM.")

    # Workload options.
    cli.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=["sharegpt", "lmsys"],
        help="Path to the dataset.",
    )
    cli.add_argument(
        "--output-len",
        type=int,
        default=8192,
        help="Output length for each request.",
    )
    cli.add_argument("--model", type=str, required=True)
    cli.add_argument("--tokenizer", type=str, default=None)
    cli.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    cli.add_argument(
        "--num-prompts",
        type=int,
        default=20000,
        help="Number of prompts to process.",
    )
    cli.add_argument("--seed", type=int, default=0)
    cli.add_argument("--start", type=int, default=0)
    cli.add_argument("--temperature", type=float, default=0.0)
    cli.add_argument("--shuffle-dataset", action="store_true")

    # Engine options.
    cli.add_argument(
        "--max-model-len",
        type=int,
        default=None,
        help="Maximum length of a sequence (including prompt and output).",
    )
    cli.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
        help="data type for model weights and activations.",
    )
    cli.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.9,
        help="the fraction of GPU memory to be used for the model executor",
    )
    cli.add_argument(
        "--enforce-eager",
        action="store_true",
        help="enforce eager execution",
    )
    cli.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
        help="Data type for kv cache storage.",
    )
    cli.add_argument(
        "--quantization-param-path",
        type=str,
        default=None,
        help="Path to the JSON file containing the KV cache scaling factors.",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="device type for vLLM execution",
    )
    cli.add_argument(
        "--enable-prefix-caching",
        action="store_true",
        help="enable automatic prefix caching for vLLM backend.",
    )
    cli.add_argument(
        "--download-dir",
        type=str,
        default=None,
        help="directory to download and load the weights",
    )

    args = cli.parse_args()
    # With no explicit tokenizer, reuse the model identifier for it.
    args.tokenizer = args.model if args.tokenizer is None else args.tokenizer
    main(args)