
import torch
import glob
import os

# 定义输出目录
output_dir = ""

# 获取所有 .pt 文件
pt_files = glob.glob(os.path.join(output_dir, "*.pt"))

# 检查是否有文件
if not pt_files:
    print("未找到任何 .pt 文件！请检查输出目录：", output_dir)
else:
    # 只取第一个 .pt 文件（这里使用 pt_files[2]，可根据需要调整）
    pt_file = pt_files[0]
    print(f"检查 .pt 文件: {pt_file}")
    
    try:
        # 加载 .pt 文件
        data = torch.load(pt_file)
        
        # 提取数据
        protein_id = data.get('protein_id', '未知')
        points = data.get('points')
        normals = data.get('normals')
        features = data.get('features')
        iface_labels = data.get('iface_labels')
        probe_radius = data.get('probe_radius', '未知')
        delta_g = data.get('delta_g')
        
        # 打印详细信息
        print(f"\n文件: {pt_file}")
        print(f"蛋白质 ID: {protein_id}")
        print(f"探针半径: {probe_radius}")
        print(f"点坐标形状: {points.shape if points is not None else '无'}")
        print(f"法向量形状: {normals.shape if normals is not None else '无'}")
        print(f"特征形状: {features.shape if features is not None else '无'}")
        print(f"界面标签形状: {iface_labels.shape if iface_labels is not None else '无'}")
        
        # 打印 delta_g
        print(f"ΔG (kcal/mol): {delta_g.item():.2f} {'(有效)' if not torch.isnan(delta_g) else '(NaN，无效)'}")
        
        # 打印数据统计信息
        if points is not None:
            print(f"点坐标范围: min={points.min().item():.2f}, max={points.max().item():.2f}")
        if normals is not None:
            print(f"法向量模长: min={normals.norm(dim=1).min().item():.2f}, max={normals.norm(dim=1).max().item():.2f}")
        if features is not None:
            print(f"化学特征范围: min={features[:, :3].min().item():.2f}, max={features[:, :3].max().item():.2f}")
            print(f"原子类型特征范围: min={features[:, 3:9].min().item():.2f}, max={features[:, 3:9].max().item():.2f}")
            print(f"曲率特征范围: min={features[:, 9:].min().item():.2f}, max={features[:, 9:].max().item():.2f}")
        if iface_labels is not None:
            print(f"界面标签正例比例: {iface_labels.mean().item():.3f}")
            print(f"界面标签值范围: min={iface_labels.min().item():.2f}, max={iface_labels.max().item():.2f}")
            
    except Exception as e:
        print(f"错误：无法加载 {pt_file}: {e}")