import torch
import torch.nn as nn
from modules import TransformerBlock1, TransformerBlock2, CrossAttentionBlock


class Model(nn.Module):
    """Hierarchical transformer encoder with a cross-attention forecasting head.

    The input is a batch of ``M`` patches, each holding ``L_in`` time steps of
    ``N`` variables.  ``patch_extractor`` runs three stacked blocks per layer:

    1. within each (patch, variable) pair, over the ``L_in`` observations plus
       a prepended patch token;
    2. within each variable, over its ``M`` patch tokens plus a prepended
       variable token;
    3. over the ``N`` variable tokens.

    ``forecasting`` then uses time embeddings of the requested time stamps as
    queries against either the variable tokens (``type == 'global'``) or the
    patch tokens (``type == 'local'``) and decodes one scalar per
    (query time, variable).

    ``TransformerBlock1``, ``TransformerBlock2`` and ``CrossAttentionBlock``
    come from the project-local ``modules`` package; their internals are not
    visible here.
    """

    # Values of ``args.type`` that ``forecasting`` knows how to handle.
    _CONTEXT_TYPES = ('global', 'local')

    def __init__(self, args):
        """Build embeddings, encoder stack, cross-attention block and decoder.

        Args:
            args: configuration object; the attributes read here are
                ``hid_dim``, ``device``, ``ndim``, ``nlayer``, ``npatch``,
                ``type`` and ``nhead``.
        """
        super().__init__()
        d_model = args.hid_dim
        self.device = args.device
        self.hid_dim = args.hid_dim
        self.N = args.ndim
        self.n_layer = args.nlayer
        self.npatch = args.npatch
        self.dropout = 0.1
        self.type = args.type

        # Time embedding: one linear channel plus (hid_dim - 1) sinusoidal
        # channels, concatenated to hid_dim in ``LearnableTE``.
        self.te_scale = nn.Linear(1, 1)
        self.te_periodic = nn.Linear(1, args.hid_dim - 1)

        self.val_emb = nn.Linear(1, d_model)
        self.patch_emb = nn.Embedding(args.npatch, d_model)
        self.var_emb = nn.Embedding(args.ndim, d_model)

        # One Sequential per layer, used only as an indexable container of the
        # three stages; the stages are invoked individually, never chained.
        self.encoder = nn.ModuleList([
            nn.Sequential(
                TransformerBlock1(d_model, args.nhead, dropout=self.dropout),
                TransformerBlock1(d_model, args.nhead, dropout=self.dropout),
                TransformerBlock2(d_model, args.nhead, 2 * d_model, dropout=self.dropout),
            )
            for _ in range(args.nlayer)
        ])
        self.cross_transformer = CrossAttentionBlock(
            d_model, args.nhead, 2 * d_model, dropout=self.dropout
        )

        self.decoder = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, 1),
        )

    def LearnableTE(self, tt):
        """Return a learnable continuous time embedding.

        Args:
            tt: tensor of time stamps whose last dimension has size 1.

        Returns:
            Tensor with the same leading dimensions as ``tt`` and last
            dimension ``hid_dim``: channel 0 is the linear term, the remaining
            channels are ``sin`` of a learned linear projection.
        """
        linear_part = self.te_scale(tt)
        periodic_part = torch.sin(self.te_periodic(tt))
        return torch.cat([linear_part, periodic_part], -1)

    def patch_extractor(self, X, truth_time_steps, mask=None):
        """Encode the patched input into variable, patch and value tokens.

        Args:
            X: observed values, shape ``(B, M, L_in, N)``.
            truth_time_steps: time stamp of every entry of ``X``, same shape.
            mask: same shape as ``X``; forwarded (with an extra leading column
                of ones for the patch token) as ``attn_mask`` to the first
                stage.  Presumably 1 marks an observed entry - confirm against
                ``TransformerBlock1``.  ``None`` is treated as all ones.

        Returns:
            Tuple ``(cls_reprs, patch_reprs, patch_vs)`` with shapes
            ``(B, N, 1, D)``, ``(B, M, N, 1, D)`` and ``(B, M, N, L_in, D)``,
            where ``D == hid_dim``.

        Note:
            ``M`` must not exceed ``npatch`` and ``N`` must not exceed
            ``ndim``, since they index ``patch_emb`` / ``var_emb``.
        """
        # 1. Setup and preprocessing.
        B, M, L_in, N = X.shape
        d_model = self.hid_dim
        # Follow the inputs' device so the module keeps working after ``.to()``.
        device = X.device

        if mask is None:
            # No mask supplied: use an all-ones mask.
            mask = torch.ones_like(X)

        X = X.permute(0, 1, 3, 2)  # (B, M, N, L_in)
        mask = mask.permute(0, 1, 3, 2)  # (B, M, N, L_in)
        truth_time_steps = truth_time_steps.permute(0, 1, 3, 2)  # (B, M, N, L_in)

        # 2. Embeddings.
        value_emb = self.val_emb(X.unsqueeze(-1))  # (B, M, N, L_in, D)
        time_emb = self.LearnableTE(truth_time_steps.unsqueeze(-1))  # (B, M, N, L_in, D)
        patch_vs = value_emb + time_emb  # (B, M, N, L_in, D)

        patch_idx = torch.arange(M, dtype=torch.long, device=device)
        patch_reprs = (
            self.patch_emb(patch_idx)
            .view(1, M, 1, 1, d_model)
            .expand(B, -1, N, -1, -1)
        )  # (B, M, N, 1, D)

        var_idx = torch.arange(N, dtype=torch.long, device=device)
        cls_reprs = (
            self.var_emb(var_idx)
            .view(1, N, 1, d_model)
            .expand(B, -1, 1, -1)
        )  # (B, N, 1, D)

        # Mask column for the prepended patch token; constant across layers.
        cls_mask = torch.ones(B, M, N, 1, device=mask.device)

        # 3. Hierarchical encoder (vectorised over batch/patch/variable).
        for n in range(self.n_layer):
            within_patch, across_patches, across_vars = self.encoder[n]

            # === 3.1 Within each (patch, variable): observations + patch token ===
            z = torch.cat([patch_reprs, patch_vs], dim=3)  # (B, M, N, L_in+1, D)
            # Flatten to (batch, seq_len, dim) for the transformer block.
            z_flat = z.reshape(B * M * N, L_in + 1, d_model)
            # Rebuilt every layer so each block receives a fresh tensor.
            attn_mask = torch.cat([cls_mask, mask], dim=3).reshape(B * M * N, L_in + 1)
            z = within_patch(z_flat, attn_mask=attn_mask.unsqueeze(-1))
            # reshape (not view): the block's output layout is not known here.
            z = z.reshape(B, M, N, L_in + 1, d_model)

            # Split the updated patch token from the updated value tokens.
            patch_tokens = z[:, :, :, 0, :]  # (B, M, N, D)
            patch_vs = z[:, :, :, 1:, :]  # (B, M, N, L_in, D)

            # === 3.2 Within each variable: M patch tokens + variable token ===
            # All (B, N) patch sequences are processed as one batch of B*N.
            z = patch_tokens.permute(0, 2, 1, 3)  # (B, N, M, D)
            z = torch.cat([cls_reprs, z], dim=2)  # (B, N, M+1, D)
            z = across_patches(z.reshape(B * N, M + 1, d_model))
            z = z.reshape(B, N, M + 1, d_model)

            # Split the updated variable token from the updated patch tokens.
            cls_reprs = z[:, :, 0, :]  # (B, N, D)
            patch_reprs = z[:, :, 1:, :].permute(0, 2, 1, 3).unsqueeze(3)  # (B, M, N, 1, D)

            # === 3.3 Across variables: the N variable tokens ===
            cls_reprs = across_vars(cls_reprs).unsqueeze(2)  # (B, N, 1, D)

        return cls_reprs, patch_reprs, patch_vs

    def forecasting(self, time_steps_to_predict, X, truth_time_steps, mask=None):
        """Predict every variable at the requested time stamps.

        Args:
            time_steps_to_predict: query time stamps, shape ``(B, L_pred)``.
            X: observed values, shape ``(B, M, L, N)``.
            truth_time_steps: time stamps of ``X``, same shape as ``X``.
            mask: same shape as ``X``; see ``patch_extractor``.  ``None`` is
                treated as all ones.

        Returns:
            Tensor of shape ``(B, L_pred, N)``.

        Raises:
            ValueError: if ``self.type`` is neither ``'global'`` nor
                ``'local'``.
        """
        # Fail before the (expensive) encoder pass instead of hitting an
        # unbound ``outputs`` afterwards.
        if self.type not in self._CONTEXT_TYPES:
            raise ValueError(
                f"Unsupported type {self.type!r}; expected one of {self._CONTEXT_TYPES}"
            )

        # 1. Input shape.
        B, M, _, N = X.shape
        d_model = self.hid_dim

        # 2. Run the three-stage hierarchical encoder.
        # cls_reprs: variable-level tokens -> (B, N, 1, D)
        # patch_reprs: patch-level tokens  -> (B, M, N, 1, D)
        cls_reprs, patch_reprs, _ = self.patch_extractor(X, truth_time_steps, mask)

        # 3. Time embeddings of the query time stamps, shared by all variables.
        L_pred = time_steps_to_predict.shape[1]
        te_pred = self.LearnableTE(time_steps_to_predict.unsqueeze(-1))  # (B, L_pred, D)
        queries = (
            te_pred.unsqueeze(1)
            .expand(-1, N, -1, -1)
            .reshape(B * N, L_pred, d_model)
        )

        # 4. Cross-attention context: one variable token ('global') or the
        #    M patch tokens ('local') per (batch, variable).
        if self.type == 'global':
            context = cls_reprs.squeeze(2).reshape(B * N, 1, d_model)
        else:  # 'local'
            context = patch_reprs.permute(0, 2, 1, 3, 4).reshape(B * N, M * 1, d_model)

        attended = self.cross_transformer(queries, context, context)
        attended = attended.reshape(B, N, L_pred, d_model).permute(0, 2, 1, 3)  # (B, L_pred, N, D)
        outputs = self.decoder(attended)  # (B, L_pred, N, 1)

        return outputs.squeeze(-1)