import torch
import torch.nn as nn
from modules import TransformerBlock1, TransformerBlock2


class Model(nn.Module):
    """Hierarchical patch / variable Transformer encoder with an MLP classification head.

    Each layer of ``self.encoder`` applies three stages in order:
      1. intra-patch attention over the observations inside each (patch, variable) cell,
         with a learnable patch token prepended as a CLS slot;
      2. inter-patch attention over the patch tokens of each variable, with a per-variable
         token prepended as a CLS slot;
      3. inter-variable attention across the per-variable tokens.

    ``args`` must provide: ``hid_dim``, ``device``, ``ndim``, ``nlayer``, ``npatch``,
    ``dropout``, ``nhead``, ``d_static`` and ``n_class``.
    NOTE(review): the per-observation timestamps and mask suggest irregularly sampled
    multivariate series — confirm against the data pipeline.
    """

    def __init__(self, args):
        super().__init__()
        d_model = args.hid_dim
        self.device = args.device  # kept for callers; forward passes use the input's device
        self.hid_dim = args.hid_dim
        self.N = args.ndim
        self.n_layer = args.nlayer
        self.npatch = args.npatch
        self.dropout = args.dropout

        # Continuous time embedding: 1 affine channel + (hid_dim - 1) sinusoidal channels.
        self.te_scale = nn.Linear(1, 1)
        self.te_periodic = nn.Linear(1, args.hid_dim - 1)

        self.val_emb = nn.Linear(1, d_model)
        self.patch_emb = nn.Embedding(args.npatch, d_model)  # one token per patch index
        self.var_emb = nn.Embedding(args.ndim, d_model)  # one token per variable index

        # Per layer: (intra-patch block, inter-patch block, inter-variable block).
        self.encoder = nn.ModuleList([
            nn.Sequential(
                TransformerBlock1(d_model, args.nhead, dropout=self.dropout),
                TransformerBlock1(d_model, args.nhead, dropout=self.dropout),
                TransformerBlock2(d_model, args.nhead, 2 * d_model, dropout=self.dropout),
            )
            for _ in range(args.nlayer)
        ])

        # The head consumes the flattened per-variable tokens (ndim * d_model), plus a
        # d_model-sized static embedding when static features are configured.
        # `emb` is registered before `classifier`, matching the original module order.
        classifier_in = args.ndim * d_model
        if args.d_static != 0:
            self.emb = nn.Linear(args.d_static, d_model)
            classifier_in += d_model
        self.classifier = nn.Sequential(
            nn.Linear(classifier_in, 300),
            nn.ReLU(),
            nn.Linear(300, 300),
            nn.ReLU(),
            nn.Linear(300, args.n_class),
        )

    def LearnableTE(self, tt):
        """Embed timestamps into ``hid_dim`` features.

        Args:
            tt: tensor whose last dimension has size 1 (one timestamp per position).

        Returns:
            Tensor with the same leading dims and last dim ``hid_dim``. Channel 0 is a
            learnable affine map of ``tt``; the remaining ``hid_dim - 1`` channels are
            ``sin`` of learnable affine maps of ``tt``.
        """
        linear_part = self.te_scale(tt)
        periodic_part = torch.sin(self.te_periodic(tt))
        return torch.cat([linear_part, periodic_part], -1)

    def patch_extractor(self, X, truth_time_steps, mask=None):
        """Run the hierarchical encoder.

        Args:
            X: values of shape (B, M, L_in, N) — M patches, L_in positions per patch,
                N variables. M must not exceed ``npatch`` and N must not exceed ``ndim``.
            truth_time_steps: timestamps laid out like ``X`` (assumed same shape — TODO
                confirm).
            mask: per-position mask laid out like ``X``. Positions set to 1 are treated
                like the always-present CLS slot (presumably 1 = observed — confirm).
                ``None`` means every position is treated as present.

        Returns:
            ``(cls_reprs, patch_reprs, patch_vs)`` with shapes (B, N, 1, D),
            (B, M, N, 1, D) and (B, M, N, L_in, D), where D = ``hid_dim``.
        """
        B, M, L_in, N = X.shape
        d_model = self.hid_dim
        device = X.device  # follow the input, not the device recorded at construction

        if mask is None:
            mask = torch.ones_like(X)

        # 1. Move the variable axis before the in-patch axis: (B, M, N, L_in).
        X = X.permute(0, 1, 3, 2)
        mask = mask.permute(0, 1, 3, 2)
        truth_time_steps = truth_time_steps.permute(0, 1, 3, 2)

        # 2. Embeddings.
        value_emb = self.val_emb(X.unsqueeze(-1))  # (B, M, N, L_in, D)
        time_emb = self.LearnableTE(truth_time_steps.unsqueeze(-1))  # (B, M, N, L_in, D)
        patch_vs = value_emb + time_emb  # (B, M, N, L_in, D)

        patch_idx = torch.arange(M, dtype=torch.long, device=device)
        patch_reprs = (
            self.patch_emb(patch_idx).view(1, M, 1, 1, d_model).expand(B, -1, N, -1, -1)
        )  # (B, M, N, 1, D)

        var_idx = torch.arange(N, dtype=torch.long, device=device)
        cls_reprs = self.var_emb(var_idx).view(1, N, 1, d_model).expand(B, -1, -1, -1)  # (B, N, 1, D)

        # The attention mask is identical for every layer, so build it once: the prepended
        # patch (CLS) slot is always 1, followed by the per-position mask.
        attn_mask = torch.cat(
            [torch.ones(B, M, N, 1, device=mask.device), mask], dim=3
        ).reshape(B * M * N, L_in + 1).unsqueeze(-1)  # (B*M*N, L_in+1, 1)

        # 3. Hierarchical encoder.
        for intra_patch, inter_patch, inter_var in self.encoder:
            # 3.1 Intra-patch: attend over the L_in positions of each (patch, variable) cell.
            z = torch.cat([patch_reprs, patch_vs], dim=3)  # (B, M, N, L_in+1, D)
            z = intra_patch(z.reshape(B * M * N, L_in + 1, d_model), attn_mask=attn_mask)
            # reshape (not view): the block's output is not guaranteed to be contiguous.
            z = z.reshape(B, M, N, L_in + 1, d_model)
            patch_tokens = z[:, :, :, 0, :]  # (B, M, N, D)
            patch_vs = z[:, :, :, 1:, :]  # (B, M, N, L_in, D)

            # 3.2 Inter-patch: attend over the M patch tokens of each variable.
            z = torch.cat([cls_reprs, patch_tokens.permute(0, 2, 1, 3)], dim=2)  # (B, N, M+1, D)
            z = inter_patch(z.reshape(B * N, M + 1, d_model)).reshape(B, N, M + 1, d_model)
            var_tokens = z[:, :, 0, :]  # (B, N, D)
            patch_reprs = z[:, :, 1:, :].permute(0, 2, 1, 3).unsqueeze(3)  # (B, M, N, 1, D)

            # 3.3 Inter-variable: attend across the N variable tokens.
            cls_reprs = inter_var(var_tokens).unsqueeze(2)  # (B, N, 1, D)

        return cls_reprs, patch_reprs, patch_vs

    def classification(self, X, truth_time_steps, mask=None, P_static=None, feature=False):
        """Classify a batch, or return its pooled feature vector.

        Args:
            X, truth_time_steps, mask: forwarded to :meth:`patch_extractor`.
            P_static: static features of width ``d_static``. Required when the model was
                built with ``d_static != 0`` and not allowed otherwise.
            feature: if True, return the flattened variable tokens (B, N*D) instead of
                logits.

        Returns:
            Output of ``self.classifier`` with last dim ``n_class``, or the (B, N*D)
            features when ``feature`` is True.

        Raises:
            ValueError: if the presence of ``P_static`` does not match ``d_static``.
        """
        cls_reprs, _, _ = self.patch_extractor(X, truth_time_steps, mask)
        B, N, _, D = cls_reprs.shape
        h = cls_reprs.reshape(B, N * D)

        if feature:
            return h

        has_static_head = hasattr(self, "emb")  # only created when d_static != 0
        if P_static is not None:
            if not has_static_head:
                raise ValueError("P_static was provided but the model was built with d_static == 0")
            static_emb = self.emb(P_static)
            return self.classifier(torch.cat([h, static_emb], dim=-1))
        if has_static_head:
            raise ValueError("the model was built with d_static != 0, so P_static is required")
        return self.classifier(h)