import torch
import torch.nn as nn
from layers.complex_func import ComplexLinear


class FrequencySelector(nn.Module):
    """Gate DynFBD tokens with one learned importance score per token.

    Each token is scored by ``weight_layer``; the score is the sigmoid of
    the magnitude of that layer's output, averaged over the leading
    (batch) axis so every sample shares the same per-token weights.  The
    tokens are separately projected by ``output_layer`` and scaled by
    those weights.

    NOTE(review): ``ComplexLinear`` is defined elsewhere in the project;
    it is presumably a complex-valued linear layer -- confirm.
    """

    def __init__(self, hidden_dim, normalize_weights=True):
        """Build the scoring and projection layers.

        Args:
            hidden_dim: Feature size passed to both ``ComplexLinear``
                layers (as input size, and as output size of the
                projection).
            normalize_weights: When true, ``forward`` rescales the token
                scores so that they sum to (approximately) one.
        """
        super().__init__()
        self.normalize_weights = normalize_weights
        self.hidden_dim = hidden_dim

        # Registration order (scorer first, projection second) is kept
        # stable so parameter / state-dict ordering does not change.
        self.weight_layer = ComplexLinear(hidden_dim, 1)
        self.output_layer = ComplexLinear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Project the tokens and scale each one by its importance score.

        Args:
            x: Rank-3 tensor; the first axis is averaged over and the
                second indexes tokens.  Presumably laid out as
                (batch, tokens, hidden_dim) -- confirm against callers.

        Returns:
            A ``(x_weighted, weights)`` tuple: the projected tokens
            multiplied by their scores, and the 1-D score vector with
            one entry per token.

        Raises:
            ValueError: If ``x`` does not have exactly three dimensions.
        """
        # The three-way unpack doubles as a rank check; only the token
        # count is needed afterwards.
        _, token_count, _ = x.shape

        # Magnitudes are non-negative, so every sigmoid gate is >= 0.5.
        magnitudes = torch.abs(self.weight_layer(x))
        gates = torch.sigmoid(magnitudes.squeeze(-1))

        # Collapse the leading axis: one shared score per token.
        scores = gates.mean(dim=0)

        if self.normalize_weights:
            # Epsilon guards the division against a zero total.
            total = scores.sum() + 1e-9
            scores = scores / total

        projected = self.output_layer(x)
        # Reshape scores so they broadcast over the first and last axes.
        return projected * scores.view(1, token_count, 1), scores
