"""
Task heads for prediction from either backbone or domain-invariant features.

Each head is a compact MLP with two hidden layers (widths 64 and 32 by default),
ReLU activations, and dropout. The last layer emits raw logits so that the caller
can choose the appropriate loss (e.g., BCEWithLogits for multi-label).
"""


from typing import Optional

import torch
import torch.nn as nn


class LabelHead(nn.Module):
    """
    Small MLP prediction head that maps a feature vector to raw logits.

    The network is two ``Linear -> ReLU -> Dropout`` stages followed by a
    final ``Linear`` projection with no activation, all held in ``self.net``.

    Parameters
    ----------
    in_dim
        Size of the input feature dimension.
    out_dim
        Number of logits produced per sample.
    hidden1, hidden2
        Widths of the first and second hidden layers.
    dropout
        Probability passed to each ``nn.Dropout`` layer.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden1: int = 64, hidden2: int = 32, dropout: float = 0.2):
        super().__init__()

        # Assemble the hidden stages pairwise from consecutive widths. Each
        # stage contributes three modules, so the Linear layers sit at
        # positions 0, 3 and 6 of the Sequential.
        widths = (in_dim, hidden1, hidden2)
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.Dropout(dropout))
        # Output projection: logits only, no activation applied here.
        stages.append(nn.Linear(hidden2, out_dim))
        self.net = nn.Sequential(*stages)

        # Once every layer exists, overwrite the parameters of each Linear
        # layer (in order): Xavier-uniform weights and zero biases.
        for layer in self.net:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the head on a batch of features.

        Parameters
        ----------
        x : [B, D]

        Returns
        -------
        logits : [B, out_dim]
        """
        logits = self.net(x)
        return logits
