""" Vision Transformer (ViT) in PyTorch

A PyTorch implementation of Vision Transformers as described in:

'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
    - https://arxiv.org/abs/2010.11929

`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
    - https://arxiv.org/abs/2106.10270

`FlexiViT: One Model for All Patch Sizes`
    - https://arxiv.org/abs/2212.08013

The official jax code is released and available at
  * https://github.com/google-research/vision_transformer
  * https://github.com/google-research/big_vision

Acknowledgments:
  * The paper authors for releasing code and weights, thanks!
  * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch
  * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
  * Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020, Ross Wightman
"""
import logging
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final
import os
import matplotlib.pyplot as plt
import copy

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
    OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
    get_act_layer, get_norm_layer, LayerType
from timm.models._builder import build_model_with_cfg
from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
import time

_logger = logging.getLogger(__name__)


def compute_distance_matrix(mask):
    """Sum, for every grid cell, the Euclidean distances to all important tokens.

    Args:
        mask (torch.Tensor): Tensor of shape (bsz, w, h); non-zero entries mark
            the important tokens.

    Returns:
        torch.Tensor: float32 tensor of shape (bsz, w, h) on ``mask.device``.
        Entry ``[b, i, j]`` is the sum, over the important tokens ``(r, c)`` of
        sample ``b``, of ``sqrt((i - r) ** 2 + (j - c) ** 2)``. Samples that
        contain no important token are left all-zero.
    """
    bsz, w, h = mask.shape
    device = mask.device

    # (row, col) coordinate of every cell. The axis order must match the
    # (row, col) order reported by torch.nonzero below, otherwise the
    # subtraction compares a row index against a column index.
    rows = torch.arange(w, device=device, dtype=torch.float32).view(w, 1).expand(w, h)
    cols = torch.arange(h, device=device, dtype=torch.float32).view(1, h).expand(w, h)
    grid = torch.stack([rows, cols], dim=-1)  # (w, h, 2)

    # One line per important token: (batch index, row, col).
    important_indices = torch.nonzero(mask)

    result = torch.zeros((bsz, w, h), dtype=torch.float32, device=device)
    for i in range(bsz):
        # Coordinates of the important tokens of sample i: (num_important, 2).
        coords = important_indices[important_indices[:, 0] == i, 1:].float()
        if coords.shape[0] == 0:
            # No important token in this sample: the sum of distances stays 0.
            continue

        # Broadcast (w, h, 1, 2) against (1, 1, num_important, 2).
        diff = grid.unsqueeze(2) - coords.view(1, 1, -1, 2)
        distances = diff.pow(2).sum(dim=-1).sqrt()  # (w, h, num_important)
        result[i] = distances.sum(dim=-1)

    return result

class Attention(nn.Module):
    fused_attn: Final[bool]
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, current, cache_dic) -> torch.Tensor:
        if current['token_cache'] and not current['is_force_fresh']:
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            q, k = self.q_norm(q), self.k_norm(k)
            if current['layer_idx'] == 3:
                current['pre_cache_v'] = copy.deepcopy(cache_dic['cache'][current['layer_idx']]['v'])

            if torch.all(current['update_mask'] == 1):
                cache_dic['cache'][current['layer_idx']]['k'] = k
                cache_dic['cache'][current['layer_idx']]['v'] = v
            else:
                mask_kv = current['update_mask'].unsqueeze(1).unsqueeze(-1).expand(-1, self.num_heads, -1, self.head_dim).bool()
                cache_dic['cache'][current['layer_idx']]['k'].masked_scatter_(mask_kv, k)
                k = cache_dic['cache'][current['layer_idx']]['k']
                cache_dic['cache'][current['layer_idx']]['v'].masked_scatter_(mask_kv, v)
                v = cache_dic['cache'][current['layer_idx']]['v']

        else:
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            q, k = self.q_norm(q), self.k_norm(k)
            if current['cfg_cache']:
                cache_dic['cache'][current['layer_idx']]['k'] = k[:int(B/2)]
                cache_dic['cache'][current['layer_idx']]['v'] = v[:int(B/2)]
            else:
                cache_dic['cache'][current['layer_idx']]['k'] = k[:B]
                cache_dic['cache'][current['layer_idx']]['v'] = v[:B]

        if False:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = torch.matmul(q, k.transpose(-2, -1))
            # attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = torch.matmul(attn, v)
            # x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    """Learnable per-channel scaling of the last dimension of the input."""

    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        """Create the scale vector.

        Args:
            dim: Number of channels (length of the learnable scale vector).
            init_values: Initial value of every entry of the scale vector.
            inplace: If True, ``forward`` multiplies its input in place.
        """
        super().__init__()
        self.inplace = inplace
        ones = torch.ones(dim)
        self.gamma = nn.Parameter(init_values * ones)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by ``gamma`` (modifying ``x`` itself when in-place)."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma

class Block(nn.Module):
    """Pre-norm transformer block (attention + MLP) with optional token pruning.

    ``forward`` receives, besides the tokens, a per-call state dict ``current``
    and a cache container ``cache_dic`` (both are also handed to the attention
    module). When ``current['token_cache']`` is set and
    ``current['is_force_fresh']`` is not, the block runs in cached mode:

    * at layer index ``_PRUNE_LAYER_IDX`` the tokens are pruned after the
      attention step (``pruning_q``);
    * at the last layer (``current['depth'] - 1``) the pruned tokens are
      scattered back to the pre-pruning sequence length (``unpruning_q``).

    Otherwise the block runs normally and stores its MLP branch output in
    ``cache_dic['cache'][layer_idx]['mlp']``.
    """

    # Layer (by ``current['layer_idx']``) after whose attention step the tokens
    # are pruned in cached mode.
    _PRUNE_LAYER_IDX = 3

    # Fraction of the currently updated tokens kept by ``pruning_q``, indexed
    # by ``current['step']``. The first schedule is used when
    # ``current['num_iter'] == 32`` ("for huge" in the original code), the
    # second one in every other case.
    _RETAIN_RATIOS_32_ITER = (1,) * 10 + (0.3,) * 10 + (0.2,) * 5 + (0.15,) * 7
    _RETAIN_RATIOS_DEFAULT = (
        (1,) * 20
        + (0.6,) + (0.5,) * 4 + (0.4,) * 5
        + (0.15,) * 5 + (0.12,) * 5
        + (0.05,) * 24
    )

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            mlp_layer: nn.Module = Mlp,
    ) -> None:
        """Build the attention and MLP branches.

        Args:
            dim: Embedding width.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
            qkv_bias: Whether the QKV projection has a bias.
            qk_norm: Normalise queries and keys inside the attention module.
            proj_drop: Dropout used for the attention projection and the MLP.
            attn_drop: Dropout on the attention weights.
            init_values: Initial LayerScale value; a falsy value (``None`` or
                0) disables LayerScale.
            drop_path: Stochastic-depth rate; values <= 0 disable it.
            act_layer: Activation layer class passed to the MLP.
            norm_layer: Normalisation layer class.
            mlp_layer: MLP layer class.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor, current, cache_dic) -> torch.Tensor:
        """Run the block.

        Args:
            x: Tokens of shape (B, N, C).
            current: Per-call state dict. Read here: ``'token_cache'``,
                ``'is_force_fresh'``, ``'layer_idx'`` and, in cached mode,
                ``'depth'`` and ``'cal_caching_num'``. Written here:
                ``'module'`` and, at the last layer in cached mode when
                ``'cal_caching_num'`` is set, ``'remainging_token'``.
            cache_dic: Cache container; ``cache_dic['cache'][layer_idx]`` must
                already exist.

        Returns:
            The block output. In cached mode its sequence length shrinks at the
            pruning layer and is restored at the last layer.
        """
        # Sequence length of the *incoming* tokens; at the last layer in cached
        # mode this is the number of tokens that survived pruning.
        _, N, _ = x.size()
        use_token_cache = current['token_cache'] and not current['is_force_fresh']

        current['module'] = 'attn'
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), current, cache_dic)))

        if not use_token_cache:
            # Regular path: remember the MLP branch output for this layer.
            current['module'] = 'mlp'
            mlp_out = self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
            cache_dic['cache'][current['layer_idx']]['mlp'] = mlp_out
            return x + mlp_out

        if current['layer_idx'] == self._PRUNE_LAYER_IDX:
            x = self.pruning_q(current, cache_dic, x)

        current['module'] = 'mlp'
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))

        if current['layer_idx'] == current['depth'] - 1:
            if current['cal_caching_num']:
                current['remainging_token'] = N
            x = self.unpruning_q(current, cache_dic, x)
        return x

    def pruning_q(self, current, cache_dic, x):
        """Keep only the tokens whose cached values changed the most.

        The cosine similarity between this layer's cached ``'v'`` and
        ``current['pre_cache_v']`` is averaged over heads; the least similar
        tokens among those flagged in ``current['update_mask']`` are retained.

        Side effects on ``current``: sets ``'next_mask'`` (a (B, N) 0/1 float
        mask over the incoming tokens), saves the incoming ``'update_mask'``
        under ``'origi_update_mask'`` and replaces ``'update_mask'`` by the
        narrowed mask.

        Args:
            current: Per-call state dict. Read here: ``'layer_idx'``,
                ``'pre_cache_v'``, ``'mask_to_pred_mask'``,
                ``'prev_mask_to_pred_mask'``, ``'update_mask'``, ``'num_iter'``,
                ``'step'``, ``'mask_to_pred_len'``, ``'prev_mask_to_pred_len'``.
            cache_dic: Cache container holding this layer's ``'v'``.
            x: Tokens of shape (B, N, C).

        Returns:
            Tensor of shape (B, num_kept, C) with the retained tokens.
        """
        B, N, C = x.shape
        cached_v = cache_dic['cache'][current['layer_idx']]['v']
        cos_sim = F.cosine_similarity(cached_v, current['pre_cache_v'], dim=-1)
        similarity_ = cos_sim.mean(dim=1).reshape(B, -1)  # average over heads

        # Positions flagged by these two masks get similarity 0, i.e. they sort
        # first below and are therefore always among the retained tokens.
        similarity_[current['mask_to_pred_mask']] = 0
        similarity_[current['prev_mask_to_pred_mask']] = 0

        # Restrict to the tokens that are currently being updated.
        # NOTE(review): assumes every row of update_mask flags the same number
        # of tokens (needed for the reshape) -- confirm with the caller.
        similarity_ = similarity_[current['update_mask'].nonzero(as_tuple=True)].reshape(B, -1)
        _, inds = torch.sort(similarity_, dim=-1, descending=False)

        if current['num_iter'] == 32:
            schedule = self._RETAIN_RATIOS_32_ITER
        else:
            schedule = self._RETAIN_RATIOS_DEFAULT
        # Steps beyond the end of the schedule reuse its last ratio instead of
        # raising IndexError.
        cur_ratio = schedule[min(current['step'], len(schedule) - 1)]

        # Keep at least the tokens to predict (current + previous) plus a
        # margin of 15, or the scheduled fraction, whichever is larger.
        # NOTE(review): 'mask_to_pred_len' must be a tensor for
        # torch.maximum / .to(...) to work -- confirm with the caller.
        min_fresh = current['mask_to_pred_len'] + current['prev_mask_to_pred_len'] + 15
        ratio_fresh = torch.tensor(int(inds.shape[1] * cur_ratio)).to(current['mask_to_pred_len'])
        fresh_num = torch.maximum(min_fresh, ratio_fresh)
        imp_inds = inds[:, :fresh_num]

        # (B, N) mask over the incoming tokens: 1 where the token is kept.
        next_mask = torch.zeros((B, N), device=x.device)
        next_mask = next_mask.scatter_(1, imp_inds, 1)
        current['next_mask'] = next_mask

        # Narrow update_mask to the kept tokens. masked_scatter_ needs a
        # boolean mask, so the incoming mask is cast explicitly.
        old_update_mask = current['update_mask']
        new_update_mask = torch.zeros_like(old_update_mask, device=x.device).bool()
        new_update_mask.masked_scatter_(old_update_mask.bool(), next_mask.bool())
        current['origi_update_mask'] = old_update_mask
        current['update_mask'] = new_update_mask

        keep = next_mask.unsqueeze(-1).expand(-1, -1, C).bool()
        pruning_x = torch.masked_select(x, keep).reshape(B, -1, C)
        return pruning_x

    def unpruning_q(self, current, cache_dic, pruning_x):
        """Scatter pruned tokens back to the pre-pruning sequence length.

        Positions that were pruned are filled with zeros. Afterwards
        ``current['update_mask']`` is restored from ``'origi_update_mask'`` and
        both ``'next_mask'`` and ``'origi_update_mask'`` are reset to ``None``.

        Args:
            current: Per-call state dict; ``'next_mask'`` and
                ``'origi_update_mask'`` must have been set by ``pruning_q``.
            cache_dic: Unused; kept for signature symmetry with ``pruning_q``.
            pruning_x: Pruned tokens of shape (B, num_kept, C).

        Returns:
            Tensor of shape (B, N, C), where N is the width of
            ``current['next_mask']``.
        """
        B, _, C = pruning_x.shape
        _, N = current['next_mask'].shape
        # new_zeros matches the dtype and device of the activations;
        # masked_scatter_ requires the destination and source dtypes to agree,
        # so a default float32 buffer would fail for half/bf16/double inputs.
        new_x_full = pruning_x.new_zeros((B, N, C))
        keep = current['next_mask'].unsqueeze(-1).expand(-1, -1, C).bool()
        new_x_full.masked_scatter_(keep, pruning_x)
        current['next_mask'] = None
        current['update_mask'] = current['origi_update_mask']
        current['origi_update_mask'] = None
        return new_x_full