import math
import torch
from torch import nn
import torch.nn.functional as F
from baselines.layers.Transformer_EncDec import Encoder, EncoderLayer, TimeXer_Encoder, TimeXer_EncoderLayer, SelfAttentionOnlyLayer, CrossAttentionOnlyLayer, CrossAttentionFFNLayer
from baselines.layers.SelfAttention_Family import FullAttention, AttentionLayer
from baselines.layers.Embed import PatchTST_Embedding, DataEmbedding_inverted, TimeXer_EnEmbedding, PatchTST_Embedding_add, PatchTST_Embedding_concat, DataEmbedding_inverted_add, DataEmbedding_inverted_concat
from modules import TransformerBlock1
# from baselines.layers.Mamba_EncDec import Encoder as MambaEncoder, EncoderLayer as MambaEncoderLayer
# from mamba_ssm import Mamba
# from modules import TransformerBlock1, CrossAttention, CrossAttentionBlock
# from baselines.layers.Embed import PatchTST_Embedding, DataEmbedding_inverted, TimeXer_EnEmbedding
    
    
class multiTimeAttention(nn.Module):
    """Attention that aggregates observed values onto a set of query embeddings.

    Queries and keys are embeddings of size ``embed_time``; ``value`` holds one
    scalar per (batch, position, variable). For every variable the values are
    averaged along the position axis with attention weights computed from the
    query/key embeddings, then projected by the last linear layer.

    Args:
        input_dim: Stored as ``self.dim``; not used by the computation itself.
        nq: Number of query positions the module is sized for.
        embed_time: Size of the query/key embeddings.
        num_heads: Number of attention heads; must divide ``embed_time``.
        npatch: Number of patches the ``nq`` query positions are split into.
            The output projection expects ``ceil(nq / npatch)`` inputs.
    """

    def __init__(self, input_dim, nq=128, embed_time=16, num_heads=1, npatch=1):
        super().__init__()
        assert embed_time % num_heads == 0
        self.embed_time = embed_time
        self.embed_time_k = embed_time // num_heads
        self.h = num_heads
        self.dim = input_dim
        self.nhidden = math.ceil(nq / npatch)

        # Order matters: index 0 projects queries, 1 projects keys and the
        # last entry maps the attended values to the per-head output size.
        query_proj = nn.Linear(embed_time, embed_time)
        key_proj = nn.Linear(embed_time, embed_time)
        out_proj = nn.Linear(self.nhidden, self.embed_time_k)
        self.linears = nn.ModuleList([query_proj, key_proj, out_proj])

    def attention(self, query, key, value, mask=None, dropout=None):
        """Scaled dot-product attention that averages ``value`` over key positions.

        ``mask`` and ``value`` must be 4-D; they are permuted with
        ``(0, 3, 1, 2)`` and given an extra axis so they broadcast against the
        attention scores. Positions where ``mask == 0`` are excluded.

        Returns:
            A tuple ``(attended, weights)`` where ``attended`` is the weighted
            sum over the last (key position) axis and ``weights`` are the
            softmax attention weights (after dropout, if given).
        """
        scale = math.sqrt(query.size(-1))
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            blocked = mask.permute(0, 3, 1, 2).unsqueeze(-3) == 0
            logits = logits.masked_fill(blocked, -1e9)
        weights = F.softmax(logits, dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        observed = value.permute(0, 3, 1, 2).unsqueeze(-3)
        return torch.sum(weights * observed, -1), weights

    def forward(self, query, key, value, mask=None, dropout=None, npatch=None):
        """Attend from ``query`` to ``key`` and summarise ``value`` per variable.

        Args:
            query: Query embeddings whose last axis has size ``embed_time``.
            key: 4-D key embeddings whose last axis has size ``embed_time``.
            value: 3-D tensor ``(batch, positions, dim)`` of observed values.
            mask: Optional tensor shaped like ``value``; zeros mark entries to
                ignore.
            dropout: Optional callable applied to the attention weights.
            npatch: When given, the query axis is split into ``npatch`` equal
                patches (zero-padded at the end if needed) and each patch is
                projected separately.

        Returns:
            ``(batch, dim, num_heads * embed_time_k)`` when ``npatch`` is
            ``None``; otherwise ``(batch, dim, npatch, num_heads * embed_time_k)``.
        """
        batch, _, dim = value.size()
        if mask is not None:
            mask = mask.unsqueeze(1)
        value = value.unsqueeze(1)

        # Split the projected queries into heads and add a broadcast axis for
        # the variable dimension.
        q_heads = self.linears[0](query).view(
            query.size(0), -1, self.h, self.embed_time_k
        )
        query = q_heads.transpose(1, 2).unsqueeze(1)

        # Split the projected keys into heads and move the variable and head
        # axes in front of the position axis.
        k_heads = self.linears[1](key).view(
            key.size(0), key.size(1), key.size(2), self.h, self.embed_time_k
        )
        key = k_heads.permute(0, 2, 3, 1, 4)

        x, _ = self.attention(query, key, value, mask, dropout)

        if npatch is None:
            x = self.linears[-1](x)
            return x.reshape(batch, dim, self.h * self.embed_time_k)

        # 1. Compute the per-patch length and zero-pad the last axis so it
        #    divides evenly into ``npatch`` patches.
        n_query = x.size(-1)
        patch_len = math.ceil(n_query / npatch)
        pad_len = patch_len * npatch - n_query
        if pad_len > 0:
            x = F.pad(x, (0, pad_len))

        # 2. Split into patches: [B, V, H, npatch, patch_len].
        x = x.view(batch, dim, self.h, npatch, patch_len)

        # 3. Embed every patch: [B, V, npatch, H, embed_time_k], then merge
        #    the head axes into one embedding axis.
        x = self.linears[-1](x).transpose(2, 3)
        return x.reshape(batch, dim, npatch, -1)


class ResBlock(nn.Module):
    """Mixer block: a residual MLP over patches followed by one over variables.

    Args:
        args: Configuration providing ``npatch``, ``ndim``, ``hid_dim`` and
            ``dropout``.
    """

    def __init__(self, args):
        super().__init__()
        temporal_width = args.npatch * args.hid_dim
        channel_width = args.ndim * args.hid_dim

        # Mixes the flattened (patch, feature) axes of every variable.
        self.temporal = nn.Sequential(
            nn.Linear(temporal_width, temporal_width),
            nn.ReLU(),
            nn.Linear(temporal_width, temporal_width),
            nn.Dropout(args.dropout),
        )

        # Mixes the flattened (variable, feature) axes of every patch.
        self.channel = nn.Sequential(
            nn.Linear(channel_width, channel_width),
            nn.ReLU(),
            nn.Linear(channel_width, channel_width),
            nn.Dropout(args.dropout),
        )

    def forward(self, x):
        """Apply both residual mixing steps to ``x`` of shape ``(B, N, M, D)``.

        ``N`` must equal ``args.ndim``, ``M`` must equal ``args.npatch`` and
        ``D`` must equal ``args.hid_dim``. The output has the same shape.
        """
        batch, n_var, n_patch, width = x.shape

        # Residual mixing across patches, independently for each variable.
        per_variable = x.reshape(batch, n_var, n_patch * width)
        per_variable = per_variable + self.temporal(per_variable)
        x = per_variable.reshape(batch, n_var, n_patch, width)

        # Residual mixing across variables, independently for each patch.
        per_patch = x.transpose(1, 2).reshape(batch, n_patch, n_var * width)
        mixed = self.channel(per_patch).reshape(batch, n_patch, n_var, width)
        return x + mixed.transpose(1, 2)
    
    
class iResBlock(nn.Module):
    """Residual mixer over the variable axis, then over the feature axis.

    Args:
        args: Configuration providing ``ndim``, ``hid_dim`` and ``dropout``.
    """

    def __init__(self, args):
        super().__init__()
        expanded = 2 * args.hid_dim

        # Acts on the last (feature) axis of the input.
        self.temporal = nn.Sequential(
            nn.Linear(args.hid_dim, expanded),
            nn.ReLU(),
            nn.Linear(expanded, args.hid_dim),
            nn.Dropout(args.dropout),
        )

        # Acts on the variable axis once it has been moved to the last position.
        self.channel = nn.Sequential(
            nn.Linear(args.ndim, expanded),
            nn.ReLU(),
            nn.Linear(expanded, args.ndim),
            nn.Dropout(args.dropout),
        )

    def forward(self, x):
        """Mix ``x`` of shape ``[B, N, D]`` (``N == ndim``, ``D == hid_dim``)."""
        # Swap the last two axes so the linear layers see the variable axis,
        # then swap back before adding the residual.
        across_variables = self.channel(x.transpose(1, 2)).transpose(1, 2)
        x = x + across_variables
        return x + self.temporal(x)


class TempBlock(nn.Module):
    """Residual MLP over a flattened axis of size ``npatch * hid_dim``.

    Args:
        args: Configuration providing ``npatch``, ``hid_dim`` and ``dropout``.
    """

    def __init__(self, args):
        super().__init__()
        width = args.npatch * args.hid_dim
        self.temporal = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.Dropout(args.dropout),
        )

    def forward(self, x):
        """Return ``x`` plus its MLP update; the last axis must be ``npatch * hid_dim``."""
        update = self.temporal(x)
        return x + update
    
    
class ChanBlock(nn.Module):
    """Residual MLP over a flattened axis of size ``ndim * hid_dim``.

    Args:
        args: Configuration providing ``ndim``, ``hid_dim`` and ``dropout``.
    """

    def __init__(self, args):
        super().__init__()
        width = args.ndim * args.hid_dim
        self.channel = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.Dropout(args.dropout),
        )

    def forward(self, x):
        """Return ``x`` plus its MLP update; the last axis must be ``ndim * hid_dim``."""
        update = self.channel(x)
        return x + update


class PatchMixerLayer(nn.Module):
    """Depthwise residual convolution followed by a 1x1 channel projection.

    Args:
        dim: Number of input channels (axis 1 of the input).
        a: Number of output channels produced by the 1x1 convolution.
        kernel_size: Kernel size of the depthwise convolution.
    """

    def __init__(self, dim, a, kernel_size=8):
        super().__init__()
        # groups=dim gives every channel its own filter; 'same' padding keeps
        # the sequence length unchanged so the residual sum is well defined.
        depthwise = nn.Conv1d(dim, dim, kernel_size=kernel_size, groups=dim, padding='same')
        self.Resnet = nn.Sequential(depthwise, nn.GELU(), nn.BatchNorm1d(dim))

        pointwise = nn.Conv1d(dim, a, kernel_size=1)
        self.Conv_1x1 = nn.Sequential(pointwise, nn.GELU(), nn.BatchNorm1d(a))

    def forward(self, x):
        """Map ``x`` of shape ``(batch, dim, length)`` to ``(batch, a, length)``."""
        residual = self.Resnet(x)
        return self.Conv_1x1(x + residual)


class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` for use inside ``nn.Sequential``.

    Args:
        *dims: The two axes to swap, forwarded to ``Tensor.transpose``.
        contiguous: When true, the result is made contiguous before returning.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        """Return ``x`` with the configured axes swapped."""
        swapped = x.transpose(*self.dims)
        return swapped.contiguous() if self.contiguous else swapped


class Model(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.irr_emb = args.irr_emb
        self.model = args.model
        self.mode = args.mode
        self.device = args.device
        self.d_model = args.hid_dim
        self.factor = 5
        
        # Embeddings
        self.te_scale = nn.Linear(1, 1)
        if self.irr_emb is not True and self.model == "segrnn":
            self.te_periodic = nn.Linear(1, (self.d_model // 2) - 1)
        else:
            self.te_periodic = nn.Linear(1, self.d_model - 1)
        
        
        # IMTS_Token_Embedding
        if (self.irr_emb):
            self.val_emb = nn.Linear(1, self.d_model)
            
            if self.model == "itransformer" or self.model == "s_mamba" or self.model == "tsmixer_v":
                if self.mode != 'mean':
                    self.var_emb = nn.Embedding(args.ndim, self.d_model)
            elif self.model == "patchtst" or self.model == "segrnn" or self.model == "patchmixer" or self.model == "tsmixer" or self.model == "tsmixer_p" or self.model == "p_slstm":
                if self.mode != 'mean':
                    self.patch_emb = nn.Embedding(args.npatch, self.d_model)
            elif self.model == "timexer":
                self.patch_emb = nn.Embedding(args.npatch, self.d_model)
                self.var_emb = nn.Embedding(args.ndim, self.d_model)
                self.ex_embedding = TransformerBlock1(self.d_model, args.nhead, dropout=args.dropout)
                #self.ex_embedding = DataEmbedding_inverted(args.maxlen*args.npatch, self.d_model, args.dropout)
            
            if self.mode in ['self', 'mean']:
                self.embedding_layer = TransformerBlock1(self.d_model, args.nhead, dropout=args.dropout)
                # SelfAttentionOnlyLayer(
                #     AttentionLayer(
                #         FullAttention(False, self.factor, attention_dropout=args.dropout, output_attention=False), self.d_model, args.nhead),
                #     self.d_model,
                #     args.dropout
                # )
            elif self.mode == 'cross':
                self.embedding_layer = CrossAttentionOnlyLayer(
                    AttentionLayer(
                        FullAttention(False, self.factor, attention_dropout=args.dropout, output_attention=False), self.d_model, args.nhead),
                    self.d_model,
                    args.dropout
                )
            elif self.mode == 'mtand':
                if self.model == "itransformer" or self.model == "s_mamba" or self.model == "tsmixer_v":
                    self.embedding_layer = multiTimeAttention(args.ndim, 128, self.d_model, 1)
                else:
                    self.embedding_layer = multiTimeAttention(args.ndim, 128, self.d_model, 1, npatch=args.npatch)
                
        else:
            if self.model == "patchtst" or self.model == "tsmixer" or self.model == "tsmixer_p":
                if self.mode == 'add':
                    self.embedding_layer = PatchTST_Embedding_add(args.maxlen, self.d_model, args.dropout)
                elif self.mode == 'concat':
                    self.embedding_layer = PatchTST_Embedding_concat(args.maxlen, self.d_model, args.dropout)
                else:
                    self.embedding_layer = PatchTST_Embedding(args.maxlen, self.d_model, args.dropout)
            elif self.model == "itransformer" or self.model == "s_mamba" or self.model == "tsmixer_v":
                if self.mode == 'add':
                    self.embedding_layer = DataEmbedding_inverted_add(args.maxlen, self.d_model, args.dropout)
                elif self.mode == 'concat':
                    self.embedding_layer = DataEmbedding_inverted_concat(args.maxlen, self.d_model, args.dropout)
                else:
                    self.embedding_layer = DataEmbedding_inverted(args.maxlen, self.d_model, args.dropout)
            elif self.model == "timexer":
                self.en_embedding = TimeXer_EnEmbedding(args.ndim, self.d_model, args.maxlen, args.dropout)
                self.ex_embedding = DataEmbedding_inverted(args.maxlen*args.npatch, self.d_model, args.dropout)
            elif self.model == "segrnn":
                self.embedding_layer = nn.Sequential(
                    nn.Linear(args.maxlen, self.d_model),
                    nn.ReLU()
                    )
            elif self.model == "patchmixer":
                self.embedding_layer = nn.Linear(args.maxlen, self.d_model)  
             
                
        # Encoder
        if self.model == 'patchtst':
            norm_layer = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(self.d_model), Transpose(1,2))
        elif self.model == "itransformer" or self.model == "s_mamba" or self.model == "timexer":
            norm_layer = torch.nn.LayerNorm(self.d_model)

        if self.model == "timexer":
            self.encoder = TimeXer_Encoder(
                [
                    TimeXer_EncoderLayer(
                        AttentionLayer(
                            FullAttention(False, self.factor, attention_dropout=args.dropout, output_attention=False),
                            self.d_model, args.nhead),
                        AttentionLayer(
                            FullAttention(False, self.factor, attention_dropout=args.dropout, output_attention=False),
                            self.d_model, args.nhead),
                        self.d_model,
                        self.d_model,
                        dropout=args.dropout,
                        activation='gelu'
                    )
                    for _ in range(args.nlayer)
                ],
                norm_layer=norm_layer
            )
        elif self.model == "s_mamba":
            d_state = getattr(args, 'd_state', 16) 
            d_conv = getattr(args, 'd_conv', 2)
            expand = getattr(args, 'expand', 1)
            
            self.encoder = MambaEncoder(
                [
                    MambaEncoderLayer(
                        Mamba(
                            d_model=self.d_model,
                            d_state=d_state,
                            d_conv=d_conv,
                            expand=expand,
                        ),
                        Mamba(
                            d_model=self.d_model,
                            d_state=d_state,
                            d_conv=d_conv,
                            expand=expand,
                        ),
                        self.d_model,
                        getattr(args, 'd_ff', 4 * self.d_model),
                        dropout=args.dropout,
                        activation='gelu'
                    ) for _ in range(args.nlayer)
                ],
                norm_layer=norm_layer
            )
        elif self.model == "segrnn":
            self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True, batch_first=True, bidirectional=False)
            if self.irr_emb is not True:
                self.channel_emb = nn.Parameter(torch.randn(args.ndim, self.d_model // 2))
            else:
                self.channel_emb = nn.Parameter(torch.randn(args.ndim, self.d_model))      
        elif self.model == "patchmixer":
            self.dropout = nn.Dropout(args.dropout)
            self.encoder_blocks = nn.ModuleList([])
            for _ in range(1):
                self.encoder_blocks.append(PatchMixerLayer(dim=args.npatch, a=args.npatch))
        elif self.model == "tsmixer":
            self.encoder_blocks = nn.ModuleList([ResBlock(args) for _ in range(args.nlayer)])
        elif self.model == "tsmixer_p":
            self.encoder_blocks = nn.ModuleList([TempBlock(args) for _ in range(args.nlayer)])
        elif self.model == "tsmixer_v":
            self.encoder_blocks = nn.ModuleList([iResBlock(args) for _ in range(args.nlayer)])
        else:
            self.encoder = Encoder(
                [
                    EncoderLayer(
                        AttentionLayer(
                            FullAttention(False, self.factor, attention_dropout=args.dropout, output_attention=False), 
                            self.d_model, args.nhead
                        ),
                        self.d_model,
                        self.d_model,
                        dropout=args.dropout,
                        activation='gelu'
                    ) for _ in range(args.nlayer)
                ],
                norm_layer=norm_layer
            )
            
            
        # Decoder
        if self.model != "segrnn":
            self.cross_transformer = CrossAttentionFFNLayer(
                AttentionLayer(
                    FullAttention(False, self.factor, attention_dropout=args.dropout,
                                output_attention=False), self.d_model, args.nhead),
                    self.d_model,
                    self.d_model,
                    dropout=args.dropout,
                    activation='gelu'
                )
            
        self.decoder = nn.Sequential(
                nn.Linear(self.d_model, self.d_model),
                nn.ReLU(inplace=True),
                nn.Linear(self.d_model, self.d_model),
                nn.ReLU(inplace=True),
                nn.Linear(self.d_model, 1)
            )
        
        
    def LearnableTE(self, tt):
        # learnable continuous time embeddings
        out1 = self.te_scale(tt)
        out2 = torch.sin(self.te_periodic(tt))
        return torch.cat([out1, out2], -1)
    
    
    def forecasting(self, time_steps_to_predict, X, truth_time_steps, mask=None):
        if (self.irr_emb):
            if self.model == "itransformer" or self.model == "s_mamba" or self.model == "tsmixer_v":
                B, M, L_in, N = X.shape
                X = X.reshape(B, M*L_in, N)
                truth_time_steps = truth_time_steps.reshape(B, M*L_in, N)
                mask = mask.reshape(B, M*L_in, N)
                
                B, L, N = X.shape
                x_input = self.val_emb(X.unsqueeze(-1)) # (B, L, N, D)
                time_emb = self.LearnableTE(truth_time_steps.unsqueeze(-1)) # (B, L, N, D)
                
                if self.mode == 'self':
                    var_vectors = self.var_emb(torch.arange(N, device=self.device)) # (N, D)
                    var_tokens = var_vectors.unsqueeze(0).expand(B, -1, -1) # (B, N, D)
                    var_flat = var_tokens.unsqueeze(2).reshape(B * N, 1, self.d_model)

                    z_flat = x_input.permute(0, 2, 1, 3).reshape(B * N, L, self.d_model)    
                    time_emb_flat = time_emb.transpose(1,2).reshape(B * N, L, self.d_model)
                    z_flat = z_flat + time_emb_flat
                    
                    mask_flat = mask.permute(0, 2, 1).reshape(B * N, L)
    
                    z_cat_flat = torch.cat([var_flat, z_flat], dim=1)
                    mask_tokens = torch.ones(B * N, 1, device=self.device)
                    z_mask_flat = torch.cat([mask_tokens, mask_flat], dim=1)
                    
                    total_l = z_cat_flat.shape[1]
                    processed_z_flat = self.embedding_layer(z_cat_flat, attn_mask=z_mask_flat.unsqueeze(-1))
                    x_processed = processed_z_flat.reshape(B, N, total_l, self.d_model).permute(0, 2, 1, 3) # (B, L+1, N, D)    
                    
                    variable_embeddings = x_processed[:, 0, :, :]  # (B, N, D)
                
                elif self.mode == 'cross':
                    var_vectors = self.var_emb(torch.arange(N, device=self.device)) # (N, D)
                    var_tokens = var_vectors.unsqueeze(0).expand(B, -1, -1) # (B, N, D)
                    var_tokens_flat = var_tokens.unsqueeze(2).reshape(B * N, 1, self.d_model)

                    z = x_input + time_emb
                    z_flat = z.reshape(B * N, L, self.d_model)    
                    
                    mask_flat = mask.permute(0, 2, 1).reshape(B * N, L)
                    
                    processed_z_flat = self.embedding_layer(var_tokens_flat, z_flat, attn_mask=mask_flat.unsqueeze(-2))
                    variable_embeddings = processed_z_flat.reshape(B, N, self.d_model) # (B, N, D) 
                
                elif self.mode == 'mtand':
                    cls_query = torch.linspace(0, 1., 128)
                    cls_query = self.LearnableTE(cls_query.unsqueeze(0).unsqueeze(-1).to(self.device))
                    variable_embeddings = self.embedding_layer(cls_query, time_emb, X, mask)
                    
                elif self.mode == 'mean':
                    z_flat = x_input.permute(0, 2, 1, 3).reshape(B * N, L, self.d_model)    
                    time_emb_flat = time_emb.transpose(1,2).reshape(B * N, L, self.d_model)
                    z_flat = z_flat + time_emb_flat

                    z_mask_flat = mask.permute(0, 2, 1).reshape(B * N, L)
                    
                    processed_z_flat = self.embedding_layer(z_flat, attn_mask=z_mask_flat.unsqueeze(-1))
                    x_processed = processed_z_flat.reshape(B, N, L, self.d_model).permute(0, 2, 1, 3) # (B, L, N, D)    
            
                    variable_embeddings = torch.mean(x_processed, dim=1)  # (B, N, D)
                    
                if self.model == "tsmixer_v":
                    #variable_embeddings_flat = variable_embeddings.reshape(B, N*self.d_model)  # (B, N*D)
                    for encoder in self.encoder_blocks:
                        variable_embeddings = encoder(variable_embeddings)  # (B, N*D)
                    irr_z = variable_embeddings#.reshape(B, N, self.d_model) # (B, N, D)
                else:
                    irr_z, _ = self.encoder(variable_embeddings) # (B, N, D)
                irr_z_flat = irr_z.unsqueeze(1).reshape(B*N, 1, self.d_model) # (B*N, 1, D)
                

            elif self.model == "patchtst" or self.model == "timexer" or self.model == "segrnn" or self.model == "patchmixer" or self.model == "tsmixer" or self.model == "tsmixer_p":
                if self.mode != "mtand":
                    B, M, L_in, N = X.shape
                    X = X.permute(0, 1, 3, 2)  # (B, M, N, L_in)
                    mask = mask.permute(0, 1, 3, 2)  # (B, M, N, L_in)
                    truth_time_steps = truth_time_steps.permute(0, 1, 3, 2) # (B, M, N, L_in)

                    x_input = self.val_emb(X.unsqueeze(-1)) # (B, M, N, L_in, D)
                    time_emb = self.LearnableTE(truth_time_steps.unsqueeze(-1)) # (B, M, N, L_in, D)
                
                if self.mode == 'self':
                    patch_vectors = self.patch_emb(torch.arange(M, dtype=torch.long, device=self.device))
                    patch_tokens = patch_vectors.reshape(1, M, 1, 1, self.d_model).expand(B, -1, N, -1, -1) # (B, M, N, 1, D)
                    
                    patch_vs = x_input + time_emb  # (B, M, N, L_in, D)
                    z = torch.cat([patch_tokens, patch_vs], dim=3) # (B, M, N, L_in+1, D)
                    z_flat = z.reshape(B * M * N, L_in + 1, self.d_model)
                    attn_mask = torch.cat([torch.ones(B, M, N, 1, device=self.device), mask], dim=3).reshape(B * M * N, L_in + 1)
                    
                    total_l = z_flat.shape[1]
                    processed_z_flat = self.embedding_layer(z_flat, attn_mask=attn_mask.unsqueeze(-1))
                    processed_z = processed_z_flat.reshape(B, M, N, total_l, self.d_model)
                    
                    patch_embeddings = processed_z[:, :, :, 0, :].transpose(1,2)  # (B, N, M, D)
                    patch_embeddings_flat = patch_embeddings.reshape(B*N, M, self.d_model)
                
                elif self.mode == 'cross':
                    patch_vectors = self.patch_emb(torch.arange(M, dtype=torch.long, device=self.device))
                    patch_tokens = patch_vectors.reshape(1, M, 1, 1, self.d_model).expand(B, -1, N, -1, -1) # (B, M, N, 1, D)
                    patch_tokens_flat = patch_tokens.reshape(B * M * N, 1, self.d_model)
                    
                    patch_vs = x_input + time_emb  # (B, M, N, L_in, D)
                    patch_vs_flat = patch_vs.reshape(B * M * N, L_in, self.d_model)
                    
                    mask_flat = mask.reshape(B * M * N, L_in)
                    
                    processed_z_flat = self.embedding_layer(patch_tokens_flat, patch_vs_flat, attn_mask=mask_flat.unsqueeze(-2)) # (B*M*N, 1, D)
                    patch_embeddings = processed_z_flat.reshape(B, M, N, self.d_model).transpose(1,2)  # (B, N, M, D)
                    patch_embeddings_flat = patch_embeddings.reshape(B*N, M, self.d_model)  # (B*N, M, D)
                
                elif self.mode == 'mtand':
                    B, M, L_in, N = X.shape
                    X = X.reshape(B, M*L_in, N)
                    truth_time_steps = truth_time_steps.reshape(B, M*L_in, N)
                    mask = mask.reshape(B, M*L_in, N)
                    
                    key = self.LearnableTE(truth_time_steps.unsqueeze(-1)) # (B, L, N, D)
                
                    cls_query = torch.linspace(0, 1., 128)
                    cls_query = self.LearnableTE(cls_query.unsqueeze(0).unsqueeze(-1).to(self.device))
                    patch_embeddings = self.embedding_layer(cls_query, key, X, mask, npatch=M)  # (B, N, M, D)
                    patch_embeddings_flat = patch_embeddings.reshape(B*N, M, self.d_model)  # (B*N, M, D)
                    
                elif self.mode == 'mean':
                    z = x_input + time_emb  # (B, M, N, L_in, D)
                    z_flat = z.reshape(B * M * N, L_in, self.d_model)
                    attn_mask = mask.reshape(B * M * N, L_in)
                    
                    total_l = z_flat.shape[1]
                    processed_z_flat = self.embedding_layer(z_flat, attn_mask=attn_mask.unsqueeze(-1))
                    processed_z = processed_z_flat.reshape(B, M, N, total_l, self.d_model)  # (B, M, N, L_in, D)
                    
                    patch_embeddings = torch.mean(processed_z, dim=-2).transpose(1,2)  # (B, N, M, D)
                    patch_embeddings_flat = patch_embeddings.reshape(B*N, M, self.d_model)
                    
                    
                # Encoder
                if self.model == "patchtst":
                    irr_z_flat, _ = self.encoder(patch_embeddings_flat) # (B*N, M, D)
                    
                elif self.model == "timexer":
                    var_vectors = self.var_emb(torch.arange(N, device=self.device)) # (N, D)
                    var_tokens = var_vectors.unsqueeze(0).expand(B, -1, -1) # (B, N, D)
                    var_flat = var_tokens.unsqueeze(2).reshape(B * N, 1, self.d_model)
                    
                    en_embed = torch.cat([patch_embeddings_flat, var_flat], dim=1)
                    
                    B, M, N, L_in = X.shape
                    X = X.transpose(-1,-2).reshape(B, M*L_in, N)
                    truth_time_steps = truth_time_steps.transpose(-1,-2).reshape(B, M*L_in, N)
                    mask = mask.transpose(-1,-2).reshape(B, M*L_in, N)
                    
                    B, L, N = X.shape
                    x_input = self.val_emb(X.unsqueeze(-1)) # (B, L, N, D)
                    time_emb = self.LearnableTE(truth_time_steps.unsqueeze(-1)) # (B, L, N, D)
                    
                    var_vectors = self.var_emb(torch.arange(N, device=self.device)) # (N, D)
                    var_tokens = var_vectors.unsqueeze(0).expand(B, -1, -1) # (B, N, D)
                    var_flat = var_tokens.unsqueeze(2).reshape(B * N, 1, self.d_model)

                    z_flat = x_input.permute(0, 2, 1, 3).reshape(B * N, L, self.d_model)    
                    time_emb_flat = time_emb.transpose(1,2).reshape(B * N, L, self.d_model)
                    z_flat = z_flat + time_emb_flat
                        
                    mask_flat = mask.permute(0, 2, 1).reshape(B * N, L)
        
                    z_cat_flat = torch.cat([var_flat, z_flat], dim=1)
                    mask_tokens = torch.ones(B * N, 1, device=self.device)
                    z_mask_flat = torch.cat([mask_tokens, mask_flat], dim=1)
                        
                    total_l = z_cat_flat.shape[1]
                    processed_z_flat = self.embedding_layer(z_cat_flat, attn_mask=z_mask_flat.unsqueeze(-1))
                    x_processed = processed_z_flat.reshape(B, N, total_l, self.d_model).permute(0, 2, 1, 3) # (B, L+1, N, D)    
                        
                    ex_embed = x_processed[:, 0, :, :]  # (B, N, D)
                    
                    irr_z_flat = self.encoder(en_embed, ex_embed)  # (B*N, M+1, D)
                    
                elif self.model == "segrnn":
                    patch_embeddings_flat = patch_embeddings_flat + self.position_embedding(patch_embeddings_flat)
                    
                    _, irr_z_flat = self.rnn(patch_embeddings_flat)  # (1, B*N, d_model)
                    
                    L_pred = time_steps_to_predict.shape[1]
                    te_pred = self.LearnableTE(time_steps_to_predict.unsqueeze(-1))  # (B, L_pred, D)
                    te_pred_flat = te_pred.unsqueeze(1).expand(-1, N, -1, -1).reshape(B * N, L_pred, self.d_model)  # (B*N, L_pred, D)
                    
                    channel_emb_flat = self.channel_emb.unsqueeze(1).expand(-1, L_pred, -1)  # (N, L_pred, D)
                    channel_emb_flat = channel_emb_flat.unsqueeze(0).expand(B, -1, -1, -1)  # (B, N, L_pred, D)
                    channel_emb_flat = channel_emb_flat.reshape(B * N, L_pred, self.d_model)  # (B*N, L_pred, D)
                    
                    pos_emb = te_pred_flat + channel_emb_flat
                    pos_emb = pos_emb.view(-1, 1, self.d_model)  # (B*N*L_pred, 1, d_model)
                    
                    _, dec_out = self.rnn(pos_emb, irr_z_flat.repeat(1, 1, L_pred).view(1, -1, self.d_model))

                    dec_out = dec_out.squeeze(0).reshape(B, L_pred, N, self.d_model)
                    dec_out = self.decoder(dec_out) # (B, L_pred, N, 1)
                    
                    return dec_out.squeeze(-1) # (B, L_pred, N)
                
                elif self.model == "patchmixer":
                    patch_embeddings_flat = self.dropout(patch_embeddings_flat)
                    
                    for encoder in self.encoder_blocks:
                        patch_embeddings_flat = encoder(patch_embeddings_flat)
                        
                    irr_z_flat = patch_embeddings_flat
                    
                elif self.model == "tsmixer":
                    patch_embeddings = patch_embeddings_flat.reshape(B, N, M, self.d_model) # (B, N, M, D)

                    for encoder in self.encoder_blocks:
                        patch_embeddings = encoder(patch_embeddings) # (B, N, M, D)
                    
                    irr_z_flat = patch_embeddings.reshape(B*N, M, self.d_model) # (B*N, M, D)
                    
                elif self.model == "tsmixer_p":
                    patch_embeddings = patch_embeddings_flat.reshape(B, N, M*self.d_model) # (B, N, M*D)
                    
                    for encoder in self.encoder_blocks:
                        patch_embeddings = encoder(patch_embeddings) # (B, N, M*D)
                        
                    irr_z_flat = patch_embeddings.reshape(B*N, M, self.d_model) # (B*N, M, D)
            
                    
        else:
            if self.model == "itransformer" or self.model == "s_mamba":
                B, L, N = X.shape
                
                if self.mode == 'add' or self.mode == 'concat':
                    t = truth_time_steps.unsqueeze(-1).expand(-1, -1, N)
                    variable_embeddings = self.embedding_layer(X, t)
                else:
                    variable_embeddings = self.embedding_layer(X, None)
                
                if self.model == "s_mamba": 
                    irr_z, _ = self.encoder(variable_embeddings)
                    irr_z_flat = irr_z.unsqueeze(1).reshape(B*N, 1, self.d_model)
                else: 
                    irr_z, _ = self.encoder(variable_embeddings)
                    irr_z_flat = irr_z.unsqueeze(1).reshape(B*N, 1, self.d_model)
            
            elif self.model == "patchtst":
                B, M, L_in, N = X.shape
                x_enc = X.permute(0, 3, 1, 2).reshape(B*N, M, L_in)
                if self.mode == 'add' or self.mode == 'concat':
                    t_enc = truth_time_steps.permute(0, 3, 1, 2).reshape(B*N, M, L_in)
                    patch_embeddings_flat = self.embedding_layer(x_enc, t_enc)
                else:
                    patch_embeddings_flat = self.embedding_layer(x_enc)
                
                irr_z_flat, _ = self.encoder(patch_embeddings_flat) # (B*N, M, D)
            
            elif self.model == "timexer":
                B, M, L_in, N = X.shape
                
                en_embed = self.en_embedding(X.permute(0, 3, 1, 2))
                ex_embed = self.ex_embedding(X.reshape(B, M*L_in, N), None)
                
                irr_z_flat = self.encoder(en_embed, ex_embed) # (B*N, M, D)
            
            elif self.model == "segrnn":
                B, M, L_in, N = X.shape
                x_enc = X.permute(0, 3, 1, 2).reshape(B * N, M, L_in)
                
                patch_embeddings_flat = self.embedding_layer(x_enc)
            
                _, irr_z_flat = self.rnn(patch_embeddings_flat)  # (1, B*N, d_model)
                
                L_pred = time_steps_to_predict.shape[1]
                te_pred = self.LearnableTE(time_steps_to_predict.unsqueeze(-1))  # (B, L_pred, D/2)
                te_pred_flat = te_pred.unsqueeze(1).expand(-1, N, -1, -1).reshape(B * N, L_pred, self.d_model // 2)  # (B*N, L_pred, D/2)
                
                channel_emb_flat = self.channel_emb.unsqueeze(1).expand(-1, L_pred, -1)  # (N, L_pred, D/2)
                channel_emb_flat = channel_emb_flat.unsqueeze(0).expand(B, -1, -1, -1)  # (B, N, L_pred, D/2)
                channel_emb_flat = channel_emb_flat.reshape(B * N, L_pred, self.d_model // 2)  # (B*N, L_pred, D/2)
                
                pos_emb = torch.cat([te_pred_flat, channel_emb_flat], dim=-1).view(-1, 1, self.d_model)  # (B*N*L_pred, 1, d_model)
                
                _, dec_out = self.rnn(pos_emb, irr_z_flat.repeat(1, 1, L_pred).view(1, -1, self.d_model))
                dec_out = dec_out.squeeze(0).reshape(B, L_pred, N, self.d_model)
                dec_out = self.decoder(dec_out) # (B, L_pred, N, 1)
                
                return dec_out.squeeze(-1) # (B, L_pred, N)
            
            elif self.model == "patchmixer":
                B, M, L_in, N = X.shape
                x_enc = X.permute(0, 3, 1, 2).reshape(B*N, M, L_in)
                
                patch_embeddings_flat = self.embedding_layer(x_enc)  # (B*N, M, D)
                patch_embeddings_flat = self.dropout(patch_embeddings_flat)

                for encoder in self.encoder_blocks:
                    patch_embeddings_flat = encoder(patch_embeddings_flat)
                irr_z_flat = patch_embeddings_flat
                
            elif self.model == "tsmixer":
                B, M, L_in, N = X.shape
                x_enc = X.permute(0, 3, 1, 2).reshape(B*N, M, L_in)
                
                patch_embeddings_flat = self.embedding_layer(x_enc)  # (B*N, M, D)
                patch_embeddings = patch_embeddings_flat.reshape(B, N, M, self.d_model) # (B, N, M, D)
                
                for encoder in self.encoder_blocks:
                    patch_embeddings = encoder(patch_embeddings) # (B, N, M, D)
                    
                irr_z_flat = patch_embeddings.reshape(B*N, M, self.d_model) # (B*N, M, D)
                
            elif self.model == "tsmixer_p":
                B, M, L_in, N = X.shape
                x_enc = X.permute(0, 3, 1, 2).reshape(B*N, M, L_in)
                
                patch_embeddings_flat = self.embedding_layer(x_enc)  # (B*N, M, D)
                patch_embeddings = patch_embeddings_flat.reshape(B, N, M*self.d_model) # (B, N, M*D)
                
                for encoder in self.encoder_blocks:
                    patch_embeddings = encoder(patch_embeddings) # (B, N, M*D)
                    
                irr_z_flat = patch_embeddings.reshape(B*N, M, self.d_model) # (B*N, M, D)
                
            elif self.model == "tsmixer_v":
                B, L, N = X.shape
                variable_embeddings = self.embedding_layer(X, None)  # (B, N, D)
                #variable_embeddings_flat = variable_embeddings.reshape(B, N*self.d_model)  # (B, N*D)
                
                for encoder in self.encoder_blocks:
                    variable_embeddings = encoder(variable_embeddings)  # (B, N*D)
                    
                irr_z_flat = variable_embeddings.unsqueeze(-2).reshape(B*N, 1, self.d_model) # (B*N, 1, D)
                
        
        # Decoder
        L_pred = time_steps_to_predict.shape[1]
        te_pred = self.LearnableTE(time_steps_to_predict.unsqueeze(-1)) # (B, L_pred, D)
        te_pred_flat = te_pred.unsqueeze(1).expand(-1, N, -1, -1).reshape(B * N, L_pred, self.d_model) # (B*N, L_pred, D)
            
        out_flat, _ = self.cross_transformer(te_pred_flat, irr_z_flat, irr_z_flat) # (B*N, L_pred, D)
        out = out_flat.reshape(B, N, L_pred, self.d_model).permute(0, 2, 1, 3) # (B, L_pred, N, D)
        dec_out = self.decoder(out) # (B, L_pred, N, 1)
                        
        return dec_out.squeeze(-1) # (B, L_pred, N)