import torch
import torch.nn as nn
import math
import torch.nn.functional as F

def moba_attn_varlen_naive(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    moba_chunk_size: int,
    moba_topk: int,
    dropout_p: float = 0.0,
    training: bool = True,
) -> torch.Tensor:
    """Naive MoBA (block-sparse causal) attention over packed sequences.

    Every sequence is split into blocks of ``moba_chunk_size`` keys. For each
    head, a query scores each block by its dot product with the block's
    mean-pooled key and keeps the ``moba_topk`` best blocks. The query's own
    block is always kept, and blocks that start after the query's block can
    never contribute because a causal mask is applied on top of the block
    selection.

    Args:
        q: Packed queries, indexed as ``[total_tokens, heads, head_dim]``.
        k: Packed keys, sliced with the same boundaries as ``q``.
        v: Packed values, sliced with the same boundaries as ``q``.
        cu_seqlens: Cumulative sequence boundaries; sequence ``i`` occupies
            rows ``cu_seqlens[i]:cu_seqlens[i + 1]``.
        max_seqlen: Accepted for signature compatibility; not used here.
        moba_chunk_size: Number of keys per block (must be positive).
        moba_topk: Number of blocks each query may attend to, including its
            own block (must be positive).
        dropout_p: Dropout probability applied to the attention weights.
        training: Dropout is applied only when this is true.

    Returns:
        A tensor with the same shape and dtype as ``q``.

    Raises:
        ValueError: If ``moba_chunk_size`` or ``moba_topk`` is not positive,
            or if ``cu_seqlens`` decreases.
    """
    if moba_chunk_size <= 0:
        raise ValueError(f"moba_chunk_size must be positive, got {moba_chunk_size}")
    if moba_topk <= 0:
        raise ValueError(f"moba_topk must be positive, got {moba_topk}")

    softmax_scale = q.shape[-1] ** (-0.5)
    o = torch.zeros_like(q)

    # Fetch all boundaries in one transfer rather than calling .item() twice
    # per sequence.
    boundaries = cu_seqlens.reshape(-1).tolist()

    for seq_start, seq_end in zip(boundaries[:-1], boundaries[1:]):
        seq_len = seq_end - seq_start
        if seq_len < 0:
            raise ValueError("cu_seqlens must be non-decreasing")
        if seq_len == 0:
            # An empty sequence has no blocks (torch.cat would fail on an
            # empty list) and owns no output rows.
            continue

        q_ = q[seq_start:seq_end]
        k_ = k[seq_start:seq_end]
        v_ = v[seq_start:seq_end]

        num_block = math.ceil(seq_len / moba_chunk_size)

        # Mean-pooled key of every block, pooled in k's dtype and then
        # promoted to float32: [N, H, D]. The last block may be shorter.
        key_gate_weight = torch.cat(
            [
                k_[start : start + moba_chunk_size].mean(dim=0, keepdim=True)
                for start in range(0, seq_len, moba_chunk_size)
            ],
            dim=0,
        ).float()

        # Block relevance score per head and query: [H, S, N].
        q_f32 = q_.float()
        gate = torch.einsum("shd,nhd->hsn", q_f32, key_gate_weight)

        for i in range(num_block):
            block_start = i * moba_chunk_size
            block_end = block_start + moba_chunk_size
            # Queries located before block i may not select it ...
            gate[:, :block_end, i] = float("-inf")
            # ... while queries inside block i always select it.
            gate[:, block_start:block_end, i] = float("inf")

        top_val, top_idx = torch.topk(
            gate, k=min(moba_topk, num_block), dim=-1, largest=True, sorted=False
        )
        threshold = top_val.min(dim=-1).values  # [H, S]
        need_attend = gate >= threshold.unsqueeze(-1)

        # Ties at the threshold could admit more than top-k blocks, so keep
        # only the blocks whose indices topk actually returned.
        in_topk = torch.zeros_like(need_attend).scatter_(
            dim=-1, index=top_idx, value=True
        )
        need_attend = torch.logical_and(need_attend, in_topk)

        # Additive bias: 0 for selected blocks, -inf otherwise, expanded from
        # block resolution to key resolution: [H, S, S].
        attn_bias = torch.zeros_like(gate).masked_fill_(~need_attend, float("-inf"))
        attn_bias = attn_bias.repeat_interleave(moba_chunk_size, dim=-1)[
            :, :, :seq_len
        ]
        # Causal mask: a query never sees keys at later positions.
        future = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=attn_bias.device
        ).triu(diagonal=1)
        attn_bias.masked_fill_(future, float("-inf"))

        # q is round-tripped through k's dtype before the attention product,
        # as the original code did; numerically a no-op when q and k share a
        # dtype.
        q_attn = q_f32.type_as(k).float()
        qk = torch.einsum("xhd,yhd->hxy", q_attn, k_.float())
        qk = (qk + attn_bias) * softmax_scale
        p = qk.softmax(dim=-1)

        # Dropout on the attention weights.
        if training and dropout_p > 0.0:
            p = F.dropout(p, p=dropout_p, training=True)

        # o_ is a view, so the in-place add writes into o.
        o_ = o[seq_start:seq_end]
        o_ += torch.einsum("hxy,yhd->xhd", p, v_.float()).type_as(o)

    return o

class MobaMultiHeadAttention(nn.Module):
    """Multi-head self-attention layer backed by ``moba_attn_varlen_naive``.

    The call signature follows the sequence-first layout of
    ``torch.nn.MultiheadAttention``:

        query: (L, N, E)
        key:   (S, N, E)
        value: (S, N, E)

    where ``E == embed_dim``. A single set of equal-length sequence
    boundaries, derived from the query length, is used to slice queries, keys
    and values alike, so ``L`` must equal ``S``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        moba_chunk_size: int = 16,
        moba_topk: int = 2,
        dropout_p: float = 0.0,
    ):
        """Create the projections and store the MoBA hyper-parameters.

        Args:
            embed_dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            moba_chunk_size: Block size forwarded to the MoBA kernel.
            moba_topk: Number of blocks per query forwarded to the MoBA kernel.
            dropout_p: Dropout probability on the attention weights.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Input and output projections.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # MoBA hyper-parameters.
        self.moba_chunk_size = moba_chunk_size
        self.moba_topk = moba_topk

        # Attention dropout probability.
        self.dropout_p = dropout_p

    def _pack_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Turn (T, N, E) into the packed (N * T, num_heads, head_dim) layout."""
        seq_len, batch, _ = x.shape
        return (
            x.transpose(0, 1)
            .contiguous()
            .view(batch * seq_len, self.num_heads, self.head_dim)
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor
    ) -> torch.Tensor:
        """Apply MoBA attention.

        Args:
            query: (L, N, E) tensor.
            key: (S, N, E) tensor with ``S == L``.
            value: (S, N, E) tensor.

        Returns:
            Tensor of shape (L, N, E).

        Raises:
            ValueError: If ``value`` does not match ``key`` in length and
                batch size, or if ``L != S``.
        """
        L, N, _ = query.shape
        S, N_k, _ = key.shape
        # Simplification: all embedding widths are assumed equal to embed_dim.
        assert N == N_k, "query,key,batch大小应一致"
        if value.shape[:2] != key.shape[:2]:
            raise ValueError(
                "value must match key in sequence length and batch size, got "
                f"{tuple(value.shape[:2])} vs {tuple(key.shape[:2])}"
            )
        if L != S:
            # The same boundaries slice q, k and v, so unequal lengths would
            # pair queries with keys from the wrong batch element.
            raise ValueError(
                f"query and key must have the same sequence length, got L={L}, S={S}"
            )

        # Project, then pack each input to [N * L, num_heads, head_dim].
        q = self._pack_heads(self.q_proj(query))
        k = self._pack_heads(self.k_proj(key))
        v = self._pack_heads(self.v_proj(value))

        # Every batch element has length L, so the boundaries are 0, L, 2L, ...
        cu_seqlens = torch.arange(0, N * L + 1, step=L, device=query.device)

        out = moba_attn_varlen_naive(
            q,
            k,
            v,
            cu_seqlens,
            L,
            self.moba_chunk_size,
            self.moba_topk,
            dropout_p=self.dropout_p,
            training=self.training,
        )
        # out: [N * L, num_heads, head_dim]

        # Back to the sequence-first layout (L, N, E).
        out = out.view(N, L, self.num_heads * self.head_dim)
        out = out.transpose(0, 1).contiguous()

        # Final output projection.
        out = self.out_proj(out)
        return out