"""
XY‑similarity evaluator
=======================
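Mathematical definition
-----------------------
**Spatial similarity** (spatial_similarity)
       For each element, the nearest-neighbour distance dᵢ = ‖rᵢ − r′(i)‖,
       where rᵢ and r′(i) are fractional (a, b) coordinates of same-element atoms
       MSE₂D = Σᵢ dᵢ² / N,  similarity = exp(−MSE₂D)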

Dependencies
------------
pip install pymatgen ase scipy numpy
"""

import os, re, glob, json, argparse
import numpy as np
from collections import Counter
from scipy.spatial import cKDTree
from pathlib import Path
# ----------- ASE 结构读入 -----------
from ase.io import read


# ---------------- CLI ----------------
# Command-line interface.
# Both directory options are mandatory: the main loop iterates over
# --recon-dirs and looks labels up in --label-dir, so leaving either one out
# used to surface only later as a TypeError.  argparse now reports the
# missing option with a usage message instead.
parser = argparse.ArgumentParser()
# One or more folders containing reconstructed CIF files.
parser.add_argument("--recon-dirs", nargs="+", required=True,
                    help="重建 CIF 文件夹(可多)")
# Folder containing the ground-truth (label) CIF files.
parser.add_argument("--label-dir", required=True,
                    help="真值 CIF 文件夹")
# File the per-pair results are written to.
parser.add_argument("--output-json", default="xy_similarity_pairs.json",
                    help="输出 JSON 文件名")
# Pairs scoring at or above this value are counted as matches in the summary.
parser.add_argument("--sim-threshold", type=float, default=0.8,
                    help="overall_similarity ≥ 阈值视为匹配 (统计用)")
args = parser.parse_args()

# ----------- Regular expressions and small helpers -----------
# Captures a material id of the form "<word characters>-<digits>"; the main
# loop applies it to the base name of each reconstructed CIF.
mid_pat   = re.compile(r"(\w+-\d+)")
# Case-insensitive markers; find_label skips label candidates whose file name
# contains any of them.
noise_pat = re.compile(r"(dose|reconstructed|tmp|iDPC|V\d+)", re.I)


def find_label(mid, label_dir):
    """Return the ground-truth CIF in ``label_dir`` that belongs to ``mid``.

    A candidate is accepted when its file name contains ``mid`` not directly
    followed by another digit (so ``mp-12`` does not pick up ``mp-123.cif``)
    and does not contain any of the ``noise_pat`` markers.

    Args:
        mid: Material id string to look for in the file names.
        label_dir: Directory holding the label ``*.cif`` files.

    Returns:
        Path of the first accepted candidate in sorted order, or ``None``
        when no file qualifies.
    """
    # Escape both parts so glob metacharacters in the directory or id are
    # matched literally instead of being interpreted as wildcards.
    pattern = os.path.join(glob.escape(label_dir), f"*{glob.escape(mid)}*.cif")
    # The glob is only a substring match; this enforces the id boundary.
    exact_id = re.compile(re.escape(mid) + r"(?!\d)")

    # glob order is platform dependent; sorting makes the choice reproducible.
    for candidate in sorted(glob.glob(pattern)):
        base = os.path.basename(candidate)
        if not exact_id.search(base):
            continue
        if noise_pat.search(base):
            continue
        return candidate
    return None


# ================= 相似度函数 =================
def compare_lattice_2d(atoms1, atoms2, length_tol=0.05, angle_tol=5):
    """Return True when two structures have similar in-plane (2D) lattices.

    Only quantities that live in the a-b plane are compared: the lengths
    ``a`` and ``b`` and the angle γ between them.  The ``c`` length and the
    angles α and β, which both involve the c axis, are ignored.

    Args:
        atoms1: Reference structure; must expose ``cell.lengths()`` and
            ``cell.angles()`` (as an ASE ``Atoms`` object does).
        atoms2: Structure compared against ``atoms1``.
        length_tol: Maximum allowed relative difference of ``a`` and ``b``,
            measured relative to ``atoms1``.
        angle_tol: Maximum allowed absolute difference of γ, in the units
            returned by ``cell.angles()``.

    Returns:
        ``True`` if both lengths and γ agree within tolerance, else ``False``.
        A reference cell with a zero in-plane length is never similar.
    """
    ab_ref = np.asarray(atoms1.cell.lengths(), dtype=float)[:2]
    ab_other = np.asarray(atoms2.cell.lengths(), dtype=float)[:2]

    # A zero-length vector would make the relative difference 0/0 or x/0,
    # which previously slipped through the tolerance check as NaN.
    if np.any(ab_ref <= 0):
        return False

    if np.any(np.abs(ab_ref - ab_other) / ab_ref > length_tol):
        return False

    # angles() is ordered (alpha, beta, gamma); gamma is the a-b angle and
    # therefore the only one that describes the 2D lattice.
    gamma_ref = float(atoms1.cell.angles()[2])
    gamma_other = float(atoms2.cell.angles()[2])
    return bool(abs(gamma_ref - gamma_other) <= angle_tol)


def compare_structure_fractional_2d(cif1_path, cif2_path, dist_thresh=0.2):
    """Score how similar two CIF structures are when projected onto the a-b plane.

    The comparison proceeds in stages and returns ``0.0`` as soon as one
    fails: in-plane lattice check (``compare_lattice_2d``), identical set of
    elements, identical atom count per element, and every atom of the first
    structure having a same-element neighbour in the second one within
    ``dist_thresh``.  Distances are measured between fractional (a, b)
    coordinates using periodic boundaries, so positions near 0 and near 1
    are treated as adjacent.

    Args:
        cif1_path: Path of the first CIF file (query structure).
        cif2_path: Path of the second CIF file (searched structure).
        dist_thresh: Largest accepted nearest-neighbour distance, in
            fractional-coordinate units.

    Returns:
        ``exp(-MSE)`` where MSE is the mean squared nearest-neighbour
        distance over all atoms of the first structure, or ``0.0`` when any
        stage rejects the pair.
    """
    atoms1 = read(cif1_path)
    atoms2 = read(cif2_path)

    # 1. Lattice check (c axis ignored).
    if not compare_lattice_2d(atoms1, atoms2):
        print("Lattice mismatch.")
        return 0.0

    # 2. Both structures must contain the same set of elements.
    symbols1 = np.asarray(atoms1.get_chemical_symbols())
    symbols2 = np.asarray(atoms2.get_chemical_symbols())
    elements = set(symbols1.tolist())
    if elements != set(symbols2.tolist()):
        print("Element mismatch.")
        return 0.0
    if not elements:
        # Nothing to compare; previously this ended in a division by zero.
        print("Empty structure.")
        return 0.0

    # 3. Fractional coordinates without z, wrapped into [0, 1).  The second
    #    modulo folds values that round up to exactly 1.0 back to 0.0, which
    #    the periodic KD-tree requires.
    frac1 = np.mod(np.mod(atoms1.get_scaled_positions()[:, :2], 1.0), 1.0)
    frac2 = np.mod(np.mod(atoms2.get_scaled_positions()[:, :2], 1.0), 1.0)

    # 4. Per-element nearest-neighbour matching.
    total_sq_dist = 0.0
    total_atoms = 0

    # Sorted so that the reported failure is the same on every run.
    for element in sorted(elements):
        pos1 = frac1[symbols1 == element]
        pos2 = frac2[symbols2 == element]

        if len(pos1) != len(pos2):
            print(f"Atom count mismatch for element {element}")
            return 0.0

        # boxsize=1.0 gives the tree a periodic topology in both directions,
        # so atoms on opposite cell edges are recognised as neighbours.
        tree = cKDTree(pos2, boxsize=1.0)
        dists, _ = tree.query(pos1, k=1)

        if np.any(dists > dist_thresh):
            print(f"Distance too large for element {element}")
            return 0.0

        total_sq_dist += float(np.sum(dists ** 2))
        total_atoms += len(pos1)

    avg_mse = total_sq_dist / total_atoms
    return float(np.exp(-avg_mse))


def overall_similarity(cif1, cif2):
    """Collect the similarity metrics for a pair of CIF files in a dict.

    Args:
        cif1: Path of the first CIF file.
        cif2: Path of the second CIF file.

    Returns:
        A dict with the key ``"spatial_similarity"`` holding the score from
        ``compare_structure_fractional_2d`` rounded to four decimals.
    """
    score = compare_structure_fractional_2d(cif1, cif2)
    metrics = {}
    metrics["spatial_similarity"] = round(score, 4)
    return metrics


# ========================== 主循环 ==========================
# Running counters and per-pair results for the whole run.
total = paired = 0      # CIFs seen / CIFs successfully scored against a label
stats_sim_ok = 0        # pairs whose score reaches --sim-threshold
overall_list = []       # scores of all paired structures
records = []            # one JSON-serialisable dict per paired structure

# Exit with a usage message rather than a TypeError further down when one of
# the directory options is missing.
if not args.recon_dirs or args.label_dir is None:
    parser.error("--recon-dirs and --label-dir are required")

for recon_dir in args.recon_dirs:
    # Escape the directory so glob metacharacters in its name are literal,
    # and sort so files are processed in a reproducible order.
    recon_pattern = os.path.join(glob.escape(recon_dir), "*.cif")
    for recon_path in sorted(glob.glob(recon_pattern)):
        total += 1
        name = os.path.basename(recon_path)
        m = mid_pat.search(name)
        if not m:
            print(f"[skip] 无 material_id : {name}")
            continue
        mid = m.group(1)
        label_path = find_label(mid, args.label_dir)
        if not label_path:
            print(f"[skip] 无 label: {mid}")
            continue

        # overall_similarity reads both files itself, so parsing them here as
        # well was redundant.  Running it inside the try keeps one broken
        # pair from aborting the whole batch.
        try:
            sim_dict = overall_similarity(recon_path, label_path)
        except Exception as e:
            print(f"[skip] 解析失败 {mid} : {e}")
            continue

        score = sim_dict["spatial_similarity"]
        paired += 1
        overall_list.append(score)
        if score >= args.sim_threshold:
            stats_sim_ok += 1

        print(f"{mid:<15}"
              f"OverallSim={score:.3f}")

        records.append({
            "reconstructed_cif": str(Path(recon_path).resolve()),
            "label_cif":         str(Path(label_path).resolve()),
            **sim_dict
        })

# ---------------- Summary ----------------
# Header plus the two counters, one per line.
print(
    "\n========== 统计 ==========",
    f"总重建文件        : {total}",
    f"成功配对          : {paired}",
    sep="\n",
)
if not paired:
    print("无有效配对。")
else:
    # Mean score and the share of pairs that reached the threshold.
    mean_sim = np.mean(overall_list)
    hit_ratio = stats_sim_ok / paired
    print(f"OverallSim 均值   : {mean_sim:.3f}")
    print(f"OverallSim ≥ {args.sim_threshold} : {stats_sim_ok}/{paired} "
          f"({hit_ratio:.1%})")

# ---------------- JSON output ----------------
# Serialise all per-pair records first, then write them in a single call.
payload = json.dumps(records, indent=4)
with open(args.output_json, "w") as out_file:
    out_file.write(payload)
print(f"\n✧ 详细结果已写入 {args.output_json}")
