import torch
from torch import nn
from layers.Transformer_EncDec import DecoderOnly, DecoderOnlyLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import PositionalEmbedding


class Model(nn.Module):
    """Patch-based transformer for long-term forecasting and imputation.

    Each channel of the input series is processed independently: the series
    is cut into non-overlapping patches of ``input_token_len`` time steps,
    every patch is linearly embedded, passed through a stack of
    ``DecoderOnlyLayer`` blocks, and projected to ``output_token_len`` values
    per patch.

    The ``configs`` object must provide: ``task_name``, ``input_token_len``,
    ``output_token_len``, ``bd_model``, ``bn_heads``, ``bd_ff``,
    ``be_layers``, ``dropout``, ``activation`` and ``use_norm``.
    """

    def __init__(self, configs):
        super().__init__()
        self.task_name = configs.task_name
        self.input_token_len = configs.input_token_len
        self.use_norm = configs.use_norm

        # Patch embedding: one patch of P time steps -> one D-dimensional token.
        self.embedding = nn.Linear(
            self.input_token_len, configs.bd_model, bias=False)
        self.position_embedding = PositionalEmbedding(
            configs.bd_model, max_len=6500)
        self.dropout = nn.Dropout(configs.dropout)

        # NOTE(review): the first positional argument of FullAttention is
        # presumably the causal-mask flag -- confirm against its definition.
        self.blocks = DecoderOnly(
            [
                DecoderOnlyLayer(
                    AttentionLayer(
                        FullAttention(True, attention_dropout=configs.dropout,
                                      output_attention=False),
                        configs.bd_model,
                        configs.bn_heads
                    ),
                    configs.bd_model,
                    configs.bd_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for _ in range(configs.be_layers)
            ],
            norm_layer=nn.LayerNorm(configs.bd_model)
        )
        # Maps every token back to ``output_token_len`` time steps.
        self.head = nn.Linear(configs.bd_model, configs.output_token_len)

    def _patch_decode(self, x):
        """Run the shared patch -> embed -> blocks -> head pipeline.

        Args:
            x: tensor of shape [B, L, C].

        Returns:
            Tensor of shape [B, N * output_token_len, C], where N is the
            number of complete ``input_token_len`` patches that fit in L
            (``unfold`` drops trailing steps that do not fill a patch).
        """
        # [B, L, C]
        B, _, C = x.shape
        # [B, C, L]
        x = x.permute(0, 2, 1)
        # [B, C, N, P]
        x = x.unfold(
            dimension=-1, size=self.input_token_len, step=self.input_token_len)
        N = x.shape[2]
        # [B * C, N, P] - channels are folded into the batch dimension
        x = x.reshape(B * C, N, -1)
        # [B * C, N, D]
        embed_out = self.embedding(x) + self.position_embedding(x)
        embed_out = self.dropout(embed_out)
        # The second return value of the blocks is not used by this model.
        embed_out, _ = self.blocks(embed_out)
        # [B * C, N, P_out]
        dec_out = self.head(embed_out)
        # [B, C, N * P_out]
        dec_out = dec_out.reshape(B, C, -1)
        # [B, N * P_out, C]
        return dec_out.permute(0, 2, 1)

    def forecast(self, x, x_mark, y_mark):
        """Predict from ``x`` with optional per-series normalization.

        Args:
            x: input tensor of shape [B, L, C].
            x_mark: unused; kept for interface compatibility.
            y_mark: unused; kept for interface compatibility.

        Returns:
            Tensor of shape [B, N * output_token_len, C].
        """
        if self.use_norm:
            means = x.mean(1, keepdim=True).detach()
            x = x - means
            stdev = torch.sqrt(
                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
            # Out-of-place division: ``x`` was an input of ``torch.var``, so
            # mutating it in place could invalidate autograd's saved tensors.
            x = x / stdev

        dec_out = self._patch_decode(x)

        if self.use_norm:
            # Undo the normalization; statistics broadcast over time.
            dec_out = dec_out * stdev + means
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_mark_dec, mask=None):
        """Reconstruct a series, normalizing with observed values only.

        Args:
            x_enc: input tensor of shape [B, L, C].
            x_mark_enc: unused; kept for interface compatibility.
            x_mark_dec: unused; kept for interface compatibility.
            mask: optional tensor broadcastable to ``x_enc`` where entries
                equal to 1 mark observed values. ``None`` means every value
                is observed. Only consulted when ``use_norm`` is set.

        Returns:
            Tensor of shape [B, N * output_token_len, C].
        """
        if self.use_norm:
            if mask is None:
                observed = torch.ones_like(x_enc, dtype=torch.bool)
            else:
                observed = mask == 1
            missing = ~observed
            # Clamp so a channel with no observed step yields zero statistics
            # instead of a 0/0 NaN that would poison the whole output.
            counts = observed.sum(dim=1).clamp(min=1)

            # Zero the unobserved entries first so they cannot leak into the
            # mean even if the caller did not zero-fill them beforehand.
            x_enc = x_enc.masked_fill(missing, 0)
            means = (x_enc.sum(dim=1) / counts).unsqueeze(1).detach()
            # Re-zero after centering: unobserved entries became ``-means``.
            x_enc = (x_enc - means).masked_fill(missing, 0)
            stdev = torch.sqrt((x_enc * x_enc).sum(dim=1) / counts + 1e-5)
            stdev = stdev.unsqueeze(1).detach()
            x_enc = x_enc / stdev

        dec_out = self._patch_decode(x_enc)

        # De-normalization
        if self.use_norm:
            dec_out = dec_out * stdev + means
        return dec_out

    def forward(self, x, x_mark, y_mark, mask=None):
        """Dispatch on ``task_name``.

        Returns the output of ``forecast`` for ``'long_term_forecast'``, the
        output of ``imputation`` for ``'imputation'``, and ``None`` for any
        other task name.
        """
        if self.task_name == 'long_term_forecast':
            return self.forecast(x, x_mark, y_mark)  # [B, T, D]
        if self.task_name == 'imputation':
            return self.imputation(x, x_mark, y_mark, mask)  # [B, T, D]
        return None