#!/usr/bin/env python3
"""
Student_v9: Student with Landmark-Aware Teacher Training

This module extends Student_v8 by using a landmark-aware teacher to train the student,
then applying landmark rescaling during inference for correct scale output.

Key Enhancement over Student_v8:
- Uses landmark-aware teacher for training (teacher handles landmark scaling internally)
- Student architecture remains unchanged (learns from better teacher)
- Applies landmark rescaling during inference/reconstruction
- Maintains full compatibility with Student_v8 infrastructure

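Architecture (as implemented by Student_v11 below):
- StableTrajectoryEstimator: the student forward model; the same module also
  serves as the time prior and as the teacher passed to the InverseEncoder
  (the earlier LandmarkAwareStudentForward_v9 design is no longer imported)
- InverseEncoder: predicts covariates (e.g. latent time) from template points
  and their deformations
- Landmark rescaling is delegated to the student forward's landmark rescaler
  when teacher deformations are mapped back to the original coordinate space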

This allows the student to learn from a more robust teacher while maintaining 
correct coordinate scaling for downstream applications.
"""
from __future__ import annotations
import sys
# NOTE(review): ".." is resolved against the current working directory, not
# against this file, so this only has an effect when the process is started
# from a particular directory. The package-relative imports below may make it
# unnecessary — TODO confirm before removing.
sys.path.append("..")  # Adjust the path to import from the parent directory
import torch
import copy
import logging
from torch import nn
from typing import Dict, Tuple

# Import teacher from current package (avoid naisr dependencies for now)
# NOTE(review): Teacher, Siren, BaseNet, extract_mesh, SinePE,
# TimePriorWithSpatialSigma and TrajectoryEstimator are not referenced by the
# code in this file as visible; they may be kept for import side effects or
# re-export — TODO confirm before pruning.
from .teacher import Teacher
from .base_net import BaseNet

from ..networks import Siren  # Changed import to Siren directly
# Configure logging
# NOTE(review): calling logging.basicConfig at import time configures the root
# logger for any program that imports this module; libraries usually leave
# that to the application entry point.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by Student_v11 for its status messages.
logger = logging.getLogger(__name__)
# NOTE(review): BaseNet is imported a second time here (also above); redundant.
from .base_net import BaseNet, extract_mesh, SinePE
from .inverse_encoder import InverseEncoder
from .time_prior import TimePriorWithSpatialSigma
#from .student_deformation import LandmarkAwareStudentForward_v9
from .trajectory_estimator import TrajectoryEstimator
from .trajectory_estimator_stable import StableTrajectoryEstimator
# NOTE(review): OMEGA_0 is not referenced in this file as visible — TODO
# confirm whether other modules import it from here.
OMEGA_0 = 30  # Constant for Siren initialization





class Student_v11(nn.Module):
    """
    Student_v11: student model trained from a landmark-aware teacher.

    Components (all built in ``__init__``):
    - ``student_forward``: a ``StableTrajectoryEstimator`` that predicts the
      deformation field. The same module is also exposed as ``time_prior`` and
      is passed to the ``InverseEncoder`` as its teacher.
    - ``inverse_encoder``: an ``InverseEncoder`` that predicts covariates
      (e.g. latent time) from template points and deformations.

    Landmark rescaling is delegated to ``student_forward.landmark_rescaler``
    (see ``template_points_with_original_covariates``). The public methods keep
    the interface of the earlier Student_v8 model.
    """

    def __init__(self, student_cfg, inverse_encoder_cfg, time_prior_cfg):
        """
        Build the student forward model and the inverse encoder.

        Args:
            student_cfg (dict): Configuration for StableTrajectoryEstimator.
            inverse_encoder_cfg (dict): Configuration for InverseEncoder.
            time_prior_cfg (dict): Accepted for interface compatibility but not
                used: the time prior is the student forward model itself.
        """
        super().__init__()

        self.student_forward = StableTrajectoryEstimator(student_cfg)

        # The student forward model doubles as the inverse encoder's teacher.
        self.inverse_encoder = InverseEncoder(inverse_encoder_cfg, teacher=self.student_forward)
        # Same submodule registered under a second attribute name, so its
        # parameters appear under both prefixes in state_dict(); changing this
        # would alter checkpoint keys.
        self.time_prior = self.student_forward

        logger.info("✅ Initialized Student_v11 with Landmark-Aware Teacher Training")
        logger.info("   Enhanced features:")
        logger.info("   - Landmark-aware teacher provides better training signal")
        logger.info("   - Student learns from teacher that handles landmark scaling")
        logger.info("   - Landmark rescaling applied at inference time")
        logger.info("   - Full compatibility with Student_v8 infrastructure")

    @property
    def device(self):
        """Device of the first parameter of ``student_forward``.

        Raises StopIteration if ``student_forward`` has no parameters.
        """
        return next(self.student_forward.parameters()).device

    def template_sdf(self, coords: torch.Tensor) -> torch.Tensor:
        """Return ``student_forward.template_sdf(coords)`` (Student_v8-compatible)."""
        return self.student_forward.template_sdf(coords)

    def forward(self, model_input: dict) -> dict:
        """
        Run the student forward model and the inverse encoder on one batch.

        Args:
            model_input: input dict forwarded unchanged to both sub-models.

        Returns:
            dict with keys:
              "student_forward": raw output of ``student_forward``,
              "inverse_encoder": raw output of ``inverse_encoder.forward``,
              "time_prior": ``student_output["fisher_var"]``.
        """
        # Deformation prediction by the student.
        student_output = self.student_forward(model_input)

        # Latent covariate (e.g. age) prediction from the deformation.
        inverse_encoder_output = self.inverse_encoder.forward(model_input)

        return {
            "student_forward": student_output,
            "inverse_encoder": inverse_encoder_output,
            "time_prior": student_output["fisher_var"],
        }

    def forward_inv(self, model_input: dict) -> dict:
        """Run only the inverse encoder (Student_v8-compatible)."""
        return self.inverse_encoder.forward(model_input)

    def forward_as_teacher(self, model_input: dict) -> dict:
        """Return ``student_forward.forward_as_teacher(model_input)`` (Student_v8-compatible)."""
        return self.student_forward.forward_as_teacher(model_input)

    def forward_prior(self, model_input: dict) -> dict:
        """Run the time prior (which is ``student_forward``) (Student_v8-compatible)."""
        return self.time_prior.forward(model_input)

    def inv_transform(self, model_input: dict) -> torch.Tensor:
        """
        Delegate to ``student_forward.inv_transform``.

        Args:
            model_input (dict): Input dictionary containing 'coords' and 'covariates'.

        Returns:
            torch.Tensor: Transformed coordinates. Whether landmark rescaling is
            applied depends on ``student_forward.inv_transform`` (described as
            rescaled in earlier versions — TODO confirm).
        """
        return self.student_forward.inv_transform(model_input)

    def compute_total_loss_with_constraints(self,
                                          model_input: dict,
                                          primary_loss: torch.Tensor,
                                          constraint_weight: float = 0.1) -> dict:
        """
        Combine a primary loss with the student's temporal constraint loss.

        ``total = primary_loss + constraint_weight * temporal_constraint_loss``,
        where the constraint comes from
        ``student_forward.compute_temporal_constraint_loss`` (same as Student_v8).

        Args:
            model_input: model input dict.
            primary_loss: main loss (e.g. reconstruction loss).
            constraint_weight: weight of the temporal constraint term.

        Returns:
            dict with "total_loss", "primary_loss", "temporal_constraint_loss",
            "constraint_weight" and "temporal_diagnostics".
        """
        constraint_results = self.student_forward.compute_temporal_constraint_loss(model_input)
        temporal_constraint_loss = constraint_results["temporal_constraint_loss"]

        total_loss = primary_loss + constraint_weight * temporal_constraint_loss

        return {
            "total_loss": total_loss,
            "primary_loss": primary_loss,
            "temporal_constraint_loss": temporal_constraint_loss,
            "constraint_weight": constraint_weight,
            "temporal_diagnostics": constraint_results["temporal_diagnostics"]
        }

    def predict_global_time(
        self,
        model_input: dict,
        time_cov_name: str = 'ind_chron_t',
        eps: float = 1e-8,
        sigma_threshold: float = 1.0,
    ) -> Dict[str, torch.Tensor]:
        """
        Estimate one global time per sample from per-point inverse-encoder predictions.

        Steps:
        1. Take the template points and the teacher's deformation of them
           (``template_points_with_original_covariates``).
        2. Predict a local time t_i at every point with the inverse encoder.
        3. Predict the variance sigma_i^2 at every point from the *predicted*
           local times (the ground-truth covariate is not used here).
        4. Per sample, take the unweighted mean of t_i over the points with
           sigma_i <= ``sigma_threshold``. If no point passes, fall back to the
           mean over all points and log a warning.

        The Fisher information 1 / (sigma^2 + eps) is returned for diagnostics
        only; it is NOT used to weight the global estimate.

        Args:
            model_input: dict containing 'covariates' (plus whatever the
                template helper and teacher need).
            time_cov_name: covariate key holding the time value.
            eps: added to the variance before inverting it for the diagnostic
                Fisher weights.
            sigma_threshold: points whose predicted sigma exceeds this value
                are excluded from the global mean (1.0 reproduces the previous
                hard-coded behavior).

        Returns:
            dict: {
                "global_time": [B, 1] filtered mean of local times,
                "local_times": per-point predictions (indexed as [B, N, 1]),
                "fisher_weights": 1 / (variance + eps), same shape as variance,
                "sigma": sqrt of the predicted variance,
                "gt_time": ground-truth time (first point if the covariate is
                    3-D, as given otherwise), or None if absent,
                "simple_mean_time": mean of local times over all points,
                "template_coords": [B, N, 3] template coordinates,
            }
        """
        # Template points and the teacher's (rescaled) deformation of them.
        template_data = self.template_points_with_original_covariates(model_input)
        template_coords = template_data["template_coords"]
        ground_truth_deformation = template_data["ground_truth_deformation"]

        B, _, _ = template_coords.shape

        # Per-point latent time from the inverse encoder.
        predicted_covariates = self.inverse_encoder.predict_covariates(
            template_coords,
            ground_truth_deformation
        )
        local_times = predicted_covariates[time_cov_name]

        # Variance evaluated at the predicted (not ground-truth) local times.
        predicted_variance = self.student_forward.predict_sigma(template_coords, local_times)
        sigma = torch.sqrt(predicted_variance)

        # Per-sample mean over confident points only.
        global_times = []
        for b in range(B):
            sigma_b = sigma[b, :, 0]
            local_t_b = local_times[b, :, 0]
            valid_mask = sigma_b <= sigma_threshold

            if valid_mask.any():
                global_times.append(local_t_b[valid_mask].mean())
            else:
                logger.warning(
                    "⚠️  Batch %d: All points have σ > %s, using simple mean",
                    b, sigma_threshold,
                )
                global_times.append(local_t_b.mean())

        global_time = torch.stack(global_times).unsqueeze(-1)

        # Diagnostic only; see docstring.
        fisher_weights = 1.0 / (predicted_variance + eps)

        gt_time = None
        if time_cov_name in model_input.get('covariates', {}):
            gt_cov = model_input['covariates'][time_cov_name]
            gt_time = gt_cov[:, 0, :] if gt_cov.dim() == 3 else gt_cov

        # Unfiltered mean, for comparison with the filtered estimate.
        simple_mean_time = local_times.mean(dim=1)

        return {
            "global_time": global_time,
            "local_times": local_times,
            "fisher_weights": fisher_weights,
            "sigma": sigma,
            "gt_time": gt_time,
            "simple_mean_time": simple_mean_time,
            "template_coords": template_coords,
        }

    def dict2array(self, dict_covariates: dict) -> torch.Tensor:
        """Concatenate the covariate tensors along the last dim, in dict order."""
        return torch.cat(list(dict_covariates.values()), dim=-1)

    def template_points_with_original_covariates(self, model_input: dict) -> dict:
        """
        Evaluate the teacher's deformation on the template mesh for every sample.

        The full ``student_forward.template_mesh`` is shared across the batch.
        The sample covariates are broadcast over the template points. The
        teacher's backward deformation is then mapped through the landmark
        rescaler:
        ``deformation = rescale(template + teacher_def) - template``.

        Args:
            model_input: dict with 'covariates' (dict of tensors whose first
                dim is the batch size); 'teacher_idx' and any extra keys are
                forwarded to the teacher.

        Returns:
            dict: {
                "template_coords": [B, N, 3] template points (an expand()ed
                    view of the template mesh, not a copy),
                "original_covariates": covariates broadcast to N points,
                "ground_truth_deformation": rescaled teacher deformation,
                "high_teacher_model_input": the dict passed to the teacher,
            }
        """
        actual_batch_size = list(model_input['covariates'].values())[0].shape[0]

        # All template points are used (no subsampling).
        template_coords_single = self.student_forward.template_mesh
        num_template_points = template_coords_single.shape[0]

        # Same points for every batch item: [B, N, 3].
        template_coords = template_coords_single.unsqueeze(0).expand(
            actual_batch_size, -1, -1
        )

        original_covariates = copy.deepcopy(model_input['covariates'])

        # Covariates not already per-point are assumed to be [B, 1] and are
        # broadcast to [B, N, 1] (expand would fail for a wider last dim).
        expanded_covariates = {}
        for cov_name, cov_value in original_covariates.items():
            if cov_value.shape[1] != num_template_points:
                expanded_cov = cov_value.unsqueeze(-1).expand(actual_batch_size, num_template_points, -1)
            else:
                expanded_cov = cov_value
            expanded_covariates[cov_name] = expanded_cov

        # NOTE(review): 'idx' is filled from 'teacher_idx', not from
        # model_input['idx'] — presumably deliberate; confirm.
        high_teacher_model_input = {
            "coords": template_coords,
            "covariates": expanded_covariates,
            'idx': model_input.get('teacher_idx', None),
            'teacher_idx': model_input.get('teacher_idx', None)
        }

        # Forward any remaining keys (e.g. landmarks) unchanged.
        for key in model_input:
            if key not in ['coords', 'covariates', 'idx', 'teacher_idx']:
                high_teacher_model_input[key] = model_input[key]

        # Teacher deformation, re-expressed in landmark-rescaled space.
        high_teacher_output = self.student_forward.teacher.backward_deformation(high_teacher_model_input)
        landmark_scale = self.student_forward.landmark_rescaler.extract_landmark_scale(high_teacher_model_input)
        coords_at_ori_space = self.student_forward.landmark_rescaler.apply_rescaling_to_coords(
            high_teacher_output + high_teacher_model_input['coords'],
            landmark_scale=landmark_scale,
        )
        high_teacher_output = coords_at_ori_space - high_teacher_model_input['coords']

        return {
            "template_coords": template_coords,
            "original_covariates": expanded_covariates,
            "ground_truth_deformation": high_teacher_output,
            "high_teacher_model_input": high_teacher_model_input
        }

    def compute_z_score_from_model_input(
        self,
        model_input: dict,
        gt_mesh: torch.Tensor,
        percentile: float = 0.01,
        z_threshold: float = -0.46,
        eps: float = 1e-8,
        use_residual_fusion: bool = True,
        lambda_residual: float = 0.5,
        use_fisher_weighting: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Score how out-of-distribution an observed mesh is, using inverse-encoder
        time predictions on the template.

        Procedure (as implemented):
        1. Get the template points and the teacher's (landmark-rescaled)
           deformation of them; ``reconstructed = template + deformation``.
        2. Crop ``gt_mesh`` to the z range of the reconstruction. Then crop the
           reconstruction, and the matching template points and deformations,
           to the z range of the cropped ``gt_mesh``.
        3. For every remaining gt point, find the nearest reconstructed point
           with a KD-tree. Take that point's template coordinate and
           deformation.
        4. full deformation = teacher deformation + (gt point - nearest
           reconstructed point). Template point + full deformation therefore
           equals the gt point (up to rounding).
        5. Predict a per-point time ``t`` from (template point, full
           deformation) with the inverse encoder. Predict a per-point sigma
           with ``student_forward.predict_sigma`` evaluated at the preprocessed
           'age' covariate.
        6. Restrict to the ROI of template z in (-0.72, -0.18), falling back
           to (-0.72, 0.6282) if the ROI is empty. Return
           ``min over ROI of (t - max(t)) / sigma``, where max(t) is taken
           over all matched points.

        NOTE(review): assumes batch size 1. Masked tensors are re-batched with
        ``unsqueeze(0)`` and only ``indices[0]`` is used. ``gt_mesh`` presumably
        has shape [1, M, 3]; a 2-D [M, 3] input would not match the [1, M]
        mask built by ``_filter_points_by_z_range`` — TODO confirm.

        Args:
            model_input: dict with 'covariates' (must include 'age' after
                ``student_forward.prepare_model_input``).
            gt_mesh: observed ground-truth points, already aligned/scaled.
            percentile, eps, use_residual_fusion, lambda_residual,
            use_fisher_weighting: currently unused; kept for interface
                compatibility.
            z_threshold: currently ignored. The ROI upper bound is fixed at
                -0.18, which is the value the previous implementation used
                after overwriting this argument.

        Returns:
            dict: {
                "z_map": per-point predicted time ``t`` (note: not a z-score),
                "ood_score": scalar tensor, the ROI minimum described above,
            }
        """
        time_cov_name = 'age'
        # Fixed ROI bounds on the template z coordinate (see docstring).
        roi_z_upper = -0.18
        roi_z_lower = -0.72
        roi_z_upper_fallback = 0.6282

        with torch.enable_grad():
            # Step 1: template points and the teacher's rescaled deformation.
            template_data = self.template_points_with_original_covariates(model_input)
            template_coords = template_data["template_coords"]
            ground_truth_deformation = template_data["ground_truth_deformation"]

            # Step 2: crop both point sets to their common z range.
            reconstructed_mesh = template_coords + ground_truth_deformation
            gt_mesh, _ = _filter_points_by_z_range(gt_mesh, reconstructed_mesh)
            reconstructed_mesh, recon_mask = _filter_points_by_z_range(reconstructed_mesh, gt_mesh)

            template_coords = template_coords[recon_mask].unsqueeze(0)
            ground_truth_deformation = ground_truth_deformation[recon_mask].unsqueeze(0)

            # Step 3: nearest reconstructed point for every gt point.
            _, closest_recon_pts, indices = _compute_residual_via_kdtree(
                reconstructed_mesh.squeeze(),
                gt_mesh,
            )
            nearest = indices[0]
            ground_truth_deformation = ground_truth_deformation[:, nearest, :]
            template_coords = template_coords[:, nearest, :]

            # Step 4: teacher deformation corrected by the residual to the gt point.
            deformation_residual = gt_mesh - closest_recon_pts
            full_deformation = ground_truth_deformation + deformation_residual

            # Step 5: per-point time and uncertainty.
            t_pred_full = self.inverse_encoder.predict_covariates(
                template_coords,
                full_deformation
            )[time_cov_name]

            model_input_preprocessed = self.student_forward.prepare_model_input(
                model_input=model_input, pts_on_template=template_coords
            )
            arr_covariates = model_input_preprocessed['covariates'][time_cov_name]
            predicted_variance = self.student_forward.predict_sigma(template_coords, arr_covariates)
            sigma = torch.sqrt(predicted_variance)

            # Step 6: ROI on the template z coordinate, then the OOD score.
            template_z = template_coords[..., [-1]]
            mask = (template_z < roi_z_upper) & (template_z > roi_z_lower)
            if mask.sum() == 0:
                mask = (template_z < roi_z_upper_fallback) & (template_z > roi_z_lower)

            ood_score = ((t_pred_full[mask] - t_pred_full.max()) / sigma[mask]).min()
        return {"z_map": t_pred_full, "ood_score": ood_score}

def _compute_residual_via_kdtree(
        arr_template: torch.Tensor,         # [N_template, 3] 标准模板点
        reconstructed_mesh: torch.Tensor,   # [B, N, 3] 重建网格（teacher预测）
    ) -> torch.Tensor:
    """
    对于reconstructed_mesh上的每个点，找到arr_template中最近的点
    然后计算residual = observed_closest - reconstructed

    核心逻辑：
    1. 建立KD-Tree(arr_template) 
    2. 对每个reconstructed_mesh上的点 query最近邻
    3. 获得closest_template点
    4. 用这个closest_template点对应的observed点来计算残差

    流程图：
    reconstructed_b [N, 3]
            ↓
    query KD-Tree(arr_template)
            ↓
    closest_template_pts [N, 3]
            ↓
    query KD-Tree(observed_b)
            ↓
    closest_observed_pts [N, 3]
            ↓
    residual = closest_observed - reconstructed [N, 3]

    Args:
        arr_template: [N_template, 3] 标准模板点（用于构建KD-Tree）
        reconstructed_mesh: [B, N, 3] 重建网格（要查询的点）
        observed_mesh: [B, M, 3] 观测GT网格
        teacher_deformation: [B, N, 3] teacher的形变（用于诊断）
        
    Returns:
        residual_deformation: [B, N, 3] 形变残差
    """
    from scipy.spatial import cKDTree

    B, N, _ = reconstructed_mesh.shape
    residuals = []
    closest_template_pts_all = []
    indices = []
    # 将arr_template转到CPU并建立KD-Tree（只需建立一次）
    arr_template_np = arr_template.detach().cpu().numpy() if isinstance(arr_template, torch.Tensor) else arr_template
    kdtree_template = cKDTree(arr_template_np)  # [N_template, 3]

    for b in range(B):
        # 转到CPU进行KD-Tree操作
        reconstructed_b = reconstructed_mesh[b].detach().cpu().numpy()  # [N, 3] 要查询的点
        
        # ========== 关键步骤1: 找reconstructed上每个点在template中最近的点 ==========
        distances_to_template, template_indices = kdtree_template.query(reconstructed_b)
        # template_indices: [N] 每个reconstructed点最近的template点的索引
        # distances_to_template: [N] 距离
        closest_template_pts = arr_template_np[template_indices]  # [N, 3]
        
        # 转回GPU
        residuals.append(
            torch.tensor(
                distances_to_template,
                dtype=torch.float32,
                device=reconstructed_mesh.device
            )
        )
        closest_template_pts_all.append(
            torch.tensor(
                closest_template_pts,
                dtype=torch.float32,
                device=reconstructed_mesh.device
            )
        )
        indices.append(template_indices)

    residuals_stacked = torch.stack(residuals, dim=0)  # [B, N, 3]
    closest_template_pts_stacked = torch.stack(closest_template_pts_all, dim=0)  # [B, N, 3]
    return residuals_stacked, closest_template_pts_stacked, indices



def _filter_points_by_z_range(
    points: torch.Tensor,      # [B, N, 3] 要过滤的点
    template: torch.Tensor,    # [M, 3] 或 [B, M, 3] 模板点
    margin: float = 0.      # 允许的边界余量
) -> Tuple[torch.Tensor, torch.BoolTensor]:
    """
    过滤点云，确保z轴范围在template范围内（带margin）
    
    Args:
        points: [B, N, 3] 需要过滤的点云
        template: [M, 3] 或 [B, M, 3] 模板点云
        margin: float, 允许超出template范围的边界余量
        
    Returns:
        Tuple[torch.Tensor, torch.BoolTensor]:
            - filtered_points: [B, N', 3] 过滤后的点云
            - mask: [B, N] 布尔掩码，标记哪些点被保留
    """
    # 确保template形状正确
    if template.dim() == 2:
        template = template.unsqueeze(0)  # [1, M, 3]
    
    # 计算template的z轴范围
    z_min = template[..., 2].min(dim=1, keepdim=True)[0]  # [B, 1]
    z_max = template[..., 2].max(dim=1, keepdim=True)[0]  # [B, 1]
    
    # 添加margin
    z_min = z_min - margin
    z_max = z_max + margin
    
    # 创建mask
    z_coords = points[..., 2]  # [B, N]
    mask = (z_coords >= z_min) & (z_coords <= z_max)  # [B, N]
    
    # 应用mask
    filtered_points = points[mask].view(points.size(0), -1, 3)  # [B, N', 3]
    
    return filtered_points, mask
