"""
Vision Transformer (ViT) in PyTorch
A PyTorch implementation of Vision Transformers as described in:
"""
import logging
from typing import Optional
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final
import copy
from timm.layers import Mlp, DropPath, use_fused_attn


_logger = logging.getLogger(__name__)


class Attention(nn.Module):
    fused_attn: Final[bool]
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, current, cache_dic) -> torch.Tensor:
        if current['token_cache'] and not current['is_force_fresh']:
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            q, k = self.q_norm(q), self.k_norm(k)
            if current['layer_idx'] == 3:
                current['pre_cache_v'] = copy.deepcopy(cache_dic['cache'][current['layer_idx']]['v'])

            if torch.all(current['update_mask'] == 1):
                cache_dic['cache'][current['layer_idx']]['k'] = k
                cache_dic['cache'][current['layer_idx']]['v'] = v
            else:
                mask_kv = current['update_mask'].unsqueeze(1).unsqueeze(-1).expand(-1, self.num_heads, -1, self.head_dim).bool()
                cache_dic['cache'][current['layer_idx']]['k'].masked_scatter_(mask_kv, k)
                k = cache_dic['cache'][current['layer_idx']]['k']
                cache_dic['cache'][current['layer_idx']]['v'].masked_scatter_(mask_kv, v)
                v = cache_dic['cache'][current['layer_idx']]['v']

        else:
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            q, k = self.q_norm(q), self.k_norm(k)
            if current['cfg_cache']:
                cache_dic['cache'][current['layer_idx']]['k'] = k[:int(B/2)]
                cache_dic['cache'][current['layer_idx']]['v'] = v[:int(B/2)]
            else:
                cache_dic['cache'][current['layer_idx']]['k'] = k[:B]
                cache_dic['cache'][current['layer_idx']]['v'] = v[:B]

        if True:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = torch.matmul(q, k.transpose(-2, -1))
            # attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = torch.matmul(attn, v)
            # x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    """Multiply the input by a learnable per-channel scale ``gamma``."""

    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        """
        Args:
            dim: Length of the learnable scale vector.
            init_values: Initial value of every entry of ``gamma``.
            inplace: If True, ``forward`` scales its input tensor in place.
        """
        super().__init__()
        self.inplace = inplace
        initial_scale = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by ``gamma`` (mutating ``x`` when ``inplace``)."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma

class Block(nn.Module):
    """Pre-norm transformer block (attention + MLP) with token-cache pruning.

    On a fresh pass the block behaves like a regular residual block and also
    stores its MLP branch output in ``cache_dic``. On a token-cache pass
    (``current['token_cache']`` truthy and ``current['is_force_fresh']`` falsy)
    the token set is pruned right after attention at layer
    ``_PRUNE_LAYER_IDX`` and scattered back to that layer's token count after
    the last layer (``current['depth'] - 1``).
    """

    # Layer after whose attention tokens are pruned. Attention.forward
    # snapshots the cached V at this same index, so the two must stay in sync.
    _PRUNE_LAYER_IDX = 3
    # Fraction of candidate tokens retained, indexed by current['step']
    # (32 entries: 10 x 1, 4 x 0.4, 18 x 0.1).
    _RETAIN_RATIOS = (1,) * 10 + (0.4,) * 4 + (0.1,) * 18
    # Constant added to to_pred_len + prev_pred_len to form the lower bound on
    # the number of retained tokens.
    _MIN_EXTRA_TOKENS = 20

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            mlp_layer: nn.Module = Mlp,
    ) -> None:
        """
        Args:
            dim: Channel width of the block.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
            qkv_bias: Passed to ``Attention``.
            qk_norm: Passed to ``Attention``.
            proj_drop: Dropout for the attention projection and the MLP.
            attn_drop: Dropout on the attention weights.
            init_values: LayerScale initial value; falsy disables LayerScale.
            drop_path: Stochastic-depth rate; ``0`` disables DropPath.
            act_layer: Activation passed to the MLP.
            norm_layer: Callable building the two normalisation layers.
            mlp_layer: Callable building the MLP.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor, current: dict, cache_dic: dict) -> torch.Tensor:
        """Apply the attention and MLP residual branches.

        Args:
            x: Input tokens of shape ``(B, N, C)``.
            current: Per-step state. Reads ``'token_cache'``,
                ``'is_force_fresh'`` and ``'layer_idx'`` (plus ``'depth'`` on a
                token-cache pass); ``'module'`` is set to ``'attn'`` / ``'mlp'``
                before the respective branch runs.
            cache_dic: Cache container; on a fresh pass the MLP branch output
                is stored under ``cache_dic['cache'][layer_idx]['mlp']``.

        Returns:
            The block output. On a token-cache pass the token dimension may be
            smaller than the input's (after pruning) or restored to the
            pruning layer's token count (after the last layer).
        """
        use_cache = current['token_cache'] and not current['is_force_fresh']
        layer_idx = current['layer_idx']

        current['module'] = 'attn'
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), current, cache_dic)))
        if use_cache and layer_idx == self._PRUNE_LAYER_IDX:
            x = self.pruning_q(current, cache_dic, x)

        current['module'] = 'mlp'
        mlp_out = self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        if not use_cache:
            cache_dic['cache'][layer_idx]['mlp'] = mlp_out
        x = x + mlp_out

        if use_cache and layer_idx == current['depth'] - 1:
            x = self.unpruning_q(current, cache_dic, x)
        return x

    def pruning_q(self, current: dict, cache_dic: dict, x: torch.Tensor) -> torch.Tensor:
        """Keep only the tokens whose cached V changed the most.

        Tokens are ranked by the head-averaged cosine similarity between the
        layer's refreshed cached V and ``current['pre_cache_v']`` (the snapshot
        taken before the refresh); the least similar ones are retained.

        Args:
            current: Reads ``'layer_idx'``, ``'pre_cache_v'``, ``'to_pred_mask'``,
                ``'prev_pred_mask'``, ``'update_mask'``, ``'step'``,
                ``'to_pred_len'`` and ``'prev_pred_len'``. Writes ``'next_mask'``
                (keep-mask over the tokens of ``x``), ``'origi_update_mask'``
                (the incoming update mask) and replaces ``'update_mask'`` with
                the narrowed boolean mask.
            cache_dic: Provides ``cache_dic['cache'][layer_idx]['v']``.
            x: Tokens of shape ``(B, N, C)``.

        Returns:
            The retained tokens, shape ``(B, kept, C)``.

        Raises:
            IndexError: If ``current['step']`` is outside the 32-entry
                retain-ratio schedule.
        """
        B, N, C = x.shape
        cached_v = cache_dic['cache'][current['layer_idx']]['v']
        cos_sim = F.cosine_similarity(cached_v, current['pre_cache_v'], dim=-1)
        similarity = cos_sim.mean(dim=1)  # average over heads
        # Similarity is forced to 0 for these positions before ranking
        # (assumes both are boolean masks over the cached tokens - TODO confirm).
        similarity[current['to_pred_mask']] = 0
        similarity[current['prev_pred_mask']] = 0

        update_mask = current['update_mask']
        # Rank only the positions that are currently being updated.
        similarity = similarity[update_mask.nonzero(as_tuple=True)].reshape(B, -1)
        _, inds = torch.sort(similarity, dim=-1, descending=False)

        cur_ratio = self._RETAIN_RATIOS[current['step']]
        # Keep at least to_pred_len + prev_pred_len + 20 tokens, or the
        # scheduled fraction of the candidates if that is larger.
        fresh_num = torch.maximum(
            current['to_pred_len'] + current['prev_pred_len'] + self._MIN_EXTRA_TOKENS,
            torch.tensor(int(inds.shape[1] * cur_ratio)).to(current['to_pred_len']),
        )
        inds = inds[:, :fresh_num]

        next_mask = torch.zeros((B, N), device=x.device)
        next_mask = next_mask.scatter_(1, inds, 1)
        current['next_mask'] = next_mask

        # Project the keep-mask (over updated tokens) back onto the full-length
        # update mask. masked_scatter_ only accepts boolean masks, so cast the
        # incoming mask explicitly, as the rest of the file does.
        new_update_mask = torch.zeros_like(update_mask, dtype=torch.bool, device=x.device)
        new_update_mask.masked_scatter_(update_mask.bool(), next_mask.bool())
        current['origi_update_mask'] = update_mask
        current['update_mask'] = new_update_mask

        keep = next_mask.unsqueeze(-1).expand(-1, -1, C).bool()
        return torch.masked_select(x, keep).reshape(B, -1, C)

    def unpruning_q(self, current: dict, cache_dic: dict, pruning_x: torch.Tensor) -> torch.Tensor:
        """Scatter pruned tokens back into a zero-filled full-length tensor.

        Args:
            current: Reads ``'next_mask'`` and ``'origi_update_mask'``; restores
                ``'update_mask'`` from the latter and resets both entries to
                ``None``.
            cache_dic: Unused.
            pruning_x: Retained tokens of shape ``(B, kept, C)``.

        Returns:
            Tensor of shape ``(B, N, C)`` with ``N`` taken from ``next_mask``;
            positions that were pruned are zero.
        """
        B, _, C = pruning_x.shape
        keep_mask = current['next_mask']
        _, N = keep_mask.shape
        new_x_full = torch.zeros((B, N, C), device=pruning_x.device, dtype=pruning_x.dtype)
        new_x_full.masked_scatter_(keep_mask.unsqueeze(-1).expand(-1, -1, C).bool(), pruning_x)
        current['next_mask'] = None
        current['update_mask'] = current['origi_update_mask']
        current['origi_update_mask'] = None
        return new_x_full