import os
import torch
import open3d as o3d
from tqdm import tqdm

import os, csv
import numpy as np
import open3d as o3d

import csv
import json

def transform_pcd(pcd, T=None):
    """Apply a 4x4 homogeneous transform to a point cloud, in place.

    Args:
        pcd: Object exposing ``transform(matrix)``; in this file an
            ``o3d.geometry.PointCloud``. It is modified in place.
        T: Optional 4x4 homogeneous transformation matrix (anything
            ``np.asarray`` accepts). When omitted, the fixed axis remapping
            below is used: x' = x, y' = -z, z' = y + 1.5.

    Returns:
        The same ``pcd`` object that was passed in.

    Raises:
        ValueError: If ``T`` is given but is not a 4x4 matrix.
    """
    if T is None:
        T = np.array([
            [1, 0, 0, 0],     # x' = x
            [0, 0, -1, 0],    # y' = -z
            [0, 1, 0, 1.5],   # z' = y + 1.5
            [0, 0, 0, 1],
        ])
    else:
        T = np.asarray(T, dtype=np.float64)
        if T.shape != (4, 4):
            raise ValueError(f"T must be a 4x4 matrix, got shape {T.shape}")
    pcd.transform(T)
    return pcd

def format_sig(x, sig=3):
    """Round a numeric-like value to ``sig`` significant digits.

    Anything ``float()`` can parse (int, float, numeric string) is returned
    as a float rounded to ``sig`` significant digits (3 by default). A value
    that cannot be converted or formatted is returned unchanged.
    """
    try:
        # Both the conversion and the formatting stay inside the try so that
        # a ValueError/TypeError from either step falls back to returning x.
        rounded_text = format(float(x), f".{sig}g")
        return float(rounded_text)
    except (ValueError, TypeError):
        return x

def csv_to_json(csv_path, json_path, sig=3):
    """Convert a per-trajectory metrics CSV into a nested JSON file.

    Each CSV row is keyed by its ``traj`` column; the remaining columns become
    that trajectory's metrics, each passed through ``format_sig`` with ``sig``
    significant digits. Rows whose ``traj`` entry is missing (``None``) are
    skipped, and a later row with the same ``traj`` replaces an earlier one.

    Args:
        csv_path: Path of the CSV file to read.
        json_path: Path of the JSON file to write (overwritten).
        sig: Number of significant digits forwarded to ``format_sig``.
    """
    per_traj = {}
    with open(csv_path, newline="") as src:
        for record in csv.DictReader(src):
            name = record.get("traj")
            if name is not None:
                per_traj[name] = {
                    column: format_sig(value, sig)
                    for column, value in record.items()
                    if column != "traj"
                }

    with open(json_path, "w") as dst:
        json.dump(per_traj, dst, indent=2)

    print(f"Saved JSON to {json_path}")

# Script entry point: convert the CSV of evaluation metrics into JSON.
if __name__ == "__main__":
    csv_to_json("eval_results.csv", "eval_results.json", sig=3)
# NOTE(review): this exit(0) sits at module level, outside the guard above, so
# it runs unconditionally (on import as well) and no statement below it is
# ever executed -- the definitions and the evaluation script that follow are
# skipped. Presumably a temporary switch to run only the conversion; confirm.
exit(0)

# ========== Core: evaluate two point clouds (pred vs gt) ==========
def evaluate_pointcloud(
    pcd_pred_raw,
    pcd_gt_raw,
    v_eval,
    taus,
    do_downsample,
    denoise,      # 'stat' | 'radius' | None
    estimate_normals,     # estimate normals to score normal consistency
    compute_voxel_iou,   # whether to compute voxel IoU (somewhat slower)
):
    """Compare a predicted point cloud against a ground-truth point cloud.

    Both clouds are copied, optionally voxel-downsampled and denoised in the
    same way, the prediction is rigidly aligned to the ground truth with ICP,
    and distance-based metrics are computed from nearest-neighbour distances
    in both directions.

    Args:
        pcd_pred_raw: Predicted ``o3d.geometry.PointCloud`` (not modified).
        pcd_gt_raw: Ground-truth ``o3d.geometry.PointCloud`` (not modified).
        v_eval: Voxel size used for downsampling, the ICP correspondence
            radius (5 * v_eval), normal estimation (3 * v_eval) and voxel IoU.
        taus: Distance thresholds for precision/recall/F1. ``None`` means
            ``[v_eval, 2 * v_eval]``.
        do_downsample: Voxel-downsample both clouds with ``v_eval`` first.
        denoise: ``"stat"`` for statistical outlier removal, ``"radius"`` for
            radius outlier removal, anything else for no denoising.
        estimate_normals: Estimate normals on both clouds and report the
            normal angular error on geometric inliers.
        compute_voxel_iou: Also report the voxel-occupancy IoU.

    Returns:
        dict with the keys
          - ``n_pred`` / ``n_gt`` / ``v_eval_m``
          - ``ICP_fitness`` / ``ICP_inlier_rmse_m``
          - ``Chamfer_L1_m`` (sum of the two mean NN distances)
          - ``RMSE_pred2gt_cm`` / ``RMSE_gt2pred_cm`` (distances * 100)
          - ``Precision@..cm`` / ``Recall@..cm`` / ``F1@..cm`` per tau
          - ``Normal_ang_median_deg`` / ``Normal_ang_mean_deg`` /
            ``Normal_samples`` (only if normals were estimated and at least
            one inlier exists)
          - ``Voxel_IoU`` / ``Voxel_intersect`` / ``Voxel_union`` (optional)
        The ``_m`` / ``_cm`` suffixes assume the input coordinates are in
        meters.
    """
    if taus is None:
        taus = [v_eval, 2*v_eval]

    # 0) Work on copies so the caller's objects are left untouched.
    p_pred = o3d.geometry.PointCloud(pcd_pred_raw)
    p_gt = o3d.geometry.PointCloud(pcd_gt_raw)

    # 1) Same voxel size and the same (optional) light denoising on both sides.
    def _prep(pcd: o3d.geometry.PointCloud):
        q = pcd.voxel_down_sample(v_eval) if do_downsample else o3d.geometry.PointCloud(pcd)
        if denoise == "stat":
            q, _ = q.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
        elif denoise == "radius":
            q, _ = q.remove_radius_outlier(nb_points=3, radius=2.5*v_eval)
        return q

    p_pred = _prep(p_pred)
    p_gt = _prep(p_gt)

    # 2) (Optional) estimate normals, used for the normal-consistency metric.
    if estimate_normals:
        for pc in (p_pred, p_gt):
            # A radius of ~3*v_eval is more stable; max_nn caps the neighbours.
            pc.estimate_normals(
                search_param=o3d.geometry.KDTreeSearchParamHybrid(
                    radius=3.0*v_eval, max_nn=30
                )
            )
            pc.orient_normals_consistent_tangent_plane(30)

    # 3) Rigid registration (pred -> gt); alignment only, geometry is unchanged.
    #    Point-to-plane ICP needs normals on the target cloud, so fall back to
    #    point-to-point ICP when the target has none (e.g. estimate_normals
    #    is False and the input carried no normals).
    registration = o3d.pipelines.registration
    if p_gt.has_normals():
        estimation = registration.TransformationEstimationPointToPlane()
    else:
        estimation = registration.TransformationEstimationPointToPoint()
    reg = registration.registration_icp(
        p_pred, p_gt, max_correspondence_distance=5.0*v_eval,
        estimation_method=estimation,
    )
    p_pred.transform(reg.transformation)

    # 4) Nearest-neighbour queries (both directions).
    def _nn_query(A: o3d.geometry.PointCloud, B: o3d.geometry.PointCloud):
        """Return (index, squared distance) of each A point's nearest B point.

        Points without a neighbour get index -1 and squared distance inf.
        """
        kd = o3d.geometry.KDTreeFlann(B)
        A_np = np.asarray(A.points)
        nn_idx = np.full(len(A_np), -1, dtype=np.int64)
        nn_d2 = np.full(len(A_np), np.inf, dtype=np.float64)
        for i, x in enumerate(A_np):
            _, idx, dist2 = kd.search_knn_vector_3d(x, 1)
            if len(idx):
                nn_idx[i] = idx[0]
                nn_d2[i] = dist2[0]
        return nn_idx, nn_d2

    idx_pg, d2_pg = _nn_query(p_pred, p_gt)   # pred -> gt
    _, d2_gp = _nn_query(p_gt, p_pred)        # gt -> pred
    d_pg = np.sqrt(d2_pg).astype(np.float32)
    d_gp = np.sqrt(d2_gp).astype(np.float32)

    # 5) Geometric metrics.
    chamfer_L1 = float(d_pg.mean() + d_gp.mean())           # m
    rmse_pg_cm = float(np.sqrt((d_pg**2).mean()) * 100.0)    # cm
    rmse_gp_cm = float(np.sqrt((d_gp**2).mean()) * 100.0)    # cm

    metrics = {
        "n_pred":        int(len(p_pred.points)),
        "n_gt":          int(len(p_gt.points)),
        "v_eval_m":      float(v_eval),
        "ICP_fitness":   float(reg.fitness),
        "ICP_inlier_rmse_m": float(reg.inlier_rmse),
        "Chamfer_L1_m":  chamfer_L1,
        "RMSE_pred2gt_cm": rmse_pg_cm,
        "RMSE_gt2pred_cm": rmse_gp_cm,
    }

    # 6) Precision / Recall / F-score at every tau.
    for tau in taus:
        P = float((d_pg <= tau).mean())
        R = float((d_gp <= tau).mean())
        F = 2.0 * P * R / (P + R + 1e-12)
        tag = f"@{tau*100:.1f}cm"
        metrics[f"Precision{tag}"] = P
        metrics[f"Recall{tag}"]    = R
        metrics[f"F1{tag}"]        = F

    # 7) Normal consistency (optional): angles in degrees, measured only on
    #    geometric inliers (d <= tau) of the pred -> gt correspondences found
    #    in step 4, using the primary threshold (the first tau).
    if (estimate_normals and len(taus) > 0
            and len(p_pred.normals) and len(p_gt.normals)):
        tau0 = taus[0]
        inlier = (idx_pg >= 0) & (d2_pg <= tau0**2)
        if inlier.any():
            n1 = np.asarray(p_pred.normals)[inlier]
            n2 = np.asarray(p_gt.normals)[idx_pg[inlier]]
            denom = np.linalg.norm(n1, axis=1) * np.linalg.norm(n2, axis=1) + 1e-12
            cosv = np.clip(np.einsum("ij,ij->i", n1, n2) / denom, -1.0, 1.0)
            # Angle in [0, 180]; normal direction is sign-ambiguous, so fold
            # it to min(angle, 180 - angle).
            ang = np.degrees(np.arccos(cosv))
            ang_arr = np.minimum(ang, 180.0 - ang).astype(np.float32)
            metrics["Normal_ang_median_deg"] = float(np.median(ang_arr))
            metrics["Normal_ang_mean_deg"]   = float(ang_arr.mean())
            metrics["Normal_samples"]        = int(len(ang_arr))

    # 8) Voxel IoU (optional): same v_eval grid, anchored at the origin.
    if compute_voxel_iou:
        def _voxel_keys(pts: np.ndarray, v: float):
            # Quantize points to integer voxel indices (origin-aligned).
            return {tuple(ijk) for ijk in np.floor(pts / v).astype(np.int32).tolist()}
        a_set = _voxel_keys(np.asarray(p_pred.points), v_eval)
        b_set = _voxel_keys(np.asarray(p_gt.points), v_eval)
        inter = len(a_set & b_set)
        union = len(a_set | b_set)
        metrics["Voxel_IoU"] = float(inter / (union + 1e-12))
        metrics["Voxel_intersect"] = int(inter)
        metrics["Voxel_union"]     = int(union)

    return metrics

# ========== Utility: write evaluation results to a CSV file ==========
def write_metrics_csv(csv_path: str, rows: list[dict]):
    """Write a list of metric dicts to ``csv_path`` as CSV.

    The header is the sorted union of all keys found in ``rows``; a row that
    lacks a key gets an empty cell for it. When ``rows`` is empty nothing is
    written and no file is created.
    """
    if rows:
        columns = set()
        for row in rows:
            columns.update(row)
        header = sorted(columns)
        with open(csv_path, "w", newline="") as out:
            writer = csv.DictWriter(out, fieldnames=header)
            writer.writeheader()
            writer.writerows(rows)

# Load the ground-truth point cloud: every file in gt_path whose name contains
# 'UwV83HsGsw3' (presumably the scene id -- confirm) is merged into gt_pcd.
gt_path = "../openscene/data/matterport_3d/test"
gt_files = [f for f in os.listdir(gt_path) if 'UwV83HsGsw3' in f]
gt_pcd = o3d.geometry.PointCloud()
for file in gt_files:
    # NOTE(review): assert is stripped under `python -O`; these checks then vanish.
    assert file.endswith('.pth')
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so this must only be used on trusted files.
    # Each file unpacks into three items; only the points are used below.
    points, colors, instance_labels = torch.load(os.path.join(gt_path, file), map_location='cpu', weights_only=False)
    # presumably points is an (N, 3) NumPy array -- TODO confirm
    p = points
    # Keep only rows in which every coordinate is finite.
    mask = np.isfinite(p).all(axis=1)
    p = p[mask]
    # Deduplicate points (optional)
    if len(p) > 0:
        p = np.unique(p, axis=0)

    # The assertion message means "too few points".
    assert len(p) > 100, f"点太少：{len(p)}"

    # Open3D point cloud
    gt_pcd_part = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(p))
    gt_pcd += gt_pcd_part
    print(f"Load gt part pcd from {file}, #points: {len(gt_pcd_part.points)}")
print(f"Load gt pcd from {gt_path}, #points: {len(gt_pcd.points)}")

# Evaluate the predicted point cloud of every entry under traj_dir against
# gt_pcd and collect one metrics row per trajectory.
# v_eval is the voxel size handed to evaluate_pointcloud.
v_eval = 0.014
all_rows = []
traj_dir = "data/UwV_walks"
# NOTE(review): every entry of traj_dir is treated as a trajectory directory;
# a stray file would produce a pcd path that does not exist.
trajs = os.listdir(traj_dir)
for traj in tqdm(trajs):
    pcd_path = os.path.join(traj_dir, traj, 'UwV83HsGsw3', 'pcd.ply')
    pcd = o3d.io.read_point_cloud(pcd_path)
    # Apply the fixed axis remapping from transform_pcd before evaluating.
    pcd = transform_pcd(pcd)
    print(f"Load pcd from {pcd_path}, points: {len(pcd.points)}")

    # Evaluate (toggle compute_voxel_iou / estimate_normals as needed)
    res = evaluate_pointcloud(
        pcd_pred_raw=pcd,
        pcd_gt_raw=gt_pcd,
        v_eval=v_eval,
        taus=[v_eval, 2*v_eval],
        do_downsample=True,          # pred and GT are both voxel-downsampled inside the function with the same v_eval
        denoise="stat",
        estimate_normals=True,
        compute_voxel_iou=True
    )
    res["traj"] = traj
    all_rows.append(res)

# 3) Save the aggregated results
write_metrics_csv("eval_results.csv", all_rows)
print("Saved metrics to eval_results.csv")