import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from baselines.layers.Embed import PatchTST_Embedding_add, PatchTST_Embedding_concat
from modules import TransformerBlock1, TransformerBlock2, CrossAttentionBlock


class multiTimeAttention(nn.Module):
    """Multi-time attention embedding for irregularly sampled series.

    A set of ``nq`` reference query time embeddings attends over each
    variable's observation-time embeddings (keys), producing one attention-
    weighted value per (variable, head, query).  The ``nq`` outputs are then
    either projected as a whole or split into patches and projected per patch
    to ``embed_time // num_heads`` features per head.

    Args:
        input_dim: number of variables (stored as ``self.dim``; not otherwise
            used in this class).
        nq: number of reference queries the caller will pass at run time.
        embed_time: width of the query/key time embeddings.
        num_heads: number of attention heads; must divide ``embed_time``.
        npatch: number of patches the ``nq`` outputs will be split into; it fixes
            the input width of the final projection to ``ceil(nq / npatch)``.
    """

    def __init__(self, input_dim, nq=128, embed_time=16, num_heads=1, npatch=1):
        super(multiTimeAttention, self).__init__()
        assert embed_time % num_heads == 0, "embed_time must be divisible by num_heads"
        self.embed_time = embed_time
        self.embed_time_k = embed_time // num_heads
        self.h = num_heads
        self.dim = input_dim
        # Input width of the final projection: length of one patch of queries.
        self.nhidden = math.ceil(nq / npatch)
        self.linears = nn.ModuleList([
            nn.Linear(embed_time, embed_time),           # query projection
            nn.Linear(embed_time, embed_time),           # key projection
            nn.Linear(self.nhidden, self.embed_time_k),  # per-patch output projection
        ])

    def attention(self, query, key, value, mask=None, dropout=None):
        """Scaled dot-product attention reduced against per-variable values.

        Shapes as produced by ``forward``: query ``(Bq, 1, h, Lq, d_k)``,
        key ``(B, N, h, L, d_k)``, value and mask ``(B, 1, L, N)``.

        Returns:
            ``(x, p_attn)`` where ``x`` is ``(B, N, h, Lq)`` and ``p_attn`` is
            ``(B, N, h, Lq, L)``.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # (B, 1, L, N) -> (B, N, 1, 1, L): broadcast over heads and queries.
            # Use the dtype's most negative finite value rather than a fixed
            # -1e9, which cannot be represented in float16 (masked_fill raises).
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask.permute(0, 3, 1, 2).unsqueeze(-3) == 0, fill_value)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        # Attention-weighted sum of each variable's values over the time axis.
        x = torch.sum(p_attn * value.permute(0, 3, 1, 2).unsqueeze(-3), -1)
        return x, p_attn

    def _check_width(self, width):
        """Raise ValueError if ``width`` differs from the final projection's input size."""
        if width != self.nhidden:
            raise ValueError(
                f"multiTimeAttention: per-patch width {width} does not match the "
                f"projection input size {self.nhidden} fixed at construction "
                f"(ceil(nq / npatch)); check that nq and npatch agree with __init__."
            )

    def forward(self, query, key, value, mask=None, dropout=None, npatch=None):
        """Embed irregular observations into patch (or whole-series) features.

        Args:
            query: reference time embeddings ``(Bq, nq, embed_time)``; ``Bq``
                must broadcast against the batch (e.g. 1).
            key: observation time embeddings ``(B, L, N, embed_time)``.
            value: observed values ``(B, L, N)``.
            mask: optional ``(B, L, N)`` tensor; zero entries are excluded from
                attention.
            dropout: optional callable applied to the attention weights.
            npatch: when given, split the ``nq`` attention outputs into this
                many (zero-padded) patches before projecting.

        Returns:
            ``(B, N, npatch, embed_time)`` when ``npatch`` is given, otherwise
            ``(B, N, embed_time)``.

        Raises:
            ValueError: if the per-patch width does not equal ``self.nhidden``.
        """
        batch, _, dim = value.size()
        if mask is not None:
            mask = mask.unsqueeze(1)
        value = value.unsqueeze(1)
        query = (self.linears[0](query)
                 .view(query.size(0), -1, self.h, self.embed_time_k)
                 .transpose(1, 2)
                 .unsqueeze(1))                                   # (Bq, 1, h, nq, d_k)
        key = (self.linears[1](key)
               .view(key.size(0), key.size(1), key.size(2), self.h, self.embed_time_k)
               .transpose(1, 2)
               .transpose(2, 3))                                  # (B, N, h, L, d_k)
        x, _ = self.attention(query, key, value, mask, dropout)   # (B, N, h, nq)

        if npatch is None:
            self._check_width(x.size(-1))
            x = self.linears[-1](x)
            return x.reshape(batch, dim, self.h * self.embed_time_k)

        # 1. Compute the per-patch length and right-pad the query axis with
        #    zeros so it divides evenly into npatch patches.
        n_queries = x.size(-1)
        patch_len = math.ceil(n_queries / npatch)
        pad_len = patch_len * npatch - n_queries
        if pad_len > 0:
            x = F.pad(x, (0, pad_len))

        # 2. Split into patches: (B, N, h, npatch, patch_len).
        x = x.view(batch, dim, self.h, npatch, patch_len)

        # 3. Project each patch: (B, N, npatch, h, d_k) -> (B, N, npatch, embed_time).
        self._check_width(patch_len)
        x = self.linears[-1](x).transpose(2, 3)
        return x.reshape(batch, dim, npatch, -1)


class Hi_Patch(nn.Module):
    """Hierarchical patch encoder with a cross-attention forecasting head.

    Inputs arrive as ``M`` patches of ``L_in`` steps for ``N`` variables.
    Each variable's patches are embedded (``args.mode`` selects the embedding).
    Two transformer levels then run per layer: a patch-level block over
    ``[cls, patch_1..patch_M]`` for each variable, and a block applied to the
    ``(B, N, D)`` per-variable summaries (presumably attending across
    variables — see ``modules.TransformerBlock2``).  Forecasting queries
    learnable time embeddings of the target times against both the summaries
    and the patch representations.
    """

    # Patch embedding strategies accepted through ``args.mode``.
    _MODES = ('add', 'concat', 'mtand')
    # Number of reference query times used in 'mtand' mode; shared by the
    # multiTimeAttention constructor and the query grid in patch_extractor.
    _N_QUERY = 128

    def __init__(self, args):
        super(Hi_Patch, self).__init__()
        if args.mode not in self._MODES:
            # Fail fast: otherwise ``embedding_layer`` is never created and the
            # problem only surfaces later as an AttributeError.
            raise ValueError(f"Unknown mode {args.mode!r}; expected one of {self._MODES}")
        d_model = args.hid_dim
        self.device = args.device
        self.hid_dim = args.hid_dim
        self.N = args.ndim
        self.n_layer = args.nlayer
        self.npatch = args.npatch
        self.mode = args.mode
        self.dropout = 0.1

        # Learnable time embedding: 1 linear channel + (hid_dim - 1) sinusoidal channels.
        self.te_scale = nn.Linear(1, 1)
        self.te_periodic = nn.Linear(1, args.hid_dim - 1)

        # One learned vector per variable, used as the initial cls token.
        self.var_emb = nn.Embedding(args.ndim, d_model)

        if self.mode == 'add':
            self.embedding_layer = PatchTST_Embedding_add(args.maxlen, self.hid_dim, self.dropout)
        elif self.mode == 'concat':
            self.embedding_layer = PatchTST_Embedding_concat(args.maxlen, self.hid_dim, self.dropout)
        else:  # 'mtand'
            self.embedding_layer = multiTimeAttention(self.N, self._N_QUERY, self.hid_dim, 1, npatch=self.npatch)

        # Each layer: (patch-level block, variable-level block).
        self.encoder = nn.ModuleList([
            nn.Sequential(
                TransformerBlock1(d_model, args.nhead, dropout=self.dropout),
                TransformerBlock2(d_model, args.nhead, 2 * d_model, dropout=self.dropout)
                ) for _ in range(args.nlayer)
            ])

        # A single cross-attention block, used for both the global and the local context.
        self.cross_transformer = CrossAttentionBlock(d_model, args.nhead, 2 * d_model, dropout=self.dropout)

        self.decoder = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, 1)
        )

    def LearnableTE(self, tt):
        """Embed time stamps ``tt`` (trailing dim of size 1) into ``hid_dim`` features.

        The first channel is a learned affine map of ``tt``; the remaining
        ``hid_dim - 1`` channels are ``sin`` of learned affine maps.
        """
        out1 = self.te_scale(tt)
        out2 = torch.sin(self.te_periodic(tt))
        return torch.cat([out1, out2], -1)

    def patch_extractor(self, X, truth_time_steps, mask=None):
        """Encode patched inputs into per-variable summaries and patch representations.

        Args:
            X: values unpacked as ``(B, M, L_in, N)``: M patches of L_in steps
                for N variables.
            truth_time_steps: observation times with the same layout as ``X``.
            mask: optional observation mask with ``B * M * L_in * N`` elements
                (reshaped to ``(B, M * L_in, N)``); it is only passed on in
                'mtand' mode.

        Returns:
            ``(cls_reprs, patch_reprs)`` of shapes ``(B, N, 1, D)`` and
            ``(B, N, M, D)``.
        """
        B, M, L_in, N = X.shape
        d_model = self.hid_dim
        if mask is not None:
            mask = mask.reshape(B, M * L_in, N)

        # 1. Patch embedding.
        if self.mode in ('add', 'concat'):
            x_enc = X.permute(0, 3, 1, 2).reshape(B * N, M, L_in)
            t_enc = truth_time_steps.permute(0, 3, 1, 2).reshape(B * N, M, L_in)
            patch_reprs = self.embedding_layer(x_enc, t_enc).reshape(B, N, M, d_model)
        else:
            X = X.reshape(B, M * L_in, N)
            truth_time_steps = truth_time_steps.reshape(B, M * L_in, N)
            key = self.LearnableTE(truth_time_steps.unsqueeze(-1))  # (B, M*L_in, N, D)
            ref_times = torch.linspace(0., 1., self._N_QUERY, device=self.device)
            cls_query = self.LearnableTE(ref_times.view(1, -1, 1))  # (1, n_query, D)
            patch_reprs = self.embedding_layer(cls_query, key, X, mask, npatch=M)  # (B, N, M, D)

        # 2. Initial cls token per variable from the variable embedding.
        var_idx = torch.arange(N, dtype=torch.long, device=self.device)
        cls_reprs = self.var_emb(var_idx).view(1, N, 1, d_model).expand(B, -1, -1, -1)  # (B, N, 1, D)

        # 3. Hierarchical encoder, vectorised over batch and variables.
        for patch_block, var_block in self.encoder:
            # 3.1 Patch level: [cls, patches] per variable as one sequence.
            tokens = torch.cat([cls_reprs, patch_reprs], dim=2).reshape(B * N, M + 1, d_model)
            tokens = patch_block(tokens).reshape(B, N, M + 1, d_model)
            patch_reprs = tokens[:, :, 1:, :]            # (B, N, M, D)
            # 3.2 Variable level: over the (B, N, D) cls summaries.
            cls_reprs = var_block(tokens[:, :, 0, :]).unsqueeze(2)  # (B, N, 1, D)

        return cls_reprs, patch_reprs

    def forecasting(self, time_steps_to_predict, X, truth_time_steps, mask=None):
        """Predict values at ``time_steps_to_predict`` for every variable.

        Args:
            time_steps_to_predict: target times ``(B, L_pred)``; a ``(1, L_pred)``
                grid is broadcast across the batch.
            X, truth_time_steps, mask: as in ``patch_extractor``.

        Returns:
            Predictions of shape ``(B, L_pred, N)``.
        """
        B, M, _, N = X.shape
        d_model = self.hid_dim

        # cls_reprs: per-variable global summary (B, N, 1, D);
        # patch_reprs: per-patch local representations (B, N, M, D).
        cls_reprs, patch_reprs = self.patch_extractor(X, truth_time_steps, mask)

        # Query: time embedding of the target times, one copy per variable.
        L_pred = time_steps_to_predict.shape[1]
        te_pred = self.LearnableTE(time_steps_to_predict.unsqueeze(-1))  # (B or 1, L_pred, D)
        query = te_pred.unsqueeze(1).expand(B, N, -1, -1).reshape(B * N, L_pred, d_model)

        # Keys/values: global summary (1 token) and local details (M tokens).
        global_kv = cls_reprs.reshape(B * N, 1, d_model)
        local_kv = patch_reprs.reshape(B * N, M, d_model)

        global_context = self.cross_transformer(query, global_kv, global_kv)
        local_context = self.cross_transformer(query, local_kv, local_kv)

        # (B*N, L_pred, D) -> (B, L_pred, N, D), then combine both contexts.
        global_context = global_context.reshape(B, N, L_pred, d_model).permute(0, 2, 1, 3)
        local_context = local_context.reshape(B, N, L_pred, d_model).permute(0, 2, 1, 3)
        combined_features = torch.cat([global_context, local_context], dim=-1)  # (B, L_pred, N, 2D)

        outputs = self.decoder(combined_features)  # (B, L_pred, N, 1)
        return outputs.squeeze(-1)