import torch
import math
from typing import Optional

import tidal._kernels as _kernels
from tidal.utils.utils import TensorLayout
from tidal.utils.kv_cache import KvCache
from tidal.utils.controller import InferenceController
from tidal.utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper

# Public API of this module. Kept as a plain list literal so that
# `from ... import *` and static tooling can read it directly.
__all__ = [
    # Names re-exported from the `tidal.utils.*` imports above.
    "TensorLayout",
    "KvCache",
    "InferenceController",
    "BatchDecodeWithPagedKVCacheWrapper",
    # Wrapper functions defined in this module.
    "append_kv",
    "prefill_forward",
    "decode_topk",
    "decode_sparse_attn",
    "rms_norm_forward",
    "apply_rope_in_place",
    "apply_llama31_rope_in_place",
]


def apply_rope_in_place(
    q: torch.Tensor,
    k: torch.Tensor,
    past_kv_len: int,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
):
    """
    Semantics of `apply_rope_in_place`:
    Apply RoPE (Rotary Position Embedding) to `q` and `k` in-place.
    `q` and `k` are produced by GEMM, so their layout is naturally NHD.

    Args:
        q: Shape: `[N, H, D]`.
        k: Shape: `[N, H, D]`.
        past_kv_len: Length of past KV cache. Used to calculate frequency.
        rope_scale: RoPE scale forwarded to the kernel. `None` selects `1.0`.
        rope_theta: RoPE theta forwarded to the kernel. `None` selects `1e4`.
    """
    # Resolve the optional knobs here so the kernel always receives floats.
    scale = 1.0 if rope_scale is None else rope_scale
    theta = 1e4 if rope_theta is None else rope_theta
    _kernels.apply_rope_in_place(q, k, past_kv_len, scale, theta)

def apply_llama31_rope_in_place(
    q: torch.Tensor,
    k: torch.Tensor,
    past_kv_len: int,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    low_freq_factor: Optional[float] = None,
    high_freq_factor: Optional[float] = None,
    old_context_length: Optional[float] = None,
):
    """
    Semantics of `apply_llama31_rope_in_place`:
    Apply the Llama 3.1 flavour of RoPE to `q` and `k` in-place by calling
    `_kernels.apply_llama31_rope_in_place`.
    `q` and `k` are produced by GEMM, so their layout is naturally NHD.

    Args:
        q: Shape: `[N, H, D]`.
        k: Shape: `[N, H, D]`.
        past_kv_len: Length of past KV cache. Used to calculate frequency.
        rope_scale: RoPE scale forwarded to the kernel. `None` selects `1.0`.
        rope_theta: RoPE theta forwarded to the kernel. `None` selects `1e4`.
        low_freq_factor: Forwarded to the kernel unchanged.
        high_freq_factor: Forwarded to the kernel unchanged.
        old_context_length: Forwarded to the kernel unchanged.

    Note:
        Unlike `rope_scale` / `rope_theta`, no default is substituted for the
        last three arguments: a `None` is handed to the kernel as `None`.
        NOTE(review): whether the kernel accepts `None` is not visible from
        this file -- callers should presumably pass explicit values.
    """
    # Only these two optional arguments get a Python-side default.
    scale = 1.0 if rope_scale is None else rope_scale
    theta = 1e4 if rope_theta is None else rope_theta
    _kernels.apply_llama31_rope_in_place(
        q,
        k,
        past_kv_len,
        scale,
        theta,
        low_freq_factor,
        high_freq_factor,
        old_context_length,
    )


def rms_norm_forward(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> torch.Tensor:
    """
    Semantics of `rms_norm_forward`:
    Run `_kernels.rms_norm_forward` on `input` with `weight` and `epsilon`,
    writing the result into a freshly allocated tensor.

    Args:
        input: Tensor handed to the kernel as its input.
        weight: Tensor handed to the kernel as its weight.
        epsilon: Epsilon value handed to the kernel.

    Returns:
        A new tensor with the same shape, dtype and device as `input`,
        filled by the kernel.
    """
    # `empty_like` already inherits dtype and device from `input`.
    result = torch.empty_like(input)
    _kernels.rms_norm_forward(input, weight, result, epsilon)
    return result


def append_kv(
    k: torch.Tensor,
    v: torch.Tensor,
    iController: InferenceController,
    layer_idx: int,
):
    """
    Semantics of `append_kv`:
    Append newly generated k/v into the kv cache and meta data cache.
    Automatically dispatches to the Prefill or the Decode kernel: the prefill
    kernel is used when `k.size(0) > 1`, the decode kernel otherwise. Both
    kernels receive exactly the same arguments.

    Notations for shapes:
    `B`: batch size
    `N`: number of heads
    `D`: head dimension
    `L`: number of layers
    `MAXLEN`: maximum length of the KV cache

    Args:
        k: Shape: `[B, N, D]`. Key projection (`X @ W_k`).
        v: Shape: `[B, N, D]`. Value projection (`X @ W_v`).
        iController: InferenceController object, which contains all needed information.
        layer_idx: Layer index of the KV cache.
    """
    # Pick the kernel first; the argument list is shared by both paths.
    if k.size(0) > 1:
        append_kernel = _kernels.append_kv_cache_prefill
    else:
        append_kernel = _kernels.append_kv_cache_decode

    append_kernel(
        k,
        v,
        iController.kv_cache.buf_layer(layer_idx),
        iController.kv_indices_with_last,
        iController.kv_indptr_for_append,
        iController.kv_cache.last_page_len,
        iController.kv_last_page_idx,
        iController.layout,
    )


def prefill_forward(
    q: torch.Tensor,
    iController: InferenceController,
    layer_idx: int,
) -> torch.Tensor:
    """
    Semantics of `prefill_forward`:
    Newly generated K/Vs are already in the kv cache and meta data cache (well-maintained).
    Perform FlashInfer self-attention with causal attention.
    Note that there is no position shift, and the current version does not
    support Prefill Optimization.

    Notations for shapes:
    `B`: batch size
    `N`: number of heads
    `D`: head dimension
    `L`: number of layers
    `MAXLEN`: maximum length of the KV cache

    Args:
        q: Shape: `[B, N, D]`. Query projection.
        iController: InferenceController object, which contains all needed information.
        layer_idx: Layer index of the KV cache.

    Returns:
        The value returned by `_kernels.prefill_with_paged_kv_cache`.
    """
    return _kernels.prefill_with_paged_kv_cache(
        q,
        iController.kv_cache.buf_layer(layer_idx),
        iController.kv_indices_with_last,
        iController.kv_cache.last_page_len,
        True,  # Causal attention
        iController.layout,
        False,  # FP16 Accumulator for 4090
    )


def decode_topk(
    iController: InferenceController,
):
    """
    Semantics of `decode_topk`:
    Select top-k pages with highest attention score.
    Calls `_kernels.topk_filtering` on buffers held by `iController`
    and returns nothing.

    Args:
        iController: InferenceController object, which contains all needed information.
            The attributes read here are `qk_product`, `kv_indices_with_last_decode`,
            `topk_dout_buffer`, `topk_dindices_buffer`, `topk_buf` and
            `inference_token_budget`.
    """
    # excluding the last page
    _kernels.topk_filtering(
        iController.qk_product,
        iController.kv_indices_with_last_decode,
        iController.topk_dout_buffer,
        iController.topk_dindices_buffer,
        iController.topk_buf,
        iController.inference_token_budget,
    )


def decode_sparse_attn(
    q: torch.Tensor,
    iController: InferenceController,
    layer_idx: int,
    topk_indices: torch.Tensor,
    output_qk_product: bool,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
) -> torch.Tensor:
    """
    Semantics of `decode_sparse_attn`:
    Execute self-attention only on the selected pages (Top-k output).

    Notations for shapes:
    `B`: batch size
    `N`: number of heads
    `D`: head dimension
    `L`: number of layers
    `MAXLEN`: maximum length of the KV cache

    Args:
        q: Shape: `[B, N, D]`. Query projection.
        iController: InferenceController object, which contains all needed information.
        layer_idx: Layer index of the KV cache.
        topk_indices: Shape: `[N, page_budget-1]`. Top-k indices.
        output_qk_product: When true, `iController.qk_product` is handed to the
            decode handler; otherwise `None` is passed in its place.
        rope_scale: Forwarded to the decode handler unchanged (may be `None`).
        rope_theta: Forwarded to the decode handler unchanged (may be `None`).

    Returns:
        A new tensor shaped like `q`, allocated here and passed to the
        decode handler as its output argument.
    """
    # `empty_like` already inherits dtype and device from `q`.
    out = torch.empty_like(q)
    handler = iController._decode_handler
    handler.forward(
        q,
        out,
        iController.kv_cache.buf_layer(layer_idx),
        topk_indices,
        iController.kv_indptr_for_approx_decode,
        iController.kv_cache.last_page_len,
        iController.kv_last_page_idx,
        iController.qk_product if output_qk_product else None,
        rope_scale,
        rope_theta,
    )
    return out
