import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class ChannelPriorMixer(nn.Module):
    """Estimate per-channel priors from frequency responses and mix them.

    Given a spectrum ``x`` shaped ``(batch, freq_len, num_channels)`` (real or
    complex), the module derives two channel-by-channel priors:

    * ``gamma`` - Pearson-style correlation of frequency-weighted amplitudes
      between channels.
    * ``phi`` - an antisymmetric phase-interaction term built from weighted
      sin/cos phase statistics (all zeros for real-valued input).

    The priors are turned into a row-stochastic mixing matrix (temperature
    softmax, optionally blended with the identity) that is applied across
    channels and blended with the input according to ``mixing_strength``.

    Args:
        c_in: Number of input channels (stored; not read by this class).
        d_model: Model width (stored; not read by this class).
        topk: Number of highest-weighted frequency bins kept when top-k
            selection is enabled; ``<= 0`` or ``>= freq_len`` disables it.
        temperature: Softmax temperature (clamped below by ``eps``).
        mixing_strength: Blend factor between input and mixed output;
            values ``<= 0`` return the input unchanged.
        diag_bias: Weight of the identity blended into the mixing matrix.
        use_phi: Whether the phase prior contributes to the mixing logits.
        eps: Small constant guarding divisions.
    """

    def __init__(
        self,
        c_in: int,
        d_model: int,
        topk: int = 16,
        temperature: float = 1.0,
        mixing_strength: float = 0.1,
        diag_bias: float = 0.2,
        use_phi: bool = False,
        eps: float = 1e-8,
    ) -> None:
        super().__init__()
        self.c_in = c_in
        self.d_model = d_model
        self.topk = topk
        self.temperature = temperature
        self.mixing_strength = mixing_strength
        self.diag_bias = diag_bias
        self.use_phi = use_phi
        self.eps = eps
        # Learnable scales for the gamma and phi terms of the mixing logits.
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(0.5))

    @staticmethod
    def _align_to_length(x: torch.Tensor, length: int) -> torch.Tensor:
        """Linearly resample a 3-D tensor along dim 1 to ``length`` entries.

        ``x`` is unpacked as ``(batch, freq, hidden)``; the frequency axis is
        resized and ``x`` is returned as-is when it already has ``length``.
        """
        _, freq_len, _ = x.shape
        if freq_len == length:
            return x
        # F.interpolate resizes the last axis of (N, C, L), so move freq last.
        resized = F.interpolate(
            x.transpose(1, 2), size=length, mode='linear', align_corners=False)
        return resized.transpose(1, 2)

    def _build_freq_weights(
        self,
        mask_proj: Optional[torch.Tensor],
        weights_proj: Optional[torch.Tensor],
        freq_len: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Combine mask/weight projections into a unified frequency prior.

        Each provided projection is averaged over its last dim (real part for
        complex tensors), linearly resampled to ``freq_len`` along dim 1,
        summed with the others, and squashed with a sigmoid. Projections are
        presumably ``(batch, freq, hidden)`` - TODO confirm against callers.

        Args:
            mask_proj: Optional mask projection.
            weights_proj: Optional weight projection.
            freq_len: Target number of frequency bins.
            device: Device for the uniform fallback used when both
                projections are ``None``; defaults to CPU.

        Returns:
            A ``(1, freq_len)`` tensor of ones when both projections are
            ``None``; otherwise sigmoid-squashed weights, one row per batch
            item.
        """
        if mask_proj is None and weights_proj is None:
            # Uniform prior; the caller chooses the device so it matches x.
            return torch.ones(
                1, freq_len, device=device if device is not None else 'cpu')

        components = []
        for proj in (mask_proj, weights_proj):
            if proj is None:
                continue
            real = proj.real if torch.is_complex(proj) else proj
            comp = real.mean(dim=-1)
            if comp.shape[1] != freq_len:
                comp = F.interpolate(
                    comp.unsqueeze(1), size=freq_len, mode='linear',
                    align_corners=False).squeeze(1)
            components.append(comp)
        weights = torch.stack(components, dim=0).sum(dim=0)
        return torch.sigmoid(weights)

    def compute_priors(
        self,
        x: torch.Tensor,
        mask_proj: Optional[torch.Tensor] = None,
        weights_proj: Optional[torch.Tensor] = None,
        use_topk: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Derive correlation (`gamma`) and phase (`phi`) priors across channels.

        Args:
            x: Spectrum shaped ``(batch, freq_len, num_channels)``, real or
                complex.
            mask_proj: Optional mask projection feeding the frequency prior.
            weights_proj: Optional weight projection feeding the prior.
            use_topk: Replace the soft frequency weights with a hard 0/1 mask
                over the ``topk`` highest-weighted bins (when applicable).

        Returns:
            ``(gamma, phi, weight_mask)``. ``gamma`` and ``phi`` are
            ``(batch, num_channels, num_channels)``. ``weight_mask`` is
            ``(batch, freq_len)``, on ``x``'s device.
        """
        batch_size, freq_len, _ = x.shape
        device = x.device

        # Always move the prior to x's device: the all-ones fallback and
        # projections living elsewhere would otherwise clash with x below.
        weights = self._build_freq_weights(
            mask_proj, weights_proj, freq_len, device=device).to(device)
        if weights.shape[0] != batch_size:
            weights = weights.expand(batch_size, -1)
        weights = weights.clamp(0.0, 1.0)

        if use_topk and 0 < self.topk < freq_len:
            # Hard mask keeping only the top-k weighted bins per batch item.
            topk_idx = torch.topk(weights, k=self.topk, dim=1).indices
            weight_mask = torch.zeros(batch_size, freq_len, device=device)
            weight_mask.scatter_(1, topk_idx, 1.0)
        else:
            weight_mask = weights

        freq_mask = weight_mask.unsqueeze(-1)
        amplitude = x.abs()
        weighted_amplitude = amplitude * freq_mask

        # Pearson-style correlation between channels over the frequency axis.
        channel_amplitude = weighted_amplitude.transpose(1, 2)
        centered = channel_amplitude - \
            channel_amplitude.mean(dim=-1, keepdim=True)
        denom = centered.pow(2).sum(dim=-1, keepdim=True).sqrt() + self.eps
        normalized = centered / denom
        gamma = torch.matmul(normalized, normalized.transpose(-1, -2))

        # Real inputs carry no phase information, so phi collapses to zeros.
        phase = torch.angle(x) if torch.is_complex(
            x) else torch.zeros_like(amplitude)
        sin_stats = (torch.sin(phase) * freq_mask).sum(dim=1)
        cos_stats = (torch.cos(phase) * freq_mask).sum(dim=1)
        # phi[b, i, j] = sin_i * cos_j - cos_i * sin_j (antisymmetric in i, j).
        phi = torch.einsum('bi,bj->bij', sin_stats, cos_stats) - \
            torch.einsum('bi,bj->bij', cos_stats, sin_stats)
        # Scale each batch item so its largest |phi| entry is at most ~1.
        max_abs = phi.abs().amax(dim=(-2, -1), keepdim=True) + self.eps
        phi = phi / max_abs

        return gamma, phi, weight_mask

    def compute_mixing(self, gamma: torch.Tensor, phi: torch.Tensor) -> torch.Tensor:
        """Convert priors into a stochastic mixing matrix with temperature control.

        Logits are ``alpha * gamma`` (plus ``beta * phi`` when ``use_phi``),
        softmaxed over the last dim. When ``diag_bias > 0`` the result is
        blended with the identity, which keeps every row summing to one.
        """
        logits = self.alpha * gamma + \
            (self.beta * phi if self.use_phi else 0.0)
        mixing = F.softmax(logits / max(self.temperature, self.eps), dim=-1)
        if self.diag_bias > 0:
            batch_size, num_channels, _ = mixing.shape
            identity = torch.eye(
                num_channels, device=mixing.device, dtype=mixing.dtype,
            ).unsqueeze(0).expand(batch_size, -1, -1)
            mixing = (1 - self.diag_bias) * mixing + self.diag_bias * identity
        return mixing

    def apply_mixing(self, mixing: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Blend original activations with mixed ones controlled by strength.

        ``mixing[b, i, j]`` weights channel ``j`` into output channel ``i``
        for every frequency bin. Returns ``x`` itself when
        ``mixing_strength <= 0``.
        """
        strength = float(self.mixing_strength)
        if strength <= 0:
            # Nothing to blend; skip the channel mixing entirely.
            return x

        if torch.is_complex(x):
            # einsum needs matching dtypes, so mix real/imag parts separately.
            mixed_real = torch.einsum('bij,bfj->bfi', mixing, x.real)
            mixed_imag = torch.einsum('bij,bfj->bfi', mixing, x.imag)
            mixed = torch.complex(mixed_real, mixed_imag)
        else:
            mixed = torch.einsum('bij,bfj->bfi', mixing, x)
        return (1.0 - strength) * x + strength * mixed

    def compute_channel_gate(
        self,
        x: torch.Tensor,
        w_eff: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Summarize spectrum energy into [0,1] gates per channel.

        Averages (optionally ``w_eff``-weighted) amplitudes over the frequency
        axis, then min-max normalizes across channels per batch item.
        """
        amplitude = x.abs()
        if w_eff is not None:
            amplitude = amplitude * w_eff.unsqueeze(-1)
        summary = amplitude.mean(dim=1)
        low = summary.min(dim=-1, keepdim=True).values
        high = summary.max(dim=-1, keepdim=True).values
        return (summary - low) / (high - low + self.eps)

    def forward(
        self,
        x: torch.Tensor,
        mask_proj: Optional[torch.Tensor] = None,
        weights_proj: Optional[torch.Tensor] = None,
        use_topk: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Mix channels of the spectrum using learned priors.

        Returns:
            ``(mixed, mixing, gamma, phi)``. ``mixed`` has the shape of ``x``
            and ``mixing`` is the row-stochastic channel mixing matrix.
            ``gamma`` and ``phi`` are the priors it was built from.
        """
        gamma, phi, _ = self.compute_priors(
            x, mask_proj, weights_proj, use_topk=use_topk)
        mixing = self.compute_mixing(gamma, phi)
        mixed = self.apply_mixing(mixing, x)
        return mixed, mixing, gamma, phi
