import math
from typing import Any

import torch
import triton
import triton.language as tl
from triton.language.standard import _log2

from ..utils.set_device import SetDevice
from ..common import calc_dims
from .triton_argsort import argsort


@triton.jit
def fast_median(arr, n_iters):
    """
    Approximate the median of ``arr`` by shrinking a bracket [yL, yR] around the
    minimizer of the convex objective ``f(y) = |lowest - y| + sum(|arr - y|)``.

    The bracket starts at [min(arr), max(arr)]. Each iteration probes ``cur``, the
    intersection of the two tangent lines at the bracket ends (values fL/fR, slopes
    gL/gR). It evaluates the objective and a slope at ``cur`` and replaces the left
    end (slope < 0) or the right end (otherwise). ``lowest = min(arr) - 1`` is an
    extra virtual element included in the objective.

    :param arr: tensor reduced along axis 0; assumed 1-D (the reductions must yield
        scalars for ``if g_cur < 0``) -- TODO confirm with callers
    :param n_iters: number of bracket-refinement iterations
    :return: the largest element of ``arr`` that is <= the final ``yR``, or
        ``lowest`` if no element satisfies that

    NOTE(review): not called by the kernels visible in this section of the file.
    NOTE(review): the slope conventions look inconsistent. The initial gL = 2 - N /
    gR = N - 2 and the ``lowest`` term in g_cur use d/dy |x - y| (negative left of
    the minimizer). The ``arr`` term in g_cur instead counts elements *above* ``cur``
    as +1. The initial fL = S - N*yL and fR = N*yR - S also do not equal f(yL) /
    f(yR) for the objective evaluated in the loop. Verify against the intended
    algorithm before relying on this function.
    """
    N = arr.shape[0]
    yL = tl.min(arr, axis=0)
    # Virtual extra element strictly below every entry of arr.
    lowest = yL - 1.0
    yR = tl.max(arr, axis=0)
    # Sum over arr plus the virtual element.
    S = tl.sum(arr, axis=0) + lowest

    # Initial objective values (fL, fR) and slopes (gL, gR) at the bracket ends.
    fL = S - N*yL
    fR = N*yR - S
    gL = -N + 2.0
    gR = N - 2.0

    for _ in range(n_iters):
        # y where fL + gL*(y - yL) == fR + gR*(y - yR), i.e. where the two tangent lines cross.
        cur = (fR - fL + yL*gL - yR*gR) / (gL - gR)

        # Objective and slope estimate at the probe point.
        f_cur = tl.abs(lowest - cur) + tl.sum(tl.abs(arr - cur), axis=0)
        g_cur = tl.where(lowest > cur, -1.0, 1.0) + tl.sum(tl.where(arr > cur, 1.0, -1.0), axis=0)

        # Negative slope: the probe becomes the new left end; otherwise the new right end.
        if g_cur < 0:
            yL = cur
            fL = f_cur
            gL = g_cur
        else:
            yR = cur
            fR = f_cur
            gR = g_cur

    # Snap to an actual element: the largest entry of arr not exceeding yR.
    median = tl.max(tl.where(arr <= yR, arr, lowest), axis=0)
    return median


@triton.jit
def get_split_(x, idx):
    """
    One halving step of ``get_split``. Keep either the even or the odd positions
    of the last axis of x, chosen by the lowest bit of idx.

    :param x: (..., N) tensor; N must be even so it can be viewed as (..., N // 2, 2)
    :param idx: scalar integer index into the last axis of x
    :return: ``(half, idx >> 1)``, where half has shape (..., N // 2) and equals
        ``x[..., 0::2]`` when idx is even and ``x[..., 1::2]`` when idx is odd.
        Hence ``half[..., idx >> 1] == x[..., idx]``.
    """
    batch_shape: tl.constexpr = x.shape[:-1]
    N: tl.constexpr = x.shape[-1]
    # (..., N) -> (..., N // 2, 2): element [..., m, 0] is x[..., 2m], [..., m, 1] is x[..., 2m + 1].
    x = tl.reshape(x, batch_shape + [N // 2, 2])
    a, b = tl.split(x)
    if idx & 1 == 0:
        return a, idx >> 1
    else:
        return b, idx >> 1


@triton.jit
def get_split(x, idx):
    """
    Extract ``x[..., idx]`` for a runtime scalar idx. This uses only reshape/split
    selections, applying ``get_split_`` log2(N) times. Each step halves the last
    axis and consumes one bit of idx (least significant first).

    :param x: (..., N) tensor; N is assumed to be a power of two (every halving
        step needs an even length)
    :param idx: scalar integer index into the last axis, 0 <= idx < N
    :return: tensor of shape ``x.shape[:-1]``
    """
    batch_shape: tl.constexpr = x.shape[:-1]
    N: tl.constexpr = x.shape[-1]
    for i in tl.static_range(_log2(N)):
        x, idx = get_split_(x, idx)
    # The last axis now has length 1; drop it.
    x = tl.reshape(x, batch_shape)
    return x


@triton.jit
def iteration(begin_block_indices, end_block_indices,
              q_splits,
              k_buffer, k_stride_n, k_stride_h, k_stride_i, k_stride_d,
              n, h, i,
              top_k: tl.constexpr, head_dim: tl.constexpr,
              query_offset, q_block_offset, q_start_padding, q_len, k_len,
              block_size_q: tl.constexpr, block_size_k: tl.constexpr,
              D_BLOCK_SIZE: tl.constexpr,
              start_sink_tokens, end_sink_tokens):
    """
    One refinement step of the top-k key block search for query block ``i``.

    Each candidate range [begin, end) of key block indices is split at its midpoint
    into two children, giving 2 * top_k children in total. Each child is scored by
    the maximum raw q.k product over:
      - the (valid) queries of this block, and
      - the keys of the child's first block.
    The top_k highest-scoring children are kept. Children of length 0 and masked
    (query, key) pairs score -1e9.

    Key block index b covers key tokens starting at
    ``start_sink_tokens + b * block_size_k``.

    :param begin_block_indices: (top_k,) range starts
    :param end_block_indices: (top_k,) range ends (exclusive)
    :param q_splits: (block_size_q, D_BLOCK_SIZE, head_dim // D_BLOCK_SIZE) query tile,
        where chunk c holds head dims [c * D_BLOCK_SIZE, (c + 1) * D_BLOCK_SIZE)
    :param k_buffer, k_stride_*: key tensor base pointer and strides (n, h, token, dim)
    :param n, h, i: batch index, head index and query block index of this program
    :param query_offset: absolute position of the first query row
    :param q_block_offset: query_offset // block_size_q
    :param q_start_padding: unused here
    :param q_len, k_len: number of query rows / key tokens
    :param start_sink_tokens, end_sink_tokens: sink sizes. A key is causal-valid for a
        query only if it is at least end_sink_tokens positions behind it.
    :return: ``(begin_block_indices, end_block_indices)``, each (top_k,). They are
        presumably in descending score order, as produced by ``argsort`` from
        .triton_argsort (TODO confirm).
    """

    mid_index = (begin_block_indices + end_block_indices) // 2
    # Children of range t are [begin_t, mid_t) and [mid_t, end_t); interleaving puts
    # them at positions 2t and 2t + 1.
    begin_block_indices_branch = tl.interleave(begin_block_indices, mid_index)
    end_block_indices_branch = tl.interleave(mid_index, end_block_indices)

    # Select representative block for each node (its first block)
    rep_block_indices = begin_block_indices_branch

    q = tl.arange(0, block_size_q)
    k = tl.arange(0, block_size_k)

    # Flat column kk * (top_k * 2) + t holds token kk of child t's representative block.
    k_indices = tl.reshape(
        start_sink_tokens + rep_block_indices[None, :] * block_size_k + k[:, None],
        (block_size_k * (top_k * 2),)
    )  # (block_size_k * (top_k * 2),)
    q_indices = i * block_size_q + q
    # Absolute query positions; query row r has position query_offset + r - q_start_padding.
    q_positions = q_block_offset * block_size_q + q_indices

    attn_scores = tl.zeros((block_size_q, block_size_k * (top_k * 2)), dtype=tl.float32)

    # Avoid blocks with 0 length (same column layout as k_indices)
    pos_block_mask = tl.reshape(
        tl.broadcast_to(
            (end_block_indices_branch - begin_block_indices_branch)[None, :] > 0,
            (block_size_k, top_k * 2)
        ), (block_size_k * (top_k * 2),)
    )

    # Accumulate q.k over head_dim in D_BLOCK_SIZE-wide chunks.
    for d_block_idx in range(0, head_dim // D_BLOCK_SIZE):
        d = d_block_idx * D_BLOCK_SIZE + tl.arange(0, D_BLOCK_SIZE)

        k_rep_blocks = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + k_indices[None, :] * k_stride_i
            + d[:, None] * k_stride_d,
            mask=(
                pos_block_mask[None, :] &
                (k_indices[None, :] < k_len)
            ),
            other=0.0
        )  # (D_BLOCK_SIZE, block_size_k * (top_k * 2))

        q_split = get_split(q_splits, d_block_idx)  # (block_size_q, D_BLOCK_SIZE)

        partial_attn_scores = tl.dot(q_split, k_rep_blocks)
        attn_scores += partial_attn_scores

    valid_mask = (
        # Mask out non-causal pairs (keys must be >= end_sink_tokens positions behind the query)
        (k_indices[None, :] <= q_positions[:, None] - end_sink_tokens) &
        # Prevent out-of-bound access
        (k_indices[None, :] < k_len) &
        # Only queries inside [query_offset, query_offset + q_len) count
        (query_offset <= q_positions[:, None]) &
        (q_positions[:, None] < query_offset + q_len)
    )  # (block_size_q, block_size_k * (top_k * 2))
    attn_scores = tl.where(valid_mask, attn_scores, -1e9)

    # Maximum attention scores inside each block. With the column layout above, this
    # reshape puts every (query, token) pair of child t into column t.
    attn_scores = tl.reshape(attn_scores, (block_size_q * block_size_k, top_k * 2))
    attn_scores = tl.max(attn_scores, axis=0)  # (top_k * 2,)

    # Do not select nodes with length 0
    branch_lengths = end_block_indices_branch - begin_block_indices_branch
    attn_scores = tl.where(branch_lengths > 0, attn_scores, -1e9)  # (top_k * 2,)

    # Select top-k nodes. (begin, end) is packed into one integer so the sort carries both.
    # NOTE(review): the round trip needs end < 2**16, and begin < 2**15 if the indices
    # are int32 (otherwise begin << 16 reaches the sign bit) -- TODO confirm key block
    # counts stay below this.
    packed_ranges = begin_block_indices_branch << 16 | end_block_indices_branch
    _, packed_ranges = argsort(attn_scores[None, :], packed_ranges[None, :], descending=True)
    begin_block_indices = packed_ranges >> 16
    end_block_indices = packed_ranges & 0xFFFF
    # (1, 2 * top_k) -> (2, top_k) -> (top_k, 2); the first split output is the first
    # top_k entries of the sorted order.
    begin_block_indices, _ = tl.split(tl.trans(tl.reshape(begin_block_indices, (2, top_k))))
    end_block_indices, _ = tl.split(tl.trans(tl.reshape(end_block_indices, (2, top_k))))

    return begin_block_indices, end_block_indices


@triton.jit
def mask_gen_kernel(
        q_buffer, q_stride_n, q_stride_h, q_stride_i, q_stride_d,  # (bsz, num_heads, q_len, head_dim)
        k_buffer, k_stride_n, k_stride_h, k_stride_i, k_stride_d,  # (bsz, num_heads, k_len, head_dim)
        r_buffer, r_stride_n, r_stride_h, r_stride_i, r_stride_t,  # (bsz, num_heads, q_blocks, top_k)
        bsz, num_heads, q_len, k_len, head_dim: tl.constexpr, top_k: tl.constexpr, query_offset,
        block_size_q: tl.constexpr, block_size_k: tl.constexpr, D_BLOCK_SIZE: tl.constexpr, n_iterations,
        start_sink_tokens, end_sink_tokens) -> Any:
    """
    Select top_k key blocks for one query block of one (batch, head) pair.

    Grid axes: 0 = query block i, 1 = batch index n, 2 = head index h.

    Setup:
      - Consider the key blocks (excluding start/end sink tokens) up to the end of
        this query block.
      - Split them evenly into top_k half-open ranges.
    Refinement runs for n_iterations steps via ``iteration``. Each step halves every
    range and keeps the top_k best halves.

    Output: the begin index of each final range is stored to r_buffer[n, h, i, :],
    or -1 for an empty range. Key block index b covers key tokens starting at
    ``start_sink_tokens + b * block_size_k``.
    """

    n = tl.program_id(1)
    h = tl.program_id(2)
    i = tl.program_id(0)

    q_block_offset = query_offset // block_size_q
    q_start_padding = query_offset - q_block_offset * block_size_q
    q_blocks = tl.cdiv(q_len + query_offset, block_size_q) - q_block_offset
    # NOTE(review): q_end_padding and k_blocks are computed but not used below.
    q_end_padding = q_blocks * block_size_q - (q_start_padding + q_len)
    k_blocks = tl.cdiv(k_len, block_size_k)
    k_blocks = max(k_blocks, tl.cdiv((q_block_offset + q_blocks) * block_size_q, block_size_k))

    # Number of non-sink key blocks up to the end (exclusive) of this query block.
    key_blocks_per_query = tl.cdiv(
        (q_block_offset + i + 1) * block_size_q - start_sink_tokens - end_sink_tokens,
        block_size_k
    )
    key_blocks_per_query = tl.maximum(key_blocks_per_query, 0)
    # Initial even partition of [0, key_blocks_per_query) into top_k ranges.
    begin_block_indices = key_blocks_per_query * tl.arange(0, top_k) // top_k
    end_block_indices = key_blocks_per_query * tl.arange(1, top_k + 1) // top_k

    q = tl.arange(0, block_size_q)
    q_indices = i * block_size_q + q
    q_positions = q_block_offset * block_size_q + q_indices
    # d_splits_indices[j, c] = c * D_BLOCK_SIZE + j
    d_splits_indices = (
        tl.arange(0, head_dim // D_BLOCK_SIZE)[None, :] * D_BLOCK_SIZE
        + tl.arange(0, D_BLOCK_SIZE)[:, None]
    )
    # q_splits[r, j, c] = query[row r, dim c * D_BLOCK_SIZE + j]; rows whose position lies
    # outside [query_offset, query_offset + q_len) are loaded as 0.
    q_splits = tl.load(
        q_buffer
        + n * q_stride_n
        + h * q_stride_h
        + (q_indices[:, None, None] - q_start_padding) * q_stride_i
        + d_splits_indices[None, :, :] * q_stride_d,
        mask=(
            (query_offset <= q_positions[:, None, None]) &
            (q_positions[:, None, None] < query_offset + q_len)
        ),
        other=0.0
    )  # (block_size_q, D_BLOCK_SIZE, head_dim // D_BLOCK_SIZE)

    for it in range(n_iterations):
        begin_block_indices, end_block_indices = iteration(
            begin_block_indices, end_block_indices,
            q_splits,
            k_buffer, k_stride_n, k_stride_h, k_stride_i, k_stride_d,
            n, h, i,
            top_k, head_dim,
            query_offset, q_block_offset, q_start_padding, q_len, k_len,
            block_size_q, block_size_k, D_BLOCK_SIZE,
            start_sink_tokens, end_sink_tokens
        )

    # Do not select nodes with length 0
    block_length = end_block_indices - begin_block_indices
    result = tl.where(block_length > 0, begin_block_indices, -1)

    tl.store(
        r_buffer
        + n * r_stride_n
        + h * r_stride_h
        + i * r_stride_i
        + tl.arange(0, top_k) * r_stride_t,
        result,
    )


def mask_gen_triton_impl(query_states, key_states,
                         top_k: int, block_size_q: int, block_size_k: int, query_offset: int,
                         start_sink_tokens: int, end_sink_tokens: int, out=None):
    """
    Top-k key block selection for every query block (launches ``mask_gen_kernel``).

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param top_k: number of key blocks to select per query block
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the first query within the key sequence
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param out: optional preallocated output used in place of a fresh
        ``torch.zeros(bsz, num_heads, q_blocks, top_k, dtype=torch.int32)``
    :return: ``(mask_block_indices, compile_info)``.
        - mask_block_indices has shape (bsz, num_heads, q_blocks, top_k) and holds
          selected key block indices, or -1 for an empty slot. Index b covers key
          tokens starting at ``start_sink_tokens + b * block_size_k``.
        - compile_info is the object returned by the kernel launch.
    """
    device = query_states.device

    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, _ = key_states.size()

    (q_block_offset, q_blocks, q_start_padding, q_end_padding, k_blocks) = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    # Each kernel iteration halves every candidate range. Starting from top_k ranges
    # of about selectable_blocks / top_k blocks, ceil(log2(selectable_blocks / top_k))
    # iterations reach single blocks.
    # When there is nothing to refine (sinks cover every key block, or there are no
    # more blocks than top_k), run zero iterations. Taking log2 of a non-positive
    # value would raise "math domain error", and a negative count already meant zero
    # iterations in the kernel loop.
    selectable_blocks = k_blocks - (start_sink_tokens + end_sink_tokens) // block_size_k
    if selectable_blocks > top_k:
        n_iterations = int(math.ceil(math.log2(selectable_blocks / top_k)))
    else:
        n_iterations = 0

    q, k = query_states, key_states
    m = out
    if m is None:
        m = torch.zeros(bsz, num_heads, q_blocks, top_k, dtype=torch.int32, device=device)

    D_BLOCK_SIZE = 16

    # One program per (query block, batch, head).
    grid = (q_blocks, bsz, num_heads)
    with SetDevice(query_states.device):
        compile_info = mask_gen_kernel[grid](
            q_buffer=q, q_stride_n=q.stride(0), q_stride_h=q.stride(1), q_stride_i=q.stride(2), q_stride_d=q.stride(3),
            k_buffer=k, k_stride_n=k.stride(0), k_stride_h=k.stride(1), k_stride_i=k.stride(2), k_stride_d=k.stride(3),
            r_buffer=m, r_stride_n=m.stride(0), r_stride_h=m.stride(1), r_stride_i=m.stride(2), r_stride_t=m.stride(3),
            bsz=bsz, num_heads=num_heads, q_len=q_len, k_len=k_len, head_dim=head_dim, top_k=top_k, query_offset=query_offset,
            block_size_q=block_size_q, block_size_k=block_size_k, D_BLOCK_SIZE=D_BLOCK_SIZE, n_iterations=n_iterations,
            start_sink_tokens=start_sink_tokens, end_sink_tokens=end_sink_tokens,
            num_stages=2,
            num_warps=16,
        )

    return m, compile_info


def mask_gen_triton(query_states, key_states,
                    top_k: int, block_size_q: int, block_size_k: int, query_offset: int,
                    start_sink_tokens: int, end_sink_tokens: int, out=None):
    """
    Top-k key block selection, skipping the leading query blocks that lie
    entirely below the attended-token budget.

    The budget is ``start_sink_tokens + top_k * block_size_k + end_sink_tokens``.
    Query blocks whose positions all fall below it are cut off before the launch
    (presumably because such queries can attend to all their keys without sparse
    selection). The remaining queries are handed to ``mask_gen_triton_impl``.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param top_k: number of key blocks to select per query block
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the first query within the key sequence
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param out: optional output forwarded to ``mask_gen_triton_impl``. It must be
        sized for the query blocks remaining after the cutoff, not for all of
        ``query_states``.
    :return: ``(mask_block_indices, compile_info)`` from ``mask_gen_triton_impl``,
        covering only the queries from the cutoff onwards.
    :raises AssertionError: if ``end_sink_tokens < block_size_q``
    """
    _, _, q_len, _ = query_states.size()
    _, _, k_len, _ = key_states.size()

    assert end_sink_tokens >= block_size_q

    q_block_offset, q_blocks, q_start_padding, _, _ = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    attended_budget = start_sink_tokens + top_k * block_size_k + end_sink_tokens
    # First (relative) query block that still needs sparse selection, capped at q_blocks.
    first_sparse_block = triton.cdiv(attended_budget, block_size_q) - 1 - q_block_offset
    if first_sparse_block > q_blocks:
        first_sparse_block = q_blocks

    # Number of leading query rows to drop (never negative).
    skipped_rows = max(0, first_sparse_block * block_size_q - q_start_padding)
    sparse_queries = query_states[:, :, skipped_rows:]

    return mask_gen_triton_impl(
        sparse_queries,
        key_states,
        top_k,
        block_size_q,
        block_size_k,
        query_offset + skipped_rows,
        start_sink_tokens,
        end_sink_tokens,
        out=out,
    )
