from __future__ import annotations

import torch
import torch.nn as nn

from layers.gauss_dynfbd import GaussDynFBD
from layers.frequency_selector import FrequencySelector
from layers.complex_func import ComplexLinear


class FFTProcessor(nn.Module):
    """Real-to-complex FFT front end.

    Applies ``torch.fft.rfft`` along dimension 1 of its input and returns the
    one-sided complex spectrum (length ``n // 2 + 1`` along that axis, where
    ``n`` is the input length there). The layer holds no parameters.
    """

    def forward(self, x_real: torch.Tensor) -> torch.Tensor:
        # The transform axis is fixed at 1; presumably the sequence/time axis
        # of a (batch, length, channels) input -- TODO confirm with callers.
        axis = 1
        spectrum = torch.fft.rfft(x_real, dim=axis)
        return spectrum


class ProjectionHandler(nn.Module):
    """Compress DynFBD masks and selector weights into small attention features.

    Two learnable projections are held:

    * ``mask_proj`` -- real ``nn.Linear`` mapping ``bands * enc_in`` features
      to ``mask_proj_dim`` features.
    * ``weights_proj`` -- ``ComplexLinear`` mapping ``target_unit_k`` inputs to
      ``weights_proj_dim`` outputs.
    """

    def __init__(self, cfg):
        super().__init__()
        # Optional sizes fall back to defaults when cfg does not define them.
        self.mask_proj_dim = getattr(cfg, 'mask_proj_dim', 16)
        self.weights_proj_dim = getattr(cfg, 'weights_proj_dim', 8)
        self.target_unit_k = getattr(cfg, 'target_unit_k', 128)
        self.enc_in = cfg.enc_in
        self.bands = 3  # number of sub-bands per channel (fixed)

        # Keep mask_proj before weights_proj: reordering could change the
        # sequence in which random parameter initialisation happens.
        mask_in_features = self.bands * self.enc_in
        self.mask_proj = nn.Linear(mask_in_features, self.mask_proj_dim)
        self.weights_proj = ComplexLinear(self.target_unit_k,
                                          self.weights_proj_dim)

    def process_mask(self, mask, B, F, C):
        """Project a band mask to ``(B, F, mask_proj_dim)`` features.

        ``mask`` is permuted with ``(0, 3, 1, 2)`` and its last two axes are
        merged into ``C * bands`` features. It presumably arrives as
        ``(B, C, bands, F)`` -- TODO confirm against the band decomposer.
        Returns ``None`` when ``mask`` is ``None``.
        """
        if mask is not None:
            per_bin = mask.permute(0, 3, 1, 2).reshape(B, F, self.bands * C)
            return self.mask_proj(per_bin)
        return None

    def process_weights(self, weights, B, K):
        """Project selector weights and tile the result over batch and tokens.

        Expects a 1-D ``weights`` vector: it is tiled to ``(B, n)``, sent
        through ``weights_proj``, then repeated ``K`` times along a new axis 1.
        Real weights are promoted to complex (zero imaginary part) first.
        Returns ``None`` when ``weights`` is ``None``.
        """
        if weights is None:
            return None
        if weights.is_complex():
            complex_weights = weights
        else:
            complex_weights = torch.complex(weights, torch.zeros_like(weights))
        # repeat (not expand) materialises every copy.
        batched = complex_weights[None].repeat(B, 1)
        projected = self.weights_proj(batched)
        return projected[:, None].repeat(1, K, 1)


class BandPreprocessor(nn.Module):
    """Prepare frequency tokens, masks and weights for FACT's encoder.

    Data flow (shape names as used in the code below):

    * ``forward`` takes a real tensor and applies ``FFTProcessor``
      (``torch.fft.rfft`` along dim 1).
    * ``process_fft`` takes the spectrum ``x_fft`` of shape ``(B, F, C)``,
      splits it into ``bands`` sub-bands with ``GaussDynFBD``, projects each
      band to ``target_unit_k`` complex tokens, optionally re-weights them with
      ``FrequencySelector`` and compresses the decomposer mask and selector
      weights via ``ProjectionHandler``.

    Raises:
        ValueError: if ``cfg.use_dynfbd`` is falsy; only the DynFBD path is
            implemented in this release.
    """

    def __init__(self, cfg):
        super().__init__()
        # Read configuration first, in the original order, so missing required
        # fields surface exactly as before.
        self.use_dynfbd = getattr(cfg, 'use_dynfbd', True)
        self.use_fselector = getattr(cfg, 'use_fselector', True)
        self.target_unit_k = getattr(cfg, 'target_unit_k', 128)
        self.bands = 3
        self.enc_in = cfg.enc_in
        # NOTE(review): the decomposer and its projection are sized by
        # cfg.d_model, so the spectrum length F presumably must equal d_model
        # -- confirm against callers.
        fft_len = cfg.d_model

        # Fail fast: reject unsupported configs before building any sub-module.
        if not self.use_dynfbd:
            raise ValueError('Minimal release expects use_dynfbd=True')

        self.fft_processor = FFTProcessor()

        self.band_decomposer = GaussDynFBD(fft_len, bands=self.bands)
        # Per-band F -> target_unit_k complex projection, stored on the
        # decomposer as ``proj``. Plain assignment is equivalent to the former
        # constant-name ``setattr`` call.
        self.band_decomposer.proj = ComplexLinear(fft_len, self.target_unit_k)

        self.frequency_selector = (
            FrequencySelector(self.enc_in * self.bands, normalize_weights=True)
            if self.use_fselector
            else None
        )
        self.projection_handler = ProjectionHandler(cfg)

    def _process_gaussian_decomposition(self, x_fft):
        """Split a spectrum into bands and project every band to tokens.

        Args:
            x_fft: complex spectrum laid out as ``(B, F, C)``.

        Returns:
            ``(tokens, mask)``. ``tokens`` has shape
            ``(B, target_unit_k, C * bands)``, with the last axis ordered
            channel-major and band-minor. ``mask`` is passed through unchanged
            from ``band_decomposer``.
        """
        # (B, F, C) -> (B, C, F); the decomposer returns x_split as
        # (B, C, bands, F).
        x_split, mask = self.band_decomposer(x_fft.transpose(1, 2))
        B, C, bands, F = x_split.shape
        # Fold bands into the batch axis so a single projection handles all of
        # them: (B, C, bands, F) -> (B, bands, C, F) -> (B * bands, C, F).
        tokens_proj = self.band_decomposer.proj(
            x_split.permute(0, 2, 1, 3).reshape(B * bands, C, F))
        # (B * bands, C, K) -> (B, bands, C, K) -> (B, K, C, bands)
        #                   -> (B, K, C * bands)
        tokens = tokens_proj.reshape(B, bands, C, self.target_unit_k).permute(
            0, 3, 2, 1).reshape(B, self.target_unit_k, C * bands)
        return tokens, mask

    def _apply_frequency_selector(self, tokens):
        """Optionally re-weight tokens with the frequency selector.

        Returns:
            ``(tokens, weights)``. ``weights`` is ``None`` when the selector is
            disabled or yields no weights. Complex weights are replaced by
            their magnitude.
        """
        if self.frequency_selector is None:
            return tokens, None
        tokens, weights = self.frequency_selector(tokens)
        # Guard against a selector that returns no weights. The downstream
        # projection already treats None as "no weights".
        if weights is not None and weights.is_complex():
            weights = weights.abs()
        return tokens, weights

    def process_fft(self, x_fft: torch.Tensor):
        """Generate DynFBD tokens plus auxiliary projections from an rFFT.

        Args:
            x_fft: spectrum of shape ``(B, F, C)``. A real tensor is promoted
                to complex with a zero imaginary part.

        Returns:
            ``(tokens, mask_proj, weights_proj)``. ``mask_proj`` is ``None``
            when the decomposer returns no mask, and ``weights_proj`` is
            ``None`` when there are no selector weights.
        """
        B, F, C = x_fft.shape
        if not x_fft.dtype.is_complex:
            x_fft = torch.complex(x_fft, torch.zeros_like(x_fft))

        tokens, mask = self._process_gaussian_decomposition(x_fft)
        tokens, weights = self._apply_frequency_selector(tokens)

        mask_proj = self.projection_handler.process_mask(mask, B, F, C)
        weights_proj = self.projection_handler.process_weights(
            weights, B, self.target_unit_k)
        return tokens, mask_proj, weights_proj

    def forward(self, x_real: torch.Tensor):
        """Run the FFT front end on ``x_real``, then :meth:`process_fft`."""
        x_fft = self.fft_processor(x_real)
        return self.process_fft(x_fft)
