import math
from typing import Optional

import torch
import torch.nn as nn

from layers.complex_func import ComplexLayerNorm, ComplexLinear


class AdaptiveFeatureFusion(nn.Module):

    """Fuse raw frequency responses with DynFBD tokens using complex attention.

    Every channel of ``raw_fft`` acts as a query that attends over the token
    sequence. Queries and keys are real projections of concatenated
    real/imaginary parts; values stay complex. Optional ``mask_proj`` and
    ``weights_proj`` inputs are mapped to a per-head additive bias on the
    attention scores and a per-head sigmoid gate on the values. The attended
    context is projected, dropped out, added back to the raw query and
    normalised.

    Shapes implied by the reshapes in :meth:`forward`:

    * ``raw_fft``: complex ``(B, d_model, C)`` (``.imag`` is read from it).
    * ``tokens``: ``(B, T, token_dim)``, presumably complex since it is fed
      through ``ComplexLinear``; ``token_dim`` is arbitrary because the token
      projection is created lazily.
    * return value: ``(B, d_model, C)``, the same layout as ``raw_fft``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        dropout: float = 0.1,
        fusion_strategy: str = 'additive',
        alpha: float = 0.7,
        use_channel_mha: bool = False,
        use_feature_axis_attention: bool = False,
    ) -> None:
        """Build the attention projections and diagnostic caches.

        Args:
            d_model: Width of ``raw_fft``'s second axis and of the attention
                values; must be divisible by ``n_heads``.
            n_heads: Number of attention heads; must be positive.
            dropout: Probability used for attention-weight dropout and for
                the complex dropout applied to the fused output.
            fusion_strategy, alpha, use_channel_mha, use_feature_axis_attention:
                Stored as attributes only; :meth:`forward` does not read them.

        Raises:
            ValueError: If ``n_heads`` is not positive or does not divide
                ``d_model``.
        """
        super().__init__()
        if n_heads <= 0:
            raise ValueError('n_heads must be positive')
        if d_model % n_heads != 0:
            raise ValueError(
                'd_model must be divisible by n_heads for cross-attention fusion')

        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim ** -0.5

        # Queries/keys see [real | imag] concatenated, hence 2 * d_model inputs.
        self.query_proj = nn.Linear(2 * d_model, n_heads * self.head_dim)
        self.key_proj = nn.Linear(2 * d_model, n_heads * self.head_dim)
        self.value_proj = ComplexLinear(d_model, d_model)
        self.output_proj = ComplexLinear(d_model, d_model)
        self.layer_norm = ComplexLayerNorm(d_model)
        self.attn_dropout = nn.Dropout(dropout)
        self.residual_dropout_p = dropout

        # Created on first use in forward(), because the feature widths of
        # tokens / mask_proj / weights_proj are only known at call time.
        self.token_value_proj: Optional[ComplexLinear] = None
        self.token_value_in_dim: Optional[int] = None
        self.mask_affine: Optional[nn.Linear] = None
        self.mask_affine_in_dim: Optional[int] = None
        self.weights_affine: Optional[nn.Linear] = None
        self.weights_affine_in_dim: Optional[int] = None

        # Detached diagnostics from the most recent forward() call.
        self.last_attention_weights: Optional[torch.Tensor] = None
        self.last_token_focus: Optional[torch.Tensor] = None
        self.last_mask_gate: Optional[torch.Tensor] = None
        self.last_weights_gate: Optional[torch.Tensor] = None
        self.last_mask_bias: Optional[torch.Tensor] = None
        self.last_weights_bias: Optional[torch.Tensor] = None

        # Configuration kept for callers; not consulted by forward() below.
        self.fusion_strategy = fusion_strategy
        self.alpha = alpha
        self.use_channel_mha = use_channel_mha
        self.use_feature_axis_attention = use_feature_axis_attention

    @staticmethod
    def _align_to_length(x: torch.Tensor, target_len: int) -> torch.Tensor:
        """Crop or zero-pad axis 1 of a 3-D tensor to ``target_len``.

        Keeps token-dependent signals length-aligned with the token sequence.
        Padding uses zeros of ``x``'s dtype on ``x``'s device.
        """
        if x.shape[1] == target_len:
            return x
        if x.shape[1] > target_len:
            return x[:, :target_len, :]
        pad_shape = (x.shape[0], target_len - x.shape[1], x.shape[2])
        padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
        return torch.cat([x, padding], dim=1)

    @staticmethod
    def _to_real_features(x: torch.Tensor) -> torch.Tensor:
        """Return ``[x.real | x.imag]`` along the last axis for complex ``x``, else ``x``."""
        return torch.cat([x.real, x.imag], dim=-1) if torch.is_complex(x) else x

    # NOTE(review): the three _ensure_* helpers create parameters inside
    # forward(). An optimizer constructed before the first forward() call will
    # not contain them. Re-creating on an in_dim change also discards the
    # previously learned weights.

    def _ensure_token_value_proj(self, in_dim: int, device: torch.device) -> "ComplexLinear":
        """Return the token projection ``in_dim -> d_model``, (re)creating it if needed."""
        if self.token_value_proj is None or self.token_value_in_dim != in_dim:
            self.token_value_proj = ComplexLinear(
                in_dim, self.d_model).to(device)
            self.token_value_in_dim = in_dim
        return self.token_value_proj

    def _ensure_mask_affine(self, in_dim: int, device: torch.device) -> nn.Linear:
        """Return the affine ``in_dim -> 2 * n_heads`` for mask bias/gate, (re)creating it."""
        if self.mask_affine is None or self.mask_affine_in_dim != in_dim:
            self.mask_affine = nn.Linear(in_dim, self.n_heads * 2).to(device)
            self.mask_affine_in_dim = in_dim
        return self.mask_affine

    def _ensure_weights_affine(self, in_dim: int, device: torch.device) -> nn.Linear:
        """Return the affine ``in_dim -> 2 * n_heads`` for selector-weight bias/gate."""
        if self.weights_affine is None or self.weights_affine_in_dim != in_dim:
            self.weights_affine = nn.Linear(
                in_dim, self.n_heads * 2).to(device)
            self.weights_affine_in_dim = in_dim
        return self.weights_affine

    def _bias_and_gate(self, proj: torch.Tensor, token_len: int, ensure_affine):
        """Map a per-token signal to ``(score bias, value gate)``, each ``(B, T, n_heads)``.

        ``ensure_affine`` is one of the ``_ensure_*_affine`` helpers. The
        aligned features are cast to the affine's weight dtype, so boolean or
        integer masks and other float dtypes are accepted.
        """
        aligned = self._align_to_length(proj, token_len)
        features = self._to_real_features(aligned)
        affine = ensure_affine(features.shape[-1], features.device)
        out = affine(features.to(affine.weight.dtype))
        bias_term, gate_term = out.chunk(2, dim=-1)
        return bias_term, torch.sigmoid(gate_term)

    def _complex_dropout(self, x: torch.Tensor, p: float) -> torch.Tensor:
        """Apply inverted dropout to a complex tensor.

        A single Bernoulli mask, shaped like ``x.real``, is shared by the real
        and imaginary parts. Each complex entry is therefore either kept whole
        (and rescaled by ``1 / (1 - p)``) or zeroed whole.

        The call is a no-op in eval mode or when ``p <= 0``. When ``p >= 1``
        everything is zeroed, which avoids a division by zero.
        """
        if not self.training or p <= 0.0:
            return x
        if p >= 1.0:
            return torch.zeros_like(x)
        keep_prob = 1.0 - p
        mask = torch.empty_like(x.real).bernoulli_(keep_prob)
        scale = 1.0 / keep_prob
        real = x.real * mask * scale
        imag = x.imag * mask * scale
        return torch.complex(real, imag)

    def forward(
        self,
        raw_fft: torch.Tensor,
        tokens: Optional[torch.Tensor] = None,
        mask_proj: Optional[torch.Tensor] = None,
        weights_proj: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Fuse DynFBD tokens into the raw FFT representation via attention.

        Args:
            raw_fft: Complex tensor of shape ``(B, d_model, C)``.
            tokens: Token sequence ``(B, T, token_dim)``. When ``None``,
                ``raw_fft`` is returned unchanged and the cached diagnostics
                are cleared.
            mask_proj: Optional per-token signal ``(B, L, F)``, which may be
                real, complex, boolean or integer. It is cropped or zero-padded
                to ``T`` along axis 1, then turned into a per-head score bias
                and value gate.
            weights_proj: Optional per-token signal handled the same way as
                ``mask_proj``, using its own affine.

        Returns:
            Tensor of shape ``(B, d_model, C)``.
        """
        if tokens is None:
            self.last_attention_weights = None
            self.last_token_focus = None
            self.last_mask_gate = None
            self.last_weights_gate = None
            self.last_mask_bias = None
            self.last_weights_bias = None
            return raw_fft

        B, _, C = raw_fft.shape
        raw_query = raw_fft.transpose(1, 2)  # (B, C, d_model)

        token_len = tokens.shape[1]
        device = tokens.device

        token_value_proj = self._ensure_token_value_proj(
            tokens.shape[-1], device)
        token_values = token_value_proj(tokens)

        # Neutral defaults: no score bias, and gates of 1 leave values unchanged.
        mask_bias_term: Optional[torch.Tensor] = None
        weights_bias_term: Optional[torch.Tensor] = None
        mask_gate = torch.ones(B, token_len, self.n_heads,
                               device=device, dtype=torch.float32)
        weights_gate = torch.ones_like(mask_gate)

        if mask_proj is not None:
            mask_bias_term, mask_gate = self._bias_and_gate(
                mask_proj, token_len, self._ensure_mask_affine)
        if weights_proj is not None:
            weights_bias_term, weights_gate = self._bias_and_gate(
                weights_proj, token_len, self._ensure_weights_affine)

        combined_gate = mask_gate * weights_gate  # (B, T, n_heads)

        bias_terms = [t for t in (mask_bias_term, weights_bias_term)
                      if t is not None]
        bias = sum(bias_terms) if bias_terms else None

        query_input = torch.cat([raw_query.real, raw_query.imag], dim=-1)
        key_input = torch.cat([token_values.real, token_values.imag], dim=-1)

        query = self.query_proj(query_input).view(
            B, C, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        key = self.key_proj(key_input).view(
            B, token_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        value = self.value_proj(token_values).view(
            B, token_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        # (B, T, H) -> (B, H, T, 1). Cast to the value dtype so a float32 gate
        # cannot promote the values away from the attention-weight dtype.
        gate_for_values = combined_gate.permute(
            0, 2, 1).unsqueeze(-1).to(value.real.dtype)
        value_real = value.real * gate_for_values
        value_imag = value.imag * gate_for_values

        attn_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        if bias is not None:
            # (B, T, H) -> (B, H, 1, T): the same bias for every query channel.
            bias_for_scores = bias.permute(0, 2, 1).unsqueeze(2)
            attn_scores = attn_scores + bias_for_scores.to(attn_scores.dtype)

        attn_probs = torch.softmax(attn_scores, dim=-1)
        # Dropout only affects the weights used for mixing. The diagnostics
        # below use the undropped probabilities so each row still sums to 1.
        attn_weights = self.attn_dropout(attn_probs)

        context_real = torch.matmul(attn_weights, value_real)
        context_imag = torch.matmul(attn_weights, value_imag)
        context_real = context_real.permute(
            0, 2, 1, 3).reshape(B, C, self.d_model)
        context_imag = context_imag.permute(
            0, 2, 1, 3).reshape(B, C, self.d_model)
        context = torch.complex(context_real, context_imag)

        fused = self.output_proj(context)
        fused = self._complex_dropout(fused, self.residual_dropout_p)
        output = self.layer_norm(fused + raw_query)

        self.last_attention_weights = attn_probs.detach()
        if token_len > 1:
            # Head-averaged attention entropy, normalised by log(T).
            # Focus = 1 - entropy, so focus is 1 when all mass is on one token.
            mean_probs = attn_probs.detach().mean(dim=1).clamp_min(1e-8)
            entropy = -(mean_probs * torch.log(mean_probs)).sum(dim=-1)
            entropy = entropy / math.log(token_len)
            self.last_token_focus = (1.0 - entropy).clamp(0.0, 1.0)
        else:
            self.last_token_focus = torch.ones(B, C, device=device)
        self.last_mask_gate = mask_gate.detach()
        self.last_weights_gate = weights_gate.detach()
        self.last_mask_bias = (
            mask_bias_term.detach() if mask_bias_term is not None else None)
        self.last_weights_bias = (
            weights_bias_term.detach() if weights_bias_term is not None else None)

        return output.transpose(1, 2)

    def get_interpretability_info(self) -> dict:
        """Return the detached diagnostics cached by the last :meth:`forward` call."""
        return {
            'attention_weights': self.last_attention_weights,
            'token_focus': self.last_token_focus,
            'mask_gate': self.last_mask_gate,
            'weights_gate': self.last_weights_gate,
            'mask_bias': self.last_mask_bias,
            'weights_bias': self.last_weights_bias,
        }
