"""Kernel test utils"""

import itertools
import random
import unittest
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
                    Type, Union)

import pytest
import torch
from torch._prims_common import TensorLikeType

from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
                        STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

# Default subset of opcheck test names.
# For now, "test_aot_dispatch_dynamic" is left out of the defaults since
# there are some bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

# Full set of opcheck test names: the defaults plus
# "test_aot_dispatch_dynamic".
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


class QKVInputs(NamedTuple):
    '''
    Unpacked (i.e. padded, batched) attention inputs: query/key/value
    tensors together with the true sequence length of every batch entry.

    Attributes:

        * {query,key,value}: unpacked (batch_size x padded_seq_len x
                             num_heads x head_size) attention inputs
        * q_seq_lens: list of query sequence lengths, one per batch entry
        * kv_seq_lens: list of key/value sequence lengths, one per batch
                       entry (keys and values share the same lengths)
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_seq_lens: List[int]
    kv_seq_lens: List[int]


class QKVO(NamedTuple):
    '''
    Unpacked attention inputs paired with the unpacked known-correct
    ("ideal") attention output for those inputs.

    Attributes:

        * qkv: unpacked (batch_size x padded_seq_len x
               num_heads x head_size) attention inputs
        * ideal_output: unpacked (batch_size x padded_seq_len x
                        num_heads x head_size) known-correct attention output
    '''

    qkv: QKVInputs
    ideal_output: torch.Tensor


class PackedQKVInputs(NamedTuple):
    '''
    Packed (padding removed, sequences concatenated) attention inputs.

    Attributes:

        * {query,key,value}: packed (number_of_tokens x num_heads
                             x head_size) attention inputs
        * q_start_loc_list: list of query start locations within packed tensor
        * kv_start_loc_list: shared list of key/value start locations within
                             packed tensor
        * q_seq_lens: query sequence lengths list
        * kv_seq_lens: shared key/value sequence lengths list

    Note: pack_qkv() stores None in the query, q_start_loc_list and
    q_seq_lens fields when the unpacked query it is given is None.
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_start_loc_list: Optional[List[int]]
    kv_start_loc_list: Optional[List[int]]
    q_seq_lens: Optional[List[int]]
    kv_seq_lens: Optional[List[int]]


class PackedQKVO(NamedTuple):
    '''
    Packed attention inputs paired with the packed known-correct
    ("ideal") attention output.

    Attributes:

        * packed_qkv: packed (number_of_tokens x num_heads
                      x head_size) attention inputs; may be None
        * ideal_output: packed (number_of_tokens x num_heads
                        x head_size) known-correct attention output
    '''

    packed_qkv: Optional[PackedQKVInputs]
    ideal_output: torch.Tensor


class KVMemoryMap(NamedTuple):
    '''
    Bundle of the two tensors that describe where a batch's keys/values
    live in the KV cache.

    Attributes:

        * block_tables: KV cache block tables
        * slot_mapping: mapping of sequence offset to physical address
    '''

    block_tables: torch.Tensor
    slot_mapping: torch.Tensor


class PhaseTestParameters(NamedTuple):
    '''
    Test parameters for one test "phase" (prefill or decode) of one
    attention scenario (encoder, decoder-self, encoder/decoder-cross).

    Attributes:

        * packed_qkvo: packed (number_of_tokens x num_heads
                       x head_size) attention inputs & known-correct
                       output
        * kv_mmap: KV cache memory mapping, specific to this test phase &
                   attention scenario; may be None
    '''

    packed_qkvo: PackedQKVO
    kv_mmap: Optional[KVMemoryMap]


def maybe_make_int_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D int torch.Tensor on `device`,
    passing None through unchanged.

    Arguments:

    * _list: list of ints, or None
    * device: CPU or CUDA device

    Returns:

    * If _list is not None: 1D int (torch.int) torch.Tensor on `device`
    * None otherwise
    '''
    if _list is None:
        return None
    return torch.tensor(_list, dtype=torch.int, device=device)


def maybe_make_long_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D long torch.Tensor on `device`,
    passing None through unchanged.

    Arguments:

    * _list: list of ints, or None
    * device: CPU or CUDA device

    Returns:

    * If _list is not None: 1D long (torch.long) torch.Tensor on `device`
    * None otherwise
    '''
    if _list is None:
        return None
    return torch.tensor(_list, dtype=torch.long, device=device)


def maybe_max(_list: Optional[List]) -> Optional[Number]:
    '''
    None-tolerant wrapper around the builtin max().

    Returns:

    * max(_list) when _list is not None
    * None when _list is None
    '''
    if _list is not None:
        return max(_list)
    return None


def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    '''
    Build an additive causal attention mask of shape
    q_max_seq_len x kv_max_seq_len.

    Entry (i, j) is float('-inf') when j > i (the key position lies
    strictly after the query position) and 0.0 otherwise.

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D tensor, q_max_seq_len x kv_max_seq_len
    '''

    # Column vector of query positions and row vector of key positions;
    # broadcasting the comparison marks every strictly-upper-triangular
    # entry (j > i).
    query_pos = torch.arange(q_max_seq_len).unsqueeze(1)
    key_pos = torch.arange(kv_max_seq_len).unsqueeze(0)
    is_future = key_pos > query_pos

    # Start from an all-zero mask and blank out the future positions
    causal_mask = torch.zeros(q_max_seq_len, kv_max_seq_len)
    return causal_mask.masked_fill(is_future, float('-inf'))


def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    '''
    Override the environment variable indicating the vLLM backend temporarily,
    using pytest monkeypatch to ensure that the env vars get
    reset once the test context exits.

    Arguments:

    * mpatch: pytest monkeypatch instance
    * backend_name: attention backend name to force; written verbatim as
      the value of the environment variable named by STR_BACKEND_ENV_VAR
    '''
    # monkeypatch.setenv records the change so pytest can undo it
    mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)


def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float,
                         custom_mask: Optional[torch.Tensor] = None,
                         q_seq_lens: Optional[List] = None,
                         kv_seq_lens: Optional[List] = None) -> torch.Tensor:
    '''
    "Golden" masked attention reference. Applies two kinds of masking:

    * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
      padding elements (always applied)
    * Custom attention mask, which can force an arbitrary mask tensor, i.e.
      causal (applied only if custom_mask is not None)

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: Attention scale factor
    * custom_mask: custom additive attention mask; good place to inject a
      causal attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index;
      required (must not be None) despite the None default
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index;
      required (must not be None) despite the None default

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size.
      Query rows beyond q_seq_lens[b] are fully masked, so their softmax
      (and hence the output at those padded positions) is NaN.

    Raises:

    * AssertionError if either seq-lens list is None or its length does not
      match the batch size
    '''

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    batch_size = query.shape[0]
    assert (len(q_seq_lens) == batch_size)
    assert (len(kv_seq_lens) == batch_size)

    # Raw attention scores, batch x heads x q_len x kv_len, in float32
    attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Basic attention mask, derived from seq lens: -inf on every padded
    # query row and every padded key column. Both lists are guaranteed
    # non-None and of length batch_size by the asserts above.
    attn_mask = torch.zeros_like(attn_weights)
    for bdx, (q_len, kv_len) in enumerate(zip(q_seq_lens, kv_seq_lens)):
        attn_mask[bdx, :, q_len:, :] = -torch.inf
        attn_mask[bdx, :, :, kv_len:] = -torch.inf
    attn_weights = attn_weights + attn_mask

    # Custom attention mask
    if custom_mask is not None:
        attn_weights = attn_weights + custom_mask.float()

    # Normalize in float32, then drop back to the value dtype for the
    # weighted sum
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
    return out


def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: Optional[int],
    num_heads: int,
    head_size: int,
    device: Union[torch.device, str],
    force_kv_seq_lens: Optional[List[int]] = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
    '''
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
    seqlens

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len
    * max_kv_seq_len: max key/value seq len (padded key/value length).
      Required when attn_type is ENCODER_DECODER and force_kv_seq_lens is
      None. If None in the other cases, the padded key/value length falls
      back to max_q_seq_len (self-attention) or max(force_kv_seq_lens)
    * num_heads
    * head_size
    * device: CPU or CUDA device
    * force_kv_seq_lens: if not None, overrides kv sequence lengths
    * attn_type: encoder, decoder self, or enc/dec cross attention. For
      ENCODER_DECODER, key/value seqlens are chosen independently of the
      query seqlens; o/w key/value seqlens match the query seqlens at
      each batch index
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2,max_q_seq_len]. Same for key/value seqlens
      and max_kv_seq_len in the ENCODER_DECODER case

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    '''

    if force_max_len:
        q_seq_lens = [max_q_seq_len] * batch_size
    else:
        q_seq_lens = [
            random.randint(2, max_q_seq_len) for _ in range(batch_size)
        ]

    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        if force_max_len:
            kv_seq_lens = [max_kv_seq_len] * batch_size
        else:
            kv_seq_lens = [
                random.randint(2, max_kv_seq_len) for _ in range(batch_size)
            ]

    if max_kv_seq_len is None:
        # max_kv_seq_len is optional outside of the random cross-attention
        # case; derive the padded K/V length so that the tensor
        # constructors below do not receive None as a dimension.
        max_kv_seq_len = (max_q_seq_len if force_kv_seq_lens is None else
                          max(force_kv_seq_lens))

    q_shape = (batch_size, max_q_seq_len, num_heads, head_size)
    kv_shape = (batch_size, max_kv_seq_len, num_heads, head_size)
    single_tok_shape = (batch_size, 1, num_heads, head_size)

    # Random values are drawn on the default (CPU) generator and then moved,
    # in query/key/value order, so seeded runs stay reproducible.
    query = torch.rand(q_shape).to(device)
    key = torch.rand(kv_shape).to(device)
    value = torch.rand(kv_shape).to(device)

    prefill_query = torch.zeros_like(query)
    prefill_key = torch.zeros_like(key)
    prefill_value = torch.zeros_like(value)

    decode_query = torch.zeros(single_tok_shape).to(device)
    decode_key = torch.zeros(single_tok_shape).to(device)
    decode_value = torch.zeros(single_tok_shape).to(device)

    for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
                                                      kv_seq_lens)):
        # Zero out the padding region of the baseline tensors
        query[bdx, q_seq_len:, :, :] = 0
        key[bdx, kv_seq_len:, :, :] = 0
        value[bdx, kv_seq_len:, :, :] = 0

        q_last = q_seq_len - 1
        kv_last = kv_seq_len - 1

        # Prefill gets everything except the last valid sequence offset
        prefill_query[bdx, 0:q_last, :, :] = query[bdx, 0:q_last, :, :]
        prefill_key[bdx, 0:kv_last, :, :] = key[bdx, 0:kv_last, :, :]
        prefill_value[bdx, 0:kv_last, :, :] = value[bdx, 0:kv_last, :, :]

        # Decode gets only the last valid sequence offset
        decode_query[bdx, :, :, :] = query[bdx, q_last:q_seq_len, :, :]
        decode_key[bdx, :, :, :] = key[bdx, kv_last:kv_seq_len, :, :]
        decode_value[bdx, :, :, :] = value[bdx, kv_last:kv_seq_len, :, :]

    prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
    prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]

    decode_q_seq_lens = [1 for _ in q_seq_lens]
    decode_kv_seq_lens = [1 for _ in kv_seq_lens]

    return (
        QKVInputs(
            query,  # Overall QKV inputs
            key,
            value,
            q_seq_lens,
            kv_seq_lens),
        QKVInputs(
            prefill_query,  # Prefill subset of QKV sequences
            prefill_key,
            prefill_value,
            prefill_q_seq_lens,
            prefill_kv_seq_lens),
        QKVInputs(
            decode_query,  # Decode subset of KV sequences
            decode_key,
            decode_value,
            decode_q_seq_lens,
            decode_kv_seq_lens))


def pack_tensor(
        unpacked_tensor: torch.Tensor, seq_lens: List[int],
        device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
    '''
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns

    * packed_tensor: number_of_tokens x num_heads x head_size, with the
      same dtype as unpacked_tensor
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens)). It has len(seq_lens) + 1
      entries; the final entry equals number_of_tokens.
    '''

    num_tok = sum(seq_lens)
    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    # Allocate in the source dtype so that packing never silently converts
    # e.g. half-precision or integer inputs to the default float dtype.
    packed_tensor = torch.zeros((num_tok, num_heads, head_size),
                                dtype=unpacked_tensor.dtype,
                                device=device)

    # zip() stops at len(seq_lens), so the trailing total in start_loc_list
    # is never used as a start offset.
    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
        packed_tensor[start_loc:(
            start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]

    return packed_tensor, start_loc_list


def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
                                           str]) -> PackedQKVInputs:
    '''
    Pack Q, K and V independently with pack_tensor(), turning each
    batch_size x padded_seq_len x num_heads x head_size tensor into a
    number_of_tokens x num_heads x head_size tensor.

    For Q, number_of_tokens = sum(q_seq_lens).

    For K and V, number_of_tokens = sum(kv_seq_lens)

    If qkv.query is None, the query is not packed and the query,
    q_start_loc_list and q_seq_lens fields of the result are all None.

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
           attention inputs
    * device: CPU or CUDA device

    Returns

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs
    '''

    # Query side: everything stays None unless a query tensor was supplied
    packed_q = None
    q_starts = None
    packed_q_seq_lens = None
    if qkv.query is not None:
        packed_q, q_starts = pack_tensor(qkv.query,
                                         qkv.q_seq_lens,
                                         device=device)
        packed_q_seq_lens = qkv.q_seq_lens

    # Key/value side: both share kv_seq_lens, so one start-loc list suffices
    packed_k, kv_starts = pack_tensor(qkv.key, qkv.kv_seq_lens, device=device)
    packed_v, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)

    return PackedQKVInputs(query=packed_q,
                           key=packed_k,
                           value=packed_v,
                           q_start_loc_list=q_starts,
                           kv_start_loc_list=kv_starts,
                           q_seq_lens=packed_q_seq_lens,
                           kv_seq_lens=qkv.kv_seq_lens)


def make_backend(backend_name: str) -> AttentionBackend:
    '''
    Instantiate the attention backend selected by the backend_name string.

    * STR_XFORMERS_ATTN_VAL -> XFormersBackend
    * STR_FLASH_ATTN_VAL -> FlashAttentionBackend

    TODO: other backends

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()

    Arguments:

    * backend_name: name of the backend to construct

    Returns:

    * Backend instance

    Raises:

    * AssertionError if backend_name is not one of the recognized names
    '''
    # Backend modules are imported lazily, only for the branch taken.
    if backend_name == STR_XFORMERS_ATTN_VAL:
        # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
        from vllm.attention.backends.xformers import XFormersBackend
        return XFormersBackend()

    if backend_name == STR_FLASH_ATTN_VAL:
        from vllm.attention.backends.flash_attn import FlashAttentionBackend
        return FlashAttentionBackend()

    raise AssertionError(
        f"Unrecognized backend_name {backend_name} for unit test")


def _make_metadata_tensors(
    seq_lens: Optional[List[int]],
    context_lens: Optional[List[int]],
    encoder_seq_lens: Optional[List[int]],
    device: Union[torch.device, str],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Any, Any,
           Optional[torch.Tensor], Optional[torch.Tensor],
           Optional[torch.Tensor], Optional[int]]:
    '''
    Build scalar & tensor values required to build attention metadata structure.

    Every input list may be None; the values derived from a None list are
    themselves None.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: start idx of each sequence
    * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
    * encoder_seq_start_loc: start idx of each encoder sequence
    * max_encoder_seq_len: max(encoder_seq_lens)
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = (None if encoder_seq_lens is None else
                           max(encoder_seq_lens))

    # Both start-loc tensors go through the same None-safe helper, so a
    # missing encoder_seq_lens (decoder-only scenario) yields None instead
    # of failing on an attribute access on None.
    seq_start_loc = _seq_lens_to_start_loc(seq_lens_tensor)
    encoder_seq_start_loc = _seq_lens_to_start_loc(encoder_seq_lens_tensor)

    return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
            seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc,
            max_encoder_seq_len)


def _seq_lens_to_start_loc(
        lens_tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    '''
    Turn a 1D tensor of N sequence lengths into an int32 tensor of N + 1
    start offsets (leading 0 followed by the running sum), on the same
    device. Returns None if lens_tensor is None.
    '''
    if lens_tensor is None:
        return None
    start_loc = torch.zeros(lens_tensor.shape[0] + 1,
                            dtype=torch.int32,
                            device=lens_tensor.device)
    # Write the cumulative sum directly into positions 1..N
    torch.cumsum(lens_tensor,
                 dim=0,
                 dtype=start_loc.dtype,
                 out=start_loc[1:])
    return start_loc


def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  backend: str,
                  default_val: float = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache whose layout depends on the backend.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * backend: 'XFORMERS' or 'FLASH_ATTN'; selects the cache layout
    * default_val: initialization value for KV cache elements; if None,
      the cache keeps its random (torch.rand) contents

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
    *     for backend 'XFORMERS' 
    * kv_cache: 2 x num_blocks x block_size x num_heads x head_size
    *     for backend 'FLASH_ATTN'  

    Raises:

    * ValueError if backend is neither 'XFORMERS' nor 'FLASH_ATTN'
    '''
    # Pick the layout first; an unknown backend fails before any allocation
    if backend == 'XFORMERS':
        cache_shape: Tuple[int, ...] = (2, num_blocks,
                                        block_size * num_heads * head_size)
    elif backend == 'FLASH_ATTN':
        cache_shape = (2, num_blocks, block_size, num_heads, head_size)
    else:
        raise ValueError(
            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
            f"'FLASH_ATTN'.")

    kv_cache = torch.rand(cache_shape).to(device)
    if default_val is not None:
        # Overwrite every element with the requested constant
        kv_cache[...] = default_val
    return kv_cache


def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
    '''
    Compute the minimum number of blocks required to hold num_tokens tokens,
    given block_size
    '''
    return (num_tokens + block_size) // block_size


def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    '''
    Create an empty slot mapping: the result of passing an empty list to
    maybe_make_long_tensor on `device`.

    Arguments:

    * device: CPU or CUDA device
    '''
    return maybe_make_long_tensor([], device)


def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    '''
    Create an empty (zero-element, 1D, default-dtype) block tables tensor
    on `device`.

    Arguments:

    * device: CPU or CUDA device
    '''
    return torch.empty(0, device=device)


def split_slot_mapping(
        slot_mapping_list: List[int], seq_lens: List[int],
        device: Union[torch.device,
                      str]) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:
    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      K_i for i in [0,N), followed by (2) decoding of a single token
      for all N prompts (N tokens total); the resultant sequence lengths
      after decode would be K_i + 1 for i in [0,N)
    * The test you want to do requires (1) having the prefill slot mapping
      for all tokens present during prefill, the number of which is
      M = sum_i(K_i), and (2) having the decode slot mapping for all N
      decoded tokens

    This function consumes a single 1D slot mapping, which is the
    concatenation of N slot mappings each of length K_i + 1 (corresponding
    to the sequence lengths after decode), with a total length of
    P = sum_i(K_i + 1) = M + N

    The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
    from each of the N subsequences in the slot mapping (i.e. omitting the
    decoded token's mapping.)

    The N excised entries are appended to obtain the decode-phase slot mapping

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
      post-decode sequences
    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as long Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as long Tensor)
      reflecting all N decoded tokens
    '''

    prefill_slot_mapping: List[int] = []
    decode_slot_mapping: List[int] = []

    base_idx = 0
    for seq_len in seq_lens:
        # Index of this sequence's final (decoded) token in the flat mapping
        last_idx = base_idx + seq_len - 1
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:last_idx])
        decode_slot_mapping.append(slot_mapping_list[last_idx])
        base_idx += seq_len

    return (torch.tensor(prefill_slot_mapping, dtype=torch.long,
                         device=device),
            torch.tensor(decode_slot_mapping, dtype=torch.long,
                         device=device))


def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: List[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
    '''
    Construct fake block tables & slot mappings.

    Each sequence with num_tokens tokens is given

    num_blocks = (num_tokens + block_size) // block_size

    KV cache blocks, so the total number of blocks handed out is

    total_cache_blocks = sum(num_blocks for all seqs)

    Block indices are handed out in descending order: the first block of
    the first sequence is

    block_base_addr + total_cache_blocks

    and the last block of the last sequence is

    block_base_addr + 1

    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: int tensor of per-sequence block tables, padded
      with 0 to a width of (longest block table + 10) entries
    * slot_mapping_list: flat list of slot indices, one per token, for all
      sequences in order
    * max_block_idx: the highest block address within this block table
    '''

    # Provision KV cache blocks for every sequence
    blocks_per_seq = [
        _num_tokens_to_min_blocks(seq_len, block_size) for seq_len in seq_lens
    ]
    widest_table = max(blocks_per_seq)
    extra_pad_columns = 10

    # Uppermost block address; allocation proceeds downward from here
    max_block_idx = block_base_addr + sum(blocks_per_seq)
    next_block = max_block_idx

    block_tables = []
    slot_mapping_list = []
    for seq_len, seq_blocks in zip(seq_lens, blocks_per_seq):
        table = list(range(next_block, next_block - seq_blocks, -1))
        for token_idx in range(seq_len):
            # Token lives at `offset` within logical block `logical_block`
            logical_block, offset = divmod(token_idx, block_size)
            slot_mapping_list.append(offset +
                                     table[logical_block] * block_size)
        block_tables.append(table)
        next_block -= seq_blocks

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=widest_table + extra_pad_columns,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)


def make_test_metadata(
    attn_backend: _Backend,
    is_prompt: bool,
    seq_lens: Optional[List[int]],
    decoder_test_params: Optional[PhaseTestParameters],
    device: Union[torch.device, str],
    encoder_test_params: Optional[PhaseTestParameters] = None,
    cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
    '''
    Construct fake attention metadata for a given test phase
    (prefill-phase or decode-phase).

    encoder_test_params and cross_test_params arguments allow encoder
    attention and enc/dec cross-attention (respectively) to use distinct
    metadata values from decoder self-attention (decoder_test_params.)

    if encoder_test_params and cross_test_params are None, the attention
    metadata will support decoder-only scenario.

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both

    Arguments:

    * attn_backend: Backend for sourcing attention kernels; its name is
                    used to construct the backend object
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence; must not be None
                in the decode phase
    * decoder_test_params: decoder self-attention test params;
                           this function requires
                           kv_mmap (memory mapping) field, which must
                           not be None in the decode phase
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params;
                           this function requires encoder query
                           sequence lengths field. If None,
                           encoder query sequence lengths are
                           treated as None
    * cross_test_params: enc/dec cross-attention test params;
                         this function requires kv_mmap field.
                         If None, KV cache memory map data
                         structures are treated as None

    Return:

    * AttentionMetadata structure
    '''

    # Decoder self-attention memory mapping; a None decoder_test_params
    # signals the encoder-only scenario, which has no kv_mmap
    kv_mmap = None
    if decoder_test_params is not None:
        kv_mmap = decoder_test_params.kv_mmap

    # No chunked prefill: the batch is either all prefill or all decode.
    # * Prefill: one prefill per sequence, sum(seq_lens) prefill tokens
    # * Decode: one decode token per sequence
    # A None seq_lens signals the encoder-only scenario, where these
    # counts are unused.
    if seq_lens is None:
        num_seqs = None
        num_phase_tokens = None
    else:
        num_seqs = len(seq_lens)
        num_phase_tokens = sum(seq_lens) if is_prompt else len(seq_lens)

    # Seems for non-prefix-caching scenarios context_lens
    # is never needed
    context_lens = None

    # Encoder/decoder or encoder-only models only: pull the encoder input
    # sequence lengths out of the encoder test params
    encoder_seq_lens = None
    num_encoder_tokens = None
    if encoder_test_params is not None:
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
        if encoder_seq_lens is not None:
            num_encoder_tokens = sum(encoder_seq_lens)

    # Encoder/decoder models only: *cross-attention* slot mapping and
    # block table
    cross_kv_mmap = (None
                     if cross_test_params is None else cross_test_params.kv_mmap)

    attn_backend_obj = make_backend(attn_backend.name)

    # Fields whose values differ between the two phases
    phase_kwargs: Dict[str, Any]
    if is_prompt:
        # Prefill-phase scenario
        phase_kwargs = dict(
            num_prefills=num_seqs,
            num_prefill_tokens=num_phase_tokens,
            num_decode_tokens=0,
            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
            max_decode_seq_len=0,
        )
    else:
        # Decode-phase scenario
        assert kv_mmap is not None
        assert num_phase_tokens is not None
        assert seq_lens is not None
        phase_kwargs = dict(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=num_phase_tokens,
            max_prefill_seq_len=0,
            max_decode_seq_len=max(seq_lens),
            max_decode_query_len=1,
        )

    (
        seq_lens_tensor,
        context_lens_tensor,
        _,
        _,
        seq_start_loc,
        encoder_seq_lens_tensor,
        encoder_seq_start_loc,
        max_encoder_seq_len,
    ) = _make_metadata_tensors(seq_lens,
                               context_lens,
                               encoder_seq_lens,
                               device=device)

    # Fields shared by both phases are spelled out here; the phase-specific
    # ones are merged in from phase_kwargs
    return attn_backend_obj.make_metadata(
        slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
        seq_lens=seq_lens,
        seq_lens_tensor=seq_lens_tensor,
        seq_start_loc=seq_start_loc,
        context_lens_tensor=context_lens_tensor,
        block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
        use_cuda_graph=False,
        num_encoder_tokens=num_encoder_tokens,
        encoder_seq_lens=encoder_seq_lens,
        encoder_seq_lens_tensor=encoder_seq_lens_tensor,
        encoder_seq_start_loc=encoder_seq_start_loc,
        max_encoder_seq_len=max_encoder_seq_len,
        cross_slot_mapping=(None if cross_kv_mmap is None else
                            cross_kv_mmap.slot_mapping),
        cross_block_tables=(None if cross_kv_mmap is None else
                            cross_kv_mmap.block_tables),
        **phase_kwargs)


def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test: torch.Tensor,
                                backend: str) -> None:
    '''
    Check the observed attention output against the known-correct
    output stored in the test parameters, using tolerances that
    depend on the backend.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value; it is viewed
                         with the shape of the ideal output before
                         the comparison
    * backend: backend name, either 'XFORMERS' or 'FLASH_ATTN'

    Raises:

    * ValueError if backend is neither 'XFORMERS' nor 'FLASH_ATTN'
    * AssertionError (from torch.testing.assert_close) on mismatch
    '''
    expected = test_params.packed_qkvo.ideal_output

    if backend == 'XFORMERS':
        # Rely on the default tolerances of torch.testing.assert_close
        tolerances = {}
    elif backend == 'FLASH_ATTN':
        # For FlashAttention override the accuracy thresholds to non
        # default values since we notice a higher difference between the
        # ideal and actual output.
        tolerances = {'atol': 0.01, 'rtol': 0.016}
    else:
        raise ValueError(
            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
            f"'FLASH_ATTN'.")

    actual = output_under_test.view_as(expected)
    torch.testing.assert_close(expected, actual, **tolerances)


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Variant of the torch.allclose reference implementation that compares
    the inputs after casting both of them to float64.

    Arguments:

    * a, b: tensors to compare
    * rtol: relative tolerance forwarded to torch.isclose
    * atol: absolute tolerance forwarded to torch.isclose
    * equal_nan: forwarded to torch.isclose

    Returns:

    * True if every element pair is reported close by torch.isclose,
      False otherwise
    """
    # Reuse the argument checks of the torch reference implementation.
    torch._refs._check_close_args(name="torch.allclose",
                                  a=a,
                                  b=b,
                                  rtol=rtol,
                                  atol=atol)

    # Compare in double precision rather than in the input dtype.
    lhs = a.double()
    rhs = b.double()
    elementwise_close = torch.isclose(lhs,
                                      rhs,
                                      rtol=rtol,
                                      atol=atol,
                                      equal_nan=equal_nan)
    return bool(elementwise_close.all().item())


# Marlin MoE test utils


def stack_and_dev(tensors: List[torch.Tensor]):
    '''
    Stack the tensors along a new leading dimension and move the result
    to the device of the first tensor in the list.

    Arguments:

    * tensors: non-empty list of tensors to stack

    Returns:

    * Stacked tensor on the device of tensors[0]
    '''
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)


def compute_max_diff(output, output_ref):
    '''
    Relative mean absolute difference between an output and a reference.

    NOTE: despite the name, this is not a maximum; it is
    mean(|output - output_ref|) / mean(|output_ref|).

    Arguments:

    * output: tensor under test
    * output_ref: reference tensor

    Returns:

    * Scalar tensor holding the ratio described above
    '''
    mean_abs_error = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_error / mean_abs_ref


def torch_moe(a, w1, w2, score, topk):
    '''
    Unfused mixture-of-experts forward pass written in plain PyTorch.

    Every row of a is replicated topk times, each replica is routed to
    one of the row's topk experts (chosen from the float32 softmax of
    score), passed through
    SiluAndMul()(x @ w1[e].T) @ w2[e].T, and the per-expert results are
    combined as a sum weighted by the selected softmax probabilities.

    Arguments:

    * a: 2D input, unpacked as (B, D)
    * w1: per-expert first-layer weights, indexed by expert on dim 0
    * w2: per-expert second-layer weights, indexed by expert on dim 0;
          w2.shape[1] is the output width
    * score: routing logits; softmax and top-k are taken over the last
             dim (presumably shaped (B, num_experts) - confirm with
             callers)
    * topk: number of experts selected per row

    Returns:

    * Tensor of shape (B, w2.shape[1])
    '''
    num_tokens, hidden_size = a.shape
    out_size = w2.shape[1]

    # One replica of each input row per selected expert
    replicated = a.view(num_tokens, -1,
                        hidden_size).repeat(1, topk,
                                            1).reshape(-1, hidden_size)
    expert_out = torch.zeros(num_tokens * topk,
                             out_size,
                             dtype=replicated.dtype,
                             device=replicated.device)

    routing_probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    selected_probs, selected_experts = torch.topk(routing_probs, topk)
    flat_probs = selected_probs.view(-1)
    flat_experts = selected_experts.view(-1)

    for expert in range(w1.shape[0]):
        routed_here = flat_experts == expert
        if not routed_here.sum():
            # No replica was routed to this expert
            continue
        hidden = replicated[routed_here] @ w1[expert].transpose(0, 1)
        expert_out[routed_here] = SiluAndMul()(hidden) @ w2[
            expert].transpose(0, 1)

    # Weight each replica by its routing probability, then sum the
    # topk replicas of every row.
    per_slot = expert_out.view(num_tokens, -1, out_size)
    slot_weights = flat_probs.view(num_tokens, -1, 1).to(expert_out.dtype)
    return (per_slot * slot_weights).sum(dim=1)


def torch_moe_single(a, w, score, topk):
    '''
    Single-layer mixture-of-experts reference in plain PyTorch.

    Every row of a is replicated topk times, each replica is routed to
    one of the row's topk experts (chosen from the float32 softmax of
    score) and multiplied by that expert's transposed weight; the topk
    results of each row are then summed without routing weights.

    Arguments:

    * a: 2D input, unpacked as (B, D)
    * w: per-expert weights, indexed by expert on dim 0; w.shape[1] is
         the output width
    * score: routing logits; softmax and top-k are taken over the last
             dim (presumably shaped (B, num_experts) - confirm with
             callers)
    * topk: number of experts selected per row

    Returns:

    * Tensor of shape (B, w.shape[1])
    '''
    num_tokens, hidden_size = a.shape
    out_size = w.shape[1]

    # One replica of each input row per selected expert
    replicated = a.view(num_tokens, -1,
                        hidden_size).repeat(1, topk,
                                            1).reshape(-1, hidden_size)
    expert_out = torch.zeros(num_tokens * topk,
                             out_size,
                             dtype=replicated.dtype,
                             device=replicated.device)

    routing_probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    flat_experts = torch.topk(routing_probs, topk)[1].view(-1)

    for expert in range(w.shape[0]):
        routed_here = flat_experts == expert
        if not routed_here.sum():
            # No replica was routed to this expert
            continue
        expert_out[routed_here] = replicated[routed_here] @ w[
            expert].transpose(0, 1)

    return expert_out.view(num_tokens, -1, out_size).sum(dim=1)


# A special version of opcheck that runs ALL_OPCHECK_TEST_UTILS by default
# and uses a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    '''
    Run torch.library.opcheck on an operator while torch.allclose is
    temporarily replaced by fp8_allclose.

    Arguments:

    * op: operator to check
    * args: positional arguments forwarded to torch.library.opcheck
    * kwargs: keyword arguments forwarded to torch.library.opcheck
    * test_utils: name(s) of the opcheck tests to run
    * raise_exception: forwarded to torch.library.opcheck
    * cond: when False the check is skipped entirely

    Returns:

    * The result of torch.library.opcheck, or an empty dict if cond
      is False
    '''
    # Import the submodule explicitly: a plain `import unittest` does not
    # load `unittest.mock`, so `unittest.mock.patch` only resolves if some
    # other module happened to import it first.
    from unittest import mock

    if not cond:
        # Nothing to check, so do not patch torch.allclose at all.
        return {}

    with mock.patch('torch.allclose', new=fp8_allclose):
        return torch.library.opcheck(op,
                                     args,
                                     kwargs,
                                     test_utils=test_utils,
                                     raise_exception=raise_exception)


# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor):
    '''
    Convert a tensor to torch.int8.

    Values are first clamped to [-128, 127], then rounded with
    torch.round, then cast.

    Arguments:

    * tensor: tensor to convert

    Returns:

    * Tensor of dtype torch.int8
    '''
    clamped = tensor.clamp(min=-128, max=127)
    return clamped.round().to(dtype=torch.int8)


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    '''
    Reference scaled matrix multiply:
    ((scale_a * a) @ (scale_b * b)).to(out_dtype) [+ bias].

    Both operands are cast to float32 before scaling. The scales may be
    given per group of elements; see the group-broadcast note below.

    Arguments:

    * a, b: matrices to multiply (passed to torch.mm after scaling)
    * scale_a: scale for a, with the same number of dims as a
    * scale_b: scale for b, with the same number of dims as b
    * out_dtype: dtype the product is cast to
    * bias: optional tensor added to the cast product

    Returns:

    * The scaled product, plus bias if given
    '''

    # We treat N-dimensional group scaling as extended numpy-style
    # broadcasting. Numpy simply stretches dimensions with an extent of 1
    # to match the target shape by repeating the data along that
    # dimension. We extend these semantics: if the extent of a dimension
    # in the source shape is not 1 and does not match the target shape, we
    # repeat each element along that dimension
    # target_shape[dim] // src_shape[dim] times.
    # Example, if we have:
    #       t = [[1, 2], and target_shape = (2, 4)
    #            [3, 4]]
    # then we would expand t to:
    #       t = [[1, 1, 2, 2],
    #            [3, 3, 4, 4]]
    # NOTE this function does not explicitly broadcast dimensions with an
    # extent of 1, since this can be done implicitly by pytorch.
    def group_broadcast(t, shape):
        for dim, target_extent in enumerate(shape):
            source_extent = t.shape[dim]
            if source_extent == target_extent or source_extent == 1:
                continue
            assert target_extent % source_extent == 0
            t = torch.repeat_interleave(t,
                                        target_extent // source_extent,
                                        dim=dim)
        return t

    scaled_a = group_broadcast(scale_a, a.shape) * a.to(dtype=torch.float32)
    scaled_b = group_broadcast(scale_b, b.shape) * b.to(dtype=torch.float32)

    product = torch.mm(scaled_a, scaled_b).to(out_dtype)

    if bias is None:
        return product
    return product + bias
