import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    """Feature-refinement block for a 2-D batch of feature rows.

    Pipeline applied in ``forward``:
      1. squeeze-excitation rescaling ``x * (1 + sigmoid(MLP(mean(x, dim=0))))``;
      2. a learned shift along ``delta``, gated per row by
         ``sigmoid(<x, delta> * gate_scalar)``;
      3. LayerNorm;
      4. when ``in_dim != out_dim``: a bottleneck adapter added to a linear
         skip projection, mapping the features to ``out_dim``;
      5. optionally, a two-gate multiplicative attention on the input rows.

    Args:
        in_dim: Feature width of the input.
        out_dim: Feature width of the output.
        cls_num_tensor: Optional per-class counts (anything accepted by
            ``torch.as_tensor``). When given, ``gate_scalar`` is
            ``mean(1 / (sqrt(count) + eps))``. When omitted, ``gate_scalar``
            is 0, so the shift gate is a constant 0.5.
        se_hidden: Hidden width of the squeeze-excitation MLP (default 8).
        adapter_hidden: Hidden width of the adapter bottleneck (default 4).
        use_attention: Enable the optional two-gate attention (default False).
    """

    def __init__(self, in_dim=64, out_dim=64, cls_num_tensor=None, *,
                 se_hidden=8, adapter_hidden=4, use_attention=False):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # The placeholder buffer keeps its original shape so existing
        # state_dicts still load. The flag stops the zero placeholder from
        # being read as real counts (it would yield a 1/eps = 1e6 gate scale).
        self.has_cls_counts = cls_num_tensor is not None
        if self.has_cls_counts:
            self.register_buffer("cls_counts", torch.as_tensor(cls_num_tensor).float())
        else:
            self.register_buffer("cls_counts", torch.zeros(1))
        self.delta = nn.Parameter(torch.randn(in_dim) * 0.01)
        self.ln = nn.LayerNorm(in_dim)
        self.se_fc1 = nn.Linear(in_dim, se_hidden, bias=False)
        self.se_fc2 = nn.Linear(se_hidden, in_dim, bias=False)
        self.eps = 1e-6
        self.use_attention = use_attention  # optional lightweight 2-head attention gate
        if in_dim != out_dim:
            self.adapter_down = nn.Linear(in_dim, adapter_hidden, bias=False)
            self.adapter_up = nn.Linear(adapter_hidden, out_dim, bias=False)
            self.adapter_ln = nn.LayerNorm(out_dim)
        else:
            self.adapter_down = None
            self.adapter_up = None
            self.adapter_ln = None
        self.att_g1 = nn.Parameter(torch.randn(in_dim) * 0.01)
        self.att_g2 = nn.Parameter(torch.randn(in_dim) * 0.01)
        self.att_scalar1 = nn.Parameter(torch.tensor(0.1))
        self.att_scalar2 = nn.Parameter(torch.tensor(0.1))
        self.dropout = nn.Dropout(p=0.15)
        # Skip projection so the adapter residual has matching widths.
        # Created last so the RNG draws that initialise the parameters
        # above are unchanged.
        if in_dim != out_dim:
            self.adapter_skip = nn.Linear(in_dim, out_dim, bias=False)
        else:
            self.adapter_skip = None

    def _gate_scalar(self):
        """Return the multiplier applied to ``<x, delta>`` before the sigmoid.

        Returns:
            ``0.0`` when no class counts were supplied (or they are empty).
            Otherwise ``mean(1 / (sqrt(cls_counts) + eps))`` as a 0-dim tensor.
            Keeping it a tensor avoids the host-device sync that ``.item()``
            would force on every forward.
        """
        if not self.has_cls_counts or self.cls_counts.numel() == 0:
            return 0.0
        with torch.no_grad():
            # NOTE(review): a class with count 0 contributes 1/eps = 1e6 to
            # this mean; clamp the counts if zero-count classes can occur.
            inv_sqrt = 1.0 / (torch.sqrt(self.cls_counts.float()) + self.eps)
            return inv_sqrt.mean()

    def forward(self, feats):
        """Refine ``feats`` of shape ``(N, in_dim)`` into ``(N, out_dim)``.

        The ``gate[:, None]`` indexing below requires ``feats`` to be 2-D.
        """
        x = feats
        # Squeeze-excitation. The channel weights come from the mean over
        # dim 0, so every row is rescaled by the same batch-level statistics.
        # The output therefore depends on the other rows in the batch.
        # NOTE(review): presumably rows form one set/graph; confirm intended.
        s = F.relu(self.se_fc1(x.mean(dim=0, keepdim=True)))
        s = torch.sigmoid(self.se_fc2(s))
        x = x * (1.0 + s)

        # Learned shift along ``delta``, gated per row by <x, delta>.
        dot = x @ self.delta
        gate = torch.sigmoid(dot * self._gate_scalar())
        x = x + gate[:, None] * self.delta
        x = self.ln(x)

        if self.adapter_down is not None:
            z = self.adapter_down(x)
            z = F.relu(z)
            z = self.adapter_up(z)
            z = self.adapter_ln(z)
            # x is still in_dim wide; project it so the residual matches z.
            x = self.adapter_skip(x) + z

        if getattr(self, "use_attention", False):
            # Gates are computed from the raw input, not the refined x.
            attn1 = torch.sigmoid(feats @ self.att_g1)
            attn2 = torch.sigmoid(feats @ self.att_g2)
            # Dropout acts on the gate values; in training it also rescales
            # the surviving ones by 1 / (1 - p).
            attn1 = self.dropout(attn1)
            attn2 = self.dropout(attn2)
            gate2 = 1.0 + attn1 * self.att_scalar1 + attn2 * self.att_scalar2
            x = x * gate2[:, None]
        return x