from mmseg.registry import MODELS
import torch.nn as nn

@MODELS.register_module()
class Adapter(nn.Module):
    """Two-layer MLP whose output is added back onto its input.

    The input is mapped from ``input_dim`` to ``middle_dim`` features by a
    bias-free linear layer, passed through a ReLU, mapped back to
    ``input_dim`` features by a second bias-free linear layer, and the
    result is summed with the original input.

    Args:
        input_dim: Feature size expected on the last dimension of the
            input; the output has the same feature size.
        middle_dim: Feature size of the hidden layer between the two
            linear projections. Defaults to 192.
    """

    def __init__(self, input_dim, middle_dim=192):
        super().__init__()
        # Layers are created in the same order they are applied and kept
        # inside one ``nn.Sequential`` stored as ``self.adapter`` so the
        # parameter names remain ``adapter.0.weight`` / ``adapter.2.weight``.
        proj_in = nn.Linear(input_dim, middle_dim, bias=False)
        proj_out = nn.Linear(middle_dim, input_dim, bias=False)
        self.adapter = nn.Sequential(proj_in, nn.ReLU(), proj_out)

    def forward(self, x):
        """Return ``x`` plus the MLP's transformation of ``x``.

        Args:
            x: Input tensor whose last dimension has size ``input_dim``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        delta = self.adapter(x)
        return x + delta
