# SPDX-License-Identifier: Apache-2.0
import enum
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from unittest.mock import patch

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
# TPU XLA related
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                               PallasMetadata)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

if TYPE_CHECKING:
    from vllm.v1.core.scheduler import SchedulerOutput

logger = init_logger(__name__)

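# Slot id written into the slot mapping for padding positions. The value is
# intentionally far out of range: we rely on out-of-bound slot indices being
# ignored by the consumer of the slot mapping.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.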
_PAD_SLOT_ID = 1_000_000_000


class ExecutionMode(enum.Enum):
    """Kind of forward pass the runner can execute."""

    PREFILL = enum.auto()
    DECODE = enum.auto()
    PREFIX_PREFILL = enum.auto()

    def is_prefill(self) -> bool:
        """Return True for PREFILL and PREFIX_PREFILL, False otherwise."""
        if self is ExecutionMode.PREFILL:
            return True
        return self is ExecutionMode.PREFIX_PREFILL


@dataclass
class PromptDecodeInfo:
    """Partition of the current batch into prompt and decode requests."""
    # Ids of the requests that still have prompt tokens to process.
    prompt_req_ids: List[str]
    # Ids of the requests that are past their prompt (decoding).
    decode_req_ids: List[str]
    # Number of tokens scheduled this step for each entry of
    # `prompt_req_ids` (same order).
    prompt_scheduled_tokens: List[int]


@dataclass
class PromptData:
    """Model inputs for the forward pass of a single prompt request."""
    # Token ids, reshaped to a single row (1, padded_len) by the producer.
    input_tokens: torch.Tensor
    # Position ids, same layout as `input_tokens`.
    input_positions: torch.Tensor
    # Attention metadata passed to the model alongside the inputs.
    attn_metadata: PallasMetadata


@dataclass
class DecodeData:
    """Model inputs for the batched decode forward pass.

    All fields default to None so an empty instance can be created.
    """
    # Token ids, one per row: reshaped to (padded_batch, 1) by the producer.
    input_tokens: Optional[torch.Tensor] = None
    # Position ids, same layout as `input_tokens`.
    input_positions: Optional[torch.Tensor] = None
    # Attention metadata passed to the model alongside the inputs.
    attn_metadata: Optional[PallasMetadata] = None


class TPUModelRunner:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        self.model: Optional[nn.Module] = None

        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
        )

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}

        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # KV caches for forward pass
        self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []

        # Cached torch/numpy tensors
        self.num_swaps = 2
        self.cur_swap_id = 0
        self.input_ids_cpu = []
        self.input_ids_np = []
        self.input_positions_cpu = []
        self.input_positions_np = []
        self.slot_mapping_cpu = []
        self.slot_mapping_np = []
        self.prompt_context_lens_cpu = []
        self.prompt_effective_query_lens_cpu = []
        self.decode_context_lens_cpu = []
        self.decode_context_lens_np = []
        for _ in range(self.num_swaps):
            self.input_ids_cpu.append(
                torch.empty(self.max_num_tokens,
                            dtype=torch.int32,
                            device="cpu"))
            self.input_ids_np.append(self.input_ids_cpu[-1].numpy())

            self.input_positions_cpu.append(
                torch.empty(self.max_num_tokens,
                            dtype=torch.int32,
                            device="cpu"))
            self.input_positions_np.append(
                self.input_positions_cpu[-1].numpy())

            self.slot_mapping_cpu.append(
                torch.empty(self.max_num_tokens,
                            dtype=torch.int64,
                            device="cpu"))
            self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy())

            self.prompt_context_lens_cpu.append(
                torch.empty((1), dtype=torch.int32, device="cpu"))
            self.prompt_effective_query_lens_cpu.append(
                torch.empty((1), dtype=torch.int32, device="cpu"))

            self.decode_context_lens_cpu.append(
                torch.empty(self.max_num_tokens,
                            dtype=torch.int32,
                            device="cpu"))
            self.decode_context_lens_np.append(
                self.decode_context_lens_cpu[-1].numpy())

        # Range tensor with values [0 .. self.max_num_tokens - 1].
        # Used to initialize positions / context_lens / seq_lens
        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)

    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        Returns:
            True if there is a new/resumed/paused/finished request in the batch.
            If False, we can skip copying SamplingMetadata to the GPU.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: List[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: List[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)
            start_index = len(req_state.block_ids) - len(
                req_data.new_block_ids)
            self.input_batch.block_table.append_row(req_index, start_index,
                                                    req_data.new_block_ids)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)
        return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0

    def swap_step(self):
        self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps

    def get_model(self) -> nn.Module:
        assert self.model is not None
        return self.model

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each 
        Attention module in the static forward context.
        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache 
            format. Layers that do not need KV cache are not included.
        """

        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        kv_cache_spec: KVCacheSpec = {}
        for layer_name, attn_module in forward_ctx.items():
            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention, MLA.
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=attn_module.dtype,
                )
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec

    def _get_prompts_and_decodes(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> PromptDecodeInfo:
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Traverse decodes first
        decode_req_ids = []
        for i in range(num_reqs):
            req_id = self.input_batch.req_ids[i]
            assert req_id is not None

            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
            num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]

            if num_computed_tokens < num_prompt_tokens:
                # This is prompt
                break

            # This is decode
            assert num_scheduled_tokens == 1
            decode_req_ids.append(req_id)

        # Traverse prompts
        prompt_req_ids = []
        prompt_scheduled_tokens = []
        for i in range(len(decode_req_ids), num_reqs):
            req_id = self.input_batch.req_ids[i]
            assert req_id is not None

            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
            num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]

            # Must be prompt
            assert num_computed_tokens < num_prompt_tokens

            prompt_req_ids.append(req_id)
            prompt_scheduled_tokens.append(num_scheduled_tokens)

        return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
                                prompt_scheduled_tokens)

    def _prepare_prompt(self, req_index: int,
                        num_scheduled_tokens: int) -> PromptData:
        num_computed_tokens = self.input_batch.num_computed_tokens_cpu[
            req_index]
        num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index]

        # Must be prompt
        assert num_computed_tokens < num_prompt_tokens

        # Prompt len
        prompt_len = num_scheduled_tokens
        padded_prompt_len = _get_padded_prompt_len(prompt_len)
        assert padded_prompt_len <= self.max_model_len

        # Seq len
        seq_len = num_computed_tokens + prompt_len
        padded_seq_len = num_computed_tokens + padded_prompt_len

        # Input tokens
        input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
            req_index, num_computed_tokens:padded_seq_len]
        input_tokens_cpu[prompt_len:] = 0

        # Input positions
        input_positions_np = self.input_positions_np[
            self.cur_swap_id][:padded_prompt_len]
        np.add(num_computed_tokens,
               self.arange_np[:padded_prompt_len],
               out=input_positions_np)
        input_positions_np[prompt_len:] = 0

        # Slot mapping
        block_table_np = \
            self.input_batch.block_table.get_numpy_array()
        block_numbers_np = block_table_np[req_index, input_positions_np //
                                          self.block_size]
        block_offsets_np = input_positions_np % self.block_size

        slot_mapping_np = self.slot_mapping_np[
            self.cur_swap_id][:padded_prompt_len]
        np.add(block_numbers_np * self.block_size,
               block_offsets_np,
               out=slot_mapping_np)
        slot_mapping_np[prompt_len:] = _PAD_SLOT_ID

        # Block table
        block_table_cpu = None
        if num_computed_tokens > 0:
            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
            block_table_cpu = block_table_cpu[req_index]

        # Context len
        self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0
        if num_computed_tokens > 0:
            self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len

        # Effective query len
        self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len

        # Get final tensors
        input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
        input_positions = self.input_positions_cpu[
            self.cur_swap_id][:padded_prompt_len].reshape(1,
                                                          -1).to(self.device)
        slot_mapping = self.slot_mapping_cpu[
            self.cur_swap_id][:padded_prompt_len].reshape(1,
                                                          -1).to(self.device)
        block_table = block_table_cpu.reshape(1, -1).to(
            self.device) if block_table_cpu is not None else None

        context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to(
            self.device)
        effective_query_lens = self.prompt_effective_query_lens_cpu[
            self.cur_swap_id].to(self.device)

        self.swap_step()

        # Attn metadata
        attn_metadata = PallasMetadata(
            num_prefills=1,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            block_tables=block_table,
            context_lens=context_lens,
            effective_query_lens=effective_query_lens,
        )

        return PromptData(input_tokens, input_positions, attn_metadata)

    def _prepare_decode(
        self,
        decode_req_ids: List[str],
    ) -> DecodeData:
        # Batch size
        batch_size = len(decode_req_ids)
        padded_batch_size = _get_padded_batch_size(batch_size)
        assert padded_batch_size <= self.max_model_len

        # Init [0 .. batch_size - 1]
        req_indices_np = self.arange_np[:padded_batch_size]

        # Input positions
        input_positions_np = self.input_positions_np[
            self.cur_swap_id][:padded_batch_size]
        np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
               0,
               out=input_positions_np)
        input_positions_np[batch_size:] = 0
        input_positions_cpu = self.input_positions_cpu[
            self.cur_swap_id][:padded_batch_size]

        # Input tokens
        token_indices_np = (
            input_positions_np +
            req_indices_np * self.input_batch.token_ids_cpu.shape[1])
        input_tokens_cpu = self.input_ids_cpu[
            self.cur_swap_id][:padded_batch_size]
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices_np),
                           out=input_tokens_cpu)
        input_tokens_cpu[batch_size:] = 0

        # Slot mapping
        block_table_indices_np = (
            req_indices_np * self.max_num_blocks_per_req +
            input_positions_np // self.block_size)

        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()

        block_numbers_np = block_table_cpu.flatten(
        )[block_table_indices_np].numpy()

        block_offsets_np = input_positions_np % self.block_size

        slot_mapping_np = self.slot_mapping_np[
            self.cur_swap_id][:padded_batch_size]
        np.add(block_numbers_np * self.block_size,
               block_offsets_np,
               out=slot_mapping_np)
        slot_mapping_np[batch_size:] = _PAD_SLOT_ID

        block_table_cpu = block_table_cpu[:padded_batch_size]

        # Context lens
        context_lens_np = self.decode_context_lens_np[
            self.cur_swap_id][:padded_batch_size]
        np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
               1,
               out=context_lens_np)
        context_lens_np[batch_size:] = 0

        # Get final tensors
        input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
        input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
        slot_mapping = self.slot_mapping_cpu[
            self.cur_swap_id][:padded_batch_size].reshape(-1,
                                                          1).to(self.device)
        block_table = block_table_cpu.to(self.device)
        context_lens = self.decode_context_lens_cpu[
            self.cur_swap_id][:padded_batch_size].to(self.device)

        self.swap_step()

        # Attn metadata
        attn_metadata = PallasMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=padded_batch_size,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            block_tables=block_table,
            context_lens=context_lens,
            effective_query_lens=None,
        )

        return DecodeData(input_tokens=input_tokens,
                          input_positions=input_positions,
                          attn_metadata=attn_metadata)

    @torch.no_grad()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput:
        """Run one scheduler step: each prompt on its own, then all decodes.

        The cached states are synced with `scheduler_output`, the batch is
        reordered so that decode requests come first, and then:

        * every prompt request is run in its own forward pass;
        * all decode requests are run together in one forward pass.

        The inputs of the next forward pass are prepared right after the
        current one has been issued and before its result is read back.

        Args:
            scheduler_output: Scheduling decision for this step.

        Returns:
            A ModelRunnerOutput listing the decode request ids followed by
            the prompt request ids, with one sampled token per request.
            Requests whose prompt is not fully processed after this step keep
            the placeholder token 0. Logprobs and prompt logprobs are None.
        """
        # Update cached state
        # (the returned "batch changed" flag is not used here)
        self._update_states(scheduler_output)

        # If necessary, swap decodes/prompts to have all decodes on the start
        ensure_decodes_first(self.input_batch)

        # Prepare prompts/decodes info
        pd_info = self._get_prompts_and_decodes(scheduler_output)

        # Init
        num_prompts = len(pd_info.prompt_req_ids)
        num_decodes = len(pd_info.decode_req_ids)
        decode_data = None
        # One slot per request in the batch; 0 is the placeholder for requests
        # that do not produce a token in this step.
        sampled_token_ids = [0] * self.input_batch.num_reqs

        # Run each prompt individually
        is_first = True
        for i in range(num_prompts):
            req_id = pd_info.prompt_req_ids[i]
            # Prompts follow the decodes in the persistent batch.
            req_index = num_decodes + i
            assert req_index == self.input_batch.req_id_to_index[
                req_id]  # TODO: Remove
            req_state = self.requests[req_id]
            num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
            prompt_len = num_scheduled_tokens
            seq_len = req_state.num_computed_tokens + num_scheduled_tokens

            # Prepare first prompt
            # (later prompts are prepared at the end of the previous iteration)
            if is_first:
                prompt_data = self._prepare_prompt(req_index,
                                                   num_scheduled_tokens)
                is_first = False

            # Run forward pass
            with set_forward_context(prompt_data.attn_metadata,
                                     self.vllm_config):
                assert self.model is not None
                selected_token_ids = self.model(prompt_data.input_tokens,
                                                prompt_data.input_positions,
                                                prompt_data.attn_metadata,
                                                self.kv_caches)

            # In parallel to TPU execution, prepare the next iteration
            if i < num_prompts - 1:
                # There is next prompt => prepare it
                prompt_data = self._prepare_prompt(
                    req_index + 1, pd_info.prompt_scheduled_tokens[i + 1])
            elif i == num_prompts - 1 and num_decodes > 0:
                # There is next decode => prepare it
                decode_data = self._prepare_decode(pd_info.decode_req_ids)

            # Update cached state (if prompt is fully done)
            if seq_len >= len(req_state.prompt_token_ids):
                # Transfer sampled tokens from TPU to CPU
                selected_token_ids_cpu = selected_token_ids.cpu()

                # Get output token
                # (the entry at the last real, i.e. unpadded, prompt position)
                token_id = selected_token_ids_cpu[prompt_len - 1].item()
                sampled_token_ids[req_index] = token_id

                # Add output token to the request
                self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
                self.input_batch.num_tokens[req_index] += 1
                req_state.output_token_ids.append(token_id)

        # Run decodes (a single batch)
        if num_decodes > 0:

            # Prepare decode (if was not yet prepared)
            # (this is the case when the step contains no prompts)
            if decode_data is None:
                decode_data = self._prepare_decode(pd_info.decode_req_ids)

            # Run forward pass
            with set_forward_context(decode_data.attn_metadata,
                                     self.vllm_config):
                assert self.model is not None
                selected_token_ids = self.model(decode_data.input_tokens,
                                                decode_data.input_positions,
                                                decode_data.attn_metadata,
                                                self.kv_caches)

            # Transfer sampled tokens from TPU to CPU
            decode_token_ids_cpu = selected_token_ids.cpu()
            # Convert to list
            decode_token_ids_list = decode_token_ids_cpu.tolist()

            # Update cached state for each decode request
            # (only the first `num_decodes` entries are read; the rest of the
            # list belongs to padding rows)
            for i in range(num_decodes):
                req_id = pd_info.decode_req_ids[i]
                # Decodes occupy the first rows of the persistent batch.
                req_index = i
                assert req_index == self.input_batch.req_id_to_index[
                    req_id]  # TODO: Remove
                req_state = self.requests[req_id]
                seq_len = req_state.num_computed_tokens + 1

                token_id = decode_token_ids_list[i]
                sampled_token_ids[req_index] = token_id

                self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
                self.input_batch.num_tokens[req_index] += 1
                req_state.output_token_ids.append(token_id)

        # Create output.
        # Decodes first, then prompts: the same order as the batch rows and
        # therefore as `sampled_token_ids`.
        all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids
        prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {}
        for req_id in all_req_ids:
            prompt_logprobs_dict[req_id] = None

        model_runner_output = ModelRunnerOutput(
            req_ids=all_req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
            spec_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
        )

        return model_runner_output

    def load_model(self) -> None:
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(vllm_config=self.vllm_config)
        model = model.eval()
        xm.mark_step()
        xm.wait_device_ops()
        model = ModelWrapperV1(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)

    def dummy_run(
        self,
        kv_caches,
        num_tokens: int,
        seq_len: Optional[int] = None,
        exec_mode: Optional[ExecutionMode] = None,
    ) -> None:
        """Run one forward pass on placeholder inputs of the given shape.

        The inputs are constant-filled tensors on `self.device`; only their
        shapes and the attention metadata differ between execution modes.

        Args:
            kv_caches: Forwarded unchanged to the model call.
            num_tokens: Size of the leading dimension of the dummy inputs.
            seq_len: Size of the second dimension of the dummy inputs.
                Required despite the `None` default. It is rounded up to a
                multiple of 16 in the prefill modes and must be exactly 1 in
                decode mode.
            exec_mode: Execution mode to emulate. Required despite the
                `None` default.
        """
        assert seq_len is not None
        assert exec_mode is not None

        exec_mode = ExecutionMode(exec_mode)
        prefill_like = exec_mode.is_prefill()
        if prefill_like:
            # Round the sequence length up to the next multiple of 16.
            seq_len = (seq_len + 15) // 16 * 16
        else:
            assert seq_len == 1

        # Every mode feeds the model the same three (num_tokens, seq_len)
        # zero tensors.
        input_shape = (num_tokens, seq_len)
        token_ids = torch.zeros(input_shape,
                                dtype=torch.int32,
                                device=self.device)
        position_ids = torch.zeros(input_shape,
                                   dtype=torch.int32,
                                   device=self.device)
        slot_mapping = torch.zeros(input_shape,
                                   dtype=torch.int64,
                                   device=self.device)
        total_tokens = num_tokens * seq_len

        if exec_mode == ExecutionMode.PREFILL:
            # Plain prefill: no block tables or context lengths are supplied.
            attn_metadata = PallasMetadata(
                num_prefills=num_tokens,
                num_prefill_tokens=total_tokens,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                multi_modal_placeholder_index_maps=None,
                enable_kv_scales_calculation=True,
                block_tables=None,
                context_lens=None,
                effective_query_lens=None,
            )
        else:
            # Prefix prefill and decode both carry per-row context lengths
            # and a block table.
            context_lens = torch.ones((num_tokens, ),
                                      dtype=torch.int32,
                                      device=self.device)
            block_tables = torch.zeros(
                (num_tokens, self.max_num_blocks_per_req),
                dtype=torch.int32,
                device=self.device)

            if prefill_like:
                attn_metadata = PallasMetadata(
                    num_prefills=num_tokens,
                    num_prefill_tokens=total_tokens,
                    num_decode_tokens=0,
                    slot_mapping=slot_mapping,
                    multi_modal_placeholder_index_maps=None,
                    enable_kv_scales_calculation=True,
                    block_tables=block_tables,
                    context_lens=context_lens,
                    effective_query_lens=torch.ones_like(context_lens),
                )
            else:
                attn_metadata = PallasMetadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=total_tokens,
                    slot_mapping=slot_mapping,
                    multi_modal_placeholder_index_maps=None,
                    enable_kv_scales_calculation=True,
                    block_tables=block_tables,
                    context_lens=context_lens,
                )

        # NOTE(woosuk): There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph will still require static shapes and needs to
        # be re-compiled for every different shapes. This overhead is inevitable
        # in the first run, but can be skipped afterwards as we cache the XLA
        # graphs in the disk (VLLM_XLA_CACHE_PATH).
        # Prefill marks dimension 1 (sequence) as dynamic, decode marks
        # dimension 0 (batch).
        dynamic_dim = 1 if prefill_like else 0
        for tensor in (token_ids, position_ids, attn_metadata.slot_mapping):
            torch._dynamo.mark_dynamic(tensor, dynamic_dim)
        if not prefill_like:
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)

        with set_forward_context(attn_metadata, self.vllm_config, 0):
            assert self.model is not None
            self.model(token_ids, position_ids, attn_metadata, kv_caches)

    def capture_model(self) -> None:
        """Warm up the model on each input shape it should be compiled for.

        Every shape is pushed through `dummy_run` and followed by
        `xm.wait_device_ops()`. The shapes are:

        * Prefill: batch size 1 with sequence lengths 16, 32, 64, ... while
          the length does not exceed `max_model_len`; the sweep also ends
          right after a run whose token count reaches
          `max_num_batched_tokens`.
        * Prefix prefill: the same sweep, only when chunked prefill is
          enabled in the scheduler config.
        * Decode: sequence length 1 with batch sizes 8, 16, 32, 48, ...
          until a batch size of at least `max_num_seqs` has been run.

        The wall-clock time of each phase is logged.
        """

        def _sweep_prefill_shapes(mode: ExecutionMode) -> None:
            # Shared by the prefill and prefix-prefill phases.
            for bs in [1]:
                length = 16
                while length <= self.model_config.max_model_len:
                    self.dummy_run(self.kv_caches,
                                   bs,
                                   length,
                                   exec_mode=mode)
                    xm.wait_device_ops()
                    logger.info("  batch_size: %d, seq_len: %d", bs, length)
                    token_budget = (
                        self.scheduler_config.max_num_batched_tokens)
                    if bs * length >= token_budget:
                        break
                    length *= 2

        # Prefill
        logger.info(
            "Compiling the model with different input shapes for prefill:")
        phase_start = time.time()
        _sweep_prefill_shapes(ExecutionMode.PREFILL)
        phase_end = time.time()
        logger.info("    -- Compilation for prefill done in %.2f [secs].",
                    phase_end - phase_start)

        # Prefix prefill
        if self.scheduler_config.enable_chunked_prefill:
            logger.info("Compiling the model with different input shapes for "
                        "prefix prefill:")
            phase_start = time.time()
            _sweep_prefill_shapes(ExecutionMode.PREFIX_PREFILL)
            phase_end = time.time()
            logger.info(
                "    -- Compilation for prefix prefill done in %.2f [secs].",
                phase_end - phase_start)

        # Decode
        logger.info(
            "Compiling the model with different input shapes for decode:")
        phase_start = time.time()
        decode_len = 1
        bs = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self.dummy_run(self.kv_caches,
                           bs,
                           decode_len,
                           exec_mode=ExecutionMode.DECODE)
            xm.wait_device_ops()
            logger.info("  batch_size: %d, seq_len: %d", bs, decode_len)

            if bs >= self.scheduler_config.max_num_seqs:
                break
            # 8 -> 16, then grow in steps of 16.
            if bs >= 16:
                bs += 16
            else:
                bs *= 2

        phase_end = time.time()
        logger.info("    -- Compilation for decode done in %.2f [secs].",
                    phase_end - phase_start)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Allocate zero-filled KV cache tensors per layer and bind them.

        Args:
            kv_cache_config: Supplies the cache groups, the per-layer cache
                specs and the byte size reserved for each layer's tensor.

        Raises:
            NotImplementedError: If the config has more than one cache group,
                or a layer spec is not a `FullAttentionSpec`.
        """
        if len(kv_cache_config.groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        # NOTE: each value stored below is a (key_cache, value_cache) tuple,
        # despite the annotation.
        kv_caches: Dict[str, torch.Tensor] = {}

        for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
            reserved_bytes = kv_cache_config.tensors[layer_name].size
            # The reserved size must be a whole number of pages.
            num_blocks, leftover_bytes = divmod(reserved_bytes,
                                                layer_spec.page_size_bytes)
            assert leftover_bytes == 0

            if not isinstance(layer_spec, FullAttentionSpec):
                raise NotImplementedError

            cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
                layer_spec.head_size)
            key_cache = torch.zeros(cache_shape,
                                    dtype=layer_spec.dtype,
                                    device=self.device)
            value_cache = torch.zeros_like(key_cache)
            kv_caches[layer_name] = (key_cache, value_cache)

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)


class ModelWrapperV1(nn.Module):
    """Wraps a model so a single call runs it and picks tokens greedily."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Run the wrapped model and return the argmax token per position.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            attn_metadata: The Pallas attention metadata. When the slot
                mapping is rewritten (see below), its `slot_mapping`
                attribute is replaced in place.
            kv_caches: The (key, value) cache pairs. The slot-mapping rewrite
                is skipped when `attn_metadata` is None or the first key
                cache is empty, which the code associates with memory
                profiling at initialization.

        Returns:
            The argmax over the last dimension of the logits computed from
            the hidden states with their first two dimensions flattened.
        """
        # Skip this in memory profiling at initialization.
        if attn_metadata is not None and kv_caches[0][0].numel() > 0:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            flat_slots = attn_metadata.slot_mapping.flatten()

            # Offset of each head within the flattened cache.
            head_offsets = torch.arange(0,
                                        num_kv_heads,
                                        device=flat_slots.device,
                                        dtype=flat_slots.dtype)
            head_offsets *= block_size * num_blocks

            # One copy of every slot per head, shifted by that head's offset.
            per_head_slots = flat_slots.repeat_interleave(num_kv_heads)
            per_head_slots = per_head_slots.view(-1, num_kv_heads)
            per_head_slots = per_head_slots + head_offsets.view(1, -1)
            attn_metadata.slot_mapping = per_head_slots.flatten()

        assert self.model is not None
        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )

        flat_hidden = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(flat_hidden, None)

        # Greedy sampling.
        greedy_ids = torch.argmax(logits, dim=-1, keepdim=True)
        return greedy_ids.squeeze(dim=-1)


def swap_positions(b: InputBatch, id_1, id_2):
    """Exchange, in place, the per-request state held at two batch indices.

    Args:
        b: The input batch to mutate.
        id_1: Index of the first request; must differ from `id_2`.
        id_2: Index of the second request.
    """
    assert id_1 != id_2
    first_req = b.req_ids[id_1]
    second_req = b.req_ids[id_2]
    assert first_req is not None
    assert second_req is not None
    assert id_1 == b.req_id_to_index[first_req]
    assert id_2 == b.req_id_to_index[second_req]

    # Request-id bookkeeping.
    b.req_ids[id_1] = second_req
    b.req_ids[id_2] = first_req
    index_of_first = b.req_id_to_index[first_req]
    index_of_second = b.req_id_to_index[second_req]
    b.req_id_to_index[first_req] = index_of_second
    b.req_id_to_index[second_req] = index_of_first

    # Indexing with the reversed pair on the right-hand side swaps two rows
    # in a single assignment.
    targets = [id_1, id_2]
    sources = [id_2, id_1]

    # Token state.
    for rows in (
            b.num_tokens,
            b.token_ids_cpu,
            b.num_prompt_tokens,
            b.num_computed_tokens_cpu,
    ):
        rows[targets] = rows[sources]

    b.block_table.swap_row(id_1, id_2)

    # Sampling parameters.
    for rows in (
            b.temperature_cpu,
            b.top_p_cpu,
            b.top_k_cpu,
            b.frequency_penalties_cpu,
            b.presence_penalties_cpu,
            b.repetition_penalties_cpu,
    ):
        rows[targets] = rows[sources]

    first_min, second_min = b.min_tokens[id_1], b.min_tokens[id_2]
    b.min_tokens[id_1], b.min_tokens[id_2] = second_min, first_min

    # Generators are keyed by index and may be absent for either request.
    first_gen = b.generators.pop(id_1, None)
    second_gen = b.generators.pop(id_2, None)
    if first_gen is not None:
        b.generators[id_2] = first_gen
    if second_gen is not None:
        b.generators[id_1] = second_gen


def ensure_decodes_first(b: InputBatch):
    """Reorder the batch in place so every decode precedes every prompt.

    A request at index `i` counts as a prompt while
    `b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]`, and as a decode
    otherwise. Misplaced requests are exchanged with `swap_positions`.

    Two cursors walk towards each other, so each index is examined at most
    once. Everything left of `left` is already a decode and everything right
    of `right` is already a prompt, so nothing outside the window needs to be
    rescanned after a swap. The swaps performed are the same, in the same
    order, as pairing the left-most prompt with the right-most decode on
    every round.

    Args:
        b: The input batch to reorder.
    """

    def _is_prompt(i: int) -> bool:
        return b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]

    left = 0
    right = b.num_reqs - 1
    while True:
        # Advance to the left-most prompt still inside the window.
        while left <= right and not _is_prompt(left):
            left += 1
        # Retreat to the right-most decode still inside the window.
        while left <= right and _is_prompt(right):
            right -= 1

        # The cursors met or crossed: the batch is partitioned.
        if left >= right:
            break

        swap_positions(b, left, right)
        # Both positions now hold the correct kind of request.
        left += 1
        right -= 1


def _get_padded_prompt_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length to the nearest
    # multiple of 16. This is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def _get_padded_batch_size(batch_size: int) -> int:
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
    if batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16
