#!/usr/bin/env python3
"""
Student_v9: Student with Landmark-Aware Teacher Training

This module extends Student_v8 by using a landmark-aware teacher to train the student,
then applying landmark rescaling during inference for correct scale output.

Key Enhancement over Student_v8:
- Uses landmark-aware teacher for training (teacher handles landmark scaling internally)
- Student architecture remains unchanged (learns from better teacher)
- Applies landmark rescaling during inference/reconstruction
- Maintains full compatibility with Student_v8 infrastructure

Architecture:
- LandmarkAwareStudentForward_v9: Enhanced student forward with landmark-aware teacher
- Student_v9: Complete model with InverseEncoder + TimePrior (same as v8)
- Landmark rescaling applied at output stage for correct scaling

This allows the student to learn from a more robust teacher while maintaining 
correct coordinate scaling for downstream applications.
"""
from __future__ import annotations
import sys
sys.path.append("..")  # Adjust the path to import from the parent directory
import torch
import copy
import logging
from torch import nn

# Import teacher from current package (avoid naisr dependencies for now)
from .teacher import Teacher
from .base_net import BaseNet, TemporalExpansionConstraint, LandmarkRescaler, extract_mesh
from ..base import BaseFCBlock
from ..networks import Siren  # Changed import to Siren directly
import copy
# Configure logging
# NOTE(review): basicConfig at import time configures the root logger for the
# whole process (it does nothing if the root logger already has handlers);
# library modules usually leave this to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

OMEGA_0 = 30  # Siren frequency factor, passed as both first_omega_0 and hidden_omega_0 below




class StableTrajectoryEstimator(BaseNet):
    """
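    StableTrajectoryEstimator: Student_v10 + landmark-aware teacher training.

    Core ideas:
    1. The student is trained from a landmark-aware teacher.
    2. The student network architecture is unchanged (same as v8).
    3. The teacher handles landmark scaling internally, providing a better training signal.
    4. Landmark rescaling is applied at inference time so outputs have the correct scale.
    5. Full compatibility with Student_v8 is preserved.

    Training advantages:
    - Learning from a teacher that handles landmark scaling gives the student a better deformation representation.
    - The teacher has already learned to handle data at different scales, giving more stable training targets.
    - The student learns to predict plausible deformations across different landmark scales.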
    """
    
    def __init__(self, cfg):
        """Build the frozen teacher, the per-covariate student networks and the noise heads.

        Args:
            cfg: Configuration mapping (must support ``in``, ``[]`` and ``.get``).
                Required keys (checked by ``validate_cfg``): ``"CovariateNames"``,
                ``"Device"``, ``"teacher"``. Optional keys read here, with defaults:
                ``HiddenFeatures`` (512), ``HiddenLayers`` (6),
                ``UseLandmarkRescaling`` (True),
                ``LandmarkRescaling["reference_scale"]`` (1.0),
                ``PredictFisherUncertainty`` (True), ``LearnAnisotropicNoise`` (True),
                ``FisherEps`` (1e-2), ``TimePriorSigma`` (0.5), ``VBlendTau`` (0.2),
                ``ObsNoiseSigma2`` (1.0).

        Raises:
            ValueError: If a required configuration key is missing.

        NOTE(review): ``self.in_dim`` and ``self.in_features`` are used below but are
        not set here; presumably they come from ``BaseNet.__init__`` — confirm.
        """
        super().__init__(cfg)
        logger.info("Initializing TrajectoryEstimator...")

        # Validate configuration (raises ValueError on the first missing key)
        required_keys = ["CovariateNames", "Device", "teacher"]
        self.validate_cfg(cfg, required_keys)

        # Extract configuration parameters
        self.covariate_names = cfg.get("CovariateNames", [])
        hidden_features = cfg.get("HiddenFeatures", 512)
        hidden_layers = cfg.get("HiddenLayers", 6)
        
        # Index <-> covariate-name lookups, kept for compatibility with Student_v8
        self.dict_idx_cov = {i: cov for i, cov in enumerate(self.covariate_names)}
        self.dict_cov_idx = {cov: i for i, cov in enumerate(self.covariate_names)}

        # Teacher configuration (the teacher itself is built and frozen below)
        teacher_cfg = cfg["teacher"]
        
        # Check whether landmark rescaling is enabled
        self.use_landmark_rescaling = cfg.get("UseLandmarkRescaling", True)
        
        if self.use_landmark_rescaling:
            # Use the landmark-aware teacher
            logger.info("🎯 Using Landmark-Aware Teacher for training student")
            
            # Landmark rescaling configuration
            landmark_cfg = cfg.get("LandmarkRescaling", {})
            self.reference_scale = landmark_cfg.get("reference_scale", 1.0)
            
            # Initialize the landmark rescaler
            self.landmark_rescaler = LandmarkRescaler(self.reference_scale)
            logger.info(f"   Reference scale: {self.reference_scale}")
        else:
            logger.info("🔧 Using standard Teacher (no landmark scaling)")
            # No rescaler on this path; note that self.reference_scale is not set either.
            self.landmark_rescaler = None
        
        self.teacher = Teacher(teacher_cfg)
        # Disable gradients for all teacher parameters and put the teacher in eval mode
        self._freeze_teacher()
        # Half-width of the cube template points are sampled from in forward():
        # 1.2 when in_dim == 3, otherwise 2.0
        self.sampling_range = 1.2 if self.in_dim == 3 else 2.0
        # === Fisher / noise-learning switches ===
        self.predict_fisher_uncertainty = cfg.get("PredictFisherUncertainty", True)
        self.learn_anisotropic_noise = cfg.get("LearnAnisotropicNoise", True)
        self.fisher_eps = float(cfg.get("FisherEps", 1e-2))
        self.time_prior_sigma = float(cfg.get("TimePriorSigma", 0.5))  # original note: "prior ~1.0 year"
        self.VBlendTau = float(cfg.get("VBlendTau", 0.2))

        # Fallback: a homoscedastic (constant) noise variance for when noise is not learned.
        # NOTE(review): not referenced in the visible part of this file.
        self.obs_noise_sigma2 = float(cfg.get("ObsNoiseSigma2", 1.0))
        #hidden_features = 256
        
        # === Noise heads (anisotropic): predict the in_features*(in_features+1)/2 free
        # parameters of a per-point lower-triangular Cholesky factor
        # (6 in 3-D: [l11, l22, l33, l21, l31, l32]). ===
        # Both heads use a fixed width of 256 and 3 hidden layers, independent of
        # HiddenFeatures / HiddenLayers.
        # NOTE(review): noise_net expects in_dim+1 inputs and noise_residual_net expects
        # in_features inputs, while both are fed encode_coord(coords) (+ time for
        # noise_net); this only lines up if the encoded width matches both — confirm.
        self.noise_net = BaseFCBlock(   
            in_features= self.in_dim+1,
            out_features=(self.in_features * (self.in_features + 1)) //2,
            hidden_features=256,
            num_hidden_layers=3,
            outermost_linear=True,
            nonlinearity='gelu',  # swish
        )
        self.noise_residual_net = BaseFCBlock(   
            in_features= self.in_features,
            out_features=(self.in_features * (self.in_features + 1)) //2,
            hidden_features=256,
            num_hidden_layers=3,
            outermost_linear=True,
            nonlinearity='gelu',  # swish
        )

        # self.noise_net = Siren(
        #     in_features=self.in_dim + 1,
        #     hidden_features=hidden_features,
        #     hidden_layers=hidden_layers,
        #     out_features=(self.in_features * (self.in_features+1))//2,         # [l11, l22, l33, l21, l31, l32]
        #     outermost_linear=True,
        #     first_omega_0=OMEGA_0,
        #     hidden_omega_0=OMEGA_0,
        #     zero_init_last_layer=False,
        #     is_first=False
        # )
        # self.noise_residual_net = Siren(
        #     in_features=self.in_dim,
        #     hidden_features=hidden_features,
        #     hidden_layers=hidden_layers,
        #     out_features=(self.in_features * (self.in_features+1))//2,         # [l11, l22, l33, l21, l31, l32]
        #     outermost_linear=True,
        #     first_omega_0=OMEGA_0,
        #     hidden_omega_0=OMEGA_0,
        #     zero_init_last_layer=False,
        #     is_first=False
        # )

        # 2. One standard deformation network per covariate (same as Student_v8).
        # The student architecture is unchanged; it is trained against the landmark-aware teacher.
        self.cov_nets = nn.ModuleDict({
            cov: Siren(
                in_features=self.in_features + 1,  # coords (encoded) + covariate (1)
                hidden_features=hidden_features,
                hidden_layers=hidden_layers,
                out_features=self.in_features,  # deformation (x, y, z)
                outermost_linear=True,
                first_omega_0=OMEGA_0,
                hidden_omega_0=OMEGA_0,
                zero_init_last_layer=False,
                is_first=False
            )
            for cov in self.covariate_names
        })

        # Covariate-independent geometric residual, added to every covariate's
        # displacement in calculate_cov_wise_disp
        self.geometry_residual_net = Siren(
            in_features=self.in_features,
            hidden_features=hidden_features,
            hidden_layers=hidden_layers,
            out_features=self.in_features,
            outermost_linear=True,
            first_omega_0=OMEGA_0,
            hidden_omega_0=OMEGA_0,
            zero_init_last_layer=False,
            is_first=False
        )


        # self.cov_nets = nn.ModuleDict({
        #     cov:  BaseFCBlock(   
        #     in_features=self.in_dim + 1,
        #     out_features=self.in_features,
        #     hidden_features=hidden_features,
        #     num_hidden_layers=hidden_layers,
        #     outermost_linear=True,
        #     nonlinearity='swish', #'gelu',  # swish
        # )
        #     for cov in self.covariate_names
        # })

        
        logger.info(f"✅ Initialized StableTrajectoryEstimator with {len(self.covariate_names)} networks")
        logger.info(f"   Covariates: {self.covariate_names}")
        logger.info(f"   Mode: Learning from landmark-aware teacher")


    @property
    def device(self):
        """Device holding the module's first registered parameter.

        Raises StopIteration if the module has no parameters.
        """
        first_param = next(iter(self.parameters()))
        return first_param.device


    def validate_cfg(self, cfg: dict, required_keys: list[str]):
        """Raise ValueError naming the first entry of ``required_keys`` that is absent from ``cfg``."""
        not_found = object()
        missing = next((name for name in required_keys if name not in cfg), not_found)
        if missing is not not_found:
            raise ValueError(f"Missing required configuration key: {missing}")

    def _freeze_teacher(self):
        """Freezes the parameters of the teacher model to prevent updates during training."""
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.teacher.eval()
        

    def extract_template_mesh(self, N: int = 512, truncate: bool = False, scale: float = None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Extract the template mesh from the teacher and cache it on the model.

        Args:
            N: Number of samples along each grid dimension, forwarded to
                ``extract_mesh`` (default: 512).
            truncate: Forwarded unchanged to ``extract_mesh``.
            scale: Scale forwarded to ``extract_mesh``. An explicitly passed value
                is always used. When ``None``, it defaults to 0.72 if an ``'age'``
                covariate is configured, otherwise to ``self.sampling_range``.

        Returns:
            tuple: ``(verts, faces)`` as returned by ``extract_mesh``.

        Side effects:
            Sets ``self.template_mesh`` to ``verts`` on the model's device and
            ``self.template_face`` to ``faces``.
        """
        if scale is None:
            # Hard-coded default for age-conditioned models; only applied when the
            # caller did not request a specific scale.
            scale = 0.72 if 'age' in self.covariate_names else self.sampling_range
        verts, faces, _, _ = extract_mesh(self.teacher, scale=scale, N=N, truncate=truncate)
        # Follow the model's device instead of hard-coding CUDA (works on CPU and
        # on non-default GPUs).
        self.template_mesh = verts.to(self.device)
        self.template_face = faces
        return verts, faces

    def _get_time_cov_name(self) -> str | None:
        for cov_name in self.covariate_names:
            if cov_name.lower() in ("group", "t", "time", "age", 'ind_chron_t', 'ind_chron_t_years'):
                return cov_name
        return None

    def _params_to_cholesky(self, raw: torch.Tensor, min_diag: float = 1e-4, dim: int = None):
        """
        Convert raw parameters to Cholesky decomposition L and covariance Σ=LLᵀ.
        Supports both 2D and 3D cases.
        
        Args:
            raw: [B,N,n_params] where n_params = dim*(dim+1)//2
                 For dim=2: [B,N,3] -> [l11, l22, l21]
                 For dim=3: [B,N,6] -> [l11, l22, l33, l21, l31, l32]
            min_diag: Minimum diagonal value for numerical stability (default: 1e-4)
            dim: Dimensionality (2 or 3). If None, infer from raw.shape[-1]
        
        Returns:
            L: [B,N,dim,dim] lower triangular Cholesky factor (diag > 0)
            Sigma: [B,N,dim,dim] covariance matrix Σ = LLᵀ
        """
        # Infer dimensionality from input size if not provided
        if dim is None:
            n_params = raw.shape[-1]
            # Solve: n_params = dim * (dim + 1) / 2
            # For dim=2: 3 params, dim=3: 6 params
            if n_params == 3:
                dim = 2
            elif n_params == 6:
                dim = 3
            else:
                raise ValueError(f"Invalid number of parameters: {n_params}. Expected 3 (dim=2) or 6 (dim=3).")
        
        zero = torch.zeros_like(raw[..., 0:1])
        
        if dim == 2:
            # For 2D: raw = [l11, l22, l21]
            l11 = torch.nn.functional.softplus(raw[..., 0:1]) + min_diag  # [B,N,1]
            l22 = torch.nn.functional.softplus(raw[..., 1:2]) + min_diag  # [B,N,1]
            l21 = raw[..., 2:3]                                            # [B,N,1]
            
            # Construct lower triangular matrix [B,N,2,2]
            L = torch.stack([
                torch.cat([l11, zero], dim=-1),           # [l11, 0]
                torch.cat([l21, l22], dim=-1),            # [l21, l22]
            ], dim=-2)  # [B,N,2,2]
            
        elif dim == 3:
            # For 3D: raw = [l11, l22, l33, l21, l31, l32]
            l11 = torch.nn.functional.softplus(raw[..., 0:1]) + min_diag  # [B,N,1]
            l22 = torch.nn.functional.softplus(raw[..., 1:2]) + min_diag  # [B,N,1]
            l33 = torch.nn.functional.softplus(raw[..., 2:3]) + min_diag  # [B,N,1]
            l21 = raw[..., 3:4]                                            # [B,N,1]
            l31 = raw[..., 4:5]                                            # [B,N,1]
            l32 = raw[..., 5:6]                                            # [B,N,1]
            
            # Construct lower triangular matrix [B,N,3,3]
            L = torch.stack([
                torch.cat([l11, zero, zero], dim=-1),     # [l11, 0,   0]
                torch.cat([l21, l22,  zero], dim=-1),     # [l21, l22, 0]
                torch.cat([l31, l32,  l33],  dim=-1),     # [l31, l32, l33]
            ], dim=-2)  # [B,N,3,3]
        else:
            raise ValueError(f"Unsupported dimensionality: {dim}. Expected 2 or 3.")
        
        # Compute covariance: Σ = LLᵀ
        Sigma = L @ L.transpose(-1, -2)
        
        return L, Sigma
    



    def predict_obs_noise(self, coords: torch.Tensor, dict_covariates: dict, time_cov_name: str):
        """
        Predict the per-point observation-noise covariance Σ(x, t).

        The raw Cholesky parameters are
            noise_net([enc(x), t]) - noise_net([enc(x), 0]) + noise_residual_net(enc(x)),
        so the time-dependent part vanishes at t = 0 and the residual head alone
        determines Σ(x, 0). No jitter is applied to t (a training-time jitter
        is present only as commented-out code).

        Args:
            coords: Spatial coordinates; encoded with ``self.encode_coord``.
            dict_covariates: Covariate dict. ``dict_covariates[time_cov_name]`` is
                concatenated to the encoded coordinates along the last axis.
            time_cov_name: Key of the time covariate.

        Returns:
            (L, Sigma) from ``_params_to_cholesky`` with ``min_diag=self.fisher_eps``
            (one 3x3 matrix per point in 3-D).
        """
        t_val = dict_covariates[time_cov_name]

        # # Optional training-time jitter on t (disabled):
        # if self.training:
        #     t_val = t_val + torch.randn_like(t_val) * 0.05

        encoded = self.encode_coord(coords)
        with_time = torch.cat([encoded, t_val], dim=-1)                       # [B,N,enc+1]
        at_time_zero = torch.cat([encoded, torch.zeros_like(t_val)], dim=-1)  # [B,N,enc+1]
        raw_params = (
            self.noise_net(with_time)
            - self.noise_net(at_time_zero)
            + self.noise_residual_net(encoded)
        )  # [B,N,6] in 3-D
        chol, cov = self._params_to_cholesky(raw_params, self.fisher_eps)
        return chol, cov


    def gaussian_nll_from_chol(self, resid: torch.Tensor, L: torch.Tensor):
        """
        Per-point Gaussian negative log-likelihood given a Cholesky factor.

        NLL = ½ (eᵀ Σ⁻¹ e + log|Σ|) with Σ = L Lᵀ, dropping the constant
        ½·d·log(2π). Σ⁻¹ e is obtained with a Cholesky solve (no explicit
        inverse), and log|Σ| = 2·Σ log diag(L).

        The residual is detached, so gradients flow only through L (the noise
        model), not into whatever produced ``resid``.

        Args:
            resid: [B,N,3] residual (y_obs - y_pred).
            L: [B,N,3,3] lower-triangular Cholesky factor of Σ.

        Returns:
            [B,N,1] per-point NLL.
        """
        err = resid.detach()[..., None]                    # [B,N,3,1]
        solved = torch.cholesky_solve(err, L, upper=False)  # Σ⁻¹ e, [B,N,3,1]
        mahalanobis = (err * solved).sum(dim=(-2, -1))[..., None]  # [B,N,1]
        log_det = (
            torch.diagonal(L, dim1=-2, dim2=-1).log().sum(dim=-1, keepdim=True).mul(2.0)
        )  # [B,N,1]
        return (mahalanobis + log_det) * 0.5


    # def gaussian_nll_from_chol(self, resid: torch.Tensor, L: torch.Tensor, detach_mean: bool = False):
    #     """
    #     resid: [B,N,3] = (y_obs - y_pred)
    #     L:     [B,N,3,3] 下三角 (Σ=LLᵀ)
    #     返回逐点 NLL: [B,N,1]
    #     """
    #     e = resid.unsqueeze(-1)  # [B,N,3,1]
    #     if detach_mean:
    #         e = e.detach()
        
    #     # 解 LLᵀ z = e  ->  z = Σ^{-1} e
    #     z = torch.cholesky_solve(e, L, upper=False)  # [B,N,3,1]
    #     quad = (e * z).sum(dim=(-2, -1), keepdim=True)  # [B,N,1]
        
    #     # log|Σ| = 2 * sum(log(diag(L)))
    #     logdet = 2.0 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(dim=-1, keepdim=True)
        
    #     # 完整 NLL（包含 1/2 因子）
    #     return 0.5 * (quad + logdet)  # 可选：+ 0.5 * d * log(2π)


    def compute_time_sensitivity(self, coords: torch.Tensor, dict_covariates: dict, time_cov_name: str):
        """
        Compute v = ∂disp/∂t for the time covariate's displacement.

        The time covariate is replaced by a detached copy that requires grad. The
        displacement from ``calculate_cov_wise_disp`` is evaluated with grad
        enabled, and one autograd pass per output channel gives that channel's
        derivative w.r.t. t. If the displacement does not depend on t, the
        derivative is reported as zero.

        Args:
            coords: Query coordinates passed to ``calculate_cov_wise_disp``.
            dict_covariates: Covariate dict (not modified).
            time_cov_name: Key of the time covariate.

        Returns:
            v: [B,N,C] per-channel derivative (detached); C = displacement channels.
            v_norm2: [B,N,1] squared Euclidean norm of v.
        """
        t = dict_covariates[time_cov_name].clone().detach().requires_grad_(True)
        dict_covs = dict(dict_covariates)
        dict_covs[time_cov_name] = t
        with torch.enable_grad():
            disp_t = self.calculate_cov_wise_disp(coords, dict_covs, time_cov_name)  # [B,N,3]

        n_channels = disp_t.shape[-1]
        grads = []
        for c in range(n_channels):
            g_c = torch.autograd.grad(
                outputs=disp_t[..., c].sum(),
                inputs=t,
                # Keep the graph only while further channels still need it.
                retain_graph=c < n_channels - 1,
                create_graph=False,
                allow_unused=True,
            )[0]
            if g_c is None:
                # t is unused by this displacement: its derivative is zero
                # (torch.stack below would otherwise fail on None).
                g_c = torch.zeros_like(t)
            grads.append(g_c)

        v = torch.stack(grads, dim=-1).squeeze(-2).detach()  # [B,N,3]
        v_norm2 = (v ** 2).sum(dim=-1, keepdim=True)          # [B,N,1]
        return v, v_norm2

       

    def compute_time_sensitivity_smooth(self, coords, dict_covariates, time_cov,
                                        sigma_delta=0.15, K=3):
        """
        Randomized central-difference estimate of v = ∂disp/∂t.

        For each of ``K`` draws, a per-element step δ = |N(0, sigma_delta²)|
        (clamped to at least 1e-6) is sampled, and
            g = (disp(t + δ) - disp(t - δ)) / (2 δ)
        is computed. The K estimates are averaged. The result is detached and
        does not take part in back-propagation.

        Args:
            coords: Query coordinates passed to ``calculate_cov_wise_disp``.
            dict_covariates: Covariate dict (not modified).
            time_cov: Key of the time covariate.
            sigma_delta: Standard deviation of the random step size.
            K: Number of random draws to average (must be >= 1).

        Returns:
            v_smooth: [B,N,3] averaged derivative estimate (detached).
            v_norm2: [B,N,1] squared norm of v_smooth (detached).

        Raises:
            ValueError: If ``K < 1``.
        """
        if K < 1:
            raise ValueError(f"K must be >= 1, got {K}")
        t0 = dict_covariates[time_cov]

        def disp_at(tv):
            dc = dict(dict_covariates)
            dc[time_cov] = tv
            return self.calculate_cov_wise_disp(coords, dc, time_cov)  # [B,N,3]

        grads = []
        # Everything is detached anyway, so do not build autograd graphs.
        with torch.no_grad():
            for _ in range(K):
                # Use the step *magnitude*. Clamping a signed Gaussian step with
                # clamp_min would turn each negative step into 1e-6 in the
                # denominator while the numerator kept the negative step, giving
                # huge, wrong-signed estimates. The central difference is
                # symmetric in the sign of δ, so |δ| loses nothing.
                delta = (torch.randn_like(t0) * sigma_delta).abs().clamp_min(1e-6)
                f_plus = disp_at(t0 + delta)
                f_minus = disp_at(t0 - delta)
                grads.append((f_plus - f_minus) / (2.0 * delta))  # [B,N,3]
        v_smooth = torch.stack(grads, dim=0).mean(dim=0)        # [B,N,3]
        v_norm2 = (v_smooth ** 2).sum(dim=-1, keepdim=True)     # [B,N,1]
        return v_smooth, v_norm2


    def compute_fisher_variance_map(self, model_input: dict, secondary: bool = True,
                                    include_cov_term: bool = False):
        """
        Per-point Fisher information about t and its inverse (a variance map).

        I_mean(x) = vᵀ Σ⁻¹ v, where v = ∂disp/∂t comes from
        ``compute_time_sensitivity`` and Σ is the learned noise covariance from
        ``predict_obs_noise``. Optionally the covariance term
        I_cov(x) = ½ tr(Σ⁻¹ Σ_t Σ⁻¹ Σ_t) is added. The total is clamped below at
        ``self.fisher_eps`` per point (no averaging).

        Args:
            model_input: Dict with ``'coords'`` and ``'covariates'``.
            secondary: Allow the covariance term. It is computed (one extra
                forward-mode AD pass) only when ``include_cov_term`` is also True.
            include_cov_term: Add I_cov to I_mean. The default, False, keeps the
                mean-only information, which is what this method has always returned.

        Returns:
            dict with
              ``fisher_var``  [B,N,1] 1 / fisher_info (not detached),
              ``fisher_info`` [B,N,1] clamped information (detached),
              ``time_grad``   [B,N,3] v (detached),
              ``v_norm2``     [B,N,1] ||v||² (detached).

        Raises:
            ValueError: If no time-like covariate is configured.
        """
        coords = model_input['coords']
        dict_covariates = model_input['covariates']

        # 1) Locate the time covariate
        time_cov = self._get_time_cov_name()
        if time_cov is None:
            raise ValueError("No time-like covariate found.")

        # 2) v = ∂disp/∂t
        v, v_norm2 = self.compute_time_sensitivity(coords, dict_covariates, time_cov)  # [B,N,3], [B,N,1]

        # 3) Learned per-point noise covariance Σ = L Lᵀ
        L, Sigma = self.predict_obs_noise(coords, dict_covariates, time_cov)  # [B,N,3,3]

        # vᵀ Σ⁻¹ v via a Cholesky solve (no explicit inverse)
        b = v.unsqueeze(-1)                       # [B,N,3,1]
        Sigma_inv_v = torch.cholesky_solve(b, L)  # [B,N,3,1]
        I_mean_map = (b * Sigma_inv_v).sum(dim=(-2, -1)).unsqueeze(-1)  # [B,N,1]

        # 4) Optional covariance term. It is skipped entirely unless it is
        #    actually used, avoiding a discarded jvp pass on every call.
        if secondary and include_cov_term:
            _, Sigma_t = self.sigma_and_sigma_t_direct(coords, dict_covariates, time_cov)
            Sigma_t = 0.5 * (Sigma_t + Sigma_t.transpose(-1, -2))  # symmetrize, [B,N,3,3]
            Sigma_inv = torch.cholesky_inverse(L)                   # [B,N,3,3]
            I_cov_map = 0.5 * torch.einsum(
                '...ij,...jk,...kl,...li->...',
                Sigma_inv, Sigma_t, Sigma_inv, Sigma_t
            ).unsqueeze(-1)                                         # [B,N,1]
            I_map = I_mean_map + I_cov_map
        else:
            I_map = I_mean_map

        # 5) Per-point clamp (no averaging)
        fisher_info = I_map.clamp_min(self.fisher_eps)  # [B,N,1]
        fisher_var = 1.0 / fisher_info

        return {
            "fisher_var": fisher_var,
            "fisher_info": fisher_info.detach(),
            "time_grad": v.detach(),
            "v_norm2": v_norm2.detach(),
        }
    

    # def fisher_cov_term_from_cholesky(self, L, Sigma_t):
    #     """
    #     L:        [B,N,3,3]  cholesky(Σ)（下三角）
    #     Sigma_t:  [B,N,3,3]  ∂Σ/∂t 的近似（比如中心差分）
    #     return:   [B,N,1]    0.5 * tr(Σ^{-1}Σ_t Σ^{-1}Σ_t)
    #     """
    #     # C = L^{-T} Σ_t L^{-1}  （对称相似变换）
    #     # 先解 L X = Σ_t   -> X = L^{-1} Σ_t
    #     X = torch.linalg.solve_triangular(L, Sigma_t, upper=False)
    #     # 再解 L^T C^T = X^T -> C = L^{-T} Σ_t L^{-1}
    #     C = torch.linalg.solve_triangular(L.transpose(-1, -2), X, upper=True)
    #     # 0.5 * ||C||_F^2  （C 对称时 tr(C^2)=||C||_F^2；数值上用 Frobenius 更稳）
    #     I_cov = 0.5 * (C * C).sum(dim=(-2, -1), keepdim=False).unsqueeze(-1)  # [B,N,1]
    #     return I_cov


    def sigma_and_sigma_t_direct(self, coords, covs, time_name):
        """
        Return Σ(x, t) and its time derivative Σ_t = ∂Σ/∂t via forward-mode AD.

        Σ is built the same way as in ``predict_obs_noise``: the noise_net
        difference plus noise_residual_net, mapped through ``_params_to_cholesky``
        with ``min_diag=self.fisher_eps``. ``jvp`` with a tangent of ones along t
        yields Σ_t in a single pass.

        Both noise heads are evaluated in eval mode. Each head's own train/eval
        mode is restored afterwards, even if the computation raises.

        Args:
            coords: Query coordinates (encoded with ``self.encode_coord``).
            covs: Covariate dict; ``covs[time_name]`` is the time value.
            time_name: Key of the time covariate.

        Returns:
            (Sigma, Sigma_t): both detached, each with Σ's shape ([B,N,3,3] in 3-D).
        """
        try:  # PyTorch >= 2.0
            from torch.func import jvp
        except ImportError:  # older PyTorch with the standalone functorch package
            from functorch import jvp

        t = covs[time_name].detach().clone().requires_grad_(True)
        xenc = self.encode_coord(coords)

        def F(tt):
            raw = (
                self.noise_net(torch.cat([xenc, tt], -1))
                - self.noise_net(torch.cat([xenc, torch.zeros_like(tt)], -1))
                + self.noise_residual_net(xenc)
            )  # [B,N,6]
            _, S = self._params_to_cholesky(raw, min_diag=self.fisher_eps)
            return S

        # Remember each head's mode separately so both are restored faithfully.
        noise_was_training = self.noise_net.training
        residual_was_training = self.noise_residual_net.training
        self.noise_net.eval()
        self.noise_residual_net.eval()
        try:
            Sigma_val, Sigma_t = jvp(F, (t,), (torch.ones_like(t),))  # JVP gives Σ_t
        finally:
            self.noise_net.train(noise_was_training)
            self.noise_residual_net.train(residual_was_training)
        return Sigma_val.detach(), Sigma_t.detach()



    def predict_sigma(self, coords: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Per-point variance σ²(x, t), i.e. ``fisher_var`` = 1 / Fisher information.

        Args:
            coords: [B, N, D] spatial coordinates.
            t: Time as [B], [B, 1] or per-point [B, N, 1]. 1-D and 2-D inputs
               are broadcast to [B, N, 1].

        Returns:
            [B, N, 1] variance from ``compute_fisher_variance_map``.
        """
        n_pts = coords.shape[1]

        t_in = t.unsqueeze(-1) if t.dim() == 1 else t  # [B] -> [B, 1]
        if t_in.dim() == 2:
            t_in = t_in.unsqueeze(1).expand(-1, n_pts, -1)  # [B, 1] -> [B, N, 1]

        fisher = self.compute_fisher_variance_map(
            {'coords': coords, 'covariates': {self._get_time_cov_name(): t_in}}
        )
        return fisher['fisher_var']


    def predict_global_sigma(self, t: torch.Tensor, top_k_ratio: float = 0.95) -> torch.Tensor:
        """
        Global variance at time ``t`` from the most informative template points.

        The Fisher map is evaluated on ``self.template_mesh`` (set by
        ``extract_template_mesh``). The ``k = clamp(int(N * top_k_ratio), 1, N)``
        points with the highest Fisher information are kept, and their variances
        are averaged with Fisher-information weights:
            global_var = Σ_i I_i · var_i / Σ_i I_i

        Args:
            t: Time value(s), stored under the time covariate before
               ``prepare_model_input`` is applied.
            top_k_ratio: Fraction of template points to keep (default: 0.95).
                Values above 1 use all points; at least one point is always kept.

        Returns:
            torch.Tensor: [B, 1] global variance.

        Raises:
            AttributeError: If ``extract_template_mesh`` has not been called yet.
        """
        template_mesh = getattr(self, 'template_mesh', None)
        if template_mesh is None:
            raise AttributeError(
                "template_mesh is not set; call extract_template_mesh() first."
            )

        model_input = {
            'coords': template_mesh[None, ...],
            'covariates': {self._get_time_cov_name(): t},
        }
        model_input = self.prepare_model_input(model_input)
        fisher_map = self.compute_fisher_variance_map(model_input)
        fisher_info = fisher_map['fisher_info'].squeeze(-1)  # [B, N]
        fisher_var = fisher_map['fisher_var'].squeeze(-1)    # [B, N] = 1 / fisher_info

        num_points = fisher_info.shape[-1]
        # Clamp so ratios > 1 (or rounding down to 0) still give a valid k for topk.
        top_k = min(num_points, max(1, int(num_points * top_k_ratio)))

        # Most confident points = highest Fisher information
        topk_values, topk_indices = torch.topk(fisher_info, k=top_k, dim=-1)   # [B, k]
        topk_variances = torch.gather(fisher_var, dim=-1, index=topk_indices)  # [B, k]

        weighted_sum = (topk_values * topk_variances).sum(dim=-1, keepdim=True)  # [B, 1]
        weight_sum = topk_values.sum(dim=-1, keepdim=True)                       # [B, 1]
        return weighted_sum / (weight_sum + 1e-8)


    @torch.no_grad()
    def template_sdf(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the teacher's template SDF at ``coords``, without gradient tracking.

        Args:
            coords (torch.Tensor): Query coordinates.

        Returns:
            torch.Tensor: Output of ``self.teacher.template(coords)``.
        """
        return self.teacher.template(coords)

    def concatenate_coords_with_cov(self, coords_encoded: torch.Tensor, covariate_value: torch.Tensor) -> torch.Tensor:
        """Append ``covariate_value`` to ``coords_encoded`` along the last (feature) axis."""
        parts = (coords_encoded, covariate_value)
        return torch.cat(parts, -1)

    def calculate_cov_wise_disp(self, coords_init: torch.Tensor, dict_covariates: dict, current_covariate_name: str) -> torch.Tensor:
        """
        Displacement attributed to one covariate, plus the shared geometry residual.

            disp = net_c([x, c]) - net_c([x, 0]) + geometry_residual_net(x)

        Here ``net_c = self.cov_nets[current_covariate_name]`` and ``c`` is that
        covariate's value, so the covariate term is exactly zero when c = 0. The
        coordinates are used without positional encoding.

        Args:
            coords_init: Input coordinates, e.g. [B, N, 3].
            dict_covariates: Covariate dict.
            current_covariate_name: Covariate whose network is evaluated.

        Returns:
            torch.Tensor: Displacement with the networks' output shape.
        """
        cov_value = dict_covariates[current_covariate_name]
        net = self.cov_nets[current_covariate_name]
        x = coords_init  # no encode_coord here

        at_value = net(self.concatenate_coords_with_cov(x, cov_value))
        at_zero = net(self.concatenate_coords_with_cov(x, torch.zeros_like(cov_value)))
        return (at_value - at_zero) + self.geometry_residual_net(x)

    def deformation_from_covariates(self, model_input: dict) -> tuple[torch.Tensor, dict]:
        """
        Sum the per-covariate student displacements.

        Each covariate's displacement comes from ``calculate_cov_wise_disp``. That
        displacement already includes the geometry residual, so the residual
        enters the sum once per covariate. No landmark rescaling is applied here.

        Returns:
            (total, per_covariate): ``total`` has the shape of
            ``model_input['coords']``; ``per_covariate`` maps each covariate name to
            its displacement, in ``self.covariate_names`` order.
        """
        covs = model_input['covariates']
        points = model_input['coords']

        per_covariate = {}
        total = torch.zeros_like(points)
        for name in self.covariate_names:
            disp = self.calculate_cov_wise_disp(points, covs, name)
            per_covariate[name] = disp
            total += disp
        return total, per_covariate







    def prepare_input_for_teacher(self, model_input: dict) -> dict:
        """
        Build the teacher's input dict.

        Returns a shallow copy of ``model_input`` (tensors are shared, not copied)
        whose ``'idx'`` entry is replaced by ``model_input['teacher_idx']``. The
        caller's dict, including its student ``'idx'``, is left untouched.

        Raises:
            KeyError: If ``'teacher_idx'`` is missing from ``model_input``.
        """
        teacher_input = dict(model_input)
        teacher_input['idx'] = model_input['teacher_idx']
        return teacher_input
    
    def predict_teacher_deformation(self, model_input: dict) -> torch.Tensor:
        """
        Teacher displacement, optionally mapped into the original coordinate space.

        When a landmark rescaler is configured, the teacher's warped points
        ``coords + teacher_deform`` are rescaled using the landmark scale extracted
        from the teacher input, and the displacement is recomputed relative to
        ``coords``. Without a rescaler (``UseLandmarkRescaling=False``), the raw
        teacher displacement is returned.

        Args:
            model_input: Student input dict containing ``'coords'`` and the
                entries ``prepare_input_for_teacher`` needs.

        Returns:
            torch.Tensor: Displacement with the shape of the teacher's output.
        """
        teacher_input = self.prepare_input_for_teacher(model_input)
        coords = model_input['coords']
        teacher_deform = self.teacher.backward_deformation(teacher_input)

        rescaler = getattr(self, 'landmark_rescaler', None)
        if rescaler is None:
            # Landmark rescaling disabled: nothing to map back.
            return teacher_deform

        landmark_scale = rescaler.extract_landmark_scale(teacher_input)
        coords_at_ori_space = rescaler.apply_rescaling_to_coords(
            teacher_deform + coords, landmark_scale=landmark_scale
        )
        return coords_at_ori_space - coords

    def get_final_coords(self, coords: torch.Tensor, dict_deforms: dict, model_input: dict) -> torch.Tensor:
        """
        Warp ``coords`` by the student's overall displacement ``dict_deforms["overall"]``.

        ``model_input`` is accepted for interface compatibility and is not used.
        """
        overall_disp = dict_deforms["overall"]
        return coords + overall_disp

    def predict_deformations(self, model_input: dict) -> dict:
        """
        Run the teacher, then the student, on the same input.

        Returns:
            dict: The per-covariate student displacements, followed by
            ``"teacher"`` (teacher displacement) and ``"overall"`` (summed
            student displacement).
        """
        teacher_disp = self.predict_teacher_deformation(model_input)
        overall_disp, deforms = self.deformation_from_covariates(model_input)
        deforms.update(teacher=teacher_disp, overall=overall_disp)
        return deforms

    def predict_deformations_as_teacher(self, model_input: dict) -> dict:
        """
        Student-only displacements for teacher mode (e.g. InverseEncoder); the teacher is not called.

        Returns:
            dict: Per-covariate displacements plus ``"overall"`` (their sum).
        """
        total, per_cov = self.deformation_from_covariates(model_input)
        per_cov.update(overall=total)
        return per_cov

    def forward(self, model_input: dict) -> dict:
        """
        Student forward pass with teacher targets and optional uncertainty outputs.

        Steps:
            1. ``prepare_model_input`` is called with template points drawn
               uniformly from [-sampling_range, sampling_range) per coordinate.
            2. Teacher and student displacements are predicted
               (``predict_deformations``).
            3. If ``learn_anisotropic_noise`` is set and a time covariate exists,
               the per-point noise covariance Σ(x, t) is predicted, and the Gaussian
               NLL of the residual ``teacher.detach() - overall`` is reported.
               Without a time covariate, ``noise_nll`` is a zero tensor.
            4. If ``predict_fisher_uncertainty`` is set, a Fisher variance map is
               computed on inputs from ``sample_random_input``.

        Returns:
            dict with ``model_in``, ``all_input`` (both the prepared coords),
            ``vec_fields`` (the deformation dict), ``template`` (teacher template
            SDF at the coords) and, when enabled, ``noise_chol``, ``noise_cov``,
            ``noise_nll_map``, ``noise_nll``, ``fisher_var``, ``fisher_info``,
            ``time_grad`` and ``v_norm2``.
        """
        # 1) Prepare inputs. No copy is made here; whether the caller's dict is
        #    modified depends on prepare_model_input (defined elsewhere).
        pts_on_template = (torch.rand_like(model_input["coords"]) - 0.5) * self.sampling_range * 2.0
        prepared = self.prepare_model_input(
            model_input=model_input,
            pts_on_template=pts_on_template,
        )

        # 2) Student / teacher displacements: {'overall': student, 'teacher': teacher, <cov>: ...}
        dict_deforms = self.predict_deformations(prepared)
        coords = prepared['coords']

        model_output_student = {
            'model_in': coords,
            'all_input': coords,
            'vec_fields': dict_deforms,
            'template': self.template_sdf(coords),
        }

        # 3) Optional: learned anisotropic noise Σ(x, t) and per-point NLL
        if getattr(self, "learn_anisotropic_noise", False):
            time_cov_name = self._get_time_cov_name()
            if time_cov_name is not None:
                L, Sigma = self.predict_obs_noise(coords, prepared['covariates'], time_cov_name)  # [B,N,3,3]
                # The (detached) teacher displacement serves as the observation.
                resid = dict_deforms["teacher"].detach() - dict_deforms["overall"]  # [B,N,3]
                nll_map = self.gaussian_nll_from_chol(resid, L)  # [B,N,1]
                model_output_student.update({
                    "noise_chol": L,                # [B,N,3,3]
                    "noise_cov": Sigma,             # [B,N,3,3]
                    "noise_nll_map": nll_map,       # [B,N,1]
                    "noise_nll": nll_map.mean(),    # scalar used for training
                })
            else:
                # No time covariate: skip the noise branch but keep the key present.
                model_output_student["noise_nll"] = torch.tensor(0.0, device=coords.device)

        # 4) Optional: Fisher uncertainty map on randomly sampled inputs
        if getattr(self, "predict_fisher_uncertainty", False):
            random_model_input = self.sample_random_input(batch_size=coords.shape[0])
            fisher = self.compute_fisher_variance_map(random_model_input)
            model_output_student.update({
                "fisher_var": fisher["fisher_var"],          # [B,N,1] per-point Var(t|x)
                "fisher_info": fisher["fisher_info"],        # [B,N,1] vᵀ Σ⁻¹ v
                "time_grad": fisher.get("time_grad", None),  # [B,N,3] ∂disp/∂t
                "v_norm2": fisher.get("v_norm2", None),      # [B,N,1] ||∂disp/∂t||²
            })

        return model_output_student



    def sample_random_input(self, bbox_scale: float = 1.5, batch_size: int = 64,
                            num_points: int = 1500, device=None) -> dict:
        """
        Uniformly sample query points inside the scaled template bounding box.

        Works for both 2D and 3D templates: the spatial dimension is taken from
        the last axis of ``self.template_mesh``.

        Args:
            bbox_scale: Factor by which the template's axis-aligned bounding
                box is expanded around its center (default: 1.5).
            batch_size: Number of batches B to sample (default: 64).
            num_points: Number of points N per batch (default: 1500). The value
                is also stored in ``self.num_of_points``.
            device: Device for the sampled tensors. Defaults to the device of
                ``self.template_mesh``, so the sampler works on CPU and on any
                GPU.

        Returns:
            dict: ``{"coords": [B, N, D], "covariates": {name: [B, N, 1]}}``.
            Coordinates are uniform within the expanded box. Each covariate in
            ``self.covariate_names`` is uniform in [-3, 3).
        """
        # Kept for backward compatibility: other code may read self.num_of_points.
        self.num_of_points = num_points

        template = self.template_mesh
        if device is None:
            device = template.device

        # Axis-aligned bounding box of the template, expanded around its center.
        bbox_min = template.min(dim=0)[0]  # [D]
        bbox_max = template.max(dim=0)[0]  # [D]
        dim = bbox_min.shape[-1]  # 2 or 3
        center = (bbox_min + bbox_max) / 2
        half_size = (bbox_max - bbox_min) / 2

        # One vectorized draw for all axes; the bounds are cast to the sample
        # dtype so the output stays in the default float dtype.
        unit = torch.rand(batch_size, num_points, dim, device=device)
        lo = (center - half_size * bbox_scale).to(device=device, dtype=unit.dtype)
        hi = (center + half_size * bbox_scale).to(device=device, dtype=unit.dtype)
        coords = unit * (hi - lo) + lo  # [B, N, D]

        covariates = {
            name: (torch.rand((batch_size, num_points, 1), device=device) - 0.5) * 6
            for name in self.covariate_names
        }

        return {"coords": coords, "covariates": covariates}
    

    def forward_as_teacher(self, model_input: dict) -> dict:
        """
        Forward pass in teacher mode (used by InverseEncoder).

        Kept fully compatible with Student_v8. The input dictionary is used
        directly (not copied).

        Args:
            model_input: Input dictionary containing ``"coords"``.

        Returns:
            dict: ``model_in`` / ``all_input`` (the input coordinates),
            ``vec_fields`` (result of ``predict_deformations_as_teacher``) and
            ``template`` (``template_sdf`` at the coordinates).
        """
        vec_fields = self.predict_deformations_as_teacher(model_input)
        coords = model_input["coords"]
        return {
            "model_in": coords,
            "all_input": coords,
            "vec_fields": vec_fields,
            "template": self.template_sdf(coords),
        }



    def inference(self, model_input_ori: dict):
        """
        Evaluation forward pass on a deep copy of the input.

        Computes the student deformation from the covariates and returns the
        deformed coordinates. When the prepared input carries a
        ``"teacher_idx"`` key, it also queries the teacher at the (detached)
        deformed coordinates and returns its ``"model_out"``.

        NOTE(review): no landmark rescaling is visible in this method; if it
        happens, it is presumably inside ``deformation_from_covariates`` or the
        teacher. Confirm.

        Args:
            model_input_ori: Input dictionary. It is deep-copied, so the
                caller's dict is not modified.

        Returns:
            dict: ``model_in`` / ``all_input``, ``vec_fields`` (with
            ``"overall"`` = student deformation), ``template``, ``output``
            (coords + overall deformation) and, if applicable, ``model_out``.
        """
        prepared = self.prepare_model_input(model_input=copy.deepcopy(model_input_ori))

        student_deform, vec_fields = self.deformation_from_covariates(prepared)
        vec_fields["overall"] = student_deform

        result = {
            "model_in": prepared["coords"],
            "all_input": prepared["coords"],
            "vec_fields": vec_fields,
            "template": self.template_sdf(prepared["coords"]),
        }

        if "teacher_idx" in prepared:
            # Teacher SDF at the deformed points; detached so no gradient reaches the student.
            warped_for_teacher = (prepared["coords"] + vec_fields["overall"]).detach().clone()
            teacher_input = self.prepare_input_for_teacher(
                self.update_input_coords(prepared, warped_for_teacher)
            )
            result["model_out"] = self.teacher.inference(teacher_input)["model_out"]

        # Coordinates are re-read after the teacher branch, as in the original flow.
        result["output"] = prepared["coords"] + vec_fields["overall"]
        return result
    


    def inference_with_template(self, model_input_ori: dict):
        """
        Evaluation forward pass that prepares the input using the template mesh.

        Like ``inference``, but ``prepare_model_input`` receives
        ``self.template_mesh`` (with a leading batch axis) as
        ``pts_on_template`` and ``retain=True``. The same ``retain=True`` is
        passed to ``update_input_coords`` in the teacher branch.

        NOTE(review): no landmark rescaling is visible in this method; if it
        happens, it is presumably inside ``deformation_from_covariates`` or the
        teacher. Confirm.

        Args:
            model_input_ori: Input dictionary. It is deep-copied, so the
                caller's dict is not modified.

        Returns:
            dict: ``model_in`` / ``all_input``, ``vec_fields`` (with
            ``"overall"`` = student deformation), ``template``, ``output``
            (coords + overall deformation) and, when ``"teacher_idx"`` is
            present, ``model_out`` from the teacher.
        """
        prepared = self.prepare_model_input(
            model_input=copy.deepcopy(model_input_ori),
            pts_on_template=self.template_mesh[None, ...],
            retain=True,
        )

        student_deform, vec_fields = self.deformation_from_covariates(prepared)
        vec_fields["overall"] = student_deform

        result = {
            "model_in": prepared["coords"],
            "all_input": prepared["coords"],
            "vec_fields": vec_fields,
            "template": self.template_sdf(prepared["coords"]),
        }

        if "teacher_idx" in prepared:
            # Teacher SDF at the deformed points; detached so no gradient reaches the student.
            warped_for_teacher = (prepared["coords"] + vec_fields["overall"]).detach().clone()
            teacher_input = self.prepare_input_for_teacher(
                self.update_input_coords(prepared, warped_for_teacher, retain=True)
            )
            result["model_out"] = self.teacher.inference(teacher_input)["model_out"]

        # Coordinates are re-read after the teacher branch, as in the original flow.
        result["output"] = prepared["coords"] + vec_fields["overall"]
        return result


    
    

    def inference_for_optimize(self, model_input: dict):
        """
        Forward pass used during optimization; always queries the teacher.

        Unlike ``inference``, the input is NOT deep-copied: it is passed
        straight to ``prepare_model_input`` (with ``retain=True``). Deformations
        come from ``predict_deformations``. The teacher is evaluated
        unconditionally at the detached deformed coordinates.

        NOTE(review): no landmark rescaling is visible in this method; if it
        happens, it is presumably inside ``predict_deformations`` or the
        teacher. Confirm.

        Args:
            model_input: Input dictionary containing ``"coords"``.

        Returns:
            dict: ``model_in`` / ``all_input``, ``vec_fields``, ``template``,
            ``model_out`` (teacher output) and ``output`` (coords + overall
            deformation).
        """
        prepared = self.prepare_model_input(model_input=model_input, retain=True)

        vec_fields = self.predict_deformations(prepared)

        result = {
            "model_in": prepared["coords"],
            "all_input": prepared["coords"],
            "vec_fields": vec_fields,
            "template": self.template_sdf(prepared["coords"]),
        }

        # Teacher SDF at the deformed points; detached so no gradient reaches the student.
        warped_for_teacher = (prepared["coords"] + vec_fields["overall"]).detach().clone()
        teacher_input = self.prepare_input_for_teacher(
            self.update_input_coords(prepared, warped_for_teacher, retain=True)
        )
        result["model_out"] = self.teacher.inference(teacher_input)["model_out"]

        result["output"] = prepared["coords"] + vec_fields["overall"]
        return result
    

    def calculate_inv_displacement(self, model_input: dict) -> torch.Tensor:
        """
        Compute the inverse displacement from the covariates.

        Kept fully compatible with Student_v8: this is the first element
        (the combined deformation) returned by ``deformation_from_covariates``.
        The per-covariate dictionary is discarded.

        Args:
            model_input: Prepared input dictionary.

        Returns:
            torch.Tensor: The combined displacement field.
        """
        displacement = self.deformation_from_covariates(model_input)[0]
        return displacement

    def inv_transform(self, model_input_ori: dict) -> torch.Tensor:
        """
        Apply the inverse transformation: coords + inverse displacement.

        The input is deep-copied before ``prepare_model_input`` runs, so the
        caller's dictionary is left untouched.

        NOTE(review): the original docstring mentions landmark rescaling, but
        none is visible here; presumably it happens inside
        ``deformation_from_covariates``. Confirm.

        Args:
            model_input_ori: Input dictionary containing ``"coords"``.

        Returns:
            torch.Tensor: The transformed coordinates.
        """
        prepared = self.prepare_model_input(model_input=copy.deepcopy(model_input_ori))
        return self.calculate_inv_displacement(prepared) + prepared["coords"]

