import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussDynFBD(nn.Module):
    """Gaussian dynamic frequency band decomposition for rFFT features.

    Learns ``bands`` Gaussian windows (centers ``mu``, widths derived from
    ``log_sigma``) over a normalized frequency axis spanning ``[0, 1]``. At
    every frequency bin the windows are normalized across bands so that they
    sum to (approximately) one; summing the band outputs therefore recovers
    the input spectrum.

    Args:
        fft_len: Number of points in the cached frequency grid; normally the
            size of the input's last dimension. Inputs with a different number
            of bins are handled by building a ``[0, 1]`` grid of that length.
        bands: Number of Gaussian bands.
        per_channel: Stored flag. NOTE(review): ``mu``/``log_sigma`` have
            shape ``(1, bands)``, so all channels share the same Gaussians and
            this flag currently does not change the output.
        init_choice: ``'linear'`` or ``'log'`` spacing of the initial centers;
            any other value uses the fixed centers ``[0.15, 0.45, 0.75]``
            (truncated to ``bands``; at most 3 bands supported).
        init_sigma: Initial Gaussian width on the normalized frequency axis.
    """

    EPS_SIGMA = 1e-6  # lower bound added to sigma to avoid division by zero
    EPS_NORM = 1e-9   # guards the across-band normalization denominator

    def __init__(self, fft_len: int, bands: int = 3,
                 per_channel: bool = False, init_choice: str = 'linear',
                 init_sigma: float = 0.15) -> None:
        super().__init__()
        self.bands = bands
        self.fft_len = fft_len
        self.per_channel = per_channel

        self.mu = nn.Parameter(torch.zeros(1, bands))
        # Unconstrained parameter; the actual width is softplus(log_sigma) + EPS_SIGMA.
        self.log_sigma = nn.Parameter(torch.zeros(1, bands))

        self.register_buffer('freq_grid', torch.linspace(0.0, 1.0, fft_len))

        self._init_parameters(init_choice, init_sigma)

    def _init_parameters(self, init_choice: str, init_sigma: float = 0.15) -> None:
        """Initialize Gaussian centers and spreads following the chosen schedule.

        Raises:
            ValueError: if ``init_sigma`` is not larger than ``EPS_SIGMA``, or
                the fallback schedule is requested for more than 3 bands.
        """
        if init_sigma <= self.EPS_SIGMA:
            raise ValueError(
                f"init_sigma must be greater than {self.EPS_SIGMA}, got {init_sigma}")

        with torch.no_grad():
            if init_choice == 'linear':
                mu_init = torch.linspace(0.15, 0.85, self.bands)
            elif init_choice == 'log':
                mu_init = torch.logspace(-1, 0, self.bands, base=10.0)
                mu_init = mu_init / mu_init.max()
            else:
                tri = torch.tensor([0.15, 0.45, 0.75])
                if self.bands > tri.numel():
                    raise ValueError(
                        f"init_choice={init_choice!r} provides only {tri.numel()} "
                        f"centers but bands={self.bands}; use 'linear' or 'log'")
                mu_init = tri[:self.bands]

            self.mu.copy_(mu_init.unsqueeze(0))

            # The width is softplus(log_sigma) + EPS_SIGMA (see
            # _expand_parameters), so invert softplus -- log(expm1(y)) -- to
            # make the initial width exactly ``init_sigma``. Filling with
            # log(init_sigma) would only be correct under an exp mapping.
            target = torch.tensor(init_sigma - self.EPS_SIGMA, dtype=torch.float64)
            self.log_sigma.fill_(torch.log(torch.expm1(target)).item())

    def _expand_parameters(self, C: int, device):
        """Broadcast the learned Gaussian parameters to ``C`` channels.

        Returns:
            ``(mu, sigma)``, each of shape ``(C, bands, 1)`` on ``device``.
        """
        mu = self.mu.to(device)
        sigma = (F.softplus(self.log_sigma) + self.EPS_SIGMA).to(device)

        # Both parameters are (1, bands), so every channel receives the same
        # values regardless of ``per_channel``; ``expand`` broadcasts them
        # without the copy that ``repeat`` would make.
        mu = mu.expand(C, -1)
        sigma = sigma.expand(C, -1)

        return mu.unsqueeze(-1), sigma.unsqueeze(-1)

    def _compute_gaussian_weights(self, mu, sigma, freq):
        """Compute Gaussian weights normalized across bands.

        Args:
            mu, sigma: Tensors of shape ``(C, bands, 1)``.
            freq: Tensor of shape ``(1, 1, F)``.

        Returns:
            Tensor of shape ``(C, bands, F)`` whose values at each frequency
            sum over the band axis (dim 1) to approximately one.
        """
        dist = (freq - mu) ** 2 / (2.0 * sigma ** 2)
        weights = torch.exp(-dist)

        return weights / (weights.sum(dim=1, keepdim=True) + self.EPS_NORM)

    def forward(self, x_fft: torch.Tensor):
        """Split spectra into multiple soft bands and return band weights.

        Args:
            x_fft: 3-D tensor ``(B, C, F)``, e.g. a complex rFFT spectrum.

        Returns:
            ``(x_split, weights)``:

            - ``weights`` has shape ``(B, C, bands, F)``; it is an expanded
              (non-copied) view over the batch dimension.
            - ``x_split = weights * x_fft`` broadcast over the band axis, also
              of shape ``(B, C, bands, F)``.

        Raises:
            ValueError: if ``x_fft`` is not 3-D.
        """
        if x_fft.dim() != 3:
            raise ValueError("Input x_fft should be a 3D (B,C,F) complex tensor")
        B, C, freq_len = x_fft.shape
        device = x_fft.device

        mu, sigma = self._expand_parameters(C, device)

        if freq_len == self.freq_grid.numel():
            grid = self.freq_grid.to(device)
        else:
            # Input bin count differs from the cached grid (e.g. fft_len was
            # given as n_fft while the rFFT has n_fft // 2 + 1 bins): use a
            # normalized [0, 1] grid of the matching length instead of failing.
            grid = torch.linspace(0.0, 1.0, freq_len,
                                  device=device, dtype=self.freq_grid.dtype)
        freq = grid.view(1, 1, freq_len)

        weights = self._compute_gaussian_weights(mu, sigma, freq)

        weights = weights.unsqueeze(0).expand(B, -1, -1, -1)
        x_split = weights * x_fft.unsqueeze(2)

        return x_split, weights
