import math

import torch
import triton
import einx

from ..common import calc_dims
from ..utils.softsort import softsort


def mask_gen_ref(query_states, key_states,
                 top_k: int, block_size_q: int, block_size_k: int, query_offset: int,
                 start_sink_tokens: int, end_sink_tokens: int,
                 soft_sort: bool = False, soft_sort_tau: float = 1.0, soft_sort_pow: float = 1.0):
    """
    Reference top-k key selection, evaluated over the query axis in chunks.

    Query positions before ``cutoff`` are skipped; the remaining positions are
    passed to ``mask_gen_ref_impl`` in chunks of at most ``BLOCK_SIZE`` queries
    and the per-chunk results are concatenated along dim 2.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param top_k: number of key blocks to select per query
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param soft_sort: whether to use soft differentiable sort
    :param soft_sort_tau: temperature for soft sort
    :param soft_sort_pow: power coefficient for soft sort
    :return: if ``soft_sort`` is False, mask_block_indices
             (bsz, num_heads, n_q_blocks, top_k), where dim 2 counts the query
             blocks of all processed chunks (not individual query positions);
             if ``soft_sort`` is True, a tuple ``(permute, sparse_indices)`` with
             shapes (bsz, num_heads, n_q_blocks, top_k, top_k*2) and
             (bsz, num_heads, n_q_blocks, top_k*2)
    """
    _, _, q_len, _ = query_states.size()
    _, _, k_len, _ = key_states.size()

    assert end_sink_tokens >= block_size_q, (
        f"end_sink_tokens ({end_sink_tokens}) must be >= block_size_q ({block_size_q})"
    )
    q_block_offset, q_blocks, q_start_padding, _, _ = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    # Tokens every query attends to: both sinks plus the top_k selected key blocks.
    total_attended_tokens = start_sink_tokens + top_k * block_size_k + end_sink_tokens
    # NOTE(review): presumably the first query block whose context exceeds that
    # budget, i.e. where selection starts to matter -- confirm against calc_dims.
    sparse_q_block_begin = min(q_blocks, triton.cdiv(total_attended_tokens, block_size_q) - 1 - q_block_offset)

    # First query position (within query_states) that is handed to the implementation.
    cutoff = max(0, sparse_q_block_begin * block_size_q - q_start_padding)

    # Process the queries chunk by chunk; each chunk gets its own absolute offset.
    BLOCK_SIZE = 512
    chunk_results = []
    for start in range(cutoff, q_len, BLOCK_SIZE):
        end = min(q_len, start + BLOCK_SIZE)
        chunk_results.append(mask_gen_ref_impl(
            query_states[:, :, start:end], key_states,
            top_k, block_size_q, block_size_k, query_offset + start,
            start_sink_tokens, end_sink_tokens,
            soft_sort, soft_sort_tau, soft_sort_pow
        ))

    # NOTE(review): when cutoff >= q_len no chunk is processed and torch.cat is
    # called with an empty list -- confirm callers never reach that case.
    if soft_sort:
        permute = [chunk[0] for chunk in chunk_results]
        sparse_indices = [chunk[1] for chunk in chunk_results]
        return torch.cat(permute, dim=2), torch.cat(sparse_indices, dim=2)
    return torch.cat(chunk_results, dim=2)


def mask_gen_ref_impl(query_states, key_states,
                      top_k: int, block_size_q: int, block_size_k: int, query_offset: int,
                      start_sink_tokens: int, end_sink_tokens: int,
                      soft_sort: bool = False, soft_sort_tau: float = 1.0, soft_sort_pow: float = 1.0):
    """
    Reference top-k key selection for one chunk of queries.

    Queries and keys are padded and split into blocks. Starting from ``top_k``
    nodes per query block (each split into two branches), every iteration
    scores the ``top_k*2`` branches by their representative key block, keeps
    ``top_k`` of them and halves those again, until the final iteration returns
    the selection.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param top_k: number of key blocks to select per query
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param soft_sort: whether to use soft differentiable sort
    :param soft_sort_tau: temperature for soft sort
    :param soft_sort_pow: power coefficient for soft sort
    :return: if ``soft_sort`` is False, mask_block_indices
             (bsz, num_heads, q_blocks, top_k) holding the first key block of
             each selected node, or -1 for nodes of length 0; if ``soft_sort``
             is True, the tuple ``(permute, sparse_indices)`` produced by the
             final ``soft_topk_iteration`` call
    """
    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, _ = key_states.size()

    (q_block_offset, q_blocks, q_start_padding, q_end_padding, k_blocks) = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    # Blockify: pad the sequence axis so it splits evenly into blocks.
    query_states = torch.nn.functional.pad(query_states, (0, 0, q_start_padding, q_end_padding))
    query_states = query_states.reshape(bsz, num_heads, q_blocks, block_size_q, head_dim)
    key_states = torch.nn.functional.pad(key_states, (0, 0, 0, k_blocks * block_size_k - k_len))
    key_states = key_states.reshape(bsz, num_heads, k_blocks, block_size_k, head_dim)

    expanded_key_states = key_states.unsqueeze(2).expand(-1, -1, q_blocks, -1, -1, -1)
    # (bsz, num_heads, q_blocks, k_blocks, block_size_k, head_dim)

    # Initial state
    ((begin_block_indices_branch, end_block_indices_branch),
     rep_blocks, rep_block_indices) = init_branches(
        expanded_key_states,
        q_block_offset, block_size_q, top_k,
        start_sink_tokens, end_sink_tokens
    )

    # Each iteration halves the branches, so ceil(log2(blocks per node)) rounds
    # are needed. At least one round must run: with zero rounds the loop body
    # (which holds the only return) never executes and the function would
    # return None; a non-positive ratio would make math.log2 raise.
    selectable_blocks = k_blocks - (start_sink_tokens + end_sink_tokens) // block_size_k
    if selectable_blocks > top_k:
        n_iterations = int(math.ceil(math.log2(selectable_blocks / top_k)))
    else:
        n_iterations = 1

    for it in range(n_iterations):
        is_final_iteration = it == n_iterations - 1

        attn_scores = compute_attn_scores(
            query_states, rep_blocks, rep_block_indices,
            end_block_indices_branch - begin_block_indices_branch,
            query_offset, q_block_offset, q_len, k_len,
            start_sink_tokens, end_sink_tokens,
        )

        if soft_sort:
            result = soft_topk_iteration(
                attn_scores,
                begin_block_indices_branch, end_block_indices_branch,
                expanded_key_states,
                start_sink_tokens,
                soft_sort_tau, soft_sort_pow,
                is_final_iteration=is_final_iteration,
            )
        else:
            result = topk_iteration(
                attn_scores,
                begin_block_indices_branch, end_block_indices_branch,
                expanded_key_states,
                start_sink_tokens,
                is_final_iteration=is_final_iteration,
            )

        if is_final_iteration:
            return result

        # Non-final iterations hand back the refined branches for the next round.
        ((begin_block_indices_branch, end_block_indices_branch),
         rep_blocks, rep_block_indices) = result

    raise AssertionError("unreachable: n_iterations is always >= 1")


def init_branches(expanded_key_states: torch.Tensor,
                  q_block_offset, block_size_q, top_k, start_sink_tokens, end_sink_tokens):
    """
    Build the initial branches of the hierarchical top-k search.

    For every query block, the key blocks left after subtracting both sinks are
    divided into ``top_k`` contiguous nodes, and every node is split at its
    midpoint into two branches.

    :param expanded_key_states: (bsz, num_heads, q_blocks, k_blocks, block_size_k, head_dim)
    :param q_block_offset: added to the local query block index to obtain its global index
    :param block_size_q: query block size
    :param top_k: number of nodes per query block
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :return: ``((begin_block_indices_branch, end_block_indices_branch), rep_blocks, rep_block_indices)``
             where the two index tensors are (bsz, num_heads, q_blocks, top_k*2)
             half-open block ranges counted from the first block after the start
             sink, ``rep_block_indices`` is the begin tensor itself, and
             ``rep_blocks`` is (bsz, num_heads, q_blocks, top_k*2, block_size_k, head_dim)
    """
    device = expanded_key_states.device
    bsz, num_heads, q_blocks, _, block_size_k, head_dim = expanded_key_states.shape

    # 1-based global index of every query block; times block_size_q this is the
    # number of token positions up to the end of that block.
    q_block_numbers = q_block_offset + torch.arange(1, q_blocks + 1, device=device)
    selectable_tokens = torch.clip(
        q_block_numbers * block_size_q - start_sink_tokens - end_sink_tokens, min=0
    )
    # Ceiling division, written out because the operand is a tensor.
    key_blocks_per_query = (selectable_tokens + block_size_k - 1) // block_size_k  # (q_blocks,)

    # Node i of a query block covers blocks [K * i // top_k, K * (i + 1) // top_k).
    # (bsz, num_heads, q_blocks, top_k)
    begin_block_indices = (key_blocks_per_query[:, None] * torch.arange(0, top_k, device=device) // top_k)
    begin_block_indices = begin_block_indices[None, None, :, :].expand(bsz, num_heads, -1, -1)
    end_block_indices = (key_blocks_per_query[:, None] * torch.arange(1, top_k + 1, device=device) // top_k)
    end_block_indices = end_block_indices[None, None, :, :].expand(bsz, num_heads, -1, -1)

    # Split every node at its midpoint: lower halves first, then upper halves.
    mid_block_indices = (begin_block_indices + end_block_indices) // 2
    begin_block_indices_branch = torch.cat([begin_block_indices, mid_block_indices], 3)
    end_block_indices_branch = torch.cat([mid_block_indices, end_block_indices], 3)
    # (bsz, num_heads, q_blocks, top_k*2)

    # Select representative block for each node: the first block of the branch.
    rep_block_indices = begin_block_indices_branch
    gather_indices = start_sink_tokens // block_size_k + rep_block_indices
    # (bsz, num_heads, q_blocks, top_k*2)

    rep_blocks = torch.gather(
        expanded_key_states, 3,
        gather_indices[..., None, None].expand(-1, -1, -1, -1, block_size_k, head_dim),
    )  # (bsz, num_heads, q_blocks, top_k*2, block_size_k, head_dim)

    return (begin_block_indices_branch, end_block_indices_branch), rep_blocks, rep_block_indices


def compute_attn_scores(query_states: torch.Tensor,
                        rep_blocks: torch.Tensor,
                        rep_block_indices: torch.Tensor,
                        len_block_indices_branch: torch.Tensor,
                        query_offset, q_block_offset, q_len, k_len, start_sink_tokens, end_sink_tokens):
    """
    Score every branch by the largest query-key dot product in its representative block.

    :param query_states: (bsz, num_heads, q_blocks, block_size_q, head_dim)
    :param rep_blocks: (bsz, num_heads, q_blocks, top_k*2, block_size_k, head_dim)
    :param rep_block_indices: block index of each representative block, counted
                              from the end of the start sink
    :param len_block_indices_branch: length of each branch; branches of length 0 score -inf
    :param query_offset: absolute position of the first real (non-padding) query
    :param q_block_offset: index of the first query block; scaled by block_size_q
                           to get absolute query positions
    :param q_len: number of real queries
    :param k_len: number of real keys
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :return: (bsz, num_heads, q_blocks, top_k*2) scores; -inf where every
             query-key pair of the branch is masked out
    """
    device = query_states.device
    block_size_q = query_states.shape[3]
    bsz, num_heads, q_blocks, two_top_k, block_size_k, _ = rep_blocks.shape
    top_k = two_top_k // 2
    neg_inf = float('-inf')

    # (bsz, num_heads, q_blocks, top_k*2, block_size_q, block_size_k)
    scores = torch.einsum('nhiqd, nhitkd -> nhitqk', query_states, rep_blocks)

    # Absolute key position of every element of each representative block.
    # (bsz, num_heads, q_blocks, top_k*2, 1, block_size_k)
    block_start = start_sink_tokens + rep_block_indices * block_size_k
    key_pos = block_start[..., None, None] + torch.arange(block_size_k, device=device)

    # Absolute query position of every element of each query block.
    query_pos = torch.arange(q_blocks * block_size_q, device=device) + q_block_offset * block_size_q
    query_pos = query_pos.reshape(1, 1, q_blocks, 1, block_size_q, 1)

    # A pair counts only if the key lies far enough behind the query (causal,
    # leaving room for the end sink) and neither side is padding.
    is_causal = key_pos <= query_pos - end_sink_tokens
    is_real_key = key_pos < k_len
    is_real_query = (query_pos >= query_offset) & (query_pos < query_offset + q_len)
    scores = torch.where(is_causal & is_real_key & is_real_query, scores, neg_inf)

    # Best pair per branch.
    scores = scores.reshape(bsz, num_heads, q_blocks, top_k * 2, block_size_q * block_size_k).amax(dim=-1)

    # Do not select nodes with length 0
    return torch.where(len_block_indices_branch > 0, scores, neg_inf)


def topk_iteration(attn_scores: torch.Tensor,
                   begin_block_indices_branch: torch.Tensor,
                   end_block_indices_branch: torch.Tensor,
                   expanded_key_states: torch.Tensor,
                   start_sink_tokens, is_final_iteration: bool):
    """
    Keep the ``top_k`` best-scoring branches and, unless final, halve them again.

    :param attn_scores: (bsz, num_heads, q_blocks, top_k*2) branch scores
    :param begin_block_indices_branch: (bsz, num_heads, q_blocks, top_k*2) first block of each branch
    :param end_block_indices_branch: (bsz, num_heads, q_blocks, top_k*2) one past the last block of each branch
    :param expanded_key_states: (bsz, num_heads, q_blocks, k_blocks, block_size_k, head_dim)
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param is_final_iteration: return the selection instead of refining it further
    :return: on the final iteration, (bsz, num_heads, q_blocks, top_k) begin
             indices of the selected branches with -1 for branches of length 0;
             otherwise ``((begin_block_indices_branch, end_block_indices_branch),
             rep_blocks, rep_block_indices)`` describing the next top_k*2 branches
    """
    block_size_k, head_dim = expanded_key_states.shape[4], expanded_key_states.shape[5]
    top_k = attn_scores.shape[3] // 2

    # (bsz, num_heads, q_blocks, top_k), best branch first
    winners = attn_scores.topk(top_k, dim=-1, sorted=True).indices

    selected_begin = begin_block_indices_branch.gather(3, winners)
    selected_end = end_block_indices_branch.gather(3, winners)

    if is_final_iteration:
        # Do not select nodes with length 0
        is_empty = (selected_end - selected_begin) == 0
        return selected_begin.masked_fill(is_empty, -1)

    # Halve each surviving branch: lower halves first, then upper halves.
    selected_mid = (selected_begin + selected_end) // 2
    next_begin = torch.cat([selected_begin, selected_mid], 3)
    next_end = torch.cat([selected_mid, selected_end], 3)
    # (bsz, num_heads, q_blocks, top_k*2)

    # Each new branch is represented by its first block; indices are shifted
    # past the start sink before looking up the key blocks.
    rep_block_indices = next_begin
    lookup = (start_sink_tokens // block_size_k + rep_block_indices)[..., None, None]
    rep_blocks = expanded_key_states.gather(
        3, lookup.expand(-1, -1, -1, -1, block_size_k, head_dim)
    )  # (bsz, num_heads, q_blocks, top_k*2, block_size_k, head_dim)

    return (next_begin, next_end), rep_blocks, rep_block_indices


def soft_topk_iteration(attn_scores: torch.Tensor,
                        begin_block_indices_branch: torch.Tensor,
                        end_block_indices_branch: torch.Tensor,
                        expanded_key_states: torch.Tensor,
                        start_sink_tokens, soft_sort_tau, soft_sort_pow,
                        is_final_iteration: bool):
    """
    Soft-sort counterpart of ``topk_iteration``.

    The branch scores are passed to ``softsort`` and the first ``top_k`` rows of
    its output are kept as ``permute``. On non-final iterations the next
    representative blocks are formed by weighting the candidate key blocks with
    ``permute``, while the integer branch bounds follow ``permute.argmax``.

    :param attn_scores: (bsz, num_heads, q_blocks, top_k*2) branch scores
    :param begin_block_indices_branch: (bsz, num_heads, q_blocks, top_k*2) first block of each branch
    :param end_block_indices_branch: (bsz, num_heads, q_blocks, top_k*2) one past the last block of each branch
    :param expanded_key_states: (bsz, num_heads, q_blocks, k_blocks, block_size_k, head_dim)
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param soft_sort_tau: temperature for soft sort
    :param soft_sort_pow: power coefficient for soft sort
    :param is_final_iteration: return the selection instead of refining it further
    :return: on the final iteration ``(permute, sparse_indices)`` where
             ``permute`` is (bsz, num_heads, q_blocks, top_k, top_k*2) and
             ``sparse_indices`` is (bsz, num_heads, q_blocks, top_k*2): the
             begin index of every branch in its original (unpermuted) order,
             or -1 for branches of length 0; otherwise
             ``((begin_block_indices_branch, end_block_indices_branch),
             rep_blocks, rep_block_indices)`` describing the next top_k*2 branches
    """
    block_size_k = expanded_key_states.shape[4]
    bsz, num_heads, q_blocks, two_top_k = attn_scores.shape
    top_k = two_top_k // 2

    # NOTE(review): presumably softsort returns one (top_k*2, top_k*2) sorting
    # matrix per row, best score first -- confirm against utils.softsort.
    permute = softsort(
        attn_scores.reshape(-1, top_k * 2), tau=soft_sort_tau, pow=soft_sort_pow, hard=True,
    ).reshape(bsz, num_heads, q_blocks, top_k * 2, top_k * 2)[..., :top_k, :]
    # (bsz, num_heads, q_blocks, top_k, top_k*2)

    if is_final_iteration:
        # Do not select nodes with length 0
        branch_lengths = end_block_indices_branch - begin_block_indices_branch  # (bsz, num_heads, q_blocks, top_k*2)
        sparse_indices = torch.where(branch_lengths > 0, begin_block_indices_branch, -1)

        return permute, sparse_indices

    # Halve every current branch; s = 0 is the lower half, s = 1 the upper half.
    mid_block_indices_branch = (begin_block_indices_branch + end_block_indices_branch) // 2
    begin_block_indices_branch_next, end_block_indices_branch_next = (
        torch.stack([begin_block_indices_branch, mid_block_indices_branch], -2),
        torch.stack([mid_block_indices_branch, end_block_indices_branch], -2),
    )  # (bsz, num_heads, q_blocks, 2, top_k*2)

    # Each half is represented by its first block, shifted past the start sink.
    rep_block_indices_next = begin_block_indices_branch_next
    gather_indices_next = start_sink_tokens // block_size_k + rep_block_indices_next

    # Below is the vectorized implementation of the following pseudocode:
    # for s, i, j in grid(2, top_k, top_k * 2):
    #    rep_blocks[b, h, q, s*top_k + i] += (
    #       permute[b, h, q, i, j] * expanded_key_states[b, h, q, gather_indices_next[b, h, q, s, j]]
    #    )
    rep_blocks = einx.dot(
        '... i j, ... s j k d -> ... (s i) k d',
        permute, einx.get_at('... [i] k d, ... s j -> ... s j k d', expanded_key_states, gather_indices_next)
    )

    # Hard choice per output slot, used for the integer bookkeeping.
    top_k_indices = permute.argmax(dim=-1)  # (bsz, num_heads, q_blocks, top_k)
    rep_block_indices = einx.get_at('... s [j], ... i -> ... (s i)', rep_block_indices_next, top_k_indices)
    begin_block_indices_branch, end_block_indices_branch = (
        einx.get_at('... s [j], ... i -> ... (s i)', begin_block_indices_branch_next, top_k_indices),
        einx.get_at('... s [j], ... i -> ... (s i)', end_block_indices_branch_next, top_k_indices),
    )

    return (begin_block_indices_branch, end_block_indices_branch), rep_blocks, rep_block_indices
