# Cross-Block Imputation Transformer
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# ---------------- helpers: renorm & sharpen ----------------

def renorm_block(A_block: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Rescale every row of an attention sub-block so that it sums to ~1.

    A slice taken out of a globally softmaxed attention matrix no longer has
    unit row sums; dividing each row by its own total restores that.  ``eps``
    keeps the division finite for rows whose total is zero.
    """
    row_totals = A_block.sum(dim=-1, keepdim=True)
    return A_block / (row_totals + eps)


def topk_sharpen(A: torch.Tensor, topk_ratio: float = 0.5, eps: float = 1e-8) -> torch.Tensor:
    """
    Top-k sharpening after block renorm.

    - For each row (last axis, length N), keep the k largest entries with
      k = int(N ** topk_ratio) + 1 and zero the rest.  With the default
      ``topk_ratio=0.5`` this is floor(sqrt(N)) + 1.
    - Renormalize each row again so it sums to ~1 (``eps`` guards zero rows).
    - If k >= N nothing would be dropped, so ``A`` is returned unchanged.

    Args:
      A: attention weights; any leading dims, rows along the last axis.
      topk_ratio: exponent controlling how many entries per row survive.
      eps: small constant added to the row sums before dividing.
    """
    N = A.size(-1)
    k = int(N ** topk_ratio) + 1
    if k >= N:
        return A

    # Indices of the k largest entries per row: (..., k)
    _, idx = A.topk(k, dim=-1)
    # 1.0 at kept positions, 0.0 elsewhere
    mask = torch.zeros_like(A).scatter_(-1, idx, 1.0)
    A = A * mask
    return A / (A.sum(dim=-1, keepdim=True) + eps)


def poly_sharpen(A: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Second-order polynomial sharpening applied after block renormalization.

    Each weight a becomes a * (1 + alpha * a) = a + alpha * a^2, which favours
    large weights when alpha > 0; rows are then rescaled to sum to ~1
    (``eps`` guards all-zero rows).
    """
    weighted = A * (1.0 + alpha * A)
    row_sum = weighted.sum(dim=-1, keepdim=True)
    return weighted / (row_sum + eps)

class DMHStageCross(nn.Module):
    """
    Dual-Manifold Multi-Head Imputation Attention (cross-block version).

    Streams inside this stage:
      - Z_input : original tokens Z = [L; Xt]                                (B, N, D)
      - Z_full  : global attention output softmax(QK^T * scale) @ V,
                  scaled by a per-head, per-token sigmoid gate              (B, H, N, d)
      - L_main / X_main : sums of the projected block outputs
                  (LL + LX for L, XX + XL for Xt)                           (B, M, D) / (B, C, D)

    A global anchor (B, 1, D) is built as the sum of the token-max-pooled LL/LX
    outputs and the token-average-pooled XX/XL outputs.

    Outputs:
      - L_out : LayerNorm(L  + dropout(gated mix of L_main and the anchor))   (B, M, D)
      - X_out : LayerNorm(Xt + dropout(gated mix of X_main and the anchor))   (B, C, D)
      - Z_next: LayerNorm(gated mix of [L_main; X_main] and the anchor
                          + merged Z_full tokens)                          (B, N, D)
                (callers may later combine [L_out, X_out] with Z_next outside this stage)
    """

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        use_gate: bool = True,
        sharpen_type: str = "none",  # "none" | "topk" | "poly"
        poly_alpha_init: float = 0.2,
        topk_ratio: float = 0.5,
        debug_save_dir: Optional[str] = None,
    ):
        """
        Args:
          model_dim: token width D; must be divisible by num_heads.
          num_heads: number of attention heads H.
          dropout: dropout rate for the block attention maps and residual branches.
          use_gate: if True, allocates ``self.gate`` (not used by forward()).
          sharpen_type: "none", "topk" or "poly" sharpening applied to each
            renormalized attention block; any other value means no sharpening.
          poly_alpha_init: initial raw alpha for "poly" (passed through softplus in forward).
          topk_ratio: exponent for "topk" (k = int(N ** topk_ratio) + 1 per block row).
          debug_save_dir: if set, every forward() saves the global attention map
            ("A_weights.pt") and its head-contrast scores ("scores_weights.pt")
            into this directory. Default None performs no file I/O.
        """
        super().__init__()
        assert model_dim % num_heads == 0, "model_dim must be divisible by num_heads"

        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.use_gate = use_gate
        self.sharpen_type = sharpen_type
        self.debug_save_dir = debug_save_dir

        # NOTE: modules are registered in the original order so that parameter
        # initialisation under a fixed seed is unchanged.

        # Shared Q/K/V projections for Z = [L; Xt]
        self.q_proj = nn.Linear(model_dim, model_dim, bias=False)
        self.k_proj = nn.Linear(model_dim, model_dim, bias=False)
        self.v_proj = nn.Linear(model_dim, model_dim, bias=False)

        # Manifold-level output projections (not used by forward(); kept so
        # existing state_dicts still load)
        self.out_L = nn.Linear(model_dim, model_dim, bias=False)
        self.out_X = nn.Linear(model_dim, model_dim, bias=False)

        # Block-wise projections (each attention block has its own W)
        # Latent side: LL and LX
        self.proj_L_LL = nn.Linear(model_dim, model_dim, bias=False)
        self.proj_L_LX = nn.Linear(model_dim, model_dim, bias=False)
        # Xt side: XX and XL
        self.proj_X_XX = nn.Linear(model_dim, model_dim, bias=False)
        self.proj_X_XL = nn.Linear(model_dim, model_dim, bias=False)

        # Not used by forward(); kept so existing state_dicts still load
        self.attn_weight = nn.Parameter(torch.tensor(0.8))

        # Single global learnable alpha for polynomial sharpening
        if sharpen_type == "poly":
            self.alpha = nn.Parameter(torch.tensor(poly_alpha_init, dtype=torch.float32))
        else:
            self.alpha = None

        self.topk_ratio = topk_ratio

        # Not used by forward(); kept so existing state_dicts still load
        if use_gate:
            self.gate = nn.Linear(2 * model_dim, 1)

        # Per-token, per-head gate logits for the global attention output
        self.sdpa_gate_proj = nn.Linear(model_dim, num_heads, bias=True)

        self.dropout = nn.Dropout(dropout)
        self.norm_L = nn.LayerNorm(model_dim)
        self.norm_X = nn.LayerNorm(model_dim)
        self.z_norm = nn.LayerNorm(model_dim)

        # Dedicated V projections for each block
        self.v_proj_LL = nn.Linear(model_dim, model_dim, bias=False)  # L -> L
        self.v_proj_LX = nn.Linear(model_dim, model_dim, bias=False)  # L -> X
        self.v_proj_XL = nn.Linear(model_dim, model_dim, bias=False)  # X -> L
        self.v_proj_XX = nn.Linear(model_dim, model_dim, bias=False)  # X -> X

        # Fusion gates mixing block outputs with the global anchor
        self.latent_fusion_gate = nn.Linear(model_dim, 1)
        self.current_fusion_gate = nn.Linear(model_dim, 1)
        self.z_fusion_gate = nn.Linear(model_dim, 1)

        self.latent_maxpool = nn.AdaptiveMaxPool1d(1)   # pooling for the L side
        self.current_avgpool = nn.AdaptiveAvgPool1d(1)  # pooling for the X side

    # --------------- head helpers ---------------

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, N, D) -> (B, H, N, d)"""
        B, N, D = x.shape
        return x.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, H, N, d) -> (B, N, D)"""
        B, H, N, d = x.shape
        return x.transpose(1, 2).contiguous().view(B, N, H * d)

    # --------------- stage helpers ---------------

    def _maybe_save_debug(self, A_full: torch.Tensor) -> None:
        """Save the global attention map and head-contrast scores if debug_save_dir is set."""
        if self.debug_save_dir is None:
            return
        import os

        os.makedirs(self.debug_save_dir, exist_ok=True)
        A = A_full.detach()
        # Per query/key pair: max over heads minus mean over heads -> (B, N, N)
        scores = A.max(dim=1).values - A.mean(dim=1)
        torch.save(A.cpu(), os.path.join(self.debug_save_dir, "A_weights.pt"))
        torch.save(scores.cpu(), os.path.join(self.debug_save_dir, "scores_weights.pt"))

    def _prepare_block(self, A_block: torch.Tensor, alpha: Optional[torch.Tensor]) -> torch.Tensor:
        """Row-renormalize one attention block, optionally sharpen it, then apply dropout."""
        A_block = renorm_block(A_block)
        if self.sharpen_type == "topk":
            A_block = topk_sharpen(A_block, self.topk_ratio)
        elif self.sharpen_type == "poly":
            A_block = poly_sharpen(A_block, alpha)
        return self.dropout(A_block)

    @staticmethod
    def _pool_tokens(pool: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Pool (B, T, D) over the token axis T into (B, 1, D)."""
        return pool(x.transpose(1, 2)).transpose(1, 2)

    @staticmethod
    def _gated_mix(gate_layer: nn.Module, local: torch.Tensor, anchor: torch.Tensor) -> torch.Tensor:
        """w * local + (1 - w) * anchor, with w = sigmoid(gate_layer(local)) of shape (B, T, 1)."""
        w = torch.sigmoid(gate_layer(local))
        return w * local + (1 - w) * anchor.expand(-1, local.size(1), -1)

    # --------------- forward ---------------

    def forward(self, L: torch.Tensor, Xt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          L : (B, M, D)  latent tokens (history manifold)
          Xt: (B, C, D)  current-time tokens (observation manifold)

        Returns:
          L_out : (B, M, D)
          X_out : (B, C, D)
          Z_next: (B, M+C, D) = z_norm(Z_fused + Z_full_tok)
        """
        M = L.size(1)

        # (1) Concatenate input tokens: (B, N, D) with N = M + C
        Z_input = torch.cat([L, Xt], dim=1)

        # (2) Shared multi-head Q/K/V: (B, H, N, d)
        Q = self._split_heads(self.q_proj(Z_input))
        K = self._split_heads(self.k_proj(Z_input))
        V = self._split_heads(self.v_proj(Z_input))

        # (3) Global attention over all tokens: (B, H, N, N)
        S = torch.matmul(Q, K.transpose(-1, -2)) * self.scale
        A_full = F.softmax(S, dim=-1)
        self._maybe_save_debug(A_full)

        # (4) Global full-attention output, gated per head and token.
        # sdpa_gate_proj yields (B, N, H); reshape to (B, H, N, 1) to scale each head.
        Z_full = torch.matmul(A_full, V)  # (B, H, N, d)
        gate = torch.sigmoid(self.sdpa_gate_proj(Z_input)).permute(0, 2, 1).unsqueeze(-1)
        Z_full = Z_full * gate

        # (5) Four blocks of the global attention map
        A_LL = A_full[:, :, :M, :M]  # (B, H, M, M)
        A_LX = A_full[:, :, :M, M:]  # (B, H, M, C)
        A_XL = A_full[:, :, M:, :M]  # (B, H, C, M)
        A_XX = A_full[:, :, M:, M:]  # (B, H, C, C)

        # (6) Dedicated V projections for each block
        V_LL = self._split_heads(self.v_proj_LL(L))   # (B, H, M, d)
        V_LX = self._split_heads(self.v_proj_LX(Xt))  # (B, H, C, d)
        V_XL = self._split_heads(self.v_proj_XL(L))   # (B, H, M, d)
        V_XX = self._split_heads(self.v_proj_XX(Xt))  # (B, H, C, d)

        # Renormalize, optionally sharpen (no extra softmax), then dropout.
        # Blocks are processed in LL, LX, XL, XX order so dropout draws keep their order.
        alpha = F.softplus(self.alpha) if self.sharpen_type == "poly" else None
        A_LL = self._prepare_block(A_LL, alpha)
        A_LX = self._prepare_block(A_LX, alpha)
        A_XL = self._prepare_block(A_XL, alpha)
        A_XX = self._prepare_block(A_XX, alpha)

        # (7) Cross-block AV, merged to token space and projected by each block's W
        L_ll_tok = self.proj_L_LL(self._merge_heads(torch.matmul(A_LL, V_LL)))  # (B, M, D) L -> L
        L_lx_tok = self.proj_L_LX(self._merge_heads(torch.matmul(A_LX, V_LX)))  # (B, M, D) L -> X
        X_xx_tok = self.proj_X_XX(self._merge_heads(torch.matmul(A_XX, V_XX)))  # (B, C, D) X -> X
        X_xl_tok = self.proj_X_XL(self._merge_heads(torch.matmul(A_XL, V_XL)))  # (B, C, D) X -> L

        # (8) Main block responses per manifold (keep per-token positions)
        L_main = L_ll_tok + L_lx_tok  # (B, M, D)
        X_main = X_xx_tok + X_xl_tok  # (B, C, D)

        # (9) Global anchor: max-pool on the L side, average-pool on the X side
        global_fused = (
            self._pool_tokens(self.latent_maxpool, L_ll_tok)
            + self._pool_tokens(self.latent_maxpool, L_lx_tok)
            + self._pool_tokens(self.current_avgpool, X_xx_tok)
            + self._pool_tokens(self.current_avgpool, X_xl_tok)
        )  # (B, 1, D)

        # (10) Gated fusion of block responses with the anchor
        L_fused = self._gated_mix(self.latent_fusion_gate, L_main, global_fused)   # (B, M, D)
        X_fused = self._gated_mix(self.current_fusion_gate, X_main, global_fused)  # (B, C, D)

        # (11) Residual connections + LayerNorm
        L_out = self.norm_L(L + self.dropout(L_fused))
        X_out = self.norm_X(Xt + self.dropout(X_fused))

        # (12) Unified stream: fused block stream plus the gated global stream
        Z_main = torch.cat([L_main, X_main], dim=1)                            # (B, N, D)
        Z_fused = self._gated_mix(self.z_fusion_gate, Z_main, global_fused)    # (B, N, D)
        Z_next = self.z_norm(Z_fused + self._merge_heads(Z_full))

        return L_out, X_out, Z_next

# ---------------- DMHStageCross_old (legacy variant) ----------------

class DMHStageCross_old(nn.Module):
    """
    Dual-Manifold Multi-Head Imputation Attention (cross-block version, legacy).

    Streams inside this stage:
      - Z_input  : original tokens Z = [L; Xt]                                (B, N, D)
      - Z_full   : global attention output softmax(QK^T * scale + block bias) @ V,
                   scaled by a per-head, per-token sigmoid gate               (B, H, N, d)
      - Z_blk_tok: concatenation of the summed projected block outputs
                   (LL + LX for L, XX + XL for Xt)                           (B, N, D)

    Outputs:
      - L_out : LayerNorm(L + dropout(out_L(LL + LX)))                       (B, M, D)
      - X_out : LayerNorm(Xt + dropout(fusion of out_X(XX + XL) with the
                mean latent update))                                          (B, C, D)
      - Z_next: LayerNorm(Z_input + w * dropout(Z_full_tok)
                          + (1 - w) * dropout(Z_blk_tok)), w = attn_weight    (B, N, D)
    """

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        use_gate: bool = True,
        sharpen_type: str = "none",  # "none" | "topk" | "poly"
        poly_alpha_init: float = 0.2,
        topk_ratio: float = 0.5
    ):
        """
        Args:
          model_dim: token width D; must be divisible by num_heads.
          num_heads: number of attention heads H.
          dropout: dropout rate for block attention maps and residual branches.
          use_gate: if True, Xt is fused with the mean latent update through a
            learned sigmoid gate; otherwise the two are simply added.
          sharpen_type: "none", "topk" or "poly" sharpening of each renormalized block.
          poly_alpha_init: initial raw alpha for "poly" (softplus-ed in forward).
          topk_ratio: exponent for "topk" (k = int(N ** topk_ratio) + 1).
        """
        super().__init__()
        assert model_dim % num_heads == 0, "model_dim must be divisible by num_heads"

        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.use_gate = use_gate
        self.sharpen_type = sharpen_type

        # Shared Q/K/V projections for Z = [L; Xt]
        self.q_proj = nn.Linear(model_dim, model_dim, bias=False)
        self.k_proj = nn.Linear(model_dim, model_dim, bias=False)
        self.v_proj = nn.Linear(model_dim, model_dim, bias=False)

        # Manifold-level output projections
        self.out_L = nn.Linear(model_dim, model_dim, bias=False)
        self.out_X = nn.Linear(model_dim, model_dim, bias=False)

        # Block-wise projections (each attention block has its own W)
        # Latent side: LL and LX
        self.proj_L_LL = nn.Linear(model_dim, model_dim, bias=False)
        self.proj_L_LX = nn.Linear(model_dim, model_dim, bias=False)
        # Xt side: XX and XL
        self.proj_X_XX = nn.Linear(model_dim, model_dim, bias=False)
        self.proj_X_XL = nn.Linear(model_dim, model_dim, bias=False)

        # Mixing weight between the global and block streams in Z_next.
        # forward() reads this; it was previously commented out, which made
        # every forward pass raise AttributeError.
        self.attn_weight = nn.Parameter(torch.tensor(0.8))

        # Single global learnable alpha for polynomial sharpening
        if sharpen_type == "poly":
            self.alpha = nn.Parameter(torch.tensor(poly_alpha_init, dtype=torch.float32))
        else:
            self.alpha = None

        self.topk_ratio = topk_ratio

        # Gate for injecting the latent global signal into Xt
        if use_gate:
            self.gate = nn.Linear(2 * model_dim, 1)

        # Per-token, per-head gate logits for the global attention output
        self.sdpa_gate_proj = nn.Linear(model_dim, num_heads, bias=True)

        self.dropout = nn.Dropout(dropout)
        self.norm_L = nn.LayerNorm(model_dim)
        self.norm_X = nn.LayerNorm(model_dim)
        self.z_norm = nn.LayerNorm(model_dim)

        # Dedicated V projections for each block
        self.v_proj_LL = nn.Linear(model_dim, model_dim, bias=False)  # L -> L
        self.v_proj_LX = nn.Linear(model_dim, model_dim, bias=False)  # L -> X
        self.v_proj_XL = nn.Linear(model_dim, model_dim, bias=False)  # X -> L
        self.v_proj_XX = nn.Linear(model_dim, model_dim, bias=False)  # X -> X

        # Learnable per-head additive score bias for each block region
        self.attn_bias_LL = nn.Parameter(torch.zeros(1, num_heads, 1, 1))
        self.attn_bias_LX = nn.Parameter(torch.zeros(1, num_heads, 1, 1))
        self.attn_bias_XL = nn.Parameter(torch.zeros(1, num_heads, 1, 1))
        self.attn_bias_XX = nn.Parameter(torch.zeros(1, num_heads, 1, 1))

        # Learnable (H, H) head-mixing matrices per block. Not used by forward();
        # kept so existing state_dicts still load.
        self.attn_matrix_LL = nn.Parameter(torch.eye(num_heads))
        self.attn_matrix_LX = nn.Parameter(torch.eye(num_heads))
        self.attn_matrix_XL = nn.Parameter(torch.eye(num_heads))
        self.attn_matrix_XX = nn.Parameter(torch.eye(num_heads))

    # --------------- head helpers ---------------

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, N, D) -> (B, H, N, d)"""
        B, N, D = x.shape
        return x.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, H, N, d) -> (B, N, D)"""
        B, H, N, d = x.shape
        return x.transpose(1, 2).contiguous().view(B, N, H * d)

    # --------------- forward ---------------

    def forward(self, L: torch.Tensor, Xt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          L : (B, M, D)  latent tokens (history manifold)
          Xt: (B, C, D)  current-time tokens (observation manifold)

        Returns:
          L_out : (B, M, D)
          X_out : (B, C, D)
          Z_next: (B, M+C, D) unified stream (see class docstring)
        """
        M = L.size(1)
        C = Xt.size(1)

        # (1) Concatenate input tokens: (B, N, D) with N = M + C
        Z_input = torch.cat([L, Xt], dim=1)

        # (2) Shared multi-head Q/K/V: (B, H, N, d)
        Q = self._split_heads(self.q_proj(Z_input))
        K = self._split_heads(self.k_proj(Z_input))
        V = self._split_heads(self.v_proj(Z_input))

        # (3) Global scores plus one learnable per-head bias per block region: (B, H, N, N)
        S = torch.matmul(Q, K.transpose(-1, -2)) * self.scale
        bias_matrix = torch.zeros_like(S)
        bias_matrix[:, :, :M, :M] = self.attn_bias_LL  # L queries -> L keys
        bias_matrix[:, :, :M, M:] = self.attn_bias_LX  # L queries -> X keys
        bias_matrix[:, :, M:, :M] = self.attn_bias_XL  # X queries -> L keys
        bias_matrix[:, :, M:, M:] = self.attn_bias_XX  # X queries -> X keys
        A_full = F.softmax(S + bias_matrix, dim=-1)

        # (4) Global full-attention output: (B, H, N, d).
        # sdpa_gate_proj yields (B, N, H); reshape to (B, H, N, 1) to scale each head.
        Z_full = torch.matmul(A_full, V)
        gate = torch.sigmoid(self.sdpa_gate_proj(Z_input)).permute(0, 2, 1).unsqueeze(-1)
        Z_full = Z_full * gate

        # (5) Split A_full into four blocks
        A_LL = A_full[:, :, :M, :M]  # (B, H, M, M)
        A_LX = A_full[:, :, :M, M:]  # (B, H, M, C)
        A_XL = A_full[:, :, M:, :M]  # (B, H, C, M)
        A_XX = A_full[:, :, M:, M:]  # (B, H, C, C)

        # Dedicated V projections for each block
        V_LL = self._split_heads(self.v_proj_LL(L))   # (B, H, M, d)
        V_LX = self._split_heads(self.v_proj_LX(Xt))  # (B, H, C, d)
        V_XL = self._split_heads(self.v_proj_XL(L))   # (B, H, M, d)
        V_XX = self._split_heads(self.v_proj_XX(Xt))  # (B, H, C, d)

        # Renormalize each block (row-wise)
        A_LL = renorm_block(A_LL)
        A_LX = renorm_block(A_LX)
        A_XL = renorm_block(A_XL)
        A_XX = renorm_block(A_XX)

        # Optional sharpening (no extra softmax)
        if self.sharpen_type == "topk":
            A_LL = topk_sharpen(A_LL, self.topk_ratio)
            A_LX = topk_sharpen(A_LX, self.topk_ratio)
            A_XL = topk_sharpen(A_XL, self.topk_ratio)
            A_XX = topk_sharpen(A_XX, self.topk_ratio)
        elif self.sharpen_type == "poly":
            effective_alpha = F.softplus(self.alpha)
            A_LL = poly_sharpen(A_LL, effective_alpha)
            A_LX = poly_sharpen(A_LX, effective_alpha)
            A_XL = poly_sharpen(A_XL, effective_alpha)
            A_XX = poly_sharpen(A_XX, effective_alpha)

        A_LL = self.dropout(A_LL)
        A_LX = self.dropout(A_LX)
        A_XL = self.dropout(A_XL)
        A_XX = self.dropout(A_XX)

        # (6) Cross-block AV for four blocks (head space)
        L_ll = torch.matmul(A_LL, V_LL)  # (B, H, M, d)  L -> L
        L_lx = torch.matmul(A_LX, V_LX)  # (B, H, M, d)  L -> X
        X_xx = torch.matmul(A_XX, V_XX)  # (B, H, C, d)  X -> X
        X_xl = torch.matmul(A_XL, V_XL)  # (B, H, C, d)  X -> L

        # (7) Merge heads and apply each block's own projection
        L_ll_tok = self.proj_L_LL(self._merge_heads(L_ll))  # (B, M, D)
        L_lx_tok = self.proj_L_LX(self._merge_heads(L_lx))  # (B, M, D)
        X_xx_tok = self.proj_X_XX(self._merge_heads(X_xx))  # (B, C, D)
        X_xl_tok = self.proj_X_XL(self._merge_heads(X_xl))  # (B, C, D)

        # (8) Manifold-level updates from the projected blocks
        L_upd = self.out_L(L_ll_tok + L_lx_tok)  # (B, M, D)
        X_upd = self.out_X(X_xx_tok + X_xl_tok)  # (B, C, D)

        # (9) Latent -> Xt fusion with the token-mean latent update
        gL_bc = L_upd.mean(dim=1, keepdim=True).expand(-1, C, -1)  # (B, C, D)
        if self.use_gate:
            w = torch.sigmoid(self.gate(torch.cat([X_upd, gL_bc], dim=-1)))  # (B, C, 1)
            X_fused = w * X_upd + (1.0 - w) * gL_bc
        else:
            X_fused = X_upd + gL_bc

        # Residual + LayerNorm on each manifold
        L_out = self.norm_L(L + self.dropout(L_upd))      # (B, M, D)
        X_out = self.norm_X(Xt + self.dropout(X_fused))   # (B, C, D)

        # (10) Global stream and block stream in token space: (B, N, D)
        Z_full_tok = self._merge_heads(Z_full)
        Z_blk_tok = torch.cat([L_ll_tok + L_lx_tok, X_xx_tok + X_xl_tok], dim=1)

        # Unified stream: learnable mix of both streams, residual + LayerNorm
        Z_next = (
            self.attn_weight * self.dropout(Z_full_tok)
            + (1 - self.attn_weight) * self.dropout(Z_blk_tok)
        )
        Z_next = self.z_norm(Z_input + Z_next)

        return L_out, X_out, Z_next



# ---------------- multi-stage head ----------------

class DMHImputerCross(nn.Module):
    """
    Multi-stage cross-block DMH head.

    Each stage returns (L_out, X_out, Z_next). The unified stream Z_next is
    re-split into latent / current-time tokens to feed the next stage. The
    final output is the last stage's [L_out; X_out] plus its Z_next. With zero
    stages both terms are the aligned inputs.
    """

    def __init__(
        self,
        latent_raw_dim: int,
        xt_raw_dim: int,
        model_dim: int,
        num_heads: int,
        num_stages: int = 2,
        dropout: float = 0.0,
        use_gate: bool = True,
        sharpen_type: str = "none",
        poly_alpha_init: float = 0.2,
        topk_ratio: float = 0.5
    ):
        """
        Args:
          latent_raw_dim: last-dim width of the latent input (projected to model_dim if different).
          xt_raw_dim: last-dim width of the current-time input (projected to model_dim if different).
          model_dim, num_heads, dropout, use_gate, sharpen_type, poly_alpha_init,
          topk_ratio: forwarded to every DMHStageCross.
          num_stages: number of stacked stages.
        """
        super().__init__()
        self.latent_align = nn.Linear(latent_raw_dim, model_dim) if latent_raw_dim != model_dim else nn.Identity()
        self.xt_align = nn.Linear(xt_raw_dim, model_dim) if xt_raw_dim != model_dim else nn.Identity()

        self.stages = nn.ModuleList([
            DMHStageCross(
                model_dim=model_dim,
                num_heads=num_heads,
                dropout=dropout,
                use_gate=use_gate,
                sharpen_type=sharpen_type,
                poly_alpha_init=poly_alpha_init,
                topk_ratio=topk_ratio
            ) for _ in range(num_stages)
        ])

    def forward(self, latent_raw, xt_raw):
        """
        Args:
          latent_raw: (B, M, latent_raw_dim)
          xt_raw    : (B, C, xt_raw_dim)

        Returns:
          (B, M+C, model_dim): [L_out; X_out] of the last stage + its Z_next.
        """
        L = self.latent_align(latent_raw)  # (B, M, D)
        Xt = self.xt_align(xt_raw)         # (B, C, D)
        M = L.size(1)

        # Defaults for the zero-stage case
        L_out, X_out = L, Xt
        Z_next = torch.cat([L, Xt], dim=1)

        for stage in self.stages:
            L_out, X_out, Z_next = stage(L, Xt)
            # Re-split the unified stream as the next stage's input; L_out /
            # X_out are kept separately so they reach the final summation.
            L, Xt = Z_next[:, :M, :], Z_next[:, M:, :]

        # Branch outputs of the last stage, aligned back to token space
        Z_from_branches = torch.cat([L_out, X_out], dim=1)  # (B, M+C, D)

        # Single unified output for downstream computation
        return Z_from_branches + Z_next  # (B, M+C, D)


class VarSpaceEncoder(nn.Module):
    """
    Data-driven variable-space encoding.

      - Static part: one learned id embedding per variable.
      - Dynamic part (optional): an MLP maps each variable's window statistics
        (mean, std, missing_ratio), computed over observed entries only, to an
        additive per-variable offset for the current batch.

    The result is passed through LayerNorm.
    """

    def __init__(
        self,
        num_vars: int,   # C
        d_model: int,    # D
        use_dynamic: bool = True,
        hidden_dim: Optional[int] = None,
    ):
        """
        Args:
          num_vars: number of variables C.
          d_model: embedding width D.
          use_dynamic: if False, only the static id embedding is used.
          hidden_dim: hidden width of the statistics MLP (default d_model // 2).
        """
        super().__init__()
        self.C = num_vars
        self.D = d_model
        self.use_dynamic = use_dynamic

        # Static per-variable id embedding
        self.id_emb = nn.Embedding(num_vars, d_model)

        if hidden_dim is None:
            hidden_dim = d_model // 2

        if use_dynamic:
            # Input features: [mean, std, missing_ratio] -> 3 dims
            self.mlp = nn.Sequential(
                nn.Linear(3, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, d_model),
            )
        else:
            self.mlp = None

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x_hist: torch.Tensor, m_hist: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x_hist: (B, L, C) history values. Entries at missing positions are
                  ignored, so they may hold any placeholder, including NaN.
          m_hist: (B, L, C) missing mask, 1 = missing, 0 = observed.

        Returns:
          var_pos: (B, C, D)

        Raises:
          ValueError: if the variable axis of x_hist does not equal num_vars.
        """
        B, L, C = x_hist.shape
        if C != self.C:
            raise ValueError(f"VarSpaceEncoder expected {self.C} variables, got {C}")

        # ---- Static part: one id embedding per variable ----
        idx = torch.arange(C, device=x_hist.device)                # (C,)
        static_emb = self.id_emb(idx).unsqueeze(0).expand(B, C, self.D)  # (B, C, D)

        if not self.use_dynamic:
            return self.norm(static_emb)

        # ---- Dynamic part: statistics of the current window, observed entries only ----
        observed_mask = m_hist == 0                 # True where observed
        observed = observed_mask.float()
        eps = 1e-6
        n_obs = observed.sum(dim=1)                 # (B, C) exact observed counts
        obs_count = n_obs + eps                     # avoids division by zero

        # torch.where (not multiplication) so NaN/inf placeholders at missing
        # positions cannot leak into the statistics (NaN * 0 == NaN).
        zeros = torch.zeros_like(x_hist)
        x_obs = torch.where(observed_mask, x_hist, zeros)
        mean = x_obs.sum(dim=1) / obs_count         # (B, C)

        centered = torch.where(observed_mask, x_hist - mean.unsqueeze(1), zeros)
        var = (centered ** 2).sum(dim=1) / obs_count  # (B, C)
        std = torch.sqrt(var + eps)                   # (B, C)

        # Missing ratio from the exact count (0 when fully observed)
        miss_ratio = 1.0 - n_obs / L                  # (B, C)

        stats = torch.stack([mean, std, miss_ratio], dim=-1)  # (B, C, 3)
        dyn_emb = self.mlp(stats)                              # (B, C, D)

        return self.norm(static_emb + dyn_emb)                 # (B, C, D)



class TemporalLatentWriter(nn.Module):
    """
    Encode a (B, T, C) window into history latents plus tokens for the
    region that still has to be imputed.

    The window is split at ``t_index``: rows ``[0, t_index)`` are the history
    that is compressed into M latent tokens; rows ``[t_index, T)`` form the
    imputation region, returned token-by-token.

    Two latent formats (selected by ``build_type``):

    (A) 'time' - each time step's C variable tokens are mean-pooled into one
        token, then M learnable seeds cross-attend over the L history steps.
        Output: Z in [B, M, D]

    (B) 'var'  - all L*C (time x variable) history tokens are flattened and
        M learnable seeds cross-attend over them.
        Output: Z in [B, M, D]

    Input:
        x_full: [B, T, C]
        m_full: [B, T, C]  (1=missing, 0=observed)
        t_index: 0-based split point; defaults to T // 2

    Output:
        Z      : [B, M, D]                history latents
        X_t_var: [B, (T - t_index) * C, D] tokens of the imputation region
    """
    def __init__(
        self,
        in_dim: int,       # C
        model_dim: int,    # D
        num_latents: int,  # M
        cross_heads: int = 4,
        max_len: int = 2048,
        dropout: float = 0.1,
        share_cross_attn: bool = False,
        build_type='time',  # 'time' or 'var'
        fourier_K_max: int = 8,
    ):
        super().__init__()
        self.C = in_dim
        self.D = model_dim
        self.M = num_latents
        self.share_cross_attn = share_cross_attn

        # NOTE: submodule creation order is kept fixed; it determines the
        # random initialization sequence and the state_dict layout.

        # ---------- embeddings ----------
        # scalar -> D (for per-variable tokens)
        self.feat_emb_scalar = nn.Linear(1, model_dim)
        self.mask_emb_scalar = nn.Linear(1, model_dim, bias=False)

        # NOTE(review): not used by any method of this class; kept so existing
        # checkpoints still load.
        self.xt_proj = nn.Linear(self.C, self.D, bias=False)
        self.mt_proj = nn.Linear(self.C, self.D, bias=False)

        # time / variable embeddings
        self.time_emb = nn.Embedding(max_len, model_dim)
        self.var_emb = nn.Embedding(in_dim, model_dim)
        # variable-space encoding with window statistics (used by 'var' latents)
        self.var_space_enc = VarSpaceEncoder(
            num_vars=in_dim,
            d_model=model_dim,
            use_dynamic=True,
        )
        self.max_len = max_len

        # latent seeds (shared across variables)
        self.latent_seeds = nn.Parameter(torch.randn(num_latents, model_dim) * 0.02)

        # cross-attention for time-latent and var-latent
        self.cross_time = nn.MultiheadAttention(model_dim, cross_heads, dropout=dropout, batch_first=True)

        # NOTE(review): not used by any method of this class; kept for checkpoint compatibility.
        self.self_attn = nn.MultiheadAttention(
            model_dim, cross_heads, dropout=dropout, batch_first=True
        )

        self.cross_var = (
            self.cross_time
            if share_cross_attn
            else nn.MultiheadAttention(model_dim, cross_heads, dropout=dropout, batch_first=True)
        )

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(model_dim)

        self.build_type = build_type

        # ---------- history-aligned Fourier encoding (tables precomputed) ----------
        self.fourier_K_max = int(fourier_K_max)

        # (1) frequencies k = 1..K_max: (K_max,)
        freqs = torch.arange(1, self.fourier_K_max + 1, dtype=torch.float32)
        self.register_buffer("fourier_freqs", freqs, persistent=False)

        # (2) position index s = 0..max_len-1: (max_len,)
        s = torch.arange(self.max_len, dtype=torch.float32)
        self.register_buffer("hist_s_idx", s, persistent=False)

        # (3) base_num[s, k] = 2*pi*s*k: (max_len, K_max)
        base_num = 2.0 * math.pi * s[:, None] * freqs[None, :]
        self.register_buffer("fourier_base_num", base_num, persistent=False)

        # (4) 1/L and (L-1)/L for L = 0..max_len (entry 0 stays 0).
        #     Computed in float64 and then cast, matching Python-float arithmetic.
        lengths = torch.arange(1, self.max_len + 1, dtype=torch.float64)
        invL = torch.zeros(self.max_len + 1, dtype=torch.float32)
        lm1_over_L = torch.zeros(self.max_len + 1, dtype=torch.float32)
        invL[1:] = (1.0 / lengths).float()
        lm1_over_L[1:] = ((lengths - 1.0) / lengths).float()
        self.register_buffer("invL_table", invL, persistent=False)
        self.register_buffer("lm1_over_L_table", lm1_over_L, persistent=False)

        # (5) K_eff(L) = min(K_max, floor(log2(max(L, 2))) + 1) for L >= 1, K_eff(0) = 0
        K_eff_table = torch.tensor(
            [0] + [
                min(self.fourier_K_max, int(math.log2(max(L, 2))) + 1)
                for L in range(1, self.max_len + 1)
            ],
            dtype=torch.long,
        )
        self.register_buffer("K_eff_table", K_eff_table, persistent=False)

        # (6) fixed projection: (2*K_max -> D)
        self.fourier_proj = nn.Linear(2 * self.fourier_K_max, model_dim, bias=False)

    # ============================================================
    # History-aligned time encoding: tau(L) -> Fourier -> (L, D)
    # Tables are precomputed in __init__; forward does sin/cos + linear only.
    # NOTE(review): not called by any method of this class.
    # ============================================================
    def _hist_time_encoding(self, L: int, device: torch.device) -> torch.Tensor:
        """
        Return a history-aligned Fourier time encoding of shape (L, D).

        tau_s = (s - (L-1)) / L lies in [-(L-1)/L, 0], so the most recent
        history step (s = L-1) is anchored at 0.
        angles(s, k) = 2*pi*k*tau_s = (2*pi*s*k)/L - 2*pi*k*(L-1)/L
        Only the first K_eff(L) frequencies are used; the remaining sin/cos
        slots are zero-padded so ``fourier_proj`` always sees 2*K_max inputs.

        Raises:
            ValueError: if L exceeds ``max_len``.
        """
        if L <= 0:
            return torch.zeros(0, self.D, device=device)
        if L > self.max_len:
            raise ValueError(f"L={L} exceeds max_len={self.max_len}. Increase max_len.")

        K_eff = int(self.K_eff_table[L].item())  # O(1) lookup
        if K_eff == 0:
            # Unreachable for L >= 1 with the current table; still return a tensor.
            return torch.zeros(L, self.D, device=device)

        freqs = self.fourier_freqs[:K_eff].to(device)                                          # (K_eff,)
        angles = self.fourier_base_num[:L, :K_eff].to(device) * self.invL_table[L].to(device)  # (L, K_eff)
        shift = (2.0 * math.pi * freqs) * self.lm1_over_L_table[L].to(device)                  # (K_eff,)
        angles = angles - shift[None, :]                                                       # (L, K_eff)

        feat_small = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (L, 2*K_eff)

        # pad to (L, 2*K_max) for the fixed Linear input size
        feat = torch.zeros(L, 2 * self.fourier_K_max, device=device, dtype=feat_small.dtype)
        feat[:, : 2 * K_eff] = feat_small

        return self.fourier_proj(feat)  # (L, D)

    # ============================================================
    # Function 1: Time-latent (history -> M), variables mean-pooled
    # Output: [B, M, D]
    # ============================================================
    def build_time_latent(
        self,
        x_hist: torch.Tensor,    # [B, history, C]
        m_hist: torch.Tensor,    # [B, history, C]
        pos_hist: torch.Tensor,  # [history]
    ) -> torch.Tensor:
        """
        Time-latent: compress L history steps into M latent tokens.

        Step 1: build 4-D tokens H_4d = [B, L, C, D] (value + mask + time + var embeddings)
        Step 2: mean-pool over the variable axis C -> H_time = [B, L, D]
        Step 3: M learnable seeds cross-attend over H_time -> Z_time = [B, M, D]

        Returns:
            Z_time: [B, M, D]. With no history (L == 0) the seeds are returned
            as-is (expanded over the batch, not normalized).
        """
        B, L, C = x_hist.shape
        if L == 0:
            # no history: one copy of the seeds per sample
            return self.latent_seeds.unsqueeze(0).expand(B, self.M, self.D)  # [B, M, D]

        # ---------- Step 1: 4-D tokens [B, L, C, D] ----------
        xh = x_hist.unsqueeze(-1)  # [B, L, C, 1]
        mh = m_hist.unsqueeze(-1)  # [B, L, C, 1]

        # scalar value + mask embeddings
        E_var = self.feat_emb_scalar(xh) + self.mask_emb_scalar(mh)  # [B, L, C, D]

        # time position embedding
        e_time = self.time_emb(pos_hist).view(1, L, 1, self.D)       # [1, L, 1, D]

        # variable position embedding
        var_idx = torch.arange(C, device=x_hist.device)
        e_var = self.var_emb(var_idx).view(1, 1, C, self.D)          # [1, 1, C, D]

        H_4d = self.dropout(self.norm(E_var + e_time + e_var))       # [B, L, C, D]

        # ---------- Step 2: pool over C -> one token per time step ----------
        # Plain mean; a weighted sum or a small MLP would be drop-in alternatives.
        H_time = H_4d.mean(dim=2)    # [B, L, D]

        # ---------- Step 3: seeds (Q) attend over the history tokens (K/V) ----------
        seeds = self.latent_seeds.unsqueeze(0).expand(B, self.M, self.D)  # [B, M, D]
        Z_time, _ = self.cross_time(
            query=seeds,       # [B, M, D]
            key=H_time,        # [B, L, D]
            value=H_time,      # [B, L, D]
            need_weights=False,
        )                      # -> [B, M, D]

        return self.dropout(self.norm(Z_time))  # [B, M, D]

    # ============================================================
    # Function 2: Var-latent ((history)*C -> M), joint tokens
    # Output: [B, M, D]
    # ============================================================
    def build_var_latent(
        self,
        x_hist: torch.Tensor,    # [B, history, C]
        m_hist: torch.Tensor,    # [B, history, C]
        pos_hist: torch.Tensor,  # [history]
    ) -> torch.Tensor:
        """
        Flatten time x variable tokens into L*C tokens and compress them to M
        tokens via cross-attention from the learnable seeds.

        The variable position comes from ``var_space_enc`` (static id plus
        window statistics) rather than the plain ``var_emb`` table.

        Returns:
            Z_var: [B, M, D]. With no history (L == 0) the seeds are returned as-is.
        """
        B, L, C = x_hist.shape
        if L == 0:
            return self.latent_seeds.unsqueeze(0).expand(B, self.M, self.D)

        # 4-D tokens first: [B, L, C, D]
        xh = x_hist.unsqueeze(-1)  # [B, L, C, 1]
        mh = m_hist.unsqueeze(-1)
        E_var = self.feat_emb_scalar(xh) + self.mask_emb_scalar(mh)  # [B, L, C, D]

        e_time = self.time_emb(pos_hist).view(1, L, 1, self.D)       # [1, L, 1, D]

        # variable-space encoding with dynamic (window-statistics) information
        var_pos = self.var_space_enc(x_hist, m_hist)                  # [B, C, D]
        e_var = var_pos.view(B, 1, C, self.D).expand(B, L, C, self.D) # [B, L, C, D]

        H_var_4d = self.dropout(self.norm(E_var + e_time + e_var))

        # flatten (L, C) -> N = L*C: [B, L*C, D]
        H_var = H_var_4d.reshape(B, L * C, self.D)

        Z0 = self.latent_seeds.unsqueeze(0).expand(B, self.M, self.D)
        Z_var, _ = self.cross_var(Z0, H_var, H_var, need_weights=False)

        return Z_var  # [B, M, D]

    # ============================================================
    # Forward: history latents + imputation-region tokens
    # ============================================================
    def forward(
        self,
        x_full: torch.Tensor,          # [B, T, C]
        m_full: torch.Tensor,          # [B, T, C]
        t_index: Optional[int] = None  # 0-based split point
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Split the window at ``t_index`` and encode both parts.

        Args:
            x_full: [B, T, C] values.
            m_full: [B, T, C] missing mask (1 = missing, 0 = observed).
            t_index: number of history rows; must satisfy 0 <= t_index < T.
                Defaults to the integer midpoint ``T // 2``.

        Returns:
            Z      : [B, M, D] history latents (format set by ``build_type``).
            X_t_var: [B, (T - t_index) * C, D] tokens of rows [t_index, T),
                     flattened time-major (all C variables of a step together).

        Raises:
            ValueError: if T exceeds ``max_len`` or ``build_type`` is unknown.
        """
        B, T, C = x_full.shape
        assert C == self.C
        if T > self.max_len:
            # time_emb only has max_len rows; fail with a clear message instead of an index error
            raise ValueError(f"T={T} exceeds max_len={self.max_len}. Increase max_len.")
        if t_index is None:
            # default split: integer midpoint (slicing below needs an int)
            t_index = T // 2
        assert 0 <= t_index < T

        hist_len = t_index
        impute_len = T - t_index

        # ---- history part -> latents ----
        x_hist = x_full[:, :hist_len, :]  # [B, L, C], L = history length
        m_hist = m_full[:, :hist_len, :]
        pos_hist = torch.arange(hist_len, device=x_full.device)  # [L]

        if self.build_type == 'time':
            Z = self.build_time_latent(x_hist, m_hist, pos_hist)  # [B, M, D]
        elif self.build_type == 'var':
            Z = self.build_var_latent(x_hist, m_hist, pos_hist)   # [B, M, D]
        else:
            raise ValueError(f"Building Latent Type {self.build_type} 不存在")

        # ---- variable tokens for the whole window: [B, T, C, D] ----
        xt = x_full.unsqueeze(-1)  # [B, T, C, 1]
        mt = m_full.unsqueeze(-1)  # [B, T, C, 1]

        # scalar value + mask embeddings
        E_all_var = self.feat_emb_scalar(xt) + self.mask_emb_scalar(mt)  # [B, T, C, D]

        # local time positions 0..T-1 inside the window
        t_idx = torch.arange(T, device=x_full.device)
        e_time_all = self.time_emb(t_idx).view(1, T, 1, self.D)          # [1, T, 1, D]

        # one embedding per channel
        var_idx = torch.arange(C, device=x_full.device)
        e_var = self.var_emb(var_idx).view(1, 1, C, self.D)              # [1, 1, C, D]

        X_all_var = self.dropout(self.norm(E_all_var + e_time_all + e_var))  # [B, T, C, D]

        # ---- keep only the imputation region and flatten (time, variable) ----
        X_t_4d = X_all_var[:, hist_len:, :, :]                  # [B, impute_len, C, D]
        X_t_var = X_t_4d.reshape(B, impute_len * C, self.D)     # [B, impute_len*C, D]

        return Z, X_t_var
    


class CBiTAttention(nn.Module):
    """
    CBiT attention block.

    Input : x (B, T, C) values, m (B, T, C) missing mask.
    Steps :
      1. TemporalLatentWriter splits the window at T - impute_len and returns
         history latents (B, M, D) plus imputation-region tokens
         (B, impute_len * C, D).
      2. DMHImputerCross mixes the two token sets into a unified stream.
      3. The last C tokens of that stream go through a linear readout.
    Output: Xt_feat (B, C, D)

    NOTE(review): step 3 assumes cbit_head keeps the imputation tokens last;
    then the last C tokens are the variables of the final imputed time step.
    Confirm against DMHImputerCross.
    """
    def __init__(
        self,
        in_dim: int,         # raw channel count C_in
        d_model: int,        # hidden size D
        num_heads: int,
        build_latent_type: str = "time",  # "time" or "var"
        num_latents: int = 16,
        latents_cross_heads: int = 4,
        max_len: int = 2048,
        dropout: float = 0.1,
        sharpen_type: str = "none",
        num_stages: int = 2,
        impute_len: int = 16,
        poly_alpha_init: float = 0.2,
        topk_ratio: float = 0.5
    ):
        super().__init__()
        self.impute_len = impute_len

        writer_cfg = dict(
            in_dim=in_dim,
            model_dim=d_model,
            num_latents=num_latents,
            cross_heads=latents_cross_heads,
            max_len=max_len,
            dropout=dropout,
            share_cross_attn=False,
            build_type=build_latent_type,
        )
        head_cfg = dict(
            latent_raw_dim=d_model,
            xt_raw_dim=d_model,
            model_dim=d_model,
            num_heads=num_heads,
            num_stages=num_stages,
            dropout=dropout,
            use_gate=True,
            sharpen_type=sharpen_type,
            poly_alpha_init=poly_alpha_init,
            topk_ratio=topk_ratio,
        )

        # creation order (writer -> head -> readout) fixes the init RNG sequence
        self.latent_writer = TemporalLatentWriter(**writer_cfg)   # (B,T,C) -> latents + Xt tokens
        self.cbit_head = DMHImputerCross(**head_cfg)              # cross-block attention
        self.xt_readout = nn.Linear(d_model, d_model)             # readout of the Xt region

    def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T, C) values.
            m: (B, T, C) missing mask.

        Returns:
            (B, C, D) features read from the last C tokens of the cross-block output.
        """
        _, T, n_vars = x.shape
        split_at = T - self.impute_len  # history rows before the imputation region

        latents, xt_tokens = self.latent_writer(x, m, split_at)
        unified = self.cbit_head(latents, xt_tokens)

        tail = unified[:, -n_vars:, :]
        return self.xt_readout(tail)


class PositionwiseFFN(nn.Module):
    """
    Position-wise feed-forward network applied over the last dimension.

    Standard variant : fc2(dropout(act(fc1(x)))), fc1: D -> d_ff, fc2: d_ff -> D.
    "swiglu" variant : fc1 maps D -> 2*d_ff; its halves (a, b) are combined
                       as a * SiLU(b), then fc2(dropout(.)).

    Args:
        d_model: input/output feature size D.
        d_ff: hidden size.
        dropout: dropout probability on the hidden activations.
        activation: "relu", "gelu", "silu", "mish", "leaky_relu" (slope 0.1)
            or "swiglu". For "swiglu", ``self.act`` holds the string "swiglu".

    Raises:
        ValueError: for any other ``activation`` value.
    """

    # factories for the non-gated activations
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "mish": nn.Mish,
        "leaky_relu": lambda: nn.LeakyReLU(0.1),
    }

    def __init__(self, d_model, d_ff, dropout=0.1, activation="gelu"):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.dropout = nn.Dropout(dropout)

        if activation == "swiglu":
            # gated unit: fc1 produces both the value half and the gate half
            self.act = "swiglu"
            fc1_out = d_ff * 2
        else:
            factory = self._ACTIVATIONS.get(activation) if isinstance(activation, str) else None
            if factory is None:
                raise ValueError(f"Unsupported activation: {activation}")
            self.act = factory()
            fc1_out = d_ff

        # fc1 is created before fc2 in both variants (same init RNG order)
        self.fc1 = nn.Linear(d_model, fc1_out)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Map (..., D) -> (..., D), e.g. (B, C, D) or (B, T, D)."""
        if self.act == "swiglu":
            value, gate = self.fc1(x).chunk(2, dim=-1)  # each (..., d_ff)
            hidden = value * F.silu(gate)
        else:
            hidden = self.act(self.fc1(x))
        return self.fc2(self.dropout(hidden))


class CBiTEncoderLayer(nn.Module):
    """
    One CBiT encoder layer.

    Pipeline:
        (x, m) --CBiTAttention--> h
        h --PositionwiseFFN--> Dropout --> residual add with h --> LayerNorm

    The carried state is the per-variable feature ``h`` returned by the
    attention block, (B, C, D) per CBiTAttention's readout. The residual
    connection wraps only the FFN.
    """
    def __init__(
        self,
        in_dim: int,          # raw channel count C_in
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = "gelu",  # "gelu" / "relu" / "silu" / ...
        build_latent_type: str = "time",
        num_latents: int = 16,
        latents_cross_heads: int = 4,
        max_len: int = 2048,
        sharpen_type: str = "none",
        num_stages: int = 2,
        impute_len: int = 16,
        poly_alpha_init: float = 0.2,
        topk_ratio: float = 0.5
    ):
        super().__init__()

        attention_cfg = dict(
            in_dim=in_dim,
            d_model=d_model,
            num_heads=n_heads,
            build_latent_type=build_latent_type,
            num_latents=num_latents,
            latents_cross_heads=latents_cross_heads,
            max_len=max_len,
            dropout=dropout,
            sharpen_type=sharpen_type,
            num_stages=num_stages,
            impute_len=impute_len,
            poly_alpha_init=poly_alpha_init,
            topk_ratio=topk_ratio,
        )

        # attention: (B,T,C) values + mask -> per-variable features
        self.attn = CBiTAttention(**attention_cfg)
        # position-wise refinement of those features
        self.ffn = PositionwiseFFN(d_model=d_model, d_ff=d_ff, dropout=dropout, activation=activation)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T, C) values.
            m: (B, T, C) missing mask.

        Returns:
            LayerNorm(h + Dropout(FFN(h))) where h = attn(x, m).
        """
        attended = self.attn(x, m)
        refined = self.dropout(self.ffn(attended))
        return self.norm(attended + refined)


class CBiT(nn.Module):
    """
    Full CBiT model.

    Input : x (B, T, C) values, m (B, T, C) missing mask.
    Flow  :
        for each CBiTEncoderLayer:
            (x, m) -> attention -> FFN + LN -> Xt_feat_l (B, C, D)
        Xt_feat of the last layer -> Linear(D, impute_len) -> transpose
    Output: y (B, impute_len, C), one value per imputed step and variable.

    ``impute_len`` is forwarded to every layer so the attention split point
    (T - impute_len) agrees with the number of steps ``out_proj`` emits.

    NOTE(review): every layer reads the raw (x, m) and only the last layer's
    output reaches ``out_proj``, so earlier layers do not influence y.
    Confirm this is intended.
    """
    def __init__(
        self,
        in_dim: int,
        num_layers: int = 4,
        d_model: int = 1024,
        n_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1,
        activation: str = "gelu",
        build_latent_type: str = "time",
        num_latents: int = 16,
        latents_cross_heads: int = 4,
        max_len: int = 2048,
        sharpen_type: str = "none",
        num_stages: int = 2,
        impute_len: int = 16,
        poly_alpha_init: float = 0.2,
        topk_ratio: float = 0.5
    ):
        super().__init__()
        if num_layers < 1:
            # forward needs at least one layer to produce features for out_proj
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.impute_len = impute_len

        # stacked encoder layers (each: CBiTAttention + FFN)
        self.layers = nn.ModuleList([
            CBiTEncoderLayer(
                in_dim=in_dim,
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                dropout=dropout,
                activation=activation,
                build_latent_type=build_latent_type,
                num_latents=num_latents,
                latents_cross_heads=latents_cross_heads,
                max_len=max_len,
                sharpen_type=sharpen_type,
                num_stages=num_stages,
                impute_len=impute_len,  # keep the split consistent with out_proj
                poly_alpha_init=poly_alpha_init,
                topk_ratio=topk_ratio
            )
            for _ in range(num_layers)
        ])

        # output head: per variable D -> impute_len
        self.out_proj = nn.Linear(d_model, impute_len)

    def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, T, C) values.
            m: (B, T, C) missing mask.

        Returns:
            y: (B, impute_len, C) predictions for the last impute_len steps.
        """
        B, T, C = x.shape  # also enforces a 3-D input
        Xt_feat = None

        # each layer extracts features from the raw (x, m); the last one wins
        for layer in self.layers:
            Xt_feat = layer(x, m)      # (B, C, D)

        y = self.out_proj(Xt_feat)     # (B, C, impute_len)
        return y.transpose(1, 2)       # (B, impute_len, C)
