"""
Fast composition/count checker + XY‑lattice RMSD
================================================
判定 (qualified):
    • 元素集合完全一致   (composition_similarity = 1.0)
    • 元素计数完全一致   (count_similarity       = 1.0)

同时计算:
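    xy_lattice_rmsd = sqrt( (a_R - a_L)^2 + (b_R - b_L)^2 ) / sqrt(2)
    (R = 重建 CIF, L = 真值 CIF; a, b 取自 _cell_length_a / _cell_length_b;
     汇总中的均值 ± 标准差仅统计 qualified 样本)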
"""

import os, re, glob, json, argparse, math
from pathlib import Path
from collections import Counter

# ---------- CLI ----------
parser = argparse.ArgumentParser()
parser.add_argument("--recon-dirs", nargs="+",
                    required=True, help="重建 CIF 目录")
parser.add_argument("--label-dir",
                    required=True, help="真值 CIF 目录")
parser.add_argument("--output-json", default="atomai_comp_count_rmsd_results.json",
                    help="输出 JSON 文件")
args = parser.parse_args()

# Fail fast on a mistyped directory. glob() on a nonexistent path silently
# returns [], which would otherwise surface only as "0 files" or as a
# "[skip] 无 label" line for every sample.
_missing_dirs = [d for d in (*args.recon_dirs, args.label_dir)
                 if not os.path.isdir(d)]
if _missing_dirs:
    parser.error("directory not found: " + ", ".join(_missing_dirs))

# ---------- Regular expressions ----------
# Material id such as "2dm-123". \w also matches "_", so a hit may carry a
# prefix (e.g. "recon_2dm-123"); extract_id keeps only the last "_" segment.
id_pat = re.compile(r"\w+-\d+")
# A line whose first token is 1-2 letters followed by whitespace
# (presumably the element symbol of an atom-site row — confirm for new CIF writers).
sym_line = re.compile(r"^\s*([A-Za-z]{1,2})\s+", flags=re.MULTILINE)
# Shared capture group for a numeric CIF value; it stops before an
# uncertainty suffix such as "(2)".
_NUMBER_GROUP = r"([\d\.Ee+\-]+)"
cell_a_pat = re.compile(r"_cell_length_a\s+" + _NUMBER_GROUP)
cell_b_pat = re.compile(r"_cell_length_b\s+" + _NUMBER_GROUP)
# Case-insensitive file-name markers; label CIFs matching any of them are skipped.
noise_pat = re.compile(r"(dose|reconstructed|tmp|iDPC|V\d+)", flags=re.IGNORECASE)


def extract_id(fname: str):
    """Return the material id found in *fname*, or None when there is none.

    ``id_pat`` may also swallow a word prefix joined by "_" (for example
    "recon_2dm-123"), so only the text after the last underscore is kept.
    """
    match = id_pat.search(fname)
    if match is None:
        return None
    _, _, tail = match.group(0).rpartition('_')
    return tail


# ---------- Lightweight CIF parsing (cached) ----------
_cache: dict[str, tuple[list[str], float, float]] = {}


def parse_cell_length(pattern: re.Pattern, txt: str) -> float:
    """Return the value captured by group 1 of *pattern* in *txt*, or 0.0.

    0.0 is returned when the tag is absent, and also when the captured text
    is not a valid float. The numeric character class admits strings such as
    "." or "E", which would otherwise raise ValueError and abort the run.
    """
    m = pattern.search(txt)
    if m is None:
        return 0.0
    try:
        return float(m.group(1))
    except ValueError:
        return 0.0


def parse_cif(path: str) -> tuple[list[str], float, float]:
    """Parse just enough of a CIF file for this checker; results are memoised per path.

    Returns:
        (symbols, a, b). ``symbols`` lists the token captured by ``sym_line``
        on each matching line, with repeats kept so they can be counted.
        ``a`` and ``b`` are the ``_cell_length_a`` / ``_cell_length_b``
        values, or 0.0 when missing or unparsable.
    """
    cached = _cache.get(path)
    if cached is not None:
        return cached
    # Explicit encoding keeps results platform-independent; errors='ignore'
    # drops undecodable bytes instead of raising.
    with open(path, 'r', encoding='utf-8', errors='ignore') as fh:
        txt = fh.read()
    symbols = [m.group(1) for m in sym_line.finditer(txt)]
    result = (symbols,
              parse_cell_length(cell_a_pat, txt),
              parse_cell_length(cell_b_pat, txt))
    _cache[path] = result
    return result


# ---------- label mapping ----------
# Maps material id -> ground-truth (label) CIF path.
label_map: dict[str, str] = {}
# sorted() makes the winner among duplicate ids reproducible (raw glob order
# is filesystem-dependent).
for p in sorted(glob.glob(os.path.join(args.label_dir, "*.cif"))):
    fname = os.path.basename(p)
    # Test noise markers against the file name only. On the full path, a
    # directory such as ".../tmp/..." or ".../dev1/..." (noise_pat is
    # case-insensitive and contains V\d+) would drop every label.
    if noise_pat.search(fname):
        continue
    mid = extract_id(fname)
    if not mid:
        continue
    if mid in label_map:
        print(f"[warn] duplicate label id {mid}: "
              f"{os.path.basename(label_map[mid])} -> {fname}")
    label_map[mid] = p

# ---------- Main loop ----------
total = paired = qualified = 0
records = []
xy_rmsd_list = []          # XY-lattice RMSD of qualified pairs only (used by the summary)

for rd in args.recon_dirs:
    # sorted() makes the console log and the JSON record order reproducible.
    for r_path in sorted(glob.glob(os.path.join(rd, "*.cif"))):
        total += 1
        r_name = os.path.basename(r_path)
        mid = extract_id(r_name)
        if not mid:
            # Report the skip, consistent with the missing-label case below.
            print(f"[skip] 无 ID: {r_name}")
            continue
        l_path = label_map.get(mid)
        if not l_path:
            print(f"[skip] 无 label: {mid}")
            continue

        sym_R, a_R, b_R = parse_cif(r_path)
        sym_L, a_L, b_L = parse_cif(l_path)

        # parse_cif reports a missing or unparsable cell length as 0.0, which
        # silently inflates the RMSD; flag such pairs.
        if 0.0 in (a_R, b_R, a_L, b_L):
            print(f"[warn] {mid}: missing cell length (a/b = 0.0), RMSD unreliable")

        comp_ok  = set(sym_R) == set(sym_L)          # same element set
        count_ok = Counter(sym_R) == Counter(sym_L)  # same per-element counts
        # Root-mean-square of the two in-plane lattice-length differences.
        xy_rmsd  = math.sqrt((a_R - a_L)**2 + (b_R - b_L)**2) / math.sqrt(2)

        ok = comp_ok and count_ok
        paired += 1
        if ok:
            qualified += 1
            xy_rmsd_list.append(xy_rmsd)

        print(f"{mid:<15} comp={int(comp_ok)} count={int(count_ok)} "
              f"RMSD={xy_rmsd:.4f} Å OK={ok}")

        records.append({
            "reconstructed_cif": str(Path(r_path).resolve()),
            "label_cif":         str(Path(l_path).resolve()),
            "composition_similarity": 1.0 if comp_ok  else 0.0,
            "count_similarity":       1.0 if count_ok else 0.0,
            "xy_lattice_rmsd":        round(xy_rmsd, 4),
            "qualified":              ok
        })

# ---------- Summary ----------
print("\n========== 统计 ==========")
print(f"总重建文件 : {total}")
print(f"成功配对   : {paired}")
if paired:
    print(f"Qualified  : {qualified}/{paired} ({qualified/paired*100:.1f}%)")
    # Mean ± population std (ddof=0, the same definition as np.std) over
    # qualified pairs only. Stdlib `statistics` avoids a late numpy import
    # that could abort the run before the JSON below is written.
    if xy_rmsd_list:
        import statistics
        mean = statistics.fmean(xy_rmsd_list)
        std = statistics.pstdev(xy_rmsd_list)
        print(f"XY‑lattice RMSD   : {mean:.2f} ± {std:.2f} Å")
    else:
        # Guard the empty case: the mean of [] would be nan plus a warning.
        print("XY‑lattice RMSD   : n/a (无合格样本)")

# ---------- JSON ----------
# Persist one record per paired sample (qualified or not).
Path(args.output_json).write_text(json.dumps(records, indent=2))
print(f"\n✓ 结果已写入 {args.output_json}")
