import torch
import torch.nn as nn
import math

class Block(nn.Module):
    """FiLM-style feature refinement block with an optional class-imbalance gate.

    Each feature vector ``x`` is modulated as ``x * (1 + gamma) + beta``, where
    ``gamma`` (sigmoid-squashed to [0, 1]) and ``beta`` come from a small
    bottleneck MLP over ``x``. If per-class counts were supplied, the modulated
    features are further scaled by ``1 + 0.08 * log(max_count / min_count)``.
    A learned sigmoid gate (one weight per feature vector) then blends the
    refined features with ``x``. The result is layer-normalised, and a final
    linear layer maps it back to ``in_dim``.

    Args:
        in_dim: Size of the last dimension of the input (and of the output).
        out_dim: Stored as ``self.out_dim`` but not used by any layer; the
            output projection maps back to ``in_dim``.
        cls_num_tensor: Optional tensor of per-class counts. It is registered
            as the ``cls_num_tensor`` buffer, so it follows ``.to()`` and is
            part of the ``state_dict``.
    """

    def __init__(self, in_dim=64, out_dim=64, cls_num_tensor=None):
        super().__init__()
        self.in_dim = in_dim
        # The working width is tied to the input width.
        self.dim = in_dim
        # NOTE(review): out_dim is recorded but unused -- the output layer
        # projects back to in_dim. Kept for interface compatibility.
        self.out_dim = out_dim
        self.use_layernorm = True
        self.has_buffer = False
        if cls_num_tensor is not None:
            self.register_buffer('cls_num_tensor', cls_num_tensor)
            self.has_buffer = True
        # Minimal bottleneck for FiLM (keeps the parameter count small).
        proj_dim = 32
        # self.dim == in_dim, so no input projection is ever required. The
        # attribute is kept (as None) for code that inspects or overrides it.
        self.input_proj = None
        # Historical name: this is a GELU, not a ReLU. It has no parameters,
        # so the name only matters for attribute access.
        self.relu = nn.GELU()
        # Layer creation order is kept stable so seeded initialisation,
        # state_dict keys and parameters() order stay unchanged.
        self.W1 = nn.Linear(self.dim, proj_dim)
        self.gamma_head = nn.Linear(proj_dim, self.dim)
        self.beta_head = nn.Linear(proj_dim, self.dim)
        self.attn_gate = nn.Linear(self.dim, 1)
        self.norm = nn.LayerNorm(self.dim)
        self.dropout = nn.Dropout(p=0.0)
        self.output = nn.Linear(self.dim, self.in_dim)

    def _imbalance_gate(self):
        """Return the class-imbalance scale as a 0-d tensor, or None if unused.

        The scale is ``1 + 0.08 * log(cmax / (cmin + eps) + eps)``. It is
        computed on the buffer's device, which avoids a host sync per forward
        pass. Returns None when no counts buffer was registered or it is empty.
        """
        if not self.has_buffer:
            return None
        counts = self.cls_num_tensor.float()
        if counts.numel() == 0:
            return None
        eps = 1e-6
        ratio = counts.max() / (counts.min() + eps)
        # max >= min, so a ratio below 1 can only come from zero or negative
        # counts (e.g. an all-zero placeholder). Without the clamp, log() of
        # that ratio would shrink the scale or even flip its sign.
        ratio = torch.clamp(ratio, min=1.0)
        gate = 1.0 + 0.08 * torch.log(ratio + eps)
        # NaN/inf counts would otherwise poison every output; use a neutral 1.
        return torch.where(torch.isfinite(gate), gate, torch.ones_like(gate))

    def forward(self, feats):
        """Refine ``feats`` and project the result back to ``in_dim``.

        Args:
            feats: Tensor whose last dimension is ``in_dim``.

        Returns:
            Tensor with the same leading dimensions as ``feats`` and a last
            dimension of ``in_dim``.
        """
        x = feats if self.input_proj is None else self.input_proj(feats)
        hidden = self.relu(self.W1(x))
        gamma = torch.sigmoid(self.gamma_head(hidden))  # in [0, 1]
        beta = self.beta_head(hidden)
        # FiLM refinement, optionally scaled by the class-imbalance gate.
        refined = x * (1.0 + gamma) + beta
        gate = self._imbalance_gate()
        if gate is not None:
            # A 0-d tensor does not up-cast fp16/bf16 activations.
            refined = refined * gate
        # Learned per-vector blend between the refined features and the input.
        attn = torch.sigmoid(self.attn_gate(x))
        refined = attn * refined + (1.0 - attn) * x
        if self.use_layernorm:
            refined = self.norm(refined)
        if self.dropout.p > 0.0:
            refined = self.dropout(refined)
        # Project back to the input dimension so the block is a drop-in layer.
        return self.output(refined)