
import torch
import torch.nn as nn
from layers.Embed import PatchEmbed
from layers.SelfAttention_Family import TSMixer, ResAttention, CrossScaleCohesionAttention
from layers.Transformer_EncDec import TSEncoder, IntAttention, PatchSampling
from layers.periodic_marker import PeriodicMarker
from layers.rlc import RLCRegressor


class Model(nn.Module):
    """Patch-based forecaster: patch embedding -> encoder stack -> linear head.

    The encoder stack is assembled by :meth:`layers_init` from three optional
    groups of layers (``IntAttention``, ``PatchSampling`` and
    ``CrossScaleCohesionAttention``), whose counts come from
    ``configs.ia_layers``, ``configs.pd_layers`` and ``configs.ca_layers``.

    Two optional components are controlled by config flags:

    * ``use_periodic_marker`` -- multiplicative gating of the normalized input
      using the time features (see :meth:`forecast`).
    * ``use_rlc`` -- an auxiliary ``RLCRegressor`` whose loss is exposed via
      :meth:`compute_rlc_loss`; it is not used in the forward pass.
    """

    def __init__(self, configs):
        """Build the model from a config object.

        Required attributes on ``configs``: ``enc_in``, ``seq_len``,
        ``pred_len``, ``d_model`` (plus ``n_heads`` / ``d_ff`` when any
        encoder layer is requested). Everything else is read with a default.

        Side effect: if ``configs.num_p`` is missing or ``None`` it is set to
        ``seq_len // period`` on the passed-in object.
        """
        super().__init__()

        # NOTE(review): stored but never read inside this class -- the
        # normalization in `forecast` is applied unconditionally.
        self.revin = getattr(configs, 'revin', True)

        self.c_in = configs.enc_in
        self.period = getattr(configs, 'period', 1)
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        # Number of patches obtained by cutting the input window by `period`.
        self.num_p = self.seq_len // self.period
        if getattr(configs, 'num_p', None) is None:
            configs.num_p = self.num_p

        # Patch embedding.
        self.embedding = PatchEmbed(configs, num_p=self.num_p)

        # Optional periodic gating of the normalized input.
        self.use_periodic_marker = getattr(configs, 'use_periodic_marker', False)
        self.periodic_alpha = float(getattr(configs, 'periodic_alpha', 0.1))
        if self.use_periodic_marker:
            # embed_dim=None => defaults to c_in; use_minute=False (change if needed).
            self.periodic_marker = PeriodicMarker(
                c_in=self.c_in, embed_dim=None, use_minute=False
            )
        else:
            self.periodic_marker = None

        # Optional auxiliary RLC regressor (only used by `compute_rlc_loss`).
        self.use_rlc = getattr(configs, 'use_rlc', False)
        self.rlc_k = int(getattr(configs, 'rlc_k', 4))
        if self.use_rlc:
            self.rlc = RLCRegressor(c_in=self.c_in, k=self.rlc_k, init_scale=1e-2)
        else:
            self.rlc = None

        self.encoder = TSEncoder(self.layers_init(configs))

        # Read `pd_layers` with the same default `layers_init` uses, so a
        # config without that attribute does not fail here after the layers
        # were already built successfully.
        pd_layers = getattr(configs, 'pd_layers', 0)
        # Without patch-sampling layers the patch count is unchanged;
        # otherwise the last PatchSampling layer emits `configs.num_p` patches.
        out_p = self.num_p if pd_layers == 0 else configs.num_p
        self.decoder = nn.Sequential(
            nn.Flatten(start_dim=-2),
            nn.Linear(out_p * configs.d_model, configs.pred_len, bias=False),
        )

    def layers_init(self, configs):
        """Create the ordered list of encoder layers.

        Returns:
            list: ``ia_layers`` IntAttention layers, then ``pd_layers``
            PatchSampling layers, then ``ca_layers``
            CrossScaleCohesionAttention layers (each count defaults to 0).
        """
        # Optional hyper-parameters shared by all three layer groups.
        attn_dropout = getattr(configs, 'attn_dropout', 0.1)
        dropout = getattr(configs, 'dropout', 0.1)
        stable_len = getattr(configs, 'stable_len', 8)
        activation = getattr(configs, 'activation', 'relu')

        def new_mixer():
            # Every layer gets its own attention/mixer instance (no sharing).
            return TSMixer(
                ResAttention(attention_dropout=attn_dropout),
                configs.d_model, configs.n_heads,
            )

        integrated_attention = [
            IntAttention(
                new_mixer(),
                configs.d_model, configs.d_ff,
                dropout=dropout,
                stable_len=stable_len,
                activation=activation,
                stable=True, enc_in=self.c_in,
            )
            for _ in range(getattr(configs, 'ia_layers', 0))
        ]

        patch_sampling = [
            PatchSampling(
                new_mixer(),
                configs.d_model, configs.d_ff,
                stable=False, stable_len=stable_len,
                # Only the first sampling layer sees the raw patch count.
                in_p=self.num_p if i == 0 else configs.num_p,
                out_p=configs.num_p,
                dropout=dropout,
                activation=activation,
            )
            for i in range(getattr(configs, 'pd_layers', 0))
        ]

        cohesion_attention = [
            CrossScaleCohesionAttention(
                new_mixer(),
                configs.d_model, configs.d_ff,
                dropout=dropout,
                activation=activation,
                stable=False, enc_in=self.c_in, stable_len=stable_len,
                use_spectral=getattr(configs, 'use_cohesion_spectral', False),
            )
            for _ in range(getattr(configs, 'ca_layers', 0))
        ]

        return [*integrated_attention, *patch_sampling, *cohesion_attention]

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Run normalization, embedding, encoder and decoder.

        Args:
            x_enc: input series, ``[B, L, C]``.
            x_mark_enc: time features ``[B, L, D_timefeat]`` or ``None``; when
                ``None`` a zero tensor with 4 features is substituted.
            x_dec, x_mark_dec: accepted for interface compatibility; unused.

        Returns:
            The decoder output mapped back to the input scale
            (``dec_out * std + mean``).
        """
        if x_mark_enc is None:
            # Placeholder time features; match x_enc's dtype as well as its
            # device so downstream layers do not hit a dtype mismatch when the
            # model runs in a precision other than float32.
            x_mark_enc = torch.zeros(
                (*x_enc.shape[:-1], 4),
                device=x_enc.device,
                dtype=x_enc.dtype,
            )

        # Per-sample statistics over dim 1, detached from the graph.
        mean = x_enc.mean(1, keepdim=True).detach()
        std = x_enc.std(1, keepdim=True).detach()
        # Normalized input for the embedding: [B, L, C].
        x_for_embed = (x_enc - mean) / (std + 1e-5)

        if self.periodic_marker is not None:
            # Gate computed from the time features (expected in (0, 1) with
            # shape [B, L, C] -- TODO confirm against PeriodicMarker).
            gate = self.periodic_marker(x_mark_enc)
            # Multiplicative emphasis: x' = x * (1 + alpha * gate).
            x_for_embed = x_for_embed * (1.0 + self.periodic_alpha * gate)

        x_patch = self.embedding(x_for_embed, x_mark_enc)
        # Take the encoder's first output and keep the first c_in entries
        # along dim 1.
        enc_out = self.encoder(x_patch)[0][:, :self.c_in, ...]
        dec_out = self.decoder(enc_out).transpose(-1, -2)

        # De-normalize. NOTE(review): normalization divides by (std + 1e-5)
        # but this multiplies by std only; kept as-is to preserve numerics.
        return dec_out * std + mean

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Return the last ``pred_len`` steps (dim 1) of :meth:`forecast`.

        ``mask`` is accepted for interface compatibility and is not used.
        """
        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]

    def compute_rlc_loss(self, x_enc, future_y, mode='mean',
                         lambda_aux=1.0, lambda_orth=1.0, reduce='sum'):
        """Combine the RLC regressor's loss terms into a single value.

        Computes ``rlc_corr + lambda_aux * rlc_aux + lambda_orth * rlc_orth``
        from the dict returned by ``self.rlc.compute_losses``; a missing term
        counts as ``0.0``.

        Args:
            x_enc, future_y: forwarded unchanged to ``compute_losses``.
            mode: forwarded to ``compute_losses``.
            lambda_aux: weight of the ``rlc_aux`` term.
            lambda_orth: weight of the ``rlc_orth`` term.
            reduce: currently unused; kept for interface compatibility.

        Returns:
            The combined loss, or ``None`` when RLC is disabled or when the
            computation raised (a warning is printed in that case).
        """
        if self.rlc is None:
            return None
        try:
            _, losses = self.rlc.compute_losses(x_enc, future_y, mode=mode)

            rlc_corr = losses.get('rlc_corr', 0.0)
            rlc_aux = losses.get('rlc_aux', 0.0)
            rlc_orth = losses.get('rlc_orth', 0.0)
            return rlc_corr + lambda_aux * rlc_aux + lambda_orth * rlc_orth
        except Exception as e:
            # Deliberately best-effort: the auxiliary loss must not abort the
            # caller, so report the failure and signal "no loss".
            print("[WARN] RLC computation failed at batch:", e)
            return None


