import dataclasses
from typing import List, Tuple, Type

import torch

from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
    ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.multi_step_model_runner import StatefulModelInput


class MockAttentionBackend(AttentionBackend):
    """Stub attention backend used by the serialization tests in this file.

    The tests pass an instance as ``attn_backend`` when rebuilding a model
    input from a broadcast tensor dict. Only the class-lookup hooks return
    something; ``get_name``, ``get_impl_cls`` and ``get_kv_cache_shape``
    raise ``NotImplementedError``, and the block-manipulation hooks are
    no-ops.
    """

    # --- Hooks that hand back a class -------------------------------------

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the base ``AttentionMetadata`` class."""
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        """Return the base ``AttentionMetadataBuilder`` class."""
        return AttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the ``CommonAttentionState`` class."""
        return CommonAttentionState

    # --- Hooks that are intentionally unimplemented -----------------------

    @staticmethod
    def get_name() -> str:
        """Not provided by this mock; always raises."""
        raise NotImplementedError

    @staticmethod
    def get_impl_cls():
        """Not provided by this mock; always raises."""
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
    ) -> Tuple[int, ...]:
        """Not provided by this mock; always raises."""
        raise NotImplementedError

    # --- Hooks that do nothing --------------------------------------------

    @staticmethod
    def swap_blocks(
            src_kv_cache: torch.Tensor,
            dst_kv_cache: torch.Tensor,
            src_to_dst: torch.Tensor,
    ) -> None:
        """No-op: the arguments are ignored."""

    @staticmethod
    def copy_blocks(
            kv_caches: List[torch.Tensor],
            src_to_dists: torch.Tensor,
    ) -> None:
        """No-op: the arguments are ignored."""


def test_model_runner_input():
    """Round-trip a ``ModelInputForGPUWithSamplingMetadata`` through its
    broadcastable tensor dict and check the rebuilt copy field by field."""
    # Placeholder strings stand in for the real sampling metadata values.
    sampling_md = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_md = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    sent = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_md,
        attn_metadata=attn_md,
    )
    assert isinstance(sent, ModelInputForGPUWithSamplingMetadata)

    # Serialize, then rebuild from the dict using the mock backend.
    payload = sent.as_broadcastable_tensor_dict()
    rebuild = ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict
    received = rebuild(payload, attn_backend=MockAttentionBackend())
    assert isinstance(received, ModelInputForGPUWithSamplingMetadata)

    # Tensor fields must be present and element-wise equal.
    for tensor_name in ("input_tokens", "input_positions"):
        got_tensor = getattr(received, tensor_name)
        assert got_tensor is not None
        assert (got_tensor == getattr(sent, tensor_name)).all()

    # Fields left unset on the sender must be None and match the sender.
    unset_names = ("multi_modal_kwargs", "lora_requests", "lora_mapping")
    for unset_name in unset_names:
        got_value = getattr(received, unset_name)
        assert got_value is None
        assert got_value == getattr(sent, unset_name)

    # Every AttentionMetadata dataclass field must match; a field missing
    # on either side is treated as None.
    received_attn = received.attn_metadata
    for md_field in dataclasses.fields(AttentionMetadata):
        expected = getattr(attn_md, md_field.name, None)
        assert getattr(received_attn, md_field.name, None) == expected

    # For sampling metadata, only selected_token_indices is copied.
    received_sampling = received.sampling_metadata
    assert (received_sampling.selected_token_indices ==
            sampling_md.selected_token_indices)
    assert received_sampling.seq_groups is None


def test_embedding_model_runner_input():
    """Round-trip a ``ModelInputForGPUWithPoolingMetadata`` through its
    broadcastable tensor dict and check the rebuilt copy field by field."""
    pooling_md = PoolingMetadata(
        seq_groups=[[0]],
        seq_data={},
        prompt_lens=[1],
    )
    attn_md = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    sent = ModelInputForGPUWithPoolingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        pooling_metadata=pooling_md,
        attn_metadata=attn_md,
    )
    assert isinstance(sent, ModelInputForGPUWithPoolingMetadata)

    # Serialize, then rebuild from the dict using the mock backend.
    payload = sent.as_broadcastable_tensor_dict()
    rebuild = ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict
    received = rebuild(payload, attn_backend=MockAttentionBackend())
    assert isinstance(received, ModelInputForGPUWithPoolingMetadata)

    # Tensor fields must be present and element-wise equal.
    for tensor_name in ("input_tokens", "input_positions"):
        got_tensor = getattr(received, tensor_name)
        assert got_tensor is not None
        assert (got_tensor == getattr(sent, tensor_name)).all()

    # Fields left unset on the sender must be None and match the sender.
    unset_names = ("multi_modal_kwargs", "lora_requests", "lora_mapping")
    for unset_name in unset_names:
        got_value = getattr(received, unset_name)
        assert got_value is None
        assert got_value == getattr(sent, unset_name)

    # Every AttentionMetadata dataclass field must match; a field missing
    # on either side is treated as None.
    received_attn = received.attn_metadata
    for md_field in dataclasses.fields(AttentionMetadata):
        expected = getattr(attn_md, md_field.name, None)
        assert getattr(received_attn, md_field.name, None) == expected

    # Pooling metadata is not broadcast.
    assert received.pooling_metadata is None


def test_multi_step_model_runner_input():
    """Round-trip a ``StatefulModelInput`` through its broadcastable tensor
    dict and check both the wrapped frozen input and the step-state fields.
    """
    # Placeholder strings stand in for the real sampling metadata values.
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    model_input = StatefulModelInput(
        frozen_model_input=frozen_model_input,
        is_last_step=True,
        is_first_multi_step=False,
        current_step=4,
        last_sampled_token_ids=torch.ones((10, 1)),
        is_multi_step=True,
        num_queries=8,
        num_seqs=5,
        cached_outputs=[],
    )

    assert isinstance(model_input, StatefulModelInput)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend))

    received_frozen_input = received_model_input.frozen_model_input

    # Check that received copy has correct values.
    assert isinstance(received_model_input, StatefulModelInput)
    assert received_frozen_input.input_tokens is not None
    assert (received_frozen_input.input_tokens ==
            frozen_model_input.input_tokens).all()
    assert received_frozen_input.input_positions is not None
    assert (received_frozen_input.input_positions ==
            frozen_model_input.input_positions).all()
    assert received_frozen_input.multi_modal_kwargs is None
    # Compare the *received* value against the original; comparing the
    # original with itself would always pass and verify nothing.
    assert (received_frozen_input.multi_modal_kwargs ==
            frozen_model_input.multi_modal_kwargs)
    assert received_frozen_input.lora_requests is None
    assert (received_frozen_input.lora_requests ==
            frozen_model_input.lora_requests)
    assert received_frozen_input.lora_mapping is None
    assert (
        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
    # Every AttentionMetadata dataclass field must match; a field missing
    # on either side is treated as None.
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_frozen_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_frozen_input.sampling_metadata.seq_groups is None

    # Check the non-frozen (per-step state) fields.
    # NOTE(review): num_queries and num_seqs are set above but not asserted
    # here -- confirm whether they are part of the broadcast dict before
    # adding checks for them.
    assert received_model_input.is_last_step == model_input.is_last_step
    assert (received_model_input.is_first_multi_step ==
            model_input.is_first_multi_step)
    assert received_model_input.current_step == model_input.current_step
    assert (received_model_input.last_sampled_token_ids ==
            model_input.last_sampled_token_ids).all()
    assert received_model_input.is_multi_step == model_input.is_multi_step
