# adapter.py

import torch
import torch.nn as nn


class Adapter(nn.Module):
    """Bottleneck adapter with a residual connection.

    Computes ``x + dropout(up_proj(gelu(down_proj(x))))``. The up-projection
    starts at zero, so a freshly constructed adapter returns its input
    unchanged. It only departs from the identity once training updates
    ``up_proj``.
    """

    def __init__(self, input_dim, bottleneck_dim=64, dropout=0.1):
        """
        Build the adapter layers.

        :param input_dim: size of the last dimension of the input features,
            i.e. the embed_dim of the Transformer block the adapter sits in
        :param bottleneck_dim: width of the bottleneck layer; a hyperparameter
            intended to be much smaller than input_dim
        :param dropout: dropout probability applied to the adapter branch output
        """
        super().__init__()
        # Registration order (down_proj, activation, up_proj, dropout) is
        # deliberate: it fixes parameter / state_dict ordering and the order in
        # which the default Linear initialisers draw from the RNG.
        self.down_proj = nn.Linear(input_dim, bottleneck_dim)
        self.activation = nn.GELU()
        self.up_proj = nn.Linear(bottleneck_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

        # down_proj keeps PyTorch's default Linear initialisation. Both
        # up_proj tensors (weight and bias) are zeroed, so the adapter branch
        # contributes nothing at the start of training. The whole module
        # therefore begins as an identity map, which keeps early training stable.
        for param in self.up_proj.parameters():
            nn.init.zeros_(param)

    def forward(self, x):
        """
        Apply the adapter.

        :param x: input tensor whose last dimension equals input_dim
        :return: tensor shaped like x: x plus the (dropped-out) adapter branch
        """
        hidden = self.activation(self.down_proj(x))
        delta = self.dropout(self.up_proj(hidden))
        # Residual connection around the bottleneck branch.
        return x + delta
