import pandas as pd
import numpy as np
import json
import argparse

# ---------- Helpers ----------
def _json_safe_key(key):
    """Return *key* in a form that ``json.dump`` accepts as an object key.

    Iterating a pandas ``groupby`` over a numeric column yields NumPy scalar
    keys (e.g. ``numpy.int64``), which the ``json`` module rejects with
    ``TypeError: keys must be str, int, float, bool or None``. NumPy scalars
    are unwrapped to their native Python equivalent; anything the encoder
    still cannot use as a key is converted with ``str``.
    """
    if isinstance(key, np.generic):
        key = key.item()
    if key is None or isinstance(key, (str, int, float, bool)):
        return key
    return str(key)


def _error_stats(errors):
    """Return ``(mean squared, mean absolute, mean signed)`` of *errors* as floats."""
    return (
        float(np.mean(errors ** 2)),
        float(np.mean(np.abs(errors))),
        float(np.mean(errors)),
    )


# ---------- Per-tier computation ----------
def compute_metrics_by_tier(df: pd.DataFrame) -> dict:
    """Compute energy error metrics for every distinct value of ``df["tier"]``.

    Args:
        df: Frame with the columns ``tier``, ``energy_error`` and
            ``energy_per_atom_error``.

    Returns:
        A dict mapping each tier (as a JSON-serializable key) to a dict with
        the MSE, MAE and mean signed error of both error columns.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    metrics_by_tier = {}

    for tier, group in df.groupby("tier"):
        energy_mse, energy_mae, energy_mean = _error_stats(group["energy_error"])
        atom_mse, atom_mae, atom_mean = _error_stats(
            group["energy_per_atom_error"]
        )

        # Group keys are normalized so the result can be passed to json.dump
        # even when the tier column is numeric.
        metrics_by_tier[_json_safe_key(tier)] = {
            "Energy MSE": energy_mse,
            "Energy MAE": energy_mae,
            "Energy per Atom MSE": atom_mse,
            "Energy per Atom MAE": atom_mae,
            "Average Energy Error": energy_mean,
            "Average Energy per Atom Error": atom_mean,
        }

    return metrics_by_tier


# ---------- CLI ----------
def main(argv=None):
    """Read the input CSV, compute the per-tier metrics and write them as JSON.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv", default='tier_3_results.csv', help="输入CSV文件路径")
    parser.add_argument("--output-json", default="energy_metrics_by_tier.json", help="输出JSON文件路径")
    args = parser.parse_args(argv)

    # Load the data and aggregate it per tier.
    df = pd.read_csv(args.csv)
    metrics_by_tier = compute_metrics_by_tier(df)

    # Save as JSON.
    with open(args.output_json, "w", encoding="utf-8") as f:
        json.dump(metrics_by_tier, f, indent=4)

    print(f"✓ 成功写入 {args.output_json}")


if __name__ == "__main__":
    main()
