from __future__ import annotations

import torch


class LinearHeadEstimator(torch.nn.Module):
    """A single bias-free linear layer over attention-head features.

    Each input feature is one scalar per (layer, head) pair. The layer maps
    the last (feature) dimension to a single score. Weights start uniform at
    ``1 / num_features``, so an untrained estimator returns the mean feature.

    For Qwen3-VL-8B-Thinking: num_features = 36 layers x 32 heads = 1152.
    """

    def __init__(self, num_features: int):
        """Create the estimator.

        Args:
            num_features: Size of the last dimension of the input features.

        Raises:
            ValueError: If ``num_features`` is not positive.
        """
        super().__init__()
        if num_features <= 0:
            raise ValueError(f"num_features must be positive, got {num_features}")
        self.linear = torch.nn.Linear(num_features, 1, bias=False)
        # Uniform init => forward() starts out as a plain average of features.
        torch.nn.init.constant_(self.linear.weight, 1.0 / float(num_features))

    @property
    def num_features(self) -> int:
        """Number of input features (the layer's ``in_features``)."""
        return int(self.linear.in_features)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Score each feature vector.

        Args:
            features: Tensor of shape ``(..., num_features)``, e.g.
                ``(batch, M, num_features)``.

        Returns:
            Tensor of shape ``(...)``, i.e. the input shape without its last
            dimension, e.g. ``(batch, M)``.
        """
        # Linear yields (..., 1); drop the singleton output dim for any
        # number of leading dims (previously hard-coded to 3-D input).
        return self.linear(features)[..., 0]

    def finalize(self) -> None:
        """L1-normalize the weights in place for interpretability.

        After this call the absolute weights sum to 1 (signs are kept). All
        weights remain zero if they were all zero. The ``Parameter`` and its
        storage are updated in place, so existing references (optimizers,
        views, tied weights) stay valid.
        """
        with torch.no_grad():
            w = self.linear.weight
            norm = w.abs().sum()
            # Clamp instead of adding a fixed epsilon: 1e-12 underflows to 0
            # in fp16/bf16, which would make an all-zero weight produce NaN.
            # For any real norm >= tiny this divides by the exact norm.
            w.div_(norm.clamp_min(torch.finfo(w.dtype).tiny))
