import torch
import torch.nn as nn
from modules import  CrossAttention, TransformerBlock1, TransformerBlock2, CrossAttentionBlock


class Hi_Patch(nn.Module):
    """Hierarchical patch-based encoder/decoder used for forecasting.

    Every observed value is embedded (``val_emb``) and summed with a learnable
    continuous time embedding (``LearnableTE``). A stack of ``args.nlayer``
    hierarchical layers then refines two kinds of tokens:

    1. ``encoder[n][0]`` (``CrossAttention``): each (patch, variable) token
       attends to the point embeddings of its own patch, under ``mask``.
    2. ``encoder[n][1]`` (``TransformerBlock1``): for each (batch, variable),
       the sequence ``[cls, patch_1, ..., patch_M]`` is processed jointly.
    3. ``encoder[n][2]`` (``TransformerBlock2``): variable-level block applied
       to the ``(B, N, D)`` cls tokens.

    ``forecasting`` queries both the cls (global) tokens and the patch (local)
    tokens with time embeddings of the target time steps through
    ``cross_transformer``. It decodes the concatenated result with an MLP.
    """

    def __init__(self, args, supports=None):
        """Build all sub-modules from the hyper-parameters in ``args``.

        Args:
            args: Config object; reads ``hid_dim``, ``device``, ``ndim``,
                ``nlayer``, ``npatch`` and ``nhead``.
            supports: Not used by this class.
        """
        super().__init__()
        d_model = args.hid_dim
        # Retained attribute. Tensor placement in the forward passes follows the
        # embedding weights instead, so moving the model with .to() keeps working.
        self.device = args.device
        self.hid_dim = args.hid_dim
        self.N = args.ndim
        self.n_layer = args.nlayer
        self.npatch = args.npatch
        self.dropout = 0.1

        # Continuous time embedding: 1 linear channel + (hid_dim - 1) sine channels.
        self.te_scale = nn.Linear(1, 1)
        self.te_periodic = nn.Linear(1, args.hid_dim - 1)

        self.val_emb = nn.Linear(1, d_model)
        # One learnable initial token per patch index.
        self.patch_emb = nn.Embedding(args.npatch, d_model)
        # One learnable initial cls token per variable.
        self.var_emb = nn.Embedding(args.ndim, d_model)

        # nn.Sequential serves only as an indexable/iterable container of the three
        # per-layer blocks; it is never called as a pipeline. The structure is kept
        # so that state_dict keys (encoder.<layer>.<i>.*) stay unchanged.
        self.encoder = nn.ModuleList([
            nn.Sequential(
                CrossAttention(d_model, args.nhead, dropout=self.dropout),
                TransformerBlock1(d_model, args.nhead, dropout=self.dropout),
                TransformerBlock2(d_model, args.nhead, 2 * d_model, dropout=self.dropout),
            )
            for _ in range(args.nlayer)
        ])

        # Shared by the global (cls) and local (patch) queries in `forecasting`.
        self.cross_transformer = CrossAttentionBlock(d_model, args.nhead, 2 * d_model, dropout=self.dropout)

        # Maps concatenated [global, local] context (2 * D) to one scalar per target.
        self.decoder = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, 1),
        )

    def LearnableTE(self, tt):
        """Compute the learnable continuous time embedding.

        Args:
            tt: Time stamps with a trailing singleton feature axis, ``(..., 1)``.

        Returns:
            Tensor ``(..., hid_dim)``. Channel 0 is ``te_scale(tt)``, and the
            remaining ``hid_dim - 1`` channels are ``sin(te_periodic(tt))``.
        """
        linear_part = self.te_scale(tt)
        periodic_part = torch.sin(self.te_periodic(tt))
        return torch.cat([linear_part, periodic_part], -1)

    def patch_extractor(self, X, truth_time_steps, mask=None):
        """Run the hierarchical encoder over patched observations.

        Args:
            X: Values of shape ``(B, M, L_in, N)``: batch, patches, positions
                within a patch, variables.
            truth_time_steps: Time stamps; expected to have the same shape as ``X``.
            mask: Tensor with the same shape as ``X``, forwarded to the
                patch-level ``CrossAttention``. If ``None``, an all-ones mask
                is used (every entry treated as observed). NOTE(review): this
                assumes 1 marks a valid entry; confirm against ``CrossAttention``.

        Returns:
            ``(cls_reprs, patch_reprs)`` with shapes ``(B, N, 1, D)`` and
            ``(B, M, N, 1, D)``, where ``D = hid_dim``.

        Raises:
            ValueError: If ``M`` exceeds the number of patch embeddings, or ``N``
                exceeds the number of variable embeddings.
        """
        B, M, L_in, N = X.shape
        d_model = self.hid_dim

        # Fail early with a clear message instead of an embedding IndexError
        # (or an opaque device-side assert on CUDA).
        if M > self.patch_emb.num_embeddings:
            raise ValueError(
                f"got {M} patches but the model has only "
                f"{self.patch_emb.num_embeddings} patch embeddings (npatch)"
            )
        if N > self.var_emb.num_embeddings:
            raise ValueError(
                f"got {N} variables but the model has only "
                f"{self.var_emb.num_embeddings} variable embeddings (ndim)"
            )
        if mask is None:
            mask = torch.ones_like(X)

        # Index tensors must live where the embedding weights live.
        emb_device = self.patch_emb.weight.device

        # Move the within-patch axis last: (B, M, N, L_in).
        X = X.permute(0, 1, 3, 2)
        mask = mask.permute(0, 1, 3, 2)
        truth_time_steps = truth_time_steps.permute(0, 1, 3, 2)

        # Point embeddings = value embedding + continuous time embedding.
        point_emb = self.val_emb(X.unsqueeze(-1)) + self.LearnableTE(truth_time_steps.unsqueeze(-1))
        point_emb_flat = point_emb.reshape(B * M * N, L_in, d_model)  # (B*M*N, L_in, D)
        mask_flat = mask.reshape(B * M * N, L_in).unsqueeze(-2)  # (B*M*N, 1, L_in)

        patch_idx = torch.arange(M, dtype=torch.long, device=emb_device)
        patch_reprs = (
            self.patch_emb(patch_idx)
            .reshape(1, M, 1, 1, d_model)
            .expand(B, -1, N, -1, -1)
        )  # (B, M, N, 1, D)

        var_idx = torch.arange(N, dtype=torch.long, device=emb_device)
        cls_reprs = self.var_emb(var_idx).reshape(1, N, 1, d_model).expand(B, -1, -1, -1)  # (B, N, 1, D)

        for patch_attn, inter_patch_block, var_block in self.encoder:
            # 1) Patch level: each patch token attends to the points of its patch.
            patch_tokens = patch_reprs.reshape(B * M * N, 1, d_model)
            patch_out = patch_attn(patch_tokens, point_emb_flat, point_emb_flat, mask=mask_flat)
            patch_out = patch_out.reshape(B, M, N, d_model)

            # 2) Per (batch, variable): transformer over [cls, patch_1 .. patch_M].
            seq = torch.cat([cls_reprs, patch_out.permute(0, 2, 1, 3)], dim=2)  # (B, N, M+1, D)
            seq = inter_patch_block(seq.reshape(B * N, M + 1, d_model)).reshape(B, N, M + 1, d_model)
            patch_reprs = seq[:, :, 1:, :].permute(0, 2, 1, 3).unsqueeze(3)  # (B, M, N, 1, D)

            # 3) Variable level: applied to the N cls tokens of each sample.
            cls_reprs = var_block(seq[:, :, 0, :]).unsqueeze(2)  # (B, N, 1, D)

        return cls_reprs, patch_reprs

    def forecasting(self, time_steps_to_predict, X, truth_time_steps, mask=None):
        """Predict every variable at the requested target time steps.

        Args:
            time_steps_to_predict: Target times of shape ``(B, L_pred)``.
            X: Observed values ``(B, M, L_in, N)``; see ``patch_extractor``.
            truth_time_steps: Observation times; see ``patch_extractor``.
            mask: Observation mask or ``None``; see ``patch_extractor``.

        Returns:
            Predictions of shape ``(B, L_pred, N)``.
        """
        B, M, _, N = X.shape
        d_model = self.hid_dim

        # cls_reprs: (B, N, 1, D); patch_reprs: (B, M, N, 1, D)
        cls_reprs, patch_reprs = self.patch_extractor(X, truth_time_steps, mask)

        L_pred = time_steps_to_predict.shape[1]
        te_pred = self.LearnableTE(time_steps_to_predict.unsqueeze(-1))  # (B, L_pred, D)

        # Query: the target-time embeddings repeated for each variable -> (B*N, L_pred, D).
        # expand() only creates a broadcast view; reshape() then materialises the copy.
        query = te_pred.unsqueeze(1).expand(-1, N, -1, -1).reshape(B * N, L_pred, d_model)

        # Global keys/values: one cls token per (batch, variable) -> (B*N, 1, D).
        global_kv = cls_reprs.reshape(B * N, 1, d_model)
        # Local keys/values: M patch tokens per (batch, variable) -> (B*N, M, D).
        local_kv = patch_reprs.permute(0, 2, 1, 3, 4).reshape(B * N, M, d_model)

        # Both contexts go through the same (weight-shared) cross_transformer.
        global_ctx = self.cross_transformer(query, global_kv, global_kv)
        local_ctx = self.cross_transformer(query, local_kv, local_kv)

        # (B*N, L_pred, D) -> (B, N, L_pred, D) -> (B, L_pred, N, D).
        # reshape (not view) tolerates non-contiguous attention outputs.
        global_ctx = global_ctx.reshape(B, N, L_pred, d_model).permute(0, 2, 1, 3)
        local_ctx = local_ctx.reshape(B, N, L_pred, d_model).permute(0, 2, 1, 3)

        outputs = self.decoder(torch.cat([global_ctx, local_ctx], dim=-1))  # (B, L_pred, N, 1)
        return outputs.squeeze(-1)