# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
from minference import streaming_forward, vertical_slash_sparse_attention

last_q = 64
arange = torch.arange(last_q, device=f"cuda:{torch.cuda.current_device()}")
LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :]


def sum_all_diagonal_matrix(mat: torch.tensor):
    b, h, n, m = mat.shape
    zero_mat = torch.zeros((b, h, n, n)).to(mat.device)  # Zero matrix used for padding
    mat_padded = torch.cat(
        (zero_mat, mat, zero_mat), -1
    )  # pads the matrix on left and right
    mat_strided = mat_padded.as_strided(
        (1, 1, n, n + m), (1, n * (2 * n + m), 2 * n + m + 1, 1)
    )  # Change the strides
    sum_diags = torch.sum(mat_strided, 2)  # Sums the resulting matrix's columns
    return sum_diags[:, :, 1:]


def get_vertical_and_slash_idx(
    q, k, v, vertical_size, slash_size, keep_global=30, keep_local=100
):
    q_len = q.shape[2]
    bsz = q.shape[0]
    vertical_size, slash_size = min(q_len, max(vertical_size, keep_global)), min(
        q_len, max(slash_size, keep_local)
    )
    last_q = min(64, q_len)
    qk = torch.einsum(f"bhmk, bhnk -> bhmn", q[:, :, -last_q:, :], k) / math.sqrt(
        q.shape[-1]
    )
    qk[:, :, :, -last_q:] = torch.where(
        LAST_Q_MASK[..., -last_q:, -last_q:].to(q.device),
        qk[:, :, :, -last_q:],
        -torch.inf,
    )
    qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
    vertical = qk.sum(-2, keepdim=True)
    vertical[..., :keep_global] = torch.inf
    vertical_topk = torch.topk(vertical, vertical_size, -1).indices

    slash = sum_all_diagonal_matrix(qk)[..., : -last_q + 1]
    slash[..., -keep_local:] = torch.inf
    # slash_topk = slash
    slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices

    return vertical_topk, slash


def minfer_attention(
    query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    key: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    value: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    minfer_config: dict,
    sparse_file=None,
):
    attn_out = torch.zeros_like(query)
    batch_size, num_heads, seq_len, head_dim = query.shape
    gqa_groups = query.shape[1] // key.shape[1]
    assert batch_size == 1
    for h in range(num_heads):
        kh = h // gqa_groups
        cfg = minfer_config[str(h)]
        if cfg[0] == "vertical_and_slash":
            vertical_size, slash_size = cfg[1], cfg[2]
            v_idx, s_idx = get_vertical_and_slash_idx(
                query[:, h : h + 1],
                key[:, kh : kh + 1],
                value[:, kh : kh + 1],
                vertical_size,
                slash_size,
            )
            attn_out[:, h : h + 1], block_count = vertical_slash_sparse_attention(
                query[:, h : h + 1],
                key[:, kh : kh + 1],
                value[:, kh : kh + 1],
                v_idx,
                s_idx,
            )
            if sparse_file is not None:
                sparse_ratio = torch.sum(block_count)/((1+block_count.size(-1))*block_count.size(-1)/2)
                with open(sparse_file, 'a') as f:
                    f.write(f"{seq_len}\t{sparse_ratio}\n")

        elif cfg[0] == "stream_llm":
            n_init, n_local = cfg[1], cfg[2]
            attn_out[:, h : h + 1] = streaming_forward(
                query[:, h : h + 1],
                key[:, kh : kh + 1],
                value[:, kh : kh + 1],
                n_init,
                n_local,
            )
        else:
            raise ValueError(f"unkonwn minfer config: {cfg[0]}")
    return attn_out
