# Besides, re-arrange the attention module
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import use_fused_attn
from torch.jit import Final

class Attention(nn.Module):
    """Multi-head self-attention that also returns its attention map.

    Besides the usual full-sequence forward pass, the module supports a
    partial update: when ``fresh_indices`` is given, only the supplied tokens
    are projected to Q/K/V, their K/V rows are written into a per-layer cache,
    and the fresh queries attend over the whole cached K/V.
    """

    # Kept for interface compatibility; forward() never takes the fused path.
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        """Build the projection, normalisation and dropout layers.

        Args:
            dim: Channel width of the input tokens; must be divisible by
                ``num_heads``.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused QKV projection has a bias term.
            qk_norm: If true, apply ``norm_layer(head_dim)`` to Q and K;
                otherwise Q and K pass through ``nn.Identity``.
            attn_drop: Dropout probability applied to the attention weights.
            proj_drop: Dropout probability applied after the output projection.
            norm_layer: Callable building the Q/K normalisation layer from
                ``head_dim``.
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # The fused kernel (use_fused_attn()) dispatches straight into C++ and
        # does not expose the attention weights that forward() has to return,
        # so fused attention stays disabled.
        self.fused_attn = False
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, cache_dic, current, fresh_indices=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and return the output together with the attention map.

        Original author's note: the fresh-token path costs roughly 0.4 ms
        extra on an A800, mainly from tensor operations.

        Args:
            x: Input tokens of shape (B, N, C). When ``fresh_indices`` is
                given, ``x`` is permuted along dim 1 by the argsort of
                ``fresh_indices``, so it is expected to hold the fresh tokens
                in the same order as ``fresh_indices`` (confirm with caller).
            cache_dic: Cache container; only read when ``fresh_indices`` is
                given, via ``cache_dic['cache'][-1][layer]['k' | 'v']``. Those
                tensors are updated in place along dim 2, so they are expected
                to be (B, num_heads, N_total, head_dim).
            current: Mapping whose ``'layer'`` entry selects the cache slot;
                only read when ``fresh_indices`` is given.
            fresh_indices: Optional integer tensor (B, fresh_ratio*N) with the
                positions of the fresh tokens inside the full sequence.

        Returns:
            A tuple ``(x, attn_map)``. ``x`` has shape (B, N_q, C) where N_q
            is the number of query tokens; on the fresh path the tokens come
            out in ascending ``fresh_indices`` order. ``attn_map`` is the
            softmax attention averaged over heads, shape (B, N_q, N_kv).
        """
        B, N, C = x.shape

        sorted_indices = None
        if fresh_indices is not None:
            N = fresh_indices.shape[1]
            # A single sort yields both the ascending positions (used to
            # address the cache) and the permutation that puts the tokens
            # into that order; the latter is exactly what argsort returns.
            sorted_indices, order = fresh_indices.sort(dim=-1, descending=False)
            x = torch.gather(
                input=x, dim=1, index=order.unsqueeze(-1).expand(-1, -1, C)
            )  # (B, fresh_ratio*N, hidden_size)

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each: (B, num_heads, N, head_dim)

        if sorted_indices is not None:
            # Write the fresh K/V rows into the cached tensors (in place) and
            # attend over the complete cached sequence.
            layer_cache = cache_dic['cache'][-1][current['layer']]
            scatter_index = sorted_indices.unsqueeze(1).unsqueeze(-1).expand(
                -1, k.shape[1], -1, k.shape[-1]
            )
            layer_cache['k'].scatter_(dim=2, index=scatter_index, src=k)
            layer_cache['v'].scatter_(dim=2, index=scatter_index, src=v)
            k = layer_cache['k']
            v = layer_cache['v']

        q, k = self.q_norm(q), self.k_norm(k)

        # Always use the explicit softmax path: the attention map is part of
        # the return value and F.scaled_dot_product_attention does not expose
        # it. Branching on self.fused_attn here used to leave attn_map
        # undefined and crash with UnboundLocalError.
        attn_map = ((q * self.scale) @ k.transpose(-2, -1)).softmax(dim=-1)
        x = self.attn_drop(attn_map) @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj_drop(self.proj(x))
        # Average over the head dimension.
        return x, attn_map.mean(dim=1)  # x: (B, N-M, C), attn_map: (B, N-M, N)