import math
import torch
from torch import nn
import torch.nn.functional as F
from baselines.layers.Transformer_EncDec import Encoder, EncoderLayer, TimeXer_Encoder, TimeXer_EncoderLayer, SelfAttentionOnlyLayer, CrossAttentionOnlyLayer, CrossAttentionFFNLayer
from baselines.layers.SelfAttention_Family import FullAttention, AttentionLayer
from baselines.layers.Embed import PatchTST_Embedding, DataEmbedding_inverted, TimeXer_EnEmbedding, PatchTST_Embedding_add, PatchTST_Embedding_concat, DataEmbedding_inverted_add, DataEmbedding_inverted_concat
from modules import TransformerBlock1
# from baselines.layers.Mamba_EncDec import Encoder as MambaEncoder, EncoderLayer as MambaEncoderLayer
# from mamba_ssm import Mamba
 
    
class multiTimeAttention(nn.Module):
    """Attention from reference-time queries onto observed time embeddings.

    Queries and keys are time embeddings; the attention weights computed
    from them are used to average the raw observed ``value`` per variable.
    The resulting per-query summaries are then projected along the query
    axis by the last linear layer, optionally patch by patch.

    Args:
        input_dim: Stored as ``self.dim``; not used in the computation.
        nq: Number of query points the output projection is sized for.
        embed_time: Width of the time embedding; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        npatch: Number of chunks the query axis is split into; the output
            projection takes ``ceil(nq / npatch)`` inputs.
    """

    def __init__(self, input_dim, nq=128, embed_time=16, num_heads=1, npatch=1):
        super(multiTimeAttention, self).__init__()
        assert embed_time % num_heads == 0
        self.embed_time = embed_time
        self.embed_time_k = embed_time // num_heads
        self.h = num_heads
        self.dim = input_dim
        # Input width of the output projection: length of one query chunk.
        self.nhidden = math.ceil(nq / npatch)
        # [0] projects queries, [1] projects keys, [2] projects the query axis
        # of the attended values down to embed_time_k.
        self.linears = nn.ModuleList([nn.Linear(embed_time, embed_time),
                                      nn.Linear(embed_time, embed_time),
                                      nn.Linear(self.nhidden, self.embed_time_k)])

    def attention(self, query, key, value, mask=None, dropout=None):
        """Scaled dot-product attention that averages scalar values.

        Args:
            query: Projected queries, broadcastable against ``key`` on all but
                the last two axes; last axis is the per-head width.
            key: Projected keys with the same last axis as ``query``.
            value: 4-D tensor; it is permuted with ``(0, 3, 1, 2)`` and given
                an extra axis so that it broadcasts against the weights.
            mask: Optional tensor laid out like ``value``; positions equal to
                0 are excluded from the softmax.
            dropout: Optional callable applied to the attention weights.

        Returns:
            Tuple ``(attended, weights)`` where ``attended`` is the weighted
            sum of ``value`` over the key axis and ``weights`` are the
            (post-dropout) attention weights.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            blocked = mask.permute(0, 3, 1, 2).unsqueeze(-3) == 0
            # Use the dtype's own minimum instead of a literal -1e9: the
            # literal is not representable in float16, so masked_fill raises
            # under half precision. For float32 the softmax is unchanged.
            scores = scores.masked_fill(blocked, torch.finfo(scores.dtype).min)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.sum(p_attn * value.permute(0, 3, 1, 2).unsqueeze(-3), -1), p_attn

    def forward(self, query, key, value, mask=None, dropout=None, npatch=None):
        """Attend from ``query`` time embeddings onto the observations.

        Args:
            query: ``(Bq, nq, embed_time)`` query time embeddings.
            key: ``(B, L, V, embed_time)`` time embeddings of the observations.
            value: ``(B, L, V)`` observed values.
            mask: Optional ``(B, L, V)`` tensor; zeros mark positions to ignore.
            dropout: Optional callable applied to the attention weights.
            npatch: When given, the query axis is zero-padded to a multiple of
                ``npatch``, split into ``npatch`` chunks and each chunk is
                projected separately. The chunk length must equal the
                ``ceil(nq / npatch)`` the module was constructed with.

        Returns:
            ``(B, V, npatch, embed_time)`` when ``npatch`` is given, otherwise
            ``(B, V, embed_time)``.
        """
        batch, _, dim = value.size()
        if mask is not None:
            mask = mask.unsqueeze(1)
        value = value.unsqueeze(1)
        # query -> (Bq, 1, H, nq, embed_time_k)
        query = self.linears[0](query).view(query.size(0), -1, self.h, self.embed_time_k).transpose(1, 2).unsqueeze(1)
        # key -> (B, V, H, L, embed_time_k)
        key = self.linears[1](key).view(key.size(0), key.size(1), key.size(2), self.h, self.embed_time_k).transpose(1, 2).transpose(2, 3)
        x, _ = self.attention(query, key, value, mask, dropout)  # (B, V, H, nq)
        if npatch is not None:
            d_k = x.size(-1)

            # 1. Compute the per-patch length and zero-pad the query axis so
            #    it divides evenly into npatch chunks.
            patch_len = math.ceil(d_k / npatch)
            total_required_len = patch_len * npatch
            pad_len = total_required_len - d_k
            if pad_len > 0:
                x = F.pad(x, (0, pad_len))  # pads the end of the last dimension

            # 2. Split into patches: [B, V, H, npatch, patch_len]
            x = x.view(batch, dim, self.h, npatch, patch_len)

            # 3. Patch embedding
            x = self.linears[-1](x).transpose(2, 3)  # [B, V, npatch, H, embed_time_k]
            x = x.reshape(batch, dim, npatch, -1)  # [B, V, npatch, embed_time]
        else:
            x = self.linears[-1](x)
            x = x.reshape(batch, dim, self.h * self.embed_time_k)
        return x


class ResBlock(nn.Module):
    """Two-stage MLP mixing block over a ``(B, N, M, D)`` tensor.

    The first stage mixes, for every variable, the flattened ``M * D``
    (patch x hidden) features and adds the result back. The second stage
    mixes, for every patch, the flattened ``N * D`` (variable x hidden)
    features and adds the result back.

    ``args`` must provide ``npatch`` (M), ``ndim`` (N), ``hid_dim`` (D) and
    ``dropout``.
    """

    def __init__(self, args):
        super().__init__()

        patch_width = args.npatch * args.hid_dim
        variable_width = args.ndim * args.hid_dim
        drop_p = args.dropout

        # Mixes the flattened (patch, hidden) features of each variable.
        self.temporal = nn.Sequential(
            nn.Linear(patch_width, patch_width),
            nn.ReLU(),
            nn.Linear(patch_width, patch_width),
            nn.Dropout(drop_p),
        )

        # Mixes the flattened (variable, hidden) features of each patch.
        self.channel = nn.Sequential(
            nn.Linear(variable_width, variable_width),
            nn.ReLU(),
            nn.Linear(variable_width, variable_width),
            nn.Dropout(drop_p),
        )

    def forward(self, x):
        """Apply both residual mixing stages; input and output are ``(B, N, M, D)``."""
        batch, n_var, n_patch, width = x.shape

        # Stage 1: residual update over each variable's (M * D) features.
        per_variable = x.reshape(batch, n_var, n_patch * width)
        per_variable = per_variable + self.temporal(per_variable)
        mixed = per_variable.reshape(batch, n_var, n_patch, width)

        # Stage 2: residual update over each patch's (N * D) features.
        per_patch = mixed.transpose(1, 2).reshape(batch, n_patch, n_var * width)
        update = self.channel(per_patch)
        update = update.reshape(batch, n_patch, n_var, width).transpose(1, 2)

        return mixed + update
    
    
class iResBlock(nn.Module):
    """Residual MLP block for ``(B, N, D)`` inputs.

    A first MLP acts along the variable axis (``N = args.ndim``) and a second
    one along the hidden axis (``D = args.hid_dim``); each result is added
    back to its input. Both MLPs use an inner width of ``2 * args.hid_dim``.
    """

    def __init__(self, args):
        super().__init__()

        hidden = args.hid_dim
        inner = 2 * hidden
        n_var = args.ndim
        drop_p = args.dropout

        # Acts on the last (hidden) axis.
        self.temporal = nn.Sequential(
            nn.Linear(hidden, inner),
            nn.ReLU(),
            nn.Linear(inner, hidden),
            nn.Dropout(drop_p),
        )

        # Acts on the variable axis (applied to the transposed input).
        self.channel = nn.Sequential(
            nn.Linear(n_var, inner),
            nn.ReLU(),
            nn.Linear(inner, n_var),
            nn.Dropout(drop_p),
        )

    def forward(self, x):  # [B, N, D]
        """Apply the variable-axis update, then the hidden-axis update."""
        across_variables = self.channel(x.transpose(1, 2)).transpose(1, 2)
        x = x + across_variables

        across_hidden = self.temporal(x)
        return x + across_hidden


class PatchMixerLayer(nn.Module):
    """Depthwise convolution with a skip connection, then a 1x1 convolution.

    Args:
        dim: Number of input channels (axis 1 of the input).
        a: Number of output channels produced by the 1x1 convolution.
        kernel_size: Kernel size of the depthwise convolution.
    """

    def __init__(self, dim, a, kernel_size=8):
        super().__init__()

        # One filter per channel (groups=dim); 'same' padding keeps the length.
        depthwise = nn.Conv1d(dim, dim, kernel_size=kernel_size, groups=dim, padding='same')
        self.Resnet = nn.Sequential(depthwise, nn.GELU(), nn.BatchNorm1d(dim))

        # Mixes channels position by position and maps dim -> a channels.
        pointwise = nn.Conv1d(dim, a, kernel_size=1)
        self.Conv_1x1 = nn.Sequential(pointwise, nn.GELU(), nn.BatchNorm1d(a))

    def forward(self, x):
        """Map ``(batch, dim, length)`` to ``(batch, a, length)``."""
        residual = x + self.Resnet(x)
        return self.Conv_1x1(residual)


class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` for use inside ``nn.Sequential``.

    Args:
        *dims: The two dimensions to swap, forwarded to ``Tensor.transpose``.
        contiguous: When true, the transposed result is made contiguous.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        swapped = x.transpose(*self.dims)
        if self.contiguous:
            return swapped.contiguous()
        return swapped


class Model(nn.Module):
    def __init__(self, args):
        """Build the embedding, encoder and classifier for the chosen backbone.

        Reads from ``args``: ``irr_emb``, ``model``, ``mode``, ``device``,
        ``hid_dim``, ``ndim``, ``npatch``, ``nhead``, ``nlayer``, ``maxlen``,
        ``dropout``, ``d_static`` and ``n_class``. For ``model == "s_mamba"``
        the optional ``d_state``, ``d_conv``, ``expand`` and ``d_ff`` are read
        with defaults.

        Sub-modules are created in a fixed order so that seeded
        initialisation and ``state_dict`` layout stay stable.
        """
        super().__init__()
        self.irr_emb = args.irr_emb
        self.model = args.model
        self.mode = args.mode
        self.device = args.device
        self.d_model = args.hid_dim
        self.factor = 5

        variable_models = ("itransformer", "s_mamba")

        # Embeddings
        self.te_scale = nn.Linear(1, 1)
        self.te_periodic = nn.Linear(1, self.d_model - 1)

        # IMTS_Token_Embedding
        if self.irr_emb:
            self.val_emb = nn.Linear(1, self.d_model)

            if self.model in variable_models:
                if self.mode != 'mean':
                    self.var_emb = nn.Embedding(args.ndim, self.d_model)
            elif self.model in ("patchtst", "segrnn", "patchmixer", "tsmixer"):
                if self.mode != 'mean':
                    self.patch_emb = nn.Embedding(args.npatch, self.d_model)
            elif self.model == "timexer":
                self.var_emb = nn.Embedding(args.ndim, self.d_model)
                self.ex_embedding = TransformerBlock1(self.d_model, args.nhead, dropout=args.dropout)
                self.patch_emb = nn.Embedding(args.npatch, self.d_model)

            if self.mode in ('self', 'mean'):
                self.embedding_layer = TransformerBlock1(self.d_model, args.nhead, dropout=args.dropout)
            elif self.mode == 'mtand':
                if self.model in variable_models:
                    self.embedding_layer = multiTimeAttention(args.ndim, 128, self.d_model, 1)
                else:
                    self.embedding_layer = multiTimeAttention(args.ndim, 128, self.d_model, 1, npatch=args.npatch)

        else:
            if self.model in ("patchtst", "tsmixer"):
                if self.mode == 'add':
                    self.embedding_layer = PatchTST_Embedding_add(args.maxlen, self.d_model, args.dropout)
                elif self.mode == 'concat':
                    self.embedding_layer = PatchTST_Embedding_concat(args.maxlen, self.d_model, args.dropout)
                else:
                    self.embedding_layer = PatchTST_Embedding(args.maxlen, self.d_model, args.dropout)
            elif self.model in variable_models:
                if self.mode == 'add':
                    self.embedding_layer = DataEmbedding_inverted_add(args.maxlen, self.d_model, args.dropout)
                elif self.mode == 'concat':
                    self.embedding_layer = DataEmbedding_inverted_concat(args.maxlen, self.d_model, args.dropout)
                else:
                    self.embedding_layer = DataEmbedding_inverted(args.maxlen, self.d_model, args.dropout)
            elif self.model == "timexer":
                self.en_embedding = TimeXer_EnEmbedding(args.ndim, self.d_model, args.maxlen, args.dropout)
                self.ex_embedding = DataEmbedding_inverted(args.maxlen * args.npatch, self.d_model, args.dropout)
            elif self.model == "patchmixer":
                self.embedding_layer = nn.Linear(args.maxlen, self.d_model)

        # Encoder
        # Default to no final norm so that a model falling through to the
        # generic Encoder branch below does not hit an unbound local.
        # NOTE(review): assumes Encoder accepts norm_layer=None — confirm.
        norm_layer = None
        if self.model == 'patchtst':
            norm_layer = nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(self.d_model), Transpose(1, 2))
        elif self.model in ("itransformer", "s_mamba", "timexer"):
            norm_layer = torch.nn.LayerNorm(self.d_model)

        if self.model == "timexer":
            self.encoder = TimeXer_Encoder(
                [
                    TimeXer_EncoderLayer(
                        AttentionLayer(
                            FullAttention(False, self.factor, attention_dropout=args.dropout, output_attention=False),
                            self.d_model, args.nhead),
                        AttentionLayer(
                            FullAttention(False, self.factor, attention_dropout=args.dropout, output_attention=False),
                            self.d_model, args.nhead),
                        self.d_model,
                        self.d_model,
                        dropout=args.dropout,
                        activation='gelu'
                    )
                    for _ in range(args.nlayer)
                ],
                norm_layer=norm_layer
            )
        elif self.model == "s_mamba":
            # The module-level Mamba imports are commented out, so these names
            # would otherwise be undefined. Import them here so that only the
            # s_mamba configuration requires the Mamba packages.
            from baselines.layers.Mamba_EncDec import Encoder as MambaEncoder
            from baselines.layers.Mamba_EncDec import EncoderLayer as MambaEncoderLayer
            from mamba_ssm import Mamba

            d_state = getattr(args, 'd_state', 16)
            d_conv = getattr(args, 'd_conv', 2)
            expand = getattr(args, 'expand', 1)

            self.encoder = MambaEncoder(
                [
                    MambaEncoderLayer(
                        Mamba(
                            d_model=self.d_model,
                            d_state=d_state,
                            d_conv=d_conv,
                            expand=expand,
                        ),
                        Mamba(
                            d_model=self.d_model,
                            d_state=d_state,
                            d_conv=d_conv,
                            expand=expand,
                        ),
                        self.d_model,
                        getattr(args, 'd_ff', 4 * self.d_model),
                        dropout=args.dropout,
                        activation='gelu'
                    ) for _ in range(args.nlayer)
                ],
                norm_layer=norm_layer
            )
        elif self.model == "patchmixer":
            self.dropout = nn.Dropout(args.dropout)
            # A single mixer layer is used regardless of args.nlayer.
            self.encoder_blocks = nn.ModuleList(
                [PatchMixerLayer(dim=args.npatch, a=args.npatch) for _ in range(1)]
            )
        elif self.model == "tsmixer":
            self.encoder_blocks = nn.ModuleList([ResBlock(args) for _ in range(args.nlayer)])
        else:
            self.encoder = Encoder(
                [
                    EncoderLayer(
                        AttentionLayer(
                            FullAttention(False, self.factor, attention_dropout=args.dropout, output_attention=False),
                            self.d_model, args.nhead
                        ),
                        self.d_model,
                        self.d_model,
                        dropout=args.dropout,
                        activation='gelu'
                    ) for _ in range(args.nlayer)
                ],
                norm_layer=norm_layer
            )

        # Decoder
        d_static = args.d_static
        if d_static != 0:
            # Static features are projected to ndim and concatenated with the
            # per-variable summary, hence the 2 * ndim classifier input.
            self.emb = nn.Linear(d_static, args.ndim)
            self.classifier = nn.Sequential(
                nn.Linear(args.ndim * 2, 200),
                nn.ReLU(),
                nn.Linear(200, args.n_class))
        else:
            self.classifier = nn.Sequential(
                nn.Linear(args.ndim, 200),
                nn.ReLU(),
                nn.Linear(200, args.n_class))
        
        
    def LearnableTE(self, tt):
        # learnable continuous time embeddings
        out1 = self.te_scale(tt)
        out2 = torch.sin(self.te_periodic(tt))
        return torch.cat([out1, out2], -1)
    
    
    def classification(self, X, truth_time_steps, mask=None, P_static=None, feature=False):
        if (self.irr_emb):
            if self.model == "itransformer" or self.model == "s_mamba":
                B, M, L_in, N = X.shape
                X = X.reshape(B, M*L_in, N)
                truth_time_steps = truth_time_steps.reshape(B, M*L_in, N)
                mask = mask.reshape(B, M*L_in, N)
                
                B, L, N = X.shape
                x_input = self.val_emb(X.unsqueeze(-1)) # (B, L, N, D)
                time_emb = self.LearnableTE(truth_time_steps.unsqueeze(-1)) # (B, L, N, D)
                
                if self.mode == 'self':
                    var_vectors = self.var_emb(torch.arange(N, device=self.device)) # (N, D)
                    var_tokens = var_vectors.unsqueeze(0).expand(B, -1, -1) # (B, N, D)
                    var_flat = var_tokens.unsqueeze(2).reshape(B * N, 1, self.d_model)

                    z_flat = x_input.permute(0, 2, 1, 3).reshape(B * N, L, self.d_model)    
                    time_emb_flat = time_emb.transpose(1,2).reshape(B * N, L, self.d_model)
                    z_flat = z_flat + time_emb_flat
                    
                    mask_flat = mask.permute(0, 2, 1).reshape(B * N, L)
    
                    z_cat_flat = torch.cat([var_flat, z_flat], dim=1)
                    mask_tokens = torch.ones(B * N, 1, device=self.device)
                    z_mask_flat = torch.cat([mask_tokens, mask_flat], dim=1)
                    
                    total_l = z_cat_flat.shape[1]
                    processed_z_flat = self.embedding_layer(z_cat_flat, attn_mask=z_mask_flat.unsqueeze(-1))
                    x_processed = processed_z_flat.reshape(B, N, total_l, self.d_model).permute(0, 2, 1, 3) # (B, L+1, N, D)    
                    
                    variable_embeddings = x_processed[:, 0, :, :]  # (B, N, D)
                
                elif self.mode == 'mtand':
                    cls_query = torch.linspace(0, 1., 128)
                    cls_query = self.LearnableTE(cls_query.unsqueeze(0).unsqueeze(-1).to(self.device))
                    variable_embeddings = self.embedding_layer(cls_query, time_emb, X, mask)
                    
                elif self.mode == 'mean':
                    z_flat = x_input.permute(0, 2, 1, 3).reshape(B * N, L, self.d_model)    
                    time_emb_flat = time_emb.transpose(1,2).reshape(B * N, L, self.d_model)
                    z_flat = z_flat + time_emb_flat

                    z_mask_flat = mask.permute(0, 2, 1).reshape(B * N, L)
                    
                    processed_z_flat = self.embedding_layer(z_flat, attn_mask=z_mask_flat.unsqueeze(-1))
                    x_processed = processed_z_flat.reshape(B, N, L, self.d_model).permute(0, 2, 1, 3) # (B, L, N, D)    
            
                    variable_embeddings = torch.mean(x_processed, dim=1)  # (B, N, D)
                
                z = variable_embeddings
                
                # Encoder
                irr_z, _ = self.encoder(variable_embeddings)  # (B, N, D)
                h = torch.sum(irr_z , dim=-1).squeeze(-1)  # (B, N)
                

            elif self.model == "patchtst" or self.model == "timexer" or self.model == "patchmixer" or self.model == "tsmixer":
                if self.mode != "mtand":
                    B, M, L_in, N = X.shape
                    X = X.permute(0, 1, 3, 2)  # (B, M, N, L_in)
                    mask = mask.permute(0, 1, 3, 2)  # (B, M, N, L_in)
                    truth_time_steps = truth_time_steps.permute(0, 1, 3, 2) # (B, M, N, L_in)

                    x_input = self.val_emb(X.unsqueeze(-1)) # (B, M, N, L_in, D)
                    time_emb = self.LearnableTE(truth_time_steps.unsqueeze(-1)) # (B, M, N, L_in, D)
                
                if self.mode == 'self':
                    patch_vectors = self.patch_emb(torch.arange(M, dtype=torch.long, device=self.device))
                    patch_tokens = patch_vectors.reshape(1, M, 1, 1, self.d_model).expand(B, -1, N, -1, -1) # (B, M, N, 1, D)
                    
                    patch_vs = x_input + time_emb  # (B, M, N, L_in, D)
                    z = torch.cat([patch_tokens, patch_vs], dim=3) # (B, M, N, L_in+1, D)
                    z_flat = z.reshape(B * M * N, L_in + 1, self.d_model)
                    attn_mask = torch.cat([torch.ones(B, M, N, 1, device=self.device), mask], dim=3).reshape(B * M * N, L_in + 1)
                    
                    total_l = z_flat.shape[1]
                    processed_z_flat = self.embedding_layer(z_flat, attn_mask=attn_mask.unsqueeze(-1))
                    processed_z = processed_z_flat.reshape(B, M, N, total_l, self.d_model)
                    
                    patch_embeddings = processed_z[:, :, :, 0, :].transpose(1,2)  # (B, N, M, D)
                    patch_embeddings_flat = patch_embeddings.reshape(B*N, M, self.d_model)  # (B*N, M, D)
                
                elif self.mode == 'mtand':
                    B, M, L_in, N = X.shape
                    X = X.reshape(B, M*L_in, N)
                    truth_time_steps = truth_time_steps.reshape(B, M*L_in, N)
                    mask = mask.reshape(B, M*L_in, N)
                    
                    key = self.LearnableTE(truth_time_steps.unsqueeze(-1)) # (B, L, N, D)
                
                    cls_query = torch.linspace(0, 1., 128)
                    cls_query = self.LearnableTE(cls_query.unsqueeze(0).unsqueeze(-1).to(self.device))
                    patch_embeddings = self.embedding_layer(cls_query, key, X, mask, npatch=M)  # (B, N, M, D)
                    patch_embeddings_flat = patch_embeddings.reshape(B*N, M, self.d_model)  # (B*N, M, D)
                    
                elif self.mode == 'mean':
                    z = x_input + time_emb  # (B, M, N, L_in, D)
                    z_flat = z.reshape(B * M * N, L_in, self.d_model)
                    attn_mask = mask.reshape(B * M * N, L_in)
                    
                    total_l = z_flat.shape[1]
                    processed_z_flat = self.embedding_layer(z_flat, attn_mask=attn_mask.unsqueeze(-1))
                    processed_z = processed_z_flat.reshape(B, M, N, total_l, self.d_model)  # (B, M, N, L_in, D)
                    
                    patch_embeddings = torch.mean(processed_z, dim=-2).transpose(1,2)  # (B, N, M, D)
                    patch_embeddings_flat = patch_embeddings.reshape(B*N, M, self.d_model)  # (B*N, M, D)
                
                z = patch_embeddings.reshape(B, N, M*self.d_model)
                
                # Encoder
                if self.model == "patchtst":
                    irr_z_flat, _ = self.encoder(patch_embeddings_flat) # (B*N, M, D)
                    irr_z = irr_z_flat.reshape(B, N, M*self.d_model) # (B, N, M*D)
                    h = torch.sum(irr_z , dim=-1).squeeze(-1)  # (B, N)
                    
                elif self.model == "timexer":
                    var_vectors = self.var_emb(torch.arange(N, device=self.device)) # (N, D)
                    var_tokens = var_vectors.unsqueeze(0).expand(B, -1, -1) # (B, N, D)
                    var_flat = var_tokens.unsqueeze(2).reshape(B * N, 1, self.d_model)
                    
                    en_embed = torch.cat([patch_embeddings_flat, var_flat], dim=1)
                    var_vectors = self.var_emb(torch.arange(N, device=self.device)) # (N, D)
                    var_tokens = var_vectors.unsqueeze(0).expand(B, -1, -1) # (B, N, D)
                    var_flat = var_tokens.unsqueeze(2).reshape(B * N, 1, self.d_model)
                    
                    en_embed = torch.cat([patch_embeddings_flat, var_flat], dim=1)
            
                    B, M, N, L_in= X.shape
                    X = X.transpose(-1,-2).reshape(B, M*L_in, N)
                    truth_time_steps = truth_time_steps.transpose(-1,-2).reshape(B, M*L_in, N)
                    mask = mask.transpose(-1,-2).reshape(B, M*L_in, N)
                    
                    B, L, N = X.shape
                    x_input = self.val_emb(X.unsqueeze(-1)) # (B, L, N, D)
                    time_emb = self.LearnableTE(truth_time_steps.unsqueeze(-1)) # (B, L, N, D)
                    
                    var_vectors = self.var_emb(torch.arange(N, device=self.device)) # (N, D)
                    var_tokens = var_vectors.unsqueeze(0).expand(B, -1, -1) # (B, N, D)
                    var_flat = var_tokens.unsqueeze(2).reshape(B * N, 1, self.d_model)

                    z_flat = x_input.permute(0, 2, 1, 3).reshape(B * N, L, self.d_model)    
                    time_emb_flat = time_emb.transpose(1,2).reshape(B * N, L, self.d_model)
                    z_flat = z_flat + time_emb_flat
                        
                    mask_flat = mask.permute(0, 2, 1).reshape(B * N, L)
        
                    z_cat_flat = torch.cat([var_flat, z_flat], dim=1)
                    mask_tokens = torch.ones(B * N, 1, device=self.device)
                    z_mask_flat = torch.cat([mask_tokens, mask_flat], dim=1)
                        
                    total_l = z_cat_flat.shape[1]
                    processed_z_flat = self.embedding_layer(z_cat_flat, attn_mask=z_mask_flat.unsqueeze(-1))
                    x_processed = processed_z_flat.reshape(B, N, total_l, self.d_model).permute(0, 2, 1, 3) # (B, L+1, N, D)    
                        
                    ex_embed = x_processed[:, 0, :, :]  # (B, N, D)
                    
                    irr_z_flat = self.encoder(en_embed, ex_embed) # (B*N, M+1, D)
                    
                    irr_z = irr_z_flat.reshape(B, N, (M+1)*self.d_model) # (B, N, (M+1)*D)
                    h = torch.sum(irr_z , dim=-1).squeeze(-1)  # (B, N)
                
                elif self.model == "patchmixer":
                    patch_embeddings_flat = self.dropout(patch_embeddings_flat)
                    
                    for encoder in self.encoder_blocks:
                        patch_embeddings_flat = encoder(patch_embeddings_flat)
                        
                    irr_z_flat = patch_embeddings_flat  # (B*N, M, D)
                    irr_z = irr_z_flat.reshape(B, N, M*self.d_model) # (B, N, M*D)
                    h = torch.sum(irr_z , dim=-1).squeeze(-1)  # (B, N)
                    
                elif self.model == "tsmixer":
                    patch_embeddings = patch_embeddings_flat.reshape(B, N, M, self.d_model) # (B, N, M, D)

                    for encoder in self.encoder_blocks:
                        patch_embeddings = encoder(patch_embeddings) # (B, N, M, D)
                    
                    irr_z_flat = patch_embeddings.reshape(B*N, M, self.d_model) # (B*N, M, D)
                    irr_z = irr_z_flat.reshape(B, N, M*self.d_model) # (B, N, M*D)
                    h = torch.sum(irr_z , dim=-1).squeeze(-1)  # (B, N)
                    
        else:
            if self.model == "itransformer" or self.model == "s_mamba":
                B, L, N = X.shape
                
                if self.mode == 'add' or self.mode == 'concat':
                    variable_embeddings = self.embedding_layer(X, truth_time_steps)
                else:
                    variable_embeddings = self.embedding_layer(X, None)
                
                z = variable_embeddings
                irr_z, _ = self.encoder(variable_embeddings)  # (B, N, D)
                h = torch.sum(irr_z, dim=-1).squeeze(-1)  # (B, N)
            
            elif self.model == "patchtst":
                B, M, L_in, N = X.shape
                x_enc = X.permute(0, 3, 1, 2).reshape(B*N, M, L_in)
                
                if self.mode == 'add' or self.mode == 'concat':
                    t_enc = truth_time_steps.permute(0, 3, 1, 2).reshape(B*N, M, L_in)
                    patch_embeddings_flat = self.embedding_layer(x_enc, t_enc)
                else:
                    patch_embeddings_flat = self.embedding_layer(x_enc)
                    
                z = patch_embeddings_flat.reshape(B, N, M*self.d_model) # (B, N, M*D)
                irr_z_flat, _ = self.encoder(patch_embeddings_flat) # (B*N, M, D)

                irr_z = irr_z_flat.reshape(B, N, M*self.d_model) # (B, N, M*D)
                h = torch.sum(irr_z , dim=-1).squeeze(-1)  # (B, N)
            
            elif self.model == "timexer":
                B, M, L_in, N = X.shape
                
                en_embed = self.en_embedding(X.permute(0, 3, 1, 2))
                ex_embed = self.ex_embedding(X.reshape(B, M*L_in, N), None)
                
                irr_z_flat = self.encoder(en_embed, ex_embed) # (B*N, M, D)
                
                irr_z = irr_z_flat.reshape(B, N, (M+1)*self.d_model) # (B, N, (M+1)*D)
                h = torch.sum(irr_z, dim=-1).squeeze(-1)  # (B, N)
            
            elif self.model == "patchmixer":
                B, M, L_in, N = X.shape
                x_enc = X.permute(0, 3, 1, 2).reshape(B*N, M, L_in)
                
                patch_embeddings_flat = self.embedding_layer(x_enc)  # (B*N, M, D)
                patch_embeddings_flat = self.dropout(patch_embeddings_flat)

                for encoder in self.encoder_blocks:
                    patch_embeddings_flat = encoder(patch_embeddings_flat)
                    
                irr_z_flat = patch_embeddings_flat  # (B*N, M, D)
                irr_z = irr_z_flat.reshape(B, N, M*self.d_model) # (B, N, M*D)
                h = torch.sum(irr_z , dim=-1).squeeze(-1)  # (B, N)
                
            elif self.model == "tsmixer":
                B, M, L_in, N = X.shape
                x_enc = X.permute(0, 3, 1, 2).reshape(B*N, M, L_in)
                
                patch_embeddings_flat = self.embedding_layer(x_enc)  # (B*N, M, D)
                patch_embeddings = patch_embeddings_flat.reshape(B, N, M, self.d_model) # (B, N, M, D)
                
                for encoder in self.encoder_blocks:
                    patch_embeddings = encoder(patch_embeddings) # (B, N, M, D)
                    
                irr_z_flat = patch_embeddings.reshape(B*N, M, self.d_model) # (B*N, M, D)
                irr_z = irr_z_flat.reshape(B, N, M*self.d_model) # (B, N, M*D)
                h = torch.sum(irr_z , dim=-1).squeeze(-1)  # (B, N)
        
        
        # Decoder        
        if feature:
            return z
        
        if P_static is not None:
            static_emb = self.emb(P_static)
            return self.classifier(torch.cat([h, static_emb], dim=-1))
        else:
            return self.classifier(h)