import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba, Mamba2
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMBlockStack, xLSTMBlockStackConfig
from transformers import AutoModel
from layers.Embed import DataEmbedding, DataEmbedding_inverted, PatchEmbedding
from models.ttt import TTTLinear, TTTMLP, TTTConfig
from models.iTransformer import Model as iTransformerModel

# NOTE(review): this turns on autograd anomaly detection for the whole process
# as a side effect of importing this module. PyTorch documents the mode as a
# debugging aid that slows execution -- confirm it should stay enabled outside
# of debugging sessions.
torch.autograd.set_detect_anomaly(True)
class Configs:
    """Model configuration assembled from an experiment ``args`` namespace.

    Most fields are copied straight from ``args``; a handful are pinned to
    fixed values here (``embed``, ``factor``, ``activation``, ``n_heads``,
    ``d_ff``, ``x_model_name``, ``embed_type``, ``num_layers``) and therefore
    ignore whatever ``args`` may carry for them.

    Args:
        args: Object exposing the attributes read below (``task_name``,
            ``pred_len``, ``seq_len``, ``d_model``, ``d_conv``, ``expand``,
            ``enc_in``, ``freq``, ``dropout``, ``top_k``, ``num_kernels``,
            ``e_layers``, ``d_layers``, ``c_out``, ``num_classes``,
            ``patch_len``, ``stride``, ``padding``, ``patch_embedding``).
            Despite the ``None`` default it is required: attribute access on
            ``None`` raises ``AttributeError``.
    """

    def __init__(self, args=None):
        # Task-specific configurations
        self.task_name = args.task_name
        self.pred_len = args.pred_len
        self.seq_len = args.seq_len

        # Model dimensions and embedding configurations
        self.d_model = args.d_model
        self.d_conv = args.d_conv
        self.expand = args.expand
        self.enc_in = args.enc_in
        self.embed = 'timeF'
        self.freq = args.freq
        self.dropout = args.dropout

        self.top_k = args.top_k
        self.num_kernels = args.num_kernels
        self.factor = 5
        self.activation = 'gelu'

        # Head count is pinned here ("can reset the number of heads to match
        # ttt"); args.n_heads is not read.
        self.n_heads = 16
        self.e_layers = args.e_layers  # Number of encoder layers
        self.d_layers = args.d_layers  # Number of decoder layers
        self.d_ff = 128  # Feedforward dimension

        # Output configurations
        self.c_out = args.c_out

        # Model selection
        self.x_model_name = "ttt_linear"
        self.embed_type = "timeF"
        self.num_layers = 1

        self.num_classes = args.num_classes

        # Patch settings. These were previously written with a trailing comma
        # (``args.patch_len,``), which silently wrapped each value in a
        # 1-tuple; they are stored as the plain values now.
        self.patch_len = args.patch_len
        self.stride = args.stride
        self.padding = args.padding
        self.patch_embedding = args.patch_embedding
        

    

class ConvFFN2(nn.Module):
    """Per-feature pointwise feed-forward block with residual and LayerNorm.

    Operates on ``x`` of shape ``[B, M, L, Dh]`` and returns the same shape.

    Both convolutions are 1x1 with ``groups=Dh``, so every one of the ``Dh``
    feature channels is transformed independently by its own learnable scale
    and bias: ``pw2(GELU(pw1(x)))``. With this kernel size and grouping the
    block does not mix information across the ``M`` or ``L`` axes, nor
    across features; the only cross-feature interaction is the LayerNorm over
    ``Dh`` applied to the residual sum.

    Args:
        M: Accepted for interface compatibility; not used by the layer.
        Dh: Size of the last (feature) axis and channel count of both convs.
        expansion: Accepted for interface compatibility; not used.
    """

    def __init__(self, M: int, Dh: int, expansion: int = 2):
        super().__init__()
        # groups=Dh with a 1x1 kernel: one weight and one bias per channel.
        self.pw1 = nn.Conv2d(Dh, Dh, kernel_size=1, groups=Dh)
        self.pw2 = nn.Conv2d(Dh, Dh, kernel_size=1, groups=Dh)
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(Dh)

        # Registered but not applied in forward(); kept so the module's
        # attribute set stays unchanged.
        self.convffn_drop = nn.Dropout(0.01)

    def forward(self, x):
        """Apply the per-channel FFN, add the residual and normalise over Dh."""
        B, M, L, Dh = x.shape
        # Conv2d expects channels in dim 1, so move Dh there: [B, Dh, L, M].
        y = x.permute(0, 3, 2, 1).contiguous()
        y = self.pw2(self.act(self.pw1(y)))
        y = y.permute(0, 3, 2, 1)  # back to [B, M, L, Dh]
        # Residual + LayerNorm on the feature axis.
        return self.ln((x + y).reshape(B * M * L, Dh)).reshape(B, M, L, Dh)
    

class PeriodMix(nn.Module):
    """Average of depthwise dilated convolutions along the last spatial axis.

    One branch is built per entry of ``dils``. Every branch is a depthwise
    (``groups=Dh``) convolution with a ``(1, 3)`` kernel whose dilation and
    padding along the last spatial axis both equal that entry, so each branch
    returns a tensor of the same size as its input. The branch outputs are
    averaged, added back to the input and LayerNorm-ed over ``Dh``.

    Input and output shape: ``[B, M, L, Dh]``.
    """

    def __init__(self, Dh, dils=(1,3,6)):
        super().__init__()
        dilated_convs = []
        for rate in dils:
            dilated_convs.append(
                nn.Conv2d(
                    Dh,
                    Dh,
                    kernel_size=(1, 3),
                    padding=(0, rate),
                    dilation=(1, rate),
                    groups=Dh,
                )
            )
        self.branches = nn.ModuleList(dilated_convs)
        self.ln = nn.LayerNorm(Dh)

    def forward(self, x):
        """Return ``LayerNorm(x + mean of branch outputs)``."""
        channels_first = x.permute(0, 3, 1, 2)  # [B, Dh, M, L]
        total = 0
        for conv in self.branches:
            total = total + conv(channels_first)
        averaged = total / len(self.branches)
        channels_last = averaged.permute(0, 2, 3, 1)  # [B, M, L, Dh]
        return self.ln(x + channels_last)

class DWTimeMix(nn.Module):
    """Depthwise convolution along the last spatial axis of ``[B, M, L, Dh]``.

    A single ``groups=Dh`` convolution with a ``(1, kernel)`` window and
    ``(kernel - 1) // 2`` padding on that axis. For an odd ``kernel`` the
    output has the same shape as the input; an even ``kernel`` yields one
    step fewer along ``L``. No residual or normalisation is applied.
    """

    def __init__(self, Dh, kernel=31):
        super().__init__()
        half_width = (kernel - 1) // 2
        self.dw = nn.Conv2d(
            Dh,
            Dh,
            (1, kernel),  # (var, time)
            padding=(0, half_width),
            groups=Dh,
        )

    def forward(self, x):
        """Run the depthwise conv with ``Dh`` temporarily moved to dim 1."""
        mixed = self.dw(x.permute(0, 3, 1, 2))  # conv over [B, Dh, M, L]
        return mixed.permute(0, 2, 3, 1)  # restore [B, M, L, Dh]



class ProcessingLayer1(nn.Module):
    """Wrap the iTransformer model and expose its keys/values.

    ``forward`` calls the wrapped model with ``return_attn=True`` and passes
    through the representation together with the ``(keys, values)`` pair it
    returns. The convolutional mixers and layer norms created in ``__init__``
    are registered as submodules but are not applied in ``forward``.
    """

    def __init__(self, configs):
        super().__init__()
        # iTransformer backbone
        self.var_attn_model = iTransformerModel(configs)

        head_dim = configs.d_model // configs.n_heads
        self.mix_kv = ConvFFN2(configs.d_model, head_dim)

        # Layer norms
        self.xlayernorm = nn.LayerNorm(configs.d_model)
        self.ylayernorm = nn.LayerNorm(configs.d_model)

        self.dw_temporal = DWTimeMix(Dh=head_dim, kernel=31)
        self.period_mix = PeriodMix(Dh=head_dim, dils=(1, 3, 6))

    def forward(self, x_enc):
        """Run the iTransformer on ``x_enc``.

        Returns:
            ``(var_repr, (keys, values))`` exactly as produced by the wrapped
            model; the attention list it also returns is dropped.
        """
        model_outputs = self.var_attn_model(
            x_enc,
            x_mark_enc=None,
            x_dec=None,
            x_mark_dec=None,
            return_attn=True,
        )
        var_repr, _attn_list, (keys, values) = model_outputs
        return var_repr, (keys, values)

class ProcessingLayer2(nn.Module):
    """Run a TTT-Linear layer and blend its output with ``var_repr``.

    The TTT layer receives the encoder input plus the ``(keys, values)`` pair
    from the iTransformer. Its output is then projected with a task-specific
    head and combined with ``var_repr`` through two learnable scalars:
    ``scalar1 * projected + scalar2 * var_repr``.

    ``var_attn_model``, ``ylayernorm`` and ``_gate`` are registered here but
    are not used in ``forward``; they are kept so the module's parameter set
    (and therefore checkpoint keys) stays unchanged.
    """

    def __init__(self, configs):
        super().__init__()

        # Classification additionally sets the TTT mini-batch size to the
        # number of input channels.
        ttt_kwargs = {"hidden_size": configs.d_model}
        if configs.task_name == 'classification':
            ttt_kwargs["mini_batch_size"] = configs.enc_in
        ttt_cfg = TTTConfig(**ttt_kwargs)

        print("[DEBUG] ttt_cfg hidden_size: ", ttt_cfg.hidden_size)

        self.ttt_layer = TTTLinear(ttt_cfg)
        self.var_attn_model = iTransformerModel(configs)

        self.task_name = configs.task_name

        # Layer norms
        self.xlayernorm = nn.LayerNorm(configs.d_model)
        self.ylayernorm = nn.LayerNorm(configs.d_model)

        # Learnable blend weights for the TTT branch and the iTransformer
        # branch respectively.
        self.scalar1 = nn.Parameter(torch.ones(1))
        self.scalar2 = nn.Parameter(torch.ones(1))

        self._gate = nn.Parameter(torch.zeros(1))

        # Task heads. The forecasting head mirrors iTransformer's projection,
        # e.g. [32, 128, 128] -> [32, 128, 96] -> (permute) [32, 96, 128].
        self.forecasting_projection = nn.Linear(configs.d_model, configs.pred_len, bias=True)
        self.anomoly_impute_projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        self.classify_projection = nn.Linear(
            configs.d_model * configs.enc_in, configs.num_classes, bias=True
        )

        self.cls_act = nn.GELU()
        self.cls_drop = nn.Dropout(configs.dropout)

    def _project_and_blend(self, projection, ttt_out, var_repr, n_keep):
        """Project the TTT output, keep ``n_keep`` trailing-axis entries, blend and LayerNorm."""
        projected = projection(ttt_out).permute(0, 2, 1)
        projected = projected[:, :, :n_keep].clone()
        blended = self.scalar1 * projected + self.scalar2 * var_repr
        return self.xlayernorm(blended)

    def forward(self, x_enc, var_repr, k_v_tuple):
        """Blend the TTT branch with ``var_repr`` for the configured task.

        Args:
            x_enc: 3-D input to the TTT layer; the size of its last axis
                bounds how many entries of the projected output are kept.
            var_repr: Output of the iTransformer branch; must broadcast
                against the projected TTT output (or against the logits for
                classification).
            k_v_tuple: ``(keys, values)`` pair forwarded to the TTT layer.

        Returns:
            The LayerNorm-ed blend for forecasting / imputation /
            anomaly-detection tasks, or the un-normalised blend of logits and
            ``var_repr`` for classification.

        Raises:
            ValueError: If ``task_name`` is not one of the supported tasks.
        """
        _, _, n_last = x_enc.shape
        keys, values = k_v_tuple
        ttt_out = self.ttt_layer(x_enc, var_attention=(keys, values))

        if self.task_name in ('short_term_forecast', 'long_term_forecast'):
            return self._project_and_blend(
                self.forecasting_projection, ttt_out, var_repr, n_last
            )

        if self.task_name in ('imputation', 'anomaly_detection'):
            return self._project_and_blend(
                self.anomoly_impute_projection, ttt_out, var_repr, n_last
            )

        if self.task_name == 'classification':
            # GELU + Dropout on the TTT output (no LayerNorm on this path).
            feats = self.cls_drop(self.cls_act(ttt_out))
            # Flatten everything but the batch axis for the linear classifier.
            flat = feats.reshape(feats.size(0), -1)
            logits = self.classify_projection(flat)  # [B, num_classes]
            return self.scalar1 * logits + self.scalar2 * var_repr

        raise ValueError(f"Unknown task name: {self.task_name}")
    

    
    
class Model(nn.Module):
    """Top-level model combining an iTransformer branch with a TTT branch.

    ``forward`` dispatches on ``task_name`` to forecasting, imputation,
    anomaly detection or classification. The first three share one pipeline:
    per-series normalisation, embedding, both processing layers, a linear
    output head and de-normalisation.
    """

    def __init__(self, args):
        super().__init__()

        configs = Configs(args)
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.d_inner = configs.d_model * configs.expand
        self.n_vars = configs.enc_in

        # 1) Value/time embedding (identical for every task).
        self.embedding = DataEmbedding(
            configs.enc_in, configs.d_model,
            configs.embed, configs.freq, configs.dropout
        )

        # Inverted embedding: its input width depends on the task.
        if configs.task_name == 'short_term_forecast':
            inverted_in = 2 * configs.pred_len
        else:
            inverted_in = configs.seq_len
        self.enc_embedding = DataEmbedding_inverted(
            inverted_in, configs.d_model, configs.embed, configs.freq,
            configs.dropout
        )

        # 2) iTransformer branch
        self.processing_layer1 = ProcessingLayer1(configs)

        # 2.5) TTT branch
        self.processing_layer2 = ProcessingLayer2(configs)

        # 3) Final output head
        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

    def _normalize_embed_decode(self, x_enc, x_mark_enc):
        """Shared pipeline: normalise, embed, run both branches, de-normalise."""
        # Normalise each series over dim 1; the statistics are detached so
        # no gradient flows through them.
        mean_enc = x_enc.mean(dim=1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
        ).detach()
        x_enc = x_enc / std_enc

        # Clones keep each consumer working on its own copy of the tensor.
        enc_out = self.embedding(x_enc.clone(), x_mark_enc)

        var_repr, itrans_kv_tuple = self.processing_layer1(enc_out.clone())

        enc_out_inverted = self.enc_embedding(enc_out.clone(), None)
        # NOTE(review): the ProcessingLayer2 result is computed but not used;
        # the value returned below is derived from `var_repr` alone. Kept
        # as-is to preserve current outputs -- confirm whether `out_layer`
        # was meant to consume this blended output instead.
        _ttt_blend = self.processing_layer2(
            enc_out_inverted, var_repr, itrans_kv_tuple
        )

        decoded = self.out_layer(var_repr)
        # Undo the input normalisation.
        return decoded * std_enc + mean_enc

    def forecast_no_patching(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast path; ``x_dec`` and ``x_mark_dec`` are accepted but unused."""
        return self._normalize_embed_decode(x_enc, x_mark_enc)

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Imputation path; ``x_dec``, ``x_mark_dec`` and ``mask`` are unused."""
        return self._normalize_embed_decode(x_enc, x_mark_enc)

    def anomaly_detection(self, x_enc):
        """Anomaly-detection path; runs the shared pipeline without time marks."""
        return self._normalize_embed_decode(x_enc, None)

    def classification(self, x_enc, x_mark_enc):
        """Classification path.

        Unlike the other tasks, the input is neither normalised nor passed
        through ``self.embedding``: ``x_enc`` goes straight into
        ``processing_layer1`` and ``enc_embedding``. ``x_mark_enc`` is
        accepted but unused.

        Returns:
            The blended output of ``processing_layer2`` (presumably logits of
            shape ``[B, num_classes]`` -- confirm against the iTransformer
            classification output).
        """
        var_repr, kv = self.processing_layer1(x_enc)
        inv = self.enc_embedding(x_enc, None)
        return self.processing_layer2(inv, var_repr, kv)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the task-specific path selected by ``task_name``.

        Returns ``None`` when ``task_name`` matches none of the known tasks.
        """
        if self.task_name in ['short_term_forecast', 'long_term_forecast']:
            x_out = self.forecast_no_patching(x_enc, x_mark_enc, x_dec, x_mark_dec)
            # Keep only the last `pred_len` steps.
            return x_out[:, -self.pred_len:, :]
        if self.task_name == 'imputation':
            return self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
        if self.task_name == 'anomaly_detection':
            return self.anomaly_detection(x_enc)
        if self.task_name == 'classification':
            return self.classification(x_enc, x_mark_enc)
        return None