import json
import os
import pickle
import random
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
from nextqa import NExTQALoader

# from nextqa.video import , VideoPrompt
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

# Download location of the ShareGPT conversation dump. The ShareGPT-based
# samplers in this module fetch it when the dataset path they are given is
# not an existing file.
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"

from sglang.bench_serving import (
    download_and_cache_file,
    gen_prompt,
    get_gen_prefix_cache_path,
)
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64

# Type of a message's content field: either a plain text prompt, or a list
# of content parts when the message also carries images/videos.
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]

# A list of all the conversations. Each conversation is a list of
# (content, prompt_len, output_len) tuples, one per turn. If multiturn is
# not enabled, each conversation has length 1, containing only the first
# Q&A pair.
# For the shared prefix workloads (synthetic, loogle, nextqa), each inner
# list instead holds the requests sharing the same prefix (synthetic
# system prompt, doc, video).
SampleOutput = List[List[Tuple[MsgContent, int, int]]]


def common_filter_chat(
    num_requests: int,
    new_dataset: List,
    tokenizer: PreTrainedTokenizerBase,
    min_prompt_len: Optional[int],
    min_output_len: Optional[int],
    max_prompt_len: Optional[int],
    max_output_len: Optional[int],
    fixed_output_len: Optional[int],
) -> SampleOutput:
    """Tokenize conversations and drop turns whose lengths are out of bounds.

    Conversations are taken from ``new_dataset`` in order, cycling back to the
    first one when ``num_requests`` exceeds the number of conversations, until
    exactly ``num_requests`` conversations have been produced. The accumulated
    input/output token counts are printed before returning.

    Args:
        num_requests: Number of conversations to return.
        new_dataset: Conversations, each a sequence of ``(prompt, completion)``
            text pairs.
        tokenizer: Tokenizer whose ``encode`` is used to measure lengths.
        min_prompt_len: Drop turns with fewer prompt tokens (``None`` = no limit).
        min_output_len: Drop turns with a smaller output length (``None`` = no limit).
        max_prompt_len: Drop turns with more prompt tokens (``None`` = no limit).
        max_output_len: Drop turns with a larger output length (``None`` = no limit).
        fixed_output_len: If not ``None``, used as the output length of every
            turn instead of the tokenized completion length.

    Returns:
        One list per conversation of ``(prompt, prompt_len, output_len)``
        tuples. A conversation whose turns were all dropped is still returned
        as an empty list and still counts toward ``num_requests``.

    Raises:
        ValueError: If ``num_requests`` is positive but ``new_dataset`` is
            empty.
    """
    # Without this guard the cycling loop below would spin forever, because
    # an empty dataset can never contribute a conversation.
    if num_requests > 0 and len(new_dataset) == 0:
        raise ValueError(
            "Cannot sample requests from an empty dataset; "
            "check the dataset path and the preprocessing filters."
        )

    filtered_dataset: SampleOutput = []
    input_tokens = 0
    output_tokens = 0
    while len(filtered_dataset) < num_requests:
        for conversation in new_dataset:
            if len(filtered_dataset) >= num_requests:
                break
            processed = []
            for turn in conversation:
                prompt = turn[0]
                prompt_len = len(tokenizer.encode(prompt))

                # The completion only needs tokenizing when its real length
                # is the output length.
                if fixed_output_len is None:
                    output_len = len(tokenizer.encode(turn[1]))
                else:
                    output_len = fixed_output_len

                too_short = (
                    min_prompt_len is not None and prompt_len < min_prompt_len
                ) or (min_output_len is not None and output_len < min_output_len)
                too_long = (
                    max_prompt_len is not None and prompt_len > max_prompt_len
                ) or (max_output_len is not None and output_len > max_output_len)
                if too_short or too_long:
                    # Prune this turn only; later turns are still considered.
                    continue

                input_tokens += prompt_len
                output_tokens += output_len
                processed.append((prompt, prompt_len, output_len))
            filtered_dataset.append(processed)

    print(f"#Input tokens: {input_tokens}")
    print(f"#Output tokens: {output_tokens}")
    return filtered_dataset


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    disable_shuffle: bool = False,
    enable_multiturn: bool = True,
    fixed_output_len: Optional[int] = None,
) -> SampleOutput:
    """Sample chat requests from a ShareGPT-format JSON file.

    Args:
        dataset_path: Path to the ShareGPT JSON file. If it is not an existing
            file, the dataset is downloaded from ``SHAREGPT_URL`` instead.
        num_requests: Number of conversations to return.
        tokenizer: Tokenizer used to measure prompt/completion lengths.
        disable_shuffle: Keep the dataset order instead of shuffling it.
        enable_multiturn: Keep every (human, assistant) pair of a conversation;
            otherwise only the first pair is kept.
        fixed_output_len: Optional output length applied to every turn.

    Returns:
        The conversations after length filtering by ``common_filter_chat``
        (prompt and output must each be at least 4 tokens).

    Raises:
        ValueError: If ``fixed_output_len`` is given and smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Download sharegpt if necessary
    if not os.path.isfile(dataset_path):
        dataset_path = download_and_cache_file(SHAREGPT_URL)

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Keep one conversation in one list
    new_dataset = []
    for data in dataset:
        turns = data["conversations"]
        # Need at least one complete pair, an even number of turns, and the
        # human speaking first so that turns pair up as (user, assistant).
        if len(turns) < 2 or len(turns) % 2 != 0:
            continue
        if turns[0]["from"] != "human":
            continue
        total_len = len(turns) if enable_multiturn else 2
        # One user, one assistant per tuple
        new_dataset.append(
            [(turns[i]["value"], turns[i + 1]["value"]) for i in range(0, total_len, 2)]
        )

    if not disable_shuffle:
        random.shuffle(new_dataset)

    # Filter out sequences that are too short
    filtered_dataset: SampleOutput = common_filter_chat(
        num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
    )
    return filtered_dataset


def sample_ultrachat_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    disable_shuffle: bool = False,
    enable_multiturn: bool = True,
    fixed_output_len: Optional[int] = None,
) -> SampleOutput:
    """Sample chat requests from an UltraChat-format JSON-lines file.

    Each line is a JSON object whose ``"data"`` entry is a flat list of
    alternating user/assistant messages.

    Args:
        dataset_path: Path to the JSON-lines file.
        num_requests: Number of conversations to return.
        tokenizer: Tokenizer used to measure prompt/completion lengths.
        disable_shuffle: Keep the dataset order instead of shuffling it.
        enable_multiturn: Keep every message pair of a conversation; otherwise
            only the first pair is kept.
        fixed_output_len: Optional output length applied to every turn.

    Returns:
        The conversations after length filtering by ``common_filter_chat``
        (prompt and output must each be at least 4 tokens).

    Raises:
        ValueError: If ``fixed_output_len`` is given and smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. Blank lines (e.g. a trailing newline at the end of
    # the file) are skipped instead of being handed to json.loads.
    dataset = []
    with open(dataset_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            dataset.append(json.loads(line))

    # Keep one conversation in one list
    new_dataset = []
    for data in dataset:
        messages = data["data"]
        # Need at least one complete pair and an even number of messages.
        if len(messages) < 2 or len(messages) % 2 != 0:
            continue
        total_len = len(messages) if enable_multiturn else 2
        # One user, one assistant per tuple
        new_dataset.append(
            [(messages[i], messages[i + 1]) for i in range(0, total_len, 2)]
        )

    if not disable_shuffle:
        random.shuffle(new_dataset)

    # Filter out sequences that are too short
    filtered_dataset: SampleOutput = common_filter_chat(
        num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
    )
    return filtered_dataset


def sample_loogle_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    disable_shuffle: bool = False,
    enable_multiturn: bool = True,
    enable_shared_prefix: bool = False,
    fixed_output_len: Optional[int] = None,
) -> SampleOutput:
    """Sample long-document QA requests from a LooGLE-format JSON-lines file.

    Each record contributes one conversation. Records without QA pairs get a
    single summarization request; records with QA pairs get one turn per pair
    according to the flags below.

    Args:
        dataset_path: Path to the JSON-lines file.
        num_requests: Number of conversations to return.
        tokenizer: Tokenizer used to measure prompt/completion lengths.
        disable_shuffle: Keep the dataset order instead of shuffling it.
        enable_multiturn: Append the follow-up questions (without the input
            document) after the first one. Ignored when
            ``enable_shared_prefix`` is set.
        enable_shared_prefix: Prefix every question with the input document,
            so that all turns of a record share the same prefix.
        fixed_output_len: Optional output length applied to every turn.

    Returns:
        The conversations after length filtering by ``common_filter_chat``
        (prompt must be at least 4 tokens).

    Raises:
        ValueError: If ``fixed_output_len`` is given and smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset. Blank lines (e.g. a trailing newline at the end of
    # the file) are skipped instead of being handed to json.loads.
    dataset = []
    with open(dataset_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            dataset.append(json.loads(line))

    # Keep one conversation in one list
    new_dataset = []
    # TODO: Add shared prefix support for loogle
    # NOTE: Now we preprocess it only for chat
    for data in dataset:
        chat = []
        if (
            "qa_pairs" not in data
            or data["qa_pairs"] == "none"
            or len(data["qa_pairs"]) == 0
        ):
            # No questions (summarization record): ask for a summary.
            # The reference completion is the first 1024 *characters* of the
            # input itself.
            # NOTE(review): presumably a stand-in for the real summary --
            # confirm whether a summary field of the record was intended.
            chat.append(
                (
                    "Input: "
                    + data["input"]
                    + " Question: "
                    + "Please summarize the input",
                    data["input"][:1024],
                )
            )
        else:
            # SECURITY: eval() executes whatever expression the dataset file
            # contains. Only run this on dataset files from a trusted source.
            qa_pairs = eval(data["qa_pairs"])
            for i, qa in enumerate(qa_pairs):
                if i == 0 or enable_shared_prefix:
                    # Combine the input document with the question
                    chat.append(
                        ("Input: " + data["input"] + " Question: " + qa["Q"], qa["A"])
                    )
                elif enable_multiturn:
                    chat.append((qa["Q"], qa["A"]))
        new_dataset.append(chat)

    if not disable_shuffle:
        random.shuffle(new_dataset)

    # Filter out prompts that are too short
    filtered_dataset: SampleOutput = common_filter_chat(
        num_requests, new_dataset, tokenizer, 4, None, None, None, fixed_output_len
    )
    return filtered_dataset


def sample_nextqa_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    max_frames: int,  # Specific for video
    model_path: str,
    disable_shuffle: bool = False,
    enable_multiturn: bool = True,  # No multiturn support for now
    backend: str = "sglang-oai",
    chat_template_name: Optional[str] = None,
    fixed_output_len: Optional[int] = None,
) -> SampleOutput:
    """Sample video QA requests from a NExT-QA video directory.

    Videos are taken in (optionally shuffled) order, cycling back to the first
    one when ``num_requests`` exceeds the number of videos.

    Example of messages:
    message = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": base64_data}},
            {"type": "text", "text": video.prompt},
        ],
    }

    Args:
        dataset_path: Directory passed to ``NExTQALoader`` as ``video_dir``.
        num_requests: Number of requests to return.
        tokenizer: Tokenizer used to measure the text prompt length.
        max_frames: Passed to ``NExTQALoader`` as ``max_frames``.
        model_path: Used to look up a chat template when neither the tokenizer
            nor ``chat_template_name`` provides one.
        disable_shuffle: Keep the loader order instead of shuffling it.
        enable_multiturn: Currently unused.
        backend: For ``"sglang"`` and ``"sglang-native"`` the chat template's
            image token is prepended to the text prompt.
        chat_template_name: Optional explicit chat template name.
        fixed_output_len: Maximum output length per request; defaults to 4096.

    Returns:
        One single-turn conversation ``[(content, prompt_len, output_len)]``
        per request. ``prompt_len`` is the text token count plus the number of
        video frames, a placeholder for the expanded length from the server.

    Raises:
        ValueError: If ``num_requests`` is positive but no video was loaded.
    """
    if fixed_output_len is None:
        fixed_output_len = 4096

    # TODO: Check for multiturn
    dataset = NExTQALoader(video_dir=dataset_path, max_frames=max_frames)
    new_dataset = list(dataset)

    if not disable_shuffle:
        random.shuffle(new_dataset)

    # TODO: prompt len can get from server side
    filtered_dataset = []
    if num_requests <= 0:
        return filtered_dataset
    # Without this guard the cycling loop below would spin forever.
    if not new_dataset:
        raise ValueError(
            f"No videos were loaded from {dataset_path!r}; cannot sample requests."
        )

    # NOTE: Chat Template is a must for video benchmark because we have to
    # add special image token for later expansion. The template does not
    # depend on the video, so resolve it once instead of once per request.
    image_token = None
    if backend == "sglang" or backend == "sglang-native":
        if "chat_template" in tokenizer.init_kwargs:
            chat_template = get_chat_template(tokenizer.get_chat_template())
        elif chat_template_name is not None:
            chat_template = get_chat_template(chat_template_name)
        else:
            chat_template = get_chat_template_by_model_path(model_path)
        image_token = chat_template.image_token

    while len(filtered_dataset) < num_requests:
        for video in new_dataset:
            if len(filtered_dataset) >= num_requests:
                break

            # text prompt
            prompt = video.prompt
            if image_token is not None:
                prompt = image_token + prompt

            prompt_len = len(tokenizer(prompt).input_ids)
            output_len = fixed_output_len  # max output len, not real output len

            # video input
            base64_data = encode_video_base64(video.path, video.num_frames)

            # NOTE: This will be replaced by the expanded length from the server
            prompt_len += video.num_frames

            # add to content
            content = [
                {"type": "image_url", "image_url": {"url": base64_data}},
                {"type": "text", "text": prompt},
            ]

            filtered_dataset.append([(content, prompt_len, output_len)])
    return filtered_dataset


def sample_random_requests(
    input_len: int,
    output_len: int,
    num_prompts: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str,
    disable_shuffle: bool = False,
) -> SampleOutput:
    """Build single-turn requests with randomly sampled input/output lengths.

    Input lengths are drawn uniformly from
    ``[max(int(input_len * range_ratio), 1), input_len]`` and output lengths
    from ``[int(output_len * range_ratio), output_len]``. Prompt text comes
    from the first turn of ShareGPT conversations: its token ids are repeated
    or truncated to the sampled input length and decoded back to text.

    Args:
        input_len: Upper bound of the sampled input length.
        output_len: Upper bound of the sampled output length.
        num_prompts: Number of requests to build.
        range_ratio: Ratio of the lower bound to the upper bound of the ranges.
        tokenizer: Tokenizer used to encode and decode the prompts.
        dataset_path: Path to the ShareGPT JSON file. If it is not an existing
            file, the dataset is downloaded from ``SHAREGPT_URL`` instead.
        disable_shuffle: Keep the dataset order instead of shuffling it.

    Returns:
        One single-turn conversation ``[(prompt, input_len, output_len)]`` per
        request. Fewer than ``num_prompts`` requests are returned when the
        dataset does not contain enough non-empty prompts.
    """
    input_lens = np.random.randint(
        max(int(input_len * range_ratio), 1),
        input_len + 1,
        size=num_prompts,
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio),
        output_len + 1,
        size=num_prompts,
    )

    # Sample token ids from ShareGPT and repeat/truncate them to satisfy the
    # input_lens. (Decoding purely random token ids was the alternative; the
    # original code noted it can cause NaN issues and never executed it.)

    # Download sharegpt if necessary
    if not os.path.isfile(dataset_path):
        dataset_path = download_and_cache_file(SHAREGPT_URL)

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)
    # Only keep the first two turns of conversations that have at least two.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
        if len(data["conversations"]) >= 2
    ]

    if not disable_shuffle:
        random.shuffle(dataset)

    input_requests: SampleOutput = []
    for data in dataset:
        i = len(input_requests)
        if i == num_prompts:
            break

        prompt_token_ids = tokenizer.encode(data[0])
        prompt_len = len(prompt_token_ids)

        # Skip empty prompt
        if prompt_len == 0:
            continue

        # Plain int so that list repetition/slicing never involves numpy scalars.
        target_len = int(input_lens[i])
        if prompt_len > target_len:
            input_ids = prompt_token_ids[:target_len]
        else:
            # Repeat the prompt enough times to cover target_len, then trim.
            repeats = (target_len + prompt_len - 1) // prompt_len
            input_ids = (prompt_token_ids * repeats)[:target_len]
        prompt = tokenizer.decode(input_ids)
        input_requests.append([(prompt, target_len, int(output_lens[i]))])

    # Report only the lengths that were actually used: the dataset may have
    # yielded fewer than num_prompts requests.
    used = len(input_requests)
    print(f"#Input tokens: {np.sum(input_lens[:used])}")
    print(f"#Output tokens: {np.sum(output_lens[:used])}")
    return input_requests


def gen_prompt(tokenizer, token_num):
    """Decode ``token_num`` token ids drawn at random from the tokenizer vocabulary.

    Ids are sampled with replacement from the values of
    ``tokenizer.get_vocab()`` and the sampled sequence is returned as whatever
    ``tokenizer.decode`` produces for it.
    """
    vocab_ids = [*tokenizer.get_vocab().values()]
    sampled_ids = random.choices(vocab_ids, k=token_num)
    return tokenizer.decode(sampled_ids)


def get_gen_prefix_cache_path(args, tokenizer):
    """Return the cache file path under ~/.cache/sglang/benchmark.

    The file name encodes the generation parameters read from ``args`` and the
    tokenizer's class name. Only the path is computed here; nothing is created
    on disk.
    """
    # Everything that distinguishes one generated dataset from another
    key_fields = (
        args.gen_num_groups,
        args.gen_prompts_per_group,
        args.gen_system_prompt_len,
        args.gen_question_len,
        args.gen_output_len,
        tokenizer.__class__.__name__,
    )
    file_name = "gen_prefix_" + "_".join(f"{field}" for field in key_fields) + ".pkl"
    return Path.home().joinpath(".cache", "sglang", "benchmark", file_name)


def sample_generated_shared_prefix_requests(
    num_groups: int,
    prompts_per_group: int,
    system_prompt_len: int,
    question_len: int,
    output_len: int,
    tokenizer: PreTrainedTokenizerBase,
    args,
    disable_shuffle: bool = False,
) -> SampleOutput:
    """Generate benchmark requests with shared system prompts using random tokens and caching.

    Each group consists of one random system prompt followed by
    ``prompts_per_group`` random questions, so all requests of a group share
    the same prefix. The result is pickled to the path given by
    ``get_gen_prefix_cache_path`` and reused on later calls.

    Args:
        num_groups: Number of system prompts (groups) to generate.
        prompts_per_group: Number of questions per group.
        system_prompt_len: Number of random tokens per system prompt.
        question_len: Number of random tokens per question.
        output_len: Output length recorded for every request.
        tokenizer: Tokenizer used to generate and measure the prompts.
        args: Namespace passed to ``get_gen_prefix_cache_path`` to build the
            cache file name.
        disable_shuffle: Keep the groups in generation order.

    Returns:
        One list per group of ``(prompt, prompt_len, output_len)`` tuples.
        When a cache file exists its content is returned as-is, so
        ``disable_shuffle`` has no effect on a cache hit.
    """
    cache_path = get_gen_prefix_cache_path(args, tokenizer)

    # Try to load from cache first
    if cache_path.exists():
        print(f"\nLoading cached generated input data from {cache_path}")
        # SECURITY: pickle.load can execute arbitrary code from the file it
        # reads. The cache directory must only be writable by trusted users.
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    print("\nGenerating new input data...")

    # Generate all system prompts first, then all questions, so the sequence
    # of random draws does not depend on how they are combined below.
    system_prompts = [
        gen_prompt(tokenizer, system_prompt_len) for _ in range(num_groups)
    ]
    questions = [
        gen_prompt(tokenizer, question_len)
        for _ in range(num_groups * prompts_per_group)
    ]

    # Combine system prompts with questions
    input_requests = []
    total_input_tokens = 0
    total_output_tokens = 0

    for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
        system_prompt = system_prompts[group_idx]
        group_requests = []
        for prompt_idx in tqdm(
            range(prompts_per_group), desc="Generating questions", leave=False
        ):
            question = questions[group_idx * prompts_per_group + prompt_idx]
            full_prompt = f"{system_prompt}\n\n{question}"
            prompt_len = len(tokenizer.encode(full_prompt))
            group_requests.append((full_prompt, prompt_len, output_len))
            total_input_tokens += prompt_len
            total_output_tokens += output_len
        input_requests.append(group_requests)

    if not disable_shuffle:
        # Shuffle the order of the groups; the order inside a group is kept.
        random.shuffle(input_requests)

    def _average_token_len(texts):
        # Mean token count of the texts; 0.0 for an empty list so that zero
        # groups or zero prompts per group do not raise ZeroDivisionError.
        if not texts:
            return 0.0
        return sum(len(tokenizer.encode(text)) for text in texts) / len(texts)

    # Print statistics
    print(f"\nGenerated shared prefix dataset statistics:")
    print(f"Number of groups: {num_groups}")
    print(f"Prompts per group: {prompts_per_group}")
    print(f"Total prompts: {len(input_requests) * prompts_per_group}")
    print(f"Total input tokens: {total_input_tokens}")
    print(f"Total output tokens: {total_output_tokens}")
    print(
        f"Average system prompt length: {_average_token_len(system_prompts):.1f} tokens"
    )
    print(f"Average question length: {_average_token_len(questions):.1f} tokens\n")

    # Save to cache
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"Caching generated input data to {cache_path}")
    with open(cache_path, "wb") as f:
        pickle.dump(input_requests, f)

    return input_requests


def get_dataset(args, tokenizer):
    """Build the request set for the dataset selected by ``args.dataset_name``.

    Dispatches to the matching ``sample_*`` function, forwarding the relevant
    fields of ``args``.

    Args:
        args: Namespace carrying ``dataset_name`` plus the options that the
            selected sampler needs.
        tokenizer: Tokenizer passed through to the sampler.

    Returns:
        The sampler's result: a list of conversations, each a list of
        ``(content, prompt_len, output_len)`` tuples.

    Raises:
        ValueError: If ``args.dataset_name`` is not a known dataset.
    """
    if args.dataset_name == "sharegpt":
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            disable_shuffle=args.disable_shuffle,
            enable_multiturn=args.enable_multiturn,
            fixed_output_len=args.fixed_output_len,
        )
    elif args.dataset_name == "ultrachat":
        input_requests = sample_ultrachat_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            disable_shuffle=args.disable_shuffle,
            enable_multiturn=args.enable_multiturn,
            fixed_output_len=args.fixed_output_len,
        )
    elif args.dataset_name == "loogle":
        input_requests = sample_loogle_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            disable_shuffle=args.disable_shuffle,
            enable_multiturn=args.enable_multiturn,
            enable_shared_prefix=args.enable_shared_prefix,
            fixed_output_len=args.fixed_output_len,
        )
    elif args.dataset_name == "nextqa":
        input_requests = sample_nextqa_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            max_frames=args.max_frames,
            model_path=args.model,
            disable_shuffle=args.disable_shuffle,
            enable_multiturn=args.enable_multiturn,
            backend=args.backend,
            chat_template_name=args.chat_template,
            fixed_output_len=args.fixed_output_len,
        )
    elif args.dataset_name == "random":
        input_requests = sample_random_requests(
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
            dataset_path=args.dataset_path,
            # Honor the shuffle switch like the other datasets do; getattr
            # keeps namespaces that never defined the option working.
            disable_shuffle=getattr(args, "disable_shuffle", False),
        )
    elif args.dataset_name == "generated-shared-prefix":
        input_requests = sample_generated_shared_prefix_requests(
            num_groups=args.gen_num_groups,
            prompts_per_group=args.gen_prompts_per_group,
            system_prompt_len=args.gen_system_prompt_len,
            question_len=args.gen_question_len,
            output_len=args.gen_output_len,
            args=args,
            tokenizer=tokenizer,
            # Honor the shuffle switch like the other datasets do; getattr
            # keeps namespaces that never defined the option working.
            disable_shuffle=getattr(args, "disable_shuffle", False),
        )
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")
    return input_requests
