import torch
import torch.nn as nn

class Block(nn.Module):
    """Feature refinement block gated by class-count statistics.

    The input is layer-normalised over its last dimension, scaled per
    feature by ``1 + gate`` and optionally projected to ``out_dim``, then
    passed through GELU. ``gate`` is produced by a small MLP (ending in
    ``Tanh``, so each entry lies in [-1, 1]) from a 4-value summary of the
    stored class counts: ``[mean, max, min, max / (mean + 1e-6)]``.

    Args:
        in_dim: Size of the last dimension of the input features.
        out_dim: Size of the last dimension of the output. A linear
            projection is created only when it differs from ``in_dim``.
        cls_num_tensor: Per-class sample counts. Accepts a tensor or
            anything ``torch.as_tensor`` can convert (list, tuple, NumPy
            array). Defaults to 100 classes with a count of 1 each.

    Raises:
        ValueError: If ``cls_num_tensor`` contains no elements.
    """

    def __init__(self, in_dim=64, out_dim=64, cls_num_tensor=None):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        if cls_num_tensor is None:
            cls_num_tensor = torch.ones(100)
        # ``Tensor.float()`` hands back the very same tensor when the input is
        # already float32, which would make the buffer alias the caller's
        # data. Detach + clone so the module owns an independent copy that is
        # also free of any autograd history.
        counts = torch.as_tensor(cls_num_tensor, dtype=torch.float32).detach().clone()
        if counts.numel() == 0:
            # max()/min() in forward() are undefined for an empty tensor;
            # fail here with a clear message instead of deep inside forward().
            raise ValueError("cls_num_tensor must contain at least one class count")
        self.register_buffer('cls_num', counts)

        # Imbalance-aware gating network: maps the 4-value class-count
        # descriptor to one gate value per input feature.
        self.imbalance_mlp = nn.Sequential(
            nn.Linear(4, 8),
            nn.ReLU(inplace=True),
            nn.Linear(8, in_dim),
            nn.Tanh()
        )

        self.norm = nn.LayerNorm(in_dim)
        if in_dim != out_dim:
            self.proj = nn.Linear(in_dim, out_dim)
        else:
            self.proj = None
        self.act = nn.GELU()

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Apply normalisation, imbalance gating, projection and activation.

        Args:
            feats: Features whose last dimension is ``in_dim``
                (e.g. ``[N, in_dim]``).

        Returns:
            Tensor with the same leading dimensions as ``feats`` and a last
            dimension of ``out_dim``.
        """
        x = self.norm(feats)

        # The descriptor depends only on the (non-trainable) class counts, so
        # it is built without tracking gradients. Gradients still reach the
        # gating MLP through its own parameters below.
        with torch.no_grad():
            mean = self.cls_num.mean()
            maxv = self.cls_num.max()
            minv = self.cls_num.min()
            ratio = maxv / (mean + 1e-6)  # epsilon guards against a zero mean
            # No-op when the buffer already lives on the same device as feats.
            imb_vec = torch.stack([mean, maxv, minv, ratio], dim=0).to(feats.device)

        gate = self.imbalance_mlp(imb_vec.view(1, -1))  # [1, in_dim]
        gate = gate.view(-1)  # [in_dim], broadcasts over the last dim of x

        # Tanh keeps gate in [-1, 1], so each feature is scaled by a factor
        # in [0, 2].
        refined = x * (1.0 + gate)

        if self.proj is not None:
            refined = self.proj(refined)

        refined = self.act(refined)
        return refined