from functools import partial
from itertools import repeat
import collections.abc
from typing import Callable, List, Optional, Tuple, Union
import warnings

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import _calculate_fan_in_and_fan_out
from torch.jit import Final



'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
'''
class Transformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
            self,
            num_tokens: int,  # 图像分割的令牌数量
            embed_dim: int = 512,  # 嵌入维度，默认为512
            depth: int = 6,  # ViT中Transformer块的数量，默认为6
            num_heads: int = 8,  # 注意力头的数量，默认为8
            mlp_ratio: float = 2.,  # MLP层的扩张比率，默认为2.0
            qkv_bias: bool = True,  # 是否在QKV线性层中使用偏置，默认为True
            qk_norm: bool = False,  # 是否在计算QK时进行标准化，默认为False
            init_values: Optional[float] = None,  # 初始化值，可选，默认为None
            pre_norm: bool = False,  # 是否在Transformer块之前进行LayerNorm，默认为False
            fc_norm: Optional[bool] = None,  # 是否对全连接层进行标准化，可选，默认为None
            drop_rate: float = 0.,  # 通用的丢弃率，默认为0.0
            pos_drop_rate: float = 0.,  # 位置嵌入丢弃率，默认为0.0
            patch_drop_rate: float = 0.,  # 补丁丢弃率，默认为0.0
            proj_drop_rate: float = 0.,  # 线性投影丢弃率，默认为0.0
            attn_drop_rate: float = 0.,  # 注意力丢弃率，默认为0.0
            drop_path_rate: float = 0.,  # 路径丢弃率，默认为0.0
            weight_init: str = '',  # 权重初始化方法，默认为空字符串
            norm_layer: Optional[Callable] = None,  # 规范化层，可选，默认为None
            act_layer: Optional[Callable] = None,  # 激活函数层，可选，默认为None
    ):
        """
        Args:
            num_tokens: Number of input tokens.
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            class_token: Use class token.
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme.
            embed_layer: Patch embedding layey.
            norm_layer: Normalization layer.
            act_layer: MLP activation layer.
            block_fn: Transformer block layer.
        """
        super().__init__()

        self.num_tokens = num_tokens  # 输入的tokens数量
        self.embed_dim = embed_dim  # 嵌入维度，用于保持与其他模型的一致性

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)  # 标准化层，默认为LayerNorm
        act_layer = act_layer or nn.GELU  # 激活函数，默认为GELU

        self.pos_embed = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim) * .02)  # 位置嵌入参数

        self.pos_drop = nn.Dropout(p=pos_drop_rate)  # 位置嵌入的dropout
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=0,
            )  # 可选的patch dropout
        else:
            self.patch_drop = nn.Identity()  # 如果patch dropout率为0，则使用恒等映射

        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()  # 是否使用预层标准化

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # 随机深度的衰减规则
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer
            )
            for i in range(depth)])  # 堆叠的Transformer块

        self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()  # 是否使用最终的标准化层

        if weight_init != 'skip':
            self.init_weights(weight_init)  # 初始化模型的权重

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'moco', '')  # 确保初始化模式有效
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.  # 根据初始化模式设置头部偏置
        trunc_normal_(self.pos_embed, std=.02)  # 对位置嵌入参数进行截断正态分布初始化
        named_apply(get_init_weights_vit(mode, head_bias), self)  # 使用特定的初始化函数对模型进行权重初始化

    def _init_weights(self, m):
        # 兼容性函数，留给下游用户使用
        init_weights_vit_timm(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        # 定义不需要进行权重衰减的参数集合
        return {'pos_embed', 'cls_token', 'dist_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # 定义权重衰减分组匹配规则
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem和embed部分
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]  # Transformer块部分
        )

    def _pos_embed(self, x):
        # 应用位置嵌入到输入张量中，并进行位置dropout
        x = x + self.pos_embed
        return self.pos_drop(x)

    def forward(self, x):
        # 前向传播方法
        x = self._pos_embed(x)  # 添加位置嵌入并进行位置dropout
        x = self.patch_drop(x)  # 应用patch dropout
        x = self.norm_pre(x)  # 应用预层标准化
        x = self.blocks(x)  # 应用Transformer块
        x = self.norm(x)  # 应用最终的标准化层
        return x  # 返回输出张量


'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
'''

class Block(nn.Module):
    def __init__(
            self,
            dim,  # 输入维度
            num_heads,  # 注意力头数
            mlp_ratio=4.,  # MLP隐藏层维度与输入维度的比例
            qkv_bias=False,  # 是否使用注意力矩阵中的偏置项
            qk_norm=False,  # 是否对QK进行归一化
            proj_drop=0.,  # 项目(dropout)概率
            attn_drop=0.,  # 注意力(dropout)概率
            init_values=None,  # 初始化值
            drop_path=0.,  # DropPath概率
            act_layer=nn.GELU,  # 激活函数
            norm_layer=nn.LayerNorm  # 归一化层
    ):
        super().__init__()
        # 第一个子模块：归一化层
        self.norm1 = norm_layer(dim)
        # 第二个子模块：注意力机制
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        # 第三个子模块：层缩放
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # 第四个子模块：DropPath
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # 第五个子模块：归一化层
        self.norm2 = norm_layer(dim)
        # 第六个子模块：MLP
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        # 第七个子模块：层缩放
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # 第八个子模块：DropPath
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        # 第一层：输入 + DropPath(层缩放(注意力(归一化层(输入))))
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        # 第二层：输入 + DropPath(层缩放(MLP(归一化层(输入))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x



'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
'''
class Attention(nn.Module):
    fast_attn: Final[bool]  # 是否使用快速的注意力计算方法

    def __init__(
            self,
            dim,  # 输入维度
            num_heads=8,  # 注意力头数
            qkv_bias=False,  # 是否使用QKV矩阵的偏置项
            qk_norm=False,  # 是否对QK进行归一化
            attn_drop=0.,  # 注意力(dropout)概率
            proj_drop=0.,  # 项目(dropout)概率
            norm_layer=nn.LayerNorm,  # 归一化层
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads  # 注意力头数
        self.head_dim = dim // num_heads  # 每个头的维度
        self.scale = self.head_dim ** -0.5  # 缩放因子
        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # 是否使用快速的注意力计算方法

        # QKV线性变换
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # 对Q进行归一化
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        # 对K进行归一化
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)  # 注意力(dropout)
        self.proj = nn.Linear(dim, dim)  # 项目线性变换
        self.proj_drop = nn.Dropout(proj_drop)  # 项目(dropout)

    def forward(self, x):
        B, N, C = x.shape
        # 计算QKV
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # 拆分QKV
        q, k = self.q_norm(q), self.k_norm(k)  # 对QK进行归一化

        if self.fast_attn:
            # 使用快速的注意力计算方法
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p,
            )
        else:
            # 计算注意力矩阵
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)  # 转换维度
        x = self.proj(x)  # 项目线性变换
        x = self.proj_drop(x)  # 项目(dropout)
        return x



'''
Taken from: 
https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/mlp.py
'''
class Mlp(nn.Module):
    """ 用于Vision Transformer、MLP-Mixer等相关网络中的多层感知机（MLP） """
    def __init__(
            self,
            in_features,  # 输入特征维度
            hidden_features=None,  # 隐藏层特征维度，默认与输入特征维度相同
            out_features=None,  # 输出特征维度，默认与输入特征维度相同
            act_layer=nn.GELU,  # 激活函数，默认GELU
            norm_layer=None,  # 归一化层，默认无
            bias=True,  # 是否使用偏置，默认True
            drop=0.,  # Dropout概率，默认0
            use_conv=False,  # 是否使用卷积层，默认False
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)  # 将偏置转换为元组形式
        drop_probs = to_2tuple(drop)  # 将Dropout概率转换为元组形式
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear  # 根据use_conv选择线性层类型

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])  # 第一个线性层
        self.act = act_layer()  # 激活函数
        self.drop1 = nn.Dropout(drop_probs[0])  # Dropout层
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()  # 归一化层
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])  # 第二个线性层
        self.drop2 = nn.Dropout(drop_probs[1])  # Dropout层

    def forward(self, x):
        x = self.fc1(x)  # 第一个线性层
        x = self.act(x)  # 激活函数
        x = self.drop1(x)  # Dropout
        x = self.fc2(x)  # 第二个线性层
        x = self.drop2(x)  # Dropout
        return x



# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
'''
class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        """
        初始化函数
        Args:
        - dim：输入张量的特征维度
        - init_values：初始值，默认为1e-5
        - inplace：是否原地操作，默认为False
        """
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))  # 缩放系数参数
    def forward(self, x):
        """
        前向传播函数
        Args:
        - x：输入张量
        Returns:
        - x：经过缩放层处理后的张量
        """
        return x.mul_(self.gamma) if self.inplace else x * self.gamma  # 如果是原地操作则使用mul_()方法，否则直接乘以缩放系数


'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
'''
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """
    随机深度（Drop paths）函数，用于每个样本上的残差块主路径中的应用。
    Args:
    - x：输入张量
    - drop_prob：丢弃概率，默认为0.
    - training：是否处于训练模式，默认为False
    - scale_by_keep：是否按照保留比例缩放，默认为True
    Returns:
    - x：处理后的张量
    """
    if drop_prob == 0. or not training:
        return x  # 如果丢弃概率为0或者不处于训练模式，则直接返回输入张量
    keep_prob = 1 - drop_prob  # 计算保留的比例
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # 为不同维度的张量创建形状
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)  # 生成随机张量用于丢弃路径
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)  # 按照保留比例进行缩放
    return x * random_tensor  # 对输入张量应用随机深度操作


'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
'''
class DropPath(nn.Module):
    """
    随机深度（Stochastic Depth）模块，用于在每个样本上（主要应用于残差块的主路径）应用随机深度。
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        """
        初始化DropPath模块。
        Args:
        - drop_prob: 丢弃概率，默认为0.
        - scale_by_keep: 是否按照保留比例缩放，默认为True
        """
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob  # 丢弃概率
        self.scale_by_keep = scale_by_keep  # 是否按照保留比例缩放
    def forward(self, x):
        """
        前向传播函数，应用随机深度操作。
        Args:
        - x: 输入张量
        Returns:
        - x: 处理后的张量
        """
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)  # 调用外部的随机深度函数处理输入张量
    def extra_repr(self):
        """
        返回模块的额外表示信息，用于显示模块的配置参数。
        Returns:
        - str: 模块的额外表示信息字符串
        """
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'  # 返回丢弃概率的额外表示信息

    

'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_dropout.py
'''
class PatchDropout(nn.Module):
    """
    实现基于位置的Dropout模块，参考论文 https://arxiv.org/abs/2212.00794
    """
    return_indices: torch.jit.Final[bool]

    def __init__(
            self,
            prob: float = 0.5,  # Dropout概率，默认为0.5，取值范围为[0, 1)
            num_prefix_tokens: int = 1,  # 排除前缀token的数量，默认为1
            ordered: bool = False,  # 是否按顺序保持Dropout后的token顺序，默认为False
            return_indices: bool = False,  # 是否返回Dropout后保留token的索引，默认为False
    ):
        super().__init__()
        assert 0 <= prob < 1., "Dropout概率应该在[0, 1)范围内"
        self.prob = prob  # Dropout概率
        self.num_prefix_tokens = num_prefix_tokens  # 排除前缀token的数量，例如CLS token
        self.ordered = ordered  # 是否按顺序保持Dropout后的token顺序
        self.return_indices = return_indices  # 是否返回Dropout后保留token的索引

    def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        # 如果不处于训练状态或者Dropout概率为0，则直接返回输入张量x
        if not self.training or self.prob == 0.:
            if self.return_indices:
                return x, None  # 返回x和None（索引）
            return x  # 返回输入张量x

        # 如果设置了num_prefix_tokens，则从输入张量x中提取前缀token和剩余部分
        if self.num_prefix_tokens:
            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
        else:
            prefix_tokens = None

        B = x.shape[0]  # batch size
        L = x.shape[1]  # 序列长度
        num_keep = max(1, int(L * (1. - self.prob)))  # 计算应该保留的token数量
        # 生成随机排序的token索引，保留前num_keep个token
        keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
        if self.ordered:
            # 如果设置了ordered参数为True，则按顺序保持保留的token
            keep_indices = keep_indices.sort(dim=-1)[0]

        # 根据保留的token索引，从待处理的张量x中选择对应的token，得到Dropout后的张量
        x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))

        # 如果提取了前缀token，则将其与Dropout后的张量拼接在一起
        if prefix_tokens is not None:
            x = torch.cat((prefix_tokens, x), dim=1)

        # 如果设置了return_indices参数为True，则返回Dropout后的张量和保留的token索引
        if self.return_indices:
            return x, keep_indices
        return x  # 返回Dropout后的张量


'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/weight_init.py
'''
def _trunc_normal_(tensor, mean, std, a, b):
    """
    从PyTorch官方master中复制和粘贴的，直到它在几个官方版本中出现为止。
    方法基于 https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    Args:
        tensor (torch.Tensor): 需要初始化的张量。
        mean (float): 正态分布的均值。
        std (float): 正态分布的标准差。
        a (float): 截断范围的下限。
        b (float): 截断范围的上限。
    Returns:
        torch.Tensor: 初始化后的张量。
    """
    def norm_cdf(x):
        # 计算标准正态分布的累积分布函数
        return (1. + math.erf(x / math.sqrt(2.))) / 2.
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        # 如果均值超出了截断范围的2倍标准差，则发出警告
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    # 根据截断正态分布的累积分布函数计算上下限的cdf值
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)
    # 使用截断均匀分布填充张量，并将其范围转换为[2l-1, 2u-1]。
    tensor.uniform_(2 * l - 1, 2 * u - 1)
    # 使用正态分布的逆累积分布函数转换为截断标准正态分布
    tensor.erfinv_()
    # 转换为真实的均值和标准差
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)
    # 限制范围以确保在正确的范围内
    tensor.clamp_(min=a, max=b)
    return tensor



'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/weight_init.py
'''
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
    applied while sampling the normal with mean/std applied, therefore a, b args
    should be adjusted to match the range of mean, std args.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    with torch.no_grad():
        return _trunc_normal_(tensor, mean, std, a, b)


def named_apply(
        fn: Callable,
        module: nn.Module, name='',
        depth_first: bool = True,
        include_root: bool = False,
) -> nn.Module:
    """
    对模块的所有子模块递归应用指定函数。
    Args:
        fn (Callable): 要应用的函数。
        module (nn.Module): 要遍历的模块。
        name (str): 当前模块的名称。
        depth_first (bool): 是否按深度优先顺序应用函数。
        include_root (bool): 是否包含根模块。
    Returns:
        nn.Module: 应用函数后的模块。
    """
    if not depth_first and include_root:
        # 如果不是深度优先且包含根模块，则直接应用函数于根模块
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        # 对每个子模块递归调用named_apply
        child_name = '.'.join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        # 如果是深度优先且包含根模块，则应用函数于根模块
        fn(module=module, name=name)
    return module


def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """
    ViT权重初始化，原始timm实现（用于可重现性）。
    Args:
        module (nn.Module): 要初始化权重的模块。
        name (str): 当前模块的名称。
    """
    if isinstance(module, nn.Linear):
        # 如果是线性层，使用截断正态分布初始化权重，标准差为0.02
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            # 如果有偏置项，将其初始化为零
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        # 如果模块有init_weights方法，调用该方法进行初始化
        module.init_weights()

def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
    """
    ViT权重初始化，与JAX（Flax）实现匹配。
    Args:
        module (nn.Module): 要初始化权重的模块。
        name (str): 当前模块的名称。
        head_bias (float): 头部模块偏置的初始值。
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            # 如果是头部模块，将权重初始化为零，偏置初始化为head_bias
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        else:
            # 否则使用xavier_uniform_初始化权重，偏置根据名称初始化
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                if 'mlp' in name:
                    # 如果是MLP模块，偏置初始化为正态分布，标准差为1e-6
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    # 否则偏置初始化为零
                    nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        # 对于卷积层，使用lecun_normal_初始化权重，偏置初始化为零
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        # 如果模块有init_weights方法，调用该方法进行初始化
        module.init_weights()


def lecun_normal_(tensor):
    """
        使用LeCun初始化方法对张量进行初始化。
    """
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')

def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """
        使用方差缩放初始化方法对张量进行初始化。
        Args:
            tensor: 要初始化的张量。
            scale (float): 缩放因子，默认为1.0。
            mode (str): 初始化模式，可选值为'fan_in'、'fan_out'或'fan_avg'，默认为'fan_in'。
            distribution (str): 初始化分布，可选值为'normal'、'truncated_normal'或'uniform'，默认为'normal'。
        """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        # 使用截断正态分布进行初始化
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        # 使用正态分布进行初始化
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # 使用均匀分布进行初始化
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
    

def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsquently scaled and shifted by the mean and std args.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    return tensor


'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
'''


def init_weights_vit_moco(module: nn.Module, name: str = ''):
    """
    ViT权重初始化，与moco-v3实现相匹配，但去除了固定的PatchEmbed部分。
    Args:
        module (nn.Module): 要初始化的模型或模块。
        name (str): 模块名称，可选，默认为空字符串。
    """
    if isinstance(module, nn.Linear):
        if 'qkv' in name:
            # 分别处理Q、K、V的权重
            val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
            nn.init.uniform_(module.weight, -val, val)
        else:
            # 使用Xavier均匀分布初始化线性层的权重
            nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            # 将偏置项初始化为零
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        # 调用模块的自定义权重初始化方法（如果有的话）
        module.init_weights()


'''
Taken from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
'''
def get_init_weights_vit(mode='jax', head_bias: float = 0.):
    """
    获取用于ViT模型的初始化权重函数，根据模式选择合适的实现。
    Args:
        mode (str): 初始化模式，可选 'jax', 'moco'，默认为 'jax'。
        head_bias (float): 头部偏置值，用于某些初始化模式，默认为 0.。
    Returns:
        Callable: 返回一个可调用的初始化权重函数。
    """
    if 'jax' in mode:
        # 如果模式包含 'jax'，返回与JAX实现相匹配的初始化函数
        return partial(init_weights_vit_jax, head_bias=head_bias)
    elif 'moco' in mode:
        # 如果模式包含 'moco'，返回与moco-v3实现相匹配的初始化函数
        return init_weights_vit_moco
    else:
        # 默认情况下返回与timm实现相匹配的初始化函数
        return init_weights_vit_timm
