from __future__ import annotations

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from layers.complex_full_attention import (
    ComplexFullAttention,
    ComplexFullAttentionLayer,
    Encoder,
    EncoderLayer,
)
from layers.complex_func import ComplexLayerNorm, ComplexProjection
from layers.embedding import FourierInterpolatedFrequencyEmbedding
from layers.revin import RevIN, RevINComplex
from layers.adaptive_fusion import AdaptiveFeatureFusion
from layers.band_preprocessor import BandPreprocessor
from layers.channel_mixer import ChannelPriorMixer


class FACTModel(nn.Module):
    """Core FACT architecture for long-range multivariate forecasting.

    The pipeline implemented by :meth:`forward` is: optional RevIN, optional
    per-instance standardisation, :meth:`forecast` (embedding, optional band
    preprocessing / channel mixing / adaptive fusion / guided gating, encoder,
    projection), and then the two normalisations undone in reverse order.

    Diagnostic attributes (``None`` until the first :meth:`forecast` call):
        last_channel_priors: dict with keys ``'mixing_matrix'``, ``'gamma'``,
            ``'phi'`` and ``'used_topk'``, or ``None`` when no priors were
            computed during the last call.
        last_encoder_input: tensor handed to the encoder in the last call.
        last_encoder_input_info: dict ``{'C': enc_in, 'L': token count}``.
        last_attn: last entry of the attention list returned by the encoder.
    """

    def __init__(self, configs) -> None:
        """Build the sub-modules selected by ``configs``.

        Required attributes on ``configs``: ``seq_len``, ``pred_len``,
        ``task_name``, ``enc_in``, ``d_model``, ``n_heads``, ``d_ff``,
        ``e_layers``, ``dropout``, ``activation`` (plus ``target_unit_k`` when
        ``use_dynfbd`` is truthy). Every other option is read with
        ``getattr`` and falls back to the default written next to it.
        """
        super().__init__()
        self.config = configs
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.task_name = configs.task_name

        # Diagnostic caches, created up front so they can be inspected safely
        # before the first forward pass.
        self.last_channel_priors = None
        self.last_encoder_input = None
        self.last_encoder_input_info = None
        self.last_attn = None

        self.use_revin = bool(getattr(configs, 'use_revin', False))
        if self.use_revin:
            self.revin_layer = RevIN(
                num_features=configs.enc_in,
                eps=1e-5,
                affine=getattr(configs, 'revin_affine', True),
                subtract_last=getattr(configs, 'revin_subtract_last', False),
            )
            self.use_complex_revin = bool(
                getattr(configs, 'use_complex_revin', False))
            if self.use_complex_revin:
                # NOTE(review): revin_complex is constructed here but never
                # called inside this class - confirm it is used elsewhere.
                self.revin_complex = RevINComplex(
                    num_features=configs.enc_in,
                    eps=1e-5,
                    affine=getattr(configs, 'revin_affine', True),
                    subtract_last=getattr(
                        configs, 'revin_subtract_last', False),
                )
        else:
            # Complex RevIN is only available on top of plain RevIN.
            self.use_complex_revin = False

        self.use_norm = bool(getattr(configs, 'use_norm', 0))
        self.layer_norm = ComplexLayerNorm(configs.d_model)

        self.band_preprocessor = BandPreprocessor(configs) if getattr(
            configs, 'use_dynfbd', False) else None
        # NOTE(review): k_tokens is stored but not read anywhere in this class.
        self.k_tokens = (
            configs.target_unit_k
            if self.band_preprocessor is not None
            else self.seq_len // 2 + 1
        )

        self.embedding = FourierInterpolatedFrequencyEmbedding(
            self.seq_len, configs.d_model, configs.enc_in)

        attention_layers = [
            EncoderLayer(
                ComplexFullAttentionLayer(
                    ComplexFullAttention(
                        mask_flag=True,
                        scale=None,
                        attention_dropout=configs.dropout,
                        output_attention=getattr(
                            configs, 'output_attention', False),
                    ),
                    configs.d_model,
                    configs.n_heads,
                ),
                configs.d_model,
                configs.d_ff,
                dropout=configs.dropout,
                activation=configs.activation,
            )
            for _ in range(configs.e_layers)
        ]
        self.encoder = Encoder(
            attention_layers, norm_layer=torch.nn.LayerNorm(configs.d_model))

        self.use_adaptive_fusion = bool(
            getattr(configs, 'use_adaptive_fusion', False))
        self.use_channel_mixing = bool(
            getattr(configs, 'use_channel_mixing', False))
        self.use_guided_gating = bool(
            getattr(configs, 'use_guided_gating', False))

        # The mixer is built whenever any of the three features is requested,
        # even though forecast() only calls it for mixing and gating.
        if self.use_adaptive_fusion or self.use_channel_mixing or self.use_guided_gating:
            self.channel_mixer = ChannelPriorMixer(
                c_in=configs.enc_in,
                d_model=configs.d_model,
                topk=getattr(configs, 'mixing_topk', 16),
                temperature=getattr(configs, 'mixing_temperature', 1.0),
                mixing_strength=getattr(configs, 'mixing_strength', 0.1),
                diag_bias=getattr(configs, 'mixing_diag_bias', 0.2),
                use_phi=getattr(configs, 'use_mixing_phi', False),
            )
        else:
            self.channel_mixer = None

        if self.use_adaptive_fusion:
            self.adaptive_fusion = AdaptiveFeatureFusion(
                d_model=configs.d_model,
                n_heads=getattr(configs, 'fusion_heads', 8),
                dropout=configs.dropout,
                fusion_strategy=getattr(
                    configs, 'fusion_strategy', 'additive'),
                alpha=getattr(configs, 'fusion_alpha', 0.7),
                use_channel_mha=getattr(configs, 'use_channel_mha', False),
                use_feature_axis_attention=getattr(
                    configs, 'use_feature_axis_attention', False),
            )
        else:
            self.adaptive_fusion = None

        # Guided gating maps a gate in [0, 1] to
        # [gate_bias, gate_bias + gate_scale].
        self.gate_bias = getattr(configs, 'gate_bias', 0.5)
        self.gate_scale = getattr(configs, 'gate_scale', 0.5)

        self.projection = ComplexProjection(configs.d_model, configs.pred_len)

    def _norm(self, x: torch.Tensor, forward: bool = True, stats=None):
        """Standardise ``x`` along dim 1, or undo a previous standardisation.

        Args:
            x: Tensor to transform; statistics are taken over ``dim=1``.
            forward: ``True`` to normalise, ``False`` to de-normalise.
            stats: ``(means, stdev)`` pair from a prior forward call. Required
                when ``forward`` is ``False``; both must broadcast against
                ``x``.

        Returns:
            ``(x_normalised, means, stdev)`` when ``forward`` is ``True``,
            otherwise the de-normalised tensor.

        Raises:
            TypeError: if ``forward`` is ``False`` and ``stats`` is ``None``.
        """
        if forward:
            # The mean is detached; the standard deviation is left attached
            # to the autograd graph, as in the original implementation.
            means = x.mean(dim=1, keepdim=True).detach()
            stdev = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
            return (x - means) / stdev, means, stdev

        if stats is None:
            raise TypeError(
                "stats=(means, stdev) is required when forward=False")
        means, stdev = stats
        # Broadcasting over dim 1 replaces the explicit repeat() copies.
        return x * stdev + means

    def _pack_channel_priors(self, mixing_matrix, gamma, phi):
        """Bundle channel-prior tensors into the ``last_channel_priors`` dict."""
        return {
            'mixing_matrix': mixing_matrix,
            'gamma': gamma,
            'phi': phi,
            'used_topk': getattr(self.config, 'mixing_topk', 16),
        }

    def _apply_guided_gating(self, enc_src, raw_fft, mask_proj, weights_proj):
        """Scale ``enc_src`` by a gate derived from the channel mixer.

        Returns ``(gated_enc_src, gamma, phi)`` so the caller can cache the
        priors that produced the gate.
        """
        gamma, phi, w_eff = self.channel_mixer.compute_priors(
            raw_fft, mask_proj=mask_proj, weights_proj=weights_proj, use_topk=True
        )
        s_gate = self.channel_mixer.compute_channel_gate(raw_fft, w_eff=w_eff)

        _, n_tokens, _ = enc_src.shape
        n_gate = s_gate.shape[-1]
        gate = s_gate
        if n_tokens != n_gate:
            # Stretch the gate to the token axis: repeat each entry, then
            # fall back to linear interpolation if the lengths still differ.
            gate = s_gate.repeat_interleave(max(n_tokens // n_gate, 1), dim=1)
            if gate.shape[1] != n_tokens:
                gate = F.interpolate(
                    gate.unsqueeze(1), size=n_tokens, mode='linear',
                    align_corners=False).squeeze(1)

        gate = gate.clamp(0, 1) * self.gate_scale + self.gate_bias
        return enc_src * gate.unsqueeze(-1), gamma, phi

    def forecast(self, x_enc, x_mark_enc=None):
        """Encode the input sequence and produce complex decoder states.

        Steps: embed and layer-normalise ``x_enc``; when a band preprocessor
        is configured, derive tokens and projections from it and optionally
        apply channel mixing and adaptive fusion; optionally apply guided
        gating; run the encoder; project, transpose, and keep the first
        ``enc_in`` entries of the last axis.

        Side effects: refreshes ``last_channel_priors``,
        ``last_encoder_input``, ``last_encoder_input_info`` and ``last_attn``.
        """
        enc_src = self.layer_norm(self.embedding(x_enc, x_mark_enc))
        raw_fft = enc_src.transpose(1, 2)
        mask_proj = weights_proj = None
        # Reset so priors from an earlier call never leak into this one.
        self.last_channel_priors = None

        if self.band_preprocessor is not None:
            tokens, mask_proj, weights_proj = self.band_preprocessor.process_fft(
                raw_fft)

            if self.channel_mixer is not None and self.use_channel_mixing:
                raw_fft, mixing_matrix, gamma, phi = self.channel_mixer(
                    raw_fft, mask_proj=mask_proj, weights_proj=weights_proj, use_topk=True
                )
                self.last_channel_priors = self._pack_channel_priors(
                    mixing_matrix, gamma, phi)

            if self.use_adaptive_fusion and self.adaptive_fusion is not None:
                enc_src = self.adaptive_fusion(
                    raw_fft=raw_fft,
                    tokens=tokens,
                    mask_proj=mask_proj,
                    weights_proj=weights_proj,
                ).transpose(1, 2)
            else:
                enc_src = raw_fft.transpose(1, 2)
        else:
            # Channel mixing and fusion are only applied when band
            # preprocessing is enabled.
            enc_src = raw_fft.transpose(1, 2)

        if self.channel_mixer is not None and self.use_guided_gating:
            try:
                enc_src, gamma, phi = self._apply_guided_gating(
                    enc_src, raw_fft, mask_proj, weights_proj)
                if self.last_channel_priors is None:
                    mixing_matrix = self.channel_mixer.compute_mixing(
                        gamma, phi)
                    self.last_channel_priors = self._pack_channel_priors(
                        mixing_matrix, gamma, phi)
            except Exception as exc:
                # Gating is advisory: report the problem and continue, so a
                # numerical issue in the gate does not abort the forward pass.
                warnings.warn(
                    f'Warning: gating skipped ({exc})', RuntimeWarning)

        self.last_encoder_input = enc_src
        self.last_encoder_input_info = {
            'C': self.config.enc_in, 'L': enc_src.shape[1]}

        enc_out, attns = self.encoder(enc_src)
        # Guard against both a missing and an empty attention list.
        has_attn = attns is not None and len(attns) > 0
        self.last_attn = attns[-1] if has_attn else None

        c_orig = self.config.enc_in
        dec_out = self.projection(enc_out).transpose(1, 2)[:, :, :c_orig]
        return dec_out

    def get_last_channel_repr(self):
        """Return cached encoder inputs realigned to channel-major layout.

        With ``L`` the size of dim 1 of the cached tensor and ``C`` the cached
        ``enc_in``:

        * ``L == C``: the cached tensor is returned unchanged;
        * ``L`` a multiple of ``C``: dim 1 is split into ``(L // C, C)`` and
          averaged over the group axis;
        * otherwise: the first ``C`` entries of dim 1 are returned.

        Returns ``None`` when nothing has been cached yet or ``C`` is missing
        or non-positive.
        """
        x = getattr(self, 'last_encoder_input', None)
        info = getattr(self, 'last_encoder_input_info', None)
        if x is None or info is None:
            return None

        B, L, D = x.shape
        C = info.get('C', None)
        if C is None or C <= 0:
            return None
        if L == C:
            return x
        if L % C == 0:
            # NOTE(review): this grouping treats token l as channel l % C,
            # whereas the guided-gating path (repeat_interleave) treats token
            # l as channel l // repeat - confirm which layout the embedding
            # actually produces.
            return x.view(B, L // C, C, D).mean(dim=1)
        return x[:, :C, :]

    def forward(self, x_enc, x_mark_enc=None, *args, **kwargs):
        """Execute the full FACT pipeline and output the forecast horizon.

        Args:
            x_enc: Input series; normalisation statistics are taken over
                dim 1.
            x_mark_enc: Optional auxiliary input forwarded to the embedding.
            *args, **kwargs: Accepted and ignored.

        Returns:
            The last ``pred_len`` entries along dim 1 of the de-normalised
            output.

        Raises:
            NotImplementedError: if ``task_name`` is not
                ``'long_term_forecast'``.
        """
        if self.task_name != 'long_term_forecast':
            raise NotImplementedError(
                'Minimal release only supports long-term forecasting.')

        if self.use_revin:
            x_enc = self.revin_layer(x_enc, 'norm')

        if self.use_norm:
            x_enc, means, stdev = self._norm(x_enc, forward=True)

        out = self.forecast(x_enc, x_mark_enc)

        # Undo the normalisations in the reverse order they were applied.
        if self.use_norm:
            out = self._norm(
                out,
                forward=False,
                stats=(means[:, 0:1, :], stdev[:, 0:1, :]),
            )

        if self.use_revin:
            out = self.revin_layer(out, 'denorm')

        return out[:, -self.pred_len:, :]
