import numpy as np
import torch

def check_tensor_stats(tensor: torch.Tensor, tensor_name: str = "tensor", critical_step: str = "") -> None:
    """
    检查张量中是否存在NaN和Inf值，若存在则立即中断代码执行，并输出详细信息
    
    参数:
        tensor: 要检查的PyTorch张量
        tensor_name: 张量的名称，用于输出信息
        critical_step: 当前执行的步骤描述，帮助定位问题发生的阶段
    """
    # 检查并统计NaN和Inf
    nan_mask = torch.isnan(tensor)
    inf_mask = torch.isinf(tensor)
    has_nan = nan_mask.any()
    has_inf = inf_mask.any()
    nan_count = nan_mask.sum().item() if has_nan else 0
    inf_count = inf_mask.sum().item() if has_inf else 0
    
    # 计算有效元素数量
    valid_mask = ~nan_mask & ~inf_mask
    valid_count = valid_mask.sum().item()
    total_count = tensor.numel()
    
    # 输出基本信息
    step_info = f"（步骤：{critical_step}）" if critical_step else ""
    # print(f"🔍 检查变量 {tensor_name} {step_info}")
    
    # 如果存在NaN或Inf，中断代码
    if has_nan or has_inf:
        error_msg = (
            f"❌ 变量 {tensor_name} 检测到无效值 {step_info}:\n"
            f"   NaN数量: {nan_count}\n"
            f"   Inf数量: {inf_count}\n"
            f"   有效元素: {valid_count}/{total_count} ({valid_count/total_count*100:.2f}%)"
        )
        print(error_msg)
        # 抛出异常中断执行
        raise ValueError(error_msg)
    else:
        # 无无效值时输出统计信息
        min_val = tensor.min().item()
        max_val = tensor.max().item()
        if max_val>1000000 or min_val<-1000000:
            print(f"✅ 变量 {tensor_name} 无无效值\n"
                f"   最小值: {min_val:.6f}\n"
                f"   最大值: {max_val:.6f}\n"
                f"   元素总数: {total_count}")

            print("----------------------------------------")
