# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

import vllm._custom_ops as ops
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadataBuilder,
                                              AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder


class CPUMLABackend(AttentionBackend):
    """Attention-backend descriptor for MLA on CPU.

    Wires together the metadata, builder, state and implementation classes
    used by the CPU MLA path, and describes the KV-cache layout.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "CPU_MLA"

    @staticmethod
    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
        """Return the attention-metadata class used by this backend."""
        return CPUMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
        """Return the class that builds metadata for each batch."""
        return CPUMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["MLACommonState"]:
        """Return the shared MLA backend-state class."""
        return MLACommonState

    @staticmethod
    def get_impl_cls() -> Type["CPUMLAImpl"]:
        """Return the attention implementation class."""
        return CPUMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape ``(num_blocks, block_size, head_size)``.

        ``num_kv_heads`` does not appear in the shape; for MLA it is assumed
        to be 1, so the cache has no head dimension.
        """
        return num_blocks, block_size, head_size

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap cache blocks between two caches via ``ops.swap_blocks``."""
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks using the MLA-specific ``ops.copy_blocks_mla``."""
        ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes this backend accepts (only 576)."""
        return [576]


@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
    """TorchSDPA attention metadata extended with the fields MLA needs."""

    # New for MLA.
    # Input positions for rotary embeddings: for MLA the rotary position
    # embeddings are applied inside the attention backend, so the positions
    # have to travel with the metadata. Defaults to None.
    input_positions: Optional[torch.Tensor] = None

    # Required by MLACommonImpl.
    is_profile_run: bool = False


class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
    """Builds :class:`CPUMLAMetadata` from a CPU model-input builder.

    Chunked prefill is not supported. The constructor rejects it, so the
    metadata produced here never carries prefill block tables.
    """

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # which would let an unsupported configuration through silently.
        if self.chunked_prefill:
            raise NotImplementedError(
                "chunked prefill is currently not supported")

    def prepare(self):
        """Capture the input data accumulated by ``input_builder``."""
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
        """Build attention metadata for the batch captured by ``prepare``.

        Args:
            seq_lens: per-sequence lengths; the first ``num_prefills``
                entries are used as the prefill KV lengths.
            query_lens: per-sequence query lengths; the first
                ``num_prefills`` entries are used as the prefill query
                lengths.
            cuda_graph_pad_size: unused by this builder.
            batch_size: unused by this builder.

        Returns:
            A :class:`CPUMLAMetadata` describing the prefill and decode
            parts of the batch.
        """
        input_data = self.input_data
        num_prefills = input_data.num_prefills
        prefill_seq_lens = seq_lens[0:num_prefills]
        prefill_query_lens = query_lens[0:num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # Metadata for prefill. Chunked prefill is rejected in __init__, so
        # no paged prefill block tables are ever needed.
        prefill_block_tables = None
        if num_prefills > 0:
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            # Exclusive prefix sums: entry i is the start offset of
            # sequence i and the last entry is the total length.
            query_start_loc = torch.zeros(num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)
        else:
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None

        # Metadata for decode.
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:num_prefills],
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models.
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        return CPUMLAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            # NOTE(review): the extra list wrapping gives this tensor a
            # leading dimension of size 1 (e.g. shape (1, num_tokens) for a
            # flat list of positions). Kept as-is; confirm that consumers
            # expect this shape.
            input_positions=torch.tensor([input_data.input_positions]))


class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
    """CPU implementation of MLA attention.

    Prefill runs IPEX ``varlen_attention`` over keys/values up-projected from
    the latent ``kv_c_normed``. Decode runs the ``mla_decode_kvcache_cpu``
    custom op against the latent KV cache and then applies
    ``_v_up_proj_and_o_proj``.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        # Any truthy value among these optional features is unsupported.
        if any((alibi_slopes, sliding_window, blocksparse_params,
                logits_soft_cap)):
            raise NotImplementedError(
                "CPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        # Only decoder self-attention is implemented.
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CPUMLAImpl")

        # A quantized (e.g. FP8) KV cache is not supported by this backend
        # yet.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CPUMLAImpl with FP8 KV cache not yet supported")

    def _forward_prefill(
            self,
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        """Run causal attention for the prefill tokens and apply o_proj.

        ``kv_c_normed`` is up-projected with ``kv_b_proj`` into per-head
        non-rotary keys and values. ``k_pe`` is expanded to the keys' leading
        dimensions and appended to them. ``kv_c_and_k_pe_cache`` is not read
        here.
        """
        meta = attn_metadata.prefill_metadata
        assert meta is not None

        nope_dim = self.qk_nope_head_dim
        kv_up = self.kv_b_proj(kv_c_normed)[0]
        kv_up = kv_up.view(-1, self.num_heads, nope_dim + self.v_head_dim)
        k_nope, v = kv_up.split([nope_dim, self.v_head_dim], dim=-1)

        k_pe_expanded = k_pe.expand((*k_nope.shape[:-1], -1))
        k = torch.cat((k_nope, k_pe_expanded), dim=-1)

        # The v head dim is smaller than the q/k head dim, so v is zero-padded
        # on the right to match; the padding is sliced off after attention.
        qk_dim = q.shape[-1]
        v_width = v.shape[-1]
        v_padded = torch.nn.functional.pad(v, [0, qk_dim - v_width], value=0)

        # Keys/values are built from the same tokens as the queries (nothing
        # is read from the cache), so the query offsets and max length are
        # used for both sides.
        attn_out = torch.empty_like(q)
        ipex_ops.varlen_attention(
            query=q,
            key=k,
            value=v_padded,
            out=attn_out,
            seqlen_q=meta.query_start_loc,
            seqlen_k=meta.query_start_loc,
            max_seqlen_q=meta.max_query_len,
            max_seqlen_k=meta.max_query_len,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=True,
            return_softmax=False,
            gen_=None,
            logits_soft_cap=0.0,
        )

        # Drop the padded tail of every head, then flatten heads for o_proj.
        per_head = attn_out.view(-1, self.num_heads, qk_dim)[..., :v_width]
        flattened = per_head.reshape(-1, self.num_heads * v_width)
        return self.o_proj(flattened)[0]

    def _forward_decode(
            self,
            q_nope: torch.Tensor,
            q_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        """Run decode attention directly against the latent KV cache.

        The per-head outputs (``kv_lora_rank`` wide) are passed to
        ``_v_up_proj_and_o_proj``, whose result is returned.
        """
        assert kv_c_and_k_pe_cache.numel() > 0

        meta = attn_metadata.decode_metadata
        assert meta is not None

        query = torch.cat((q_nope, q_pe), dim=-1)
        num_tokens = query.shape[0]
        # Output buffer handed to the custom op below.
        latent_out = query.new_empty(num_tokens, self.num_heads,
                                     self.kv_lora_rank)

        # Run MQA over the latent cache.
        ops.mla_decode_kvcache_cpu(latent_out, query, kv_c_and_k_pe_cache,
                                   self.scale, meta.block_tables,
                                   meta.seq_lens_tensor)
        return self._v_up_proj_and_o_proj(latent_out)
