import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba
from models.RevIN import RevIN
from models.ttt import TTTLinear, TTTMLP, TTTConfig
from models.iTransformer import Model as iTransformerModel

class Configs:
    """Plain attribute container built from a parsed-argument namespace.

    Most fields are copied verbatim from ``args``; a handful are pinned to
    fixed values (``embed``, ``embed_type``, ``factor``, ``activation``,
    ``n_heads``, ``d_ff``) regardless of what ``args`` contains.

    Parameters
    ----------
    args : object
        Any object exposing the attributes read below (for example an
        ``argparse.Namespace``). The ``None`` default is kept only for
        signature compatibility; omitting ``args`` is an error.

    Raises
    ------
    ValueError
        If ``args`` is ``None``.
    AttributeError
        If ``args`` lacks one of the required attributes.
    """

    def __init__(self, args=None):
        if args is None:
            # Fail fast with an explicit message instead of the opaque
            # "'NoneType' object has no attribute 'task_name'".
            raise ValueError(
                "Configs requires an 'args' namespace; got None"
            )

        # Task-specific configurations
        self.task_name = args.task_name
        self.pred_len = args.pred_len
        self.seq_len = args.seq_len

        # Model dimensions and embedding configurations
        self.d_model = args.d_model
        self.d_conv = args.d_conv
        self.expand = args.expand
        self.enc_in = args.enc_in
        self.embed = 'timeF'
        self.freq = args.freq
        self.dropout = args.dropout

        self.top_k = args.top_k
        self.num_kernels = args.num_kernels
        self.factor = 5
        self.activation = 'gelu'

        # The head count is pinned to 8 rather than taken from args.n_heads
        # (original note: "can reset the number of heads to match ttt").
        self.n_heads = 8
        self.e_layers = args.e_layers  # Number of encoder layers
        self.d_layers = args.d_layers  # Number of decoder layers
        self.d_ff = 2048  # Feedforward dimension

        # Output configurations
        self.c_out = args.c_out

        self.embed_type = "timeF"


class TTT2D(TTTLinear):
    """TTTLinear layer combined with an iTransformer branch.

    ``forward`` runs ``iTransformerModel`` on the input, hands the keys and
    values it returns to ``TTTLinear.forward`` through the ``var_attention``
    keyword, and returns a learned weighted sum of the iTransformer
    representation and the TTTLinear output.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int = 8,
        mini_batch_size: int = 16,
        **kwargs,
    ):
        """Create the layer.

        Parameters
        ----------
        hidden_size, num_attention_heads, mini_batch_size
            Accepted for interface compatibility.
            NOTE(review): these are currently NOT forwarded to ``TTTConfig``;
            every instance is built from the default ``TTTConfig()``. Earlier
            commented-out code suggests they were meant to be passed through
            -- confirm the intended behaviour before wiring them in.
        **kwargs
            Forwarded unchanged to ``TTTLinear.__init__``.
        """
        config = TTTConfig()

        # The parent initializer must run before any submodule or Parameter
        # is assigned: torch.nn.Module raises "cannot assign module before
        # Module.__init__() call" otherwise (assumes TTTLinear is an
        # nn.Module -- TODO confirm).
        super().__init__(config, **kwargs)

        # NOTE(review): relies on iTransformerModel accepting a TTTConfig and
        # on TTTConfig exposing ``d_model`` -- confirm against models.ttt.
        self.var_attn_model = iTransformerModel(config)
        self.xlayernorm = nn.LayerNorm(config.d_model)

        # Learnable mixing weights for the two branches, both initialised to 1.
        self.scalar1 = nn.Parameter(torch.ones(1))
        self.scalar2 = nn.Parameter(torch.ones(1))

    def forward(self, x_enc, cache_params=None, **kwargs):
        """Mix the iTransformer representation with the TTTLinear output.

        Parameters
        ----------
        x_enc : torch.Tensor
            Input sequence; presumably ``[B, L, C]`` -- TODO confirm.
        cache_params : optional
            Passed through to ``TTTLinear.forward``.
        **kwargs
            Extra keyword arguments passed through to ``TTTLinear.forward``.

        Returns
        -------
        torch.Tensor
            ``scalar1 * var_repr + scalar2 * TTTLinear.forward(...)``, where
            ``var_repr`` is the first value returned by the iTransformer.
        """
        # With return_attn=True the iTransformer returns its representation,
        # a list of attention scores (unused here) and a (keys, values) pair.
        var_repr, _attn_list, (keys, values) = self.var_attn_model(
            x_enc, x_mark_enc=None,
            x_dec=None, x_mark_dec=None,
            return_attn=True
        )

        # Inject the iTransformer keys/values into the TTT linear layer.
        ttt_out = super().forward(
            x_enc,
            cache_params=cache_params,
            var_attention=(keys, values),
            **kwargs,
        )
        return self.scalar1 * var_repr + self.scalar2 * ttt_out
    

class TTT_TImeMachine(torch.nn.Module):
    """Forecaster built from linear projections and four TTT2D blocks.

    The input is normalised (RevIN when ``configs.revin == 1``, otherwise
    mean/std statistics taken over dim 1), projected along its time axis to
    width ``n1`` and then ``n2``, processed by two TTT2D blocks at each
    width, and finally projected to ``pred_len`` before the normalisation is
    undone.
    """

    def __init__(self, configs):
        """Build the projection layers and the four TTT2D blocks.

        ``configs`` must expose ``revin``, ``enc_in``, ``seq_len``, ``n1``,
        ``n2``, ``dropout``, ``ch_ind``, ``n_heads`` and ``pred_len``;
        ``forward`` additionally reads ``residual``.
        """
        super().__init__()
        self.configs = configs
        cfg = configs

        if cfg.revin == 1:
            self.revin_layer = RevIN(cfg.enc_in)

        self.lin1 = torch.nn.Linear(cfg.seq_len, cfg.n1)
        self.dropout1 = torch.nn.Dropout(cfg.dropout)

        self.lin2 = torch.nn.Linear(cfg.n1, cfg.n2)
        self.dropout2 = torch.nn.Dropout(cfg.dropout)

        # In channel-independent mode the transposed branches see width 1.
        channel_independent = cfg.ch_ind == 1
        self.d_model_param1 = 1 if channel_independent else cfg.n2
        self.d_model_param2 = 1 if channel_independent else cfg.n1

        heads = cfg.n_heads
        self.ttt1 = TTT2D(hidden_size=self.d_model_param1, num_attention_heads=heads)
        self.ttt2 = TTT2D(hidden_size=cfg.n2, num_attention_heads=heads)
        self.ttt3 = TTT2D(hidden_size=cfg.n1, num_attention_heads=heads)
        self.ttt4 = TTT2D(hidden_size=self.d_model_param2, num_attention_heads=heads)

        self.lin3 = torch.nn.Linear(cfg.n2, cfg.n1)
        self.lin4 = torch.nn.Linear(2 * cfg.n1, cfg.pred_len)

    @staticmethod
    def _maybe_swap(tensor, swap):
        """Return ``tensor`` with its last two axes exchanged when ``swap`` is true."""
        return tensor.permute(0, 2, 1) if swap else tensor

    def forward(self, x):
        """Produce the forecast for ``x``.

        ``x`` is laid out as (batch, seq_len, channels): statistics are taken
        over dim 1 and ``lin1`` consumes ``seq_len`` after the axes are
        swapped. The result is laid out as (batch, pred_len, channels).
        """
        cfg = self.configs
        use_revin = cfg.revin == 1
        per_channel = cfg.ch_ind == 1
        with_residual = cfg.residual == 1

        # ---- normalisation ----
        if use_revin:
            x = self.revin_layer(x, 'norm')
        else:
            means = x.mean(1, keepdim=True).detach()
            x = x - means
            stdev = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5
            )
            x /= stdev

        x = x.permute(0, 2, 1)
        if per_channel:
            # Fold the channel axis into the batch axis: one series per row.
            rows, channels, steps = x.shape
            x = x.reshape(rows * channels, 1, steps)

        # ---- outer level (width n1) ----
        outer = self.lin1(x)
        outer_skip = outer
        outer = self.dropout1(outer)
        outer_direct = self.ttt3(outer)
        outer_swapped = self._maybe_swap(outer, per_channel)
        outer_swapped = self.ttt4(outer_swapped)
        outer_swapped = self._maybe_swap(outer_swapped, per_channel)
        outer_mix = outer_swapped + outer_direct

        # ---- inner level (width n2) ----
        inner = self.lin2(outer)
        inner_skip = inner
        inner = self.dropout2(inner)
        inner_swapped = self._maybe_swap(inner, per_channel)
        inner_swapped = self.ttt1(inner_swapped)
        inner_swapped = self._maybe_swap(inner_swapped, per_channel)
        inner_direct = self.ttt2(inner)

        if with_residual:
            merged = inner_swapped + inner_skip + inner_direct
        else:
            merged = inner_swapped + inner_direct

        merged = self.lin3(merged)
        if with_residual:
            merged = merged + outer_skip

        # ---- projection to the forecast horizon ----
        out = self.lin4(torch.cat([merged, outer_mix], dim=2))
        if per_channel:
            out = out.reshape(-1, cfg.enc_in, cfg.pred_len)
        out = out.permute(0, 2, 1)

        # ---- undo the normalisation ----
        if use_revin:
            out = self.revin_layer(out, 'denorm')
        else:
            # stdev / means have size 1 on dim 1; stretch them to pred_len.
            out = out * stdev.expand(-1, cfg.pred_len, -1)
            out = out + means.expand(-1, cfg.pred_len, -1)

        return out
    
