from mmseg.registry import MODELS
import torch.nn as nn
import torch.nn.functional as F
import torch

@MODELS.register_module()
class Convpass(nn.Module):
    """Bottleneck adapter with a depthwise 3x3 convolution over patch tokens.

    The input token sequence is projected down to ``middle_dim`` channels,
    passed through GELU, and the trailing ``H * W`` tokens are laid out on
    their 2-D grid and filtered by a depthwise 3x3 convolution. Any leading
    tokens (for example a class token) skip the convolution. The result goes
    through GELU and dropout, is projected back to ``input_dim`` and is
    added to the input as a residual.

    Args:
        input_dim: Channel size of the incoming tokens.
        middle_dim: Channel size of the bottleneck. Defaults to 192.
        drop_rate: Dropout probability applied before the up projection.
            Defaults to 0.1.
    """

    def __init__(self, input_dim, middle_dim=192, drop_rate=0.1):
        super().__init__()
        self.adapter_down = nn.Linear(input_dim, middle_dim, bias=False)
        self.adapter_up = nn.Linear(middle_dim, input_dim, bias=False)
        # groups == channels: every channel is filtered independently
        # (depthwise convolution); padding=1 keeps the H x W grid size.
        self.conv2 = nn.Conv2d(
            middle_dim,
            middle_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=middle_dim,
            bias=False,
        )
        self.act = F.gelu
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x, hw_shape):
        """Apply the adapter and return ``x`` plus the adapter output.

        Args:
            x: Tensor of shape ``(B, N, input_dim)``. The last ``H * W``
                tokens are treated as patch tokens in row-major order; the
                ``N - H * W`` tokens before them (zero or more) are treated
                as prefix tokens and bypass the convolution.
            hw_shape: ``(H, W)`` size of the patch grid.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``N`` is smaller than ``H * W``.
        """
        batch, num_tokens, _ = x.shape
        H, W = hw_shape
        num_patches = H * W
        # Number of leading non-patch tokens (1 for a single class token).
        num_prefix = num_tokens - num_patches
        if num_prefix < 0:
            raise ValueError(
                f"Sequence length {num_tokens} is smaller than the number "
                f"of patches H*W={num_patches} for hw_shape=({H}, {W})."
            )

        x_down = self.act(self.adapter_down(x))  # (B, N, middle_dim)

        # (B, H*W, middle_dim) -> (B, middle_dim, H, W) for the 2-D conv.
        x_patch = x_down[:, num_prefix:]
        x_patch = x_patch.transpose(1, 2).reshape(batch, -1, H, W)
        x_patch = self.conv2(x_patch)
        # Back to token layout: (B, H*W, middle_dim).
        x_patch = x_patch.flatten(2).transpose(1, 2)

        if num_prefix > 0:
            # Prefix tokens are re-attached without being convolved.
            x_down = torch.cat([x_down[:, :num_prefix], x_patch], dim=1)
        else:
            x_down = x_patch

        x_down = self.dropout(self.act(x_down))
        x_up = self.adapter_up(x_down)  # (B, N, input_dim)

        return x + x_up
