class Partial_Channel_Filtering(nn.Module):
    """Split a feature map into its highest- and lowest-scoring channels.

    Each channel is scored per sample by the absolute value of its spatial
    mean, computed with ``nn.AdaptiveAvgPool2d((1, 1))``. Channels are ranked
    per sample in descending score order. The first ``pdim`` ranked channels
    form the first output and the remaining channels form the second.

    Args:
        dim: Stored as ``self.dim`` but not read by ``forward``. Presumably
            this is the expected channel count; TODO confirm with callers.
        pdim: Number of top-ranked channels routed to the first output.
    """

    def __init__(self, dim, pdim):
        super().__init__()
        self.dim = dim
        self.pdim = pdim
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Partition the channels of ``x`` by per-sample saliency.

        Args:
            x: 4-D tensor, unpacked here as (B, C, H, W).

        Returns:
            A tuple ``(selected, remaining)``:

            - ``selected`` has shape (B, k, H, W), where ``k`` is
              ``min(pdim, C)`` by slicing semantics.
            - ``remaining`` has shape (B, C - k, H, W).

            Within each output, channels appear in descending score order,
            not in their original order. The order among tied scores is
            whatever ``argsort`` produces.
        """
        batch, channels, height, width = x.shape

        # Pooling to 1x1 gives (B, C, 1, 1); flatten it to one score per channel.
        saliency = self.global_pool(x).view(batch, channels).abs()
        ranking = saliency.argsort(dim=1, descending=True)

        def _pick_planes(channel_idx):
            # Broadcast the (B, k) channel index across the spatial grid so
            # that gather copies whole H x W planes for each sample.
            plane_idx = channel_idx[:, :, None, None].expand(batch, -1, height, width)
            return x.gather(1, plane_idx)

        selected = _pick_planes(ranking[:, :self.pdim])
        remaining = _pick_planes(ranking[:, self.pdim:])
        return selected, remaining