import torch
from torch import nn
from baselines.layers.Transformer_EncDec import SelfAttentionOnlyLayer, CrossAttentionFFNLayer, CrossAttentionOnlyLayer
from baselines.layers.SelfAttention_Family import FullAttention, AttentionLayer

        
class Model(nn.Module):
    """Forecaster for irregularly sampled multivariate time series.

    Every observation is embedded as ``val_emb(value) + LearnableTE(time)``.
    The observations of each variable are then summarised into a single
    vector per variable, in one of three ways selected by ``args.mode``:

    * ``'self'``  - a learnable per-variable token is prepended to the
      variable's observation sequence, the sequence goes through
      ``SelfAttentionOnlyLayer`` and the output at the token position is kept.
    * ``'cross'`` - the per-variable token is passed to
      ``CrossAttentionOnlyLayer`` together with the observation sequence.
    * ``'mean'``  - the observation sequence goes through
      ``SelfAttentionOnlyLayer`` and its outputs are averaged over the
      sequence axis.

    The summary is concatenated with the time embedding of each query
    timestamp and decoded by an MLP into one scalar per (timestamp, variable).

    ``args`` must provide ``mode``, ``device``, ``hid_dim``, ``dropout`` and
    ``nhead``; ``ndim`` (number of variables) is additionally required for the
    ``'self'`` and ``'cross'`` modes.

    Raises:
        ValueError: if ``args.mode`` is not ``'self'``, ``'cross'`` or
            ``'mean'``.
    """

    _VALID_MODES = ('self', 'cross', 'mean')

    def __init__(self, args):
        super().__init__()
        # Fail early: with an unknown mode no embedding layer would be built
        # and forecasting() would only break later with an unrelated error.
        if args.mode not in self._VALID_MODES:
            raise ValueError(
                f"Unsupported mode {args.mode!r}; expected one of {self._VALID_MODES}"
            )

        self.mode = args.mode
        # Kept for callers that read it; forecasting() derives devices from
        # its inputs and parameters instead.
        self.device = args.device
        self.d_model = args.hid_dim
        self.factor = 5
        self.mean_emb = self.mode == 'mean'

        # Time embedding: one linear channel plus (d_model - 1) sinusoidal ones.
        self.te_scale = nn.Linear(1, 1)
        self.te_periodic = nn.Linear(1, self.d_model - 1)

        # IMTS token embedding: scalar value -> d_model.
        self.val_emb = nn.Linear(1, self.d_model)
        if self.mode in ('self', 'cross'):
            # One learnable token per variable.
            self.var_emb = nn.Embedding(args.ndim, self.d_model)

        attention = AttentionLayer(
            FullAttention(
                False,
                self.factor,
                attention_dropout=args.dropout,
                output_attention=False,
            ),
            self.d_model,
            args.nhead,
        )

        if self.mode == 'cross':
            self.embedding_layer = CrossAttentionOnlyLayer(
                attention, self.d_model, args.dropout
            )
        else:  # 'self' or 'mean'
            self.embedding_layer = SelfAttentionOnlyLayer(
                attention, self.d_model, args.dropout
            )

        # Decoder input is [variable summary ; query time embedding].
        self.decoder = nn.Sequential(
            nn.Linear(self.d_model * 2, self.d_model),
            nn.ReLU(inplace=True),
            nn.Linear(self.d_model, self.d_model),
            nn.ReLU(inplace=True),
            nn.Linear(self.d_model, 1),
        )

    def LearnableTE(self, tt):
        """Return learnable continuous time embeddings.

        Args:
            tt: timestamps with a trailing singleton axis, shape ``(..., 1)``.

        Returns:
            Tensor of shape ``(..., d_model)``: the ``te_scale`` output
            followed by the sine of the ``te_periodic`` output.
        """
        linear_part = self.te_scale(tt)
        periodic_part = torch.sin(self.te_periodic(tt))
        return torch.cat([linear_part, periodic_part], -1)

    def forecasting(self, time_steps_to_predict, X, truth_time_steps, mask=None):
        """Predict every variable at the requested timestamps.

        Args:
            time_steps_to_predict: query timestamps, shape ``(B, L_pred)``.
            X: observed values, shape ``(B, M, L_in, N)``.
            truth_time_steps: observation timestamps, reshapeable to
                ``(B, M * L_in, N)``.
            mask: observation mask, reshapeable to ``(B, M * L_in, N)``. It is
                forwarded to the attention layer as ``attn_mask``
                (presumably 1 = observed, 0 = missing - confirm against the
                layer implementation). ``None`` is treated as all ones.

        Returns:
            Tensor of shape ``(B, L_pred, N)``.
        """
        # 1. embeddings
        B, M, L_in, N = X.shape
        L = M * L_in
        if mask is None:
            # The signature advertises mask=None; use an all-ones mask then.
            mask = torch.ones_like(X)

        X = X.reshape(B, L, N)
        truth_time_steps = truth_time_steps.reshape(B, L, N)
        mask = mask.reshape(B, L, N)

        x_input = self.val_emb(X.unsqueeze(-1))  # (B, L, N, D)
        time_emb = self.LearnableTE(truth_time_steps.unsqueeze(-1))  # (B, L, N, D)

        if self.mode in ('self', 'cross'):
            # Build the indices on the embedding's own device so the lookup
            # keeps working after the module has been moved with .to().
            var_ids = torch.arange(N, device=self.var_emb.weight.device)
            var_vectors = self.var_emb(var_ids)  # (N, D)
            var_tokens = var_vectors.unsqueeze(0).expand(B, -1, -1)  # (B, N, D)
            var_flat = var_tokens.reshape(B * N, 1, self.d_model)

        # imts_variable_embedding
        # (B, L, N, D) -> (B, N, L, D) -> (B*N, L, D)
        z_flat = x_input.permute(0, 2, 1, 3).reshape(B * N, L, self.d_model)
        time_emb_flat = time_emb.transpose(1, 2).reshape(B * N, L, self.d_model)
        z_flat = z_flat + time_emb_flat

        z_mask_flat = mask.permute(0, 2, 1).reshape(B * N, L)

        if self.mode == 'self':
            z_cat_flat = torch.cat([var_flat, z_flat], dim=1)
            # The prepended variable token is always marked with 1 in the mask.
            mask_tokens = torch.ones(B * N, 1, device=z_mask_flat.device)
            z_mask_flat_with_tokens = torch.cat([mask_tokens, z_mask_flat], dim=1)

            processed_z_flat = self.embedding_layer(
                z_cat_flat, attn_mask=z_mask_flat_with_tokens.unsqueeze(-1)
            )

            total_l = z_cat_flat.shape[1]
            x_processed = processed_z_flat.reshape(
                B, N, total_l, self.d_model
            ).permute(0, 2, 1, 3)  # (B, L+1, N, D)

            # Extract the variable token (position 0).
            irr_z = x_processed[:, 0, :, :].unsqueeze(1)  # (B, 1, N, D)

        elif self.mode == 'cross':
            processed_z_flat = self.embedding_layer(
                var_flat, z_flat, attn_mask=z_mask_flat.unsqueeze(-2)
            )
            irr_z = processed_z_flat.reshape(
                B, N, 1, self.d_model
            ).permute(0, 2, 1, 3)  # (B, 1, N, D)

        else:  # 'mean'
            processed_z_flat = self.embedding_layer(
                z_flat, attn_mask=z_mask_flat.unsqueeze(-1)
            )
            x_processed = processed_z_flat.reshape(
                B, N, L, self.d_model
            ).permute(0, 2, 1, 3)  # (B, L, N, D)

            # NOTE(review): this averages over every position, including the
            # ones where mask is 0 - confirm that an unmasked mean is intended.
            irr_z = torch.mean(x_processed, dim=1).unsqueeze(1)  # (B, 1, N, D)

        # Forecasting
        L_pred = time_steps_to_predict.shape[1]
        te_pred = self.LearnableTE(time_steps_to_predict.unsqueeze(-1))  # (B, L_pred, D)
        te_pred = te_pred.unsqueeze(2).expand(-1, -1, N, -1)  # (B, L_pred, N, D)

        irr_z = irr_z.expand(-1, L_pred, -1, -1)  # (B, L_pred, N, D)
        h = torch.cat([irr_z, te_pred], dim=-1)  # (B, L_pred, N, 2D)

        dec_out = self.decoder(h)

        return dec_out.squeeze(-1)  # (B, L_pred, N)