import torch
import torch.nn as nn
from modules import TransformerBlock1, TransformerBlock2, CrossAttentionBlock


class DensityFeatureExtractor(nn.Module):
    """Summarise how densely and how regularly each sequence was sampled.

    For every sequence along the last dimension two scalar features are
    produced:

    * observation intensity ``rho = n_obs / Seq_Len``;
    * coefficient of variation (CV) of the gaps between consecutive observed
      timestamps, ``std(dt) / mean(dt)``.

    CV is 0 for perfectly regular sampling, and is 0 by definition when fewer
    than two observations exist or when every observation shares a single
    timestamp (zero variance).
    """

    def __init__(self, epsilon=1e-6):
        super().__init__()
        # Added to the CV denominator so the division stays finite.
        self.epsilon = epsilon

    def forward(self, mask, time_point):
        """
        Args:
            mask: (..., Seq_Len) - 0/1 (or bool) observation indicator.
            time_point: (..., Seq_Len) - Normalized time; entries whose mask
                is 0 are ignored.
        Returns:
            features: (..., 2) -> [Intensity, CV]
        """
        seq_len = mask.shape[-1]
        observed = mask != 0

        # 1. Observation intensity (rho).
        n_obs = observed.sum(dim=-1, keepdim=True).to(time_point.dtype)  # (..., 1)
        intensity = n_obs / seq_len

        # 2. Coefficient of variation of the inter-observation intervals.
        # Push unobserved slots to +inf so an ascending sort gathers the
        # observed times at the front. Unlike adding a large finite penalty,
        # this works for bool masks and cannot overflow in low precision.
        masked_time = time_point.masked_fill(~observed, float("inf"))
        sorted_time, _ = torch.sort(masked_time, dim=-1)  # (..., L)
        dt = sorted_time[..., 1:] - sorted_time[..., :-1]  # (..., L-1)

        # k observations give k-1 valid gaps, at sorted positions 0..k-2.
        # Gaps touching the +inf padding are inf or NaN (inf - inf). They are
        # discarded with torch.where, because multiplying by a 0/1 mask would
        # leave NaN * 0 = NaN behind.
        n_gaps = (n_obs - 1).clamp(min=0)
        gap_idx = torch.arange(seq_len - 1, device=mask.device)
        gap_valid = gap_idx < n_gaps  # (..., L-1)
        zeros = torch.zeros_like(dt)
        gap_dt = torch.where(gap_valid, dt, zeros)

        # With no valid gaps the sums are 0, so clamping the count to 1
        # yields mean = var = 0 without biasing the non-degenerate case.
        gap_count = n_gaps.clamp(min=1)
        mean_dt = gap_dt.sum(dim=-1, keepdim=True) / gap_count
        deviation = torch.where(gap_valid, gap_dt - mean_dt, zeros)
        var_dt = (deviation ** 2).sum(dim=-1, keepdim=True) / gap_count

        # Exact std (no additive epsilon floor, which used to produce
        # CV = sqrt(eps)/eps = 1000 for coincident timestamps). sqrt is only
        # evaluated where var > 0, so no infinite gradient arises at 0.
        positive = var_dt > 0
        safe_var = torch.where(positive, var_dt, torch.ones_like(var_dt))
        std_dt = torch.where(positive, torch.sqrt(safe_var), torch.zeros_like(var_dt))

        cv = std_dt / (mean_dt + self.epsilon)
        # CV is undefined with fewer than two observations; report 0.
        cv = torch.where(n_obs >= 2, cv, torch.zeros_like(cv))

        return torch.cat([intensity, cv], dim=-1)


class Hi_Patch(nn.Module):
    """
    Hi-Patch model for ISMTS forecasting and classification.

    Observations are grouped into M patches of L_in slots for each of N
    variables and encoded hierarchically:

    1. intra-patch: a patch token attends to its observation tokens;
    2. inter-patch: a per-variable token attends to that variable's patches;
    3. inter-variable: the per-variable tokens attend to each other.

    For forecasting, time embeddings of the query times cross-attend to the
    variable-level and patch-level summaries, and an MLP decodes the result.
    """
    def __init__(self, args, supports=None):
        # ``supports`` is accepted for interface compatibility; it is unused.
        super().__init__()
        d_model = args.hid_dim
        # Kept for backward compatibility only. Forward passes create tensors
        # on the inputs' device so the model follows .to() / replicas.
        self.device = args.device
        self.hid_dim = args.hid_dim
        self.N = args.ndim
        self.n_layer = args.nlayer
        self.npatch = args.npatch
        self.dropout = 0.1

        # Learnable continuous-time embedding: one linear channel plus
        # (hid_dim - 1) sinusoidal channels.
        self.te_scale = nn.Linear(1, 1)
        self.te_periodic = nn.Linear(1, args.hid_dim - 1)

        # Sampling-density features of each (patch, variable) cell initialise
        # the patch token.
        self.density_extractor = DensityFeatureExtractor()
        self.density_proj = nn.Linear(2, d_model)  # [Intensity, CV] -> D_model

        self.val_emb = nn.Linear(1, d_model)
        self.patch_emb = nn.Embedding(args.npatch, d_model)
        self.var_emb = nn.Embedding(args.ndim, d_model)

        # One stage per layer: [intra-patch, inter-patch, inter-variable].
        self.encoder = nn.ModuleList([
            nn.Sequential(
                TransformerBlock1(d_model, args.nhead, dropout=self.dropout),
                TransformerBlock1(d_model, args.nhead, dropout=self.dropout),
                TransformerBlock2(d_model, args.nhead, 2 * d_model, dropout=self.dropout),
            )
            for _ in range(args.nlayer)
        ])
        self.cross_transformer = CrossAttentionBlock(d_model, args.nhead, 2 * d_model, dropout=self.dropout)

        if args.task == 'forecasting':
            # Decodes [global context ; local context] (2 * d_model) to one value.
            self.decoder = nn.Sequential(
                nn.Linear(2 * d_model, d_model),
                nn.ReLU(inplace=True),
                nn.Linear(d_model, d_model),
                nn.ReLU(inplace=True),
                nn.Linear(d_model, 1)
            )
        else:
            d_static = args.d_static
            if d_static != 0:
                # Presumably static features are projected to ndim and
                # concatenated with an ndim-sized representation (hence
                # ndim * 2 inputs) — confirm against the classification path.
                self.emb = nn.Linear(d_static, args.ndim)
                self.classifier = nn.Sequential(
                    nn.Linear(args.ndim * 2, 200),
                    nn.ReLU(),
                    nn.Linear(200, args.n_class))
            else:
                self.classifier = nn.Sequential(
                    nn.Linear(args.ndim, 200),
                    nn.ReLU(),
                    nn.Linear(200, args.n_class))

    def LearnableTE(self, tt):
        """Embed times ``tt`` (trailing dim of size 1) into hid_dim channels.

        Returns the linear channel ``te_scale(tt)`` concatenated with the
        sinusoidal channels ``sin(te_periodic(tt))`` along the last dim.
        """
        linear_part = self.te_scale(tt)
        periodic_part = torch.sin(self.te_periodic(tt))
        return torch.cat([linear_part, periodic_part], -1)

    def patch_extractor(self, X, truth_time_steps, mask=None):
        """Run the hierarchical encoder over patched observations.

        Args:
            X: (B, M, L_in, N) observed values.
            truth_time_steps: (B, M, L_in, N) observation times.
            mask: (B, M, L_in, N) non-zero where a value is observed. ``None``
                treats every slot as observed.

        Returns:
            cls_reprs: (B, N, 1, D) variable-level summaries.
            patch_reprs: (B, M, N, 1, D) patch-level summaries.
            patch_vs: (B, M, N, L_in, D) observation-level representations.
        """
        B, M, L_in, N = X.shape
        d_model = self.hid_dim
        # Use the inputs' device, not the construction-time ``self.device``,
        # so the model works after .to() and inside DataParallel replicas.
        device = X.device

        if mask is None:
            mask = torch.ones_like(X)

        # 1. Put the variable axis before the in-patch slot axis.
        X = X.permute(0, 1, 3, 2)  # (B, M, N, L_in)
        mask = mask.permute(0, 1, 3, 2)  # (B, M, N, L_in)
        truth_time_steps = truth_time_steps.permute(0, 1, 3, 2)  # (B, M, N, L_in)

        # 2. Embeddings.
        value_emb = self.val_emb(X.unsqueeze(-1))  # (B, M, N, L_in, D)
        time_emb = self.LearnableTE(truth_time_steps.unsqueeze(-1))  # (B, M, N, L_in, D)

        dens_feat = self.density_extractor(mask, truth_time_steps)  # (B, M, N, 2)
        patch_reprs = self.density_proj(dens_feat).unsqueeze(-2)  # (B, M, N, 1, D)

        patch_idx = torch.arange(M, dtype=torch.long, device=device)
        patch_id = self.patch_emb(patch_idx).view(1, M, 1, 1, d_model).expand(B, -1, N, -1, -1)  # (B, M, N, 1, D)
        patch_reprs = patch_reprs + patch_id

        var_idx = torch.arange(N, dtype=torch.long, device=device)
        cls_reprs = self.var_emb(var_idx).view(1, N, 1, d_model).expand(B, -1, 1, -1)  # (B, N, 1, D)

        patch_vs = value_emb + time_emb  # (B, M, N, L_in, D)

        # Loop-invariant attention mask for the intra-patch stage: the patch
        # token (slot 0) is always attendable, observation slots follow ``mask``.
        attn_mask = torch.cat([torch.ones(B, M, N, 1, device=device), mask], dim=3)
        attn_mask = attn_mask.reshape(B * M * N, L_in + 1).unsqueeze(-1)

        # 3. Hierarchical encoder (vectorised over batch, patches, variables).
        for n in range(self.n_layer):
            stage = self.encoder[n]

            # 3.1 Intra-patch: each patch token with its L_in observation tokens.
            z = torch.cat([patch_reprs, patch_vs], dim=3)  # (B, M, N, L_in+1, D)
            z = stage[0](z.reshape(B * M * N, L_in + 1, d_model), attn_mask=attn_mask)
            z = z.reshape(B, M, N, L_in + 1, d_model)
            patch_tokens = z[:, :, :, 0, :]  # (B, M, N, D)
            patch_vs = z[:, :, :, 1:, :]  # (B, M, N, L_in, D)

            # 3.2 Inter-patch: per (batch, variable), the variable token with
            # that variable's M patch tokens.
            z = torch.cat([cls_reprs, patch_tokens.permute(0, 2, 1, 3)], dim=2)  # (B, N, M+1, D)
            z = stage[1](z.reshape(B * N, M + 1, d_model)).reshape(B, N, M + 1, d_model)
            cls_reprs = z[:, :, 0, :]  # (B, N, D)
            patch_reprs = z[:, :, 1:, :].permute(0, 2, 1, 3).unsqueeze(3)  # (B, M, N, 1, D)

            # 3.3 Inter-variable: variable tokens attend to each other.
            cls_reprs = stage[2](cls_reprs).unsqueeze(2)  # (B, N, 1, D)

        return cls_reprs, patch_reprs, patch_vs

    def forecasting(self, time_steps_to_predict, X, truth_time_steps, mask=None):
        """Predict values of every variable at the requested times.

        Args:
            time_steps_to_predict: (B, L_pred) query times.
            X, truth_time_steps, mask: as in ``patch_extractor``.

        Returns:
            (B, L_pred, N) predictions.
        """
        B, M, _, N = X.shape
        d_model = self.hid_dim

        # Variable-level (B, N, 1, D) and patch-level (B, M, N, 1, D) summaries.
        cls_reprs, patch_reprs, _ = self.patch_extractor(X, truth_time_steps, mask)

        # Queries: time embedding of each target time, shared across variables.
        L_pred = time_steps_to_predict.shape[1]
        te_pred = self.LearnableTE(time_steps_to_predict.unsqueeze(-1))  # (B, L_pred, D)
        queries = te_pred.unsqueeze(1).expand(-1, N, -1, -1).reshape(B * N, L_pred, d_model)

        # Keys/values per (batch, variable).
        global_memory = cls_reprs.reshape(B * N, 1, d_model)
        local_memory = patch_reprs.permute(0, 2, 1, 3, 4).reshape(B * N, M, d_model)

        # The same cross-attention block reads both memories (shared weights).
        global_context = self.cross_transformer(queries, global_memory, global_memory)
        local_context = self.cross_transformer(queries, local_memory, local_memory)

        global_context = global_context.reshape(B, N, L_pred, d_model).permute(0, 2, 1, 3)
        local_context = local_context.reshape(B, N, L_pred, d_model).permute(0, 2, 1, 3)

        combined_features = torch.cat([global_context, local_context], dim=-1)  # (B, L_pred, N, 2D)
        outputs = self.decoder(combined_features)  # (B, L_pred, N, 1)
        return outputs.squeeze(-1)