"""
Controlled cross-scenario validation for NTS vs RTD/SRTD
========================================================

This script is designed for the reviewer request:
"validate NTS as a cross-scenario normalized metric using a controlled setting
with ground truth, and compare it against RTD/SRTD."

Included metrics:
    - NTS-E
    - NTS-M
    - RTD      full persistent-homology version, symmetrized directional RTD
    - SRTD     full persistent-homology version
    - RTD-lite MST version
    - SRTD-lite MST version

Important dependency for full RTD/SRTD:
    pip install giotto-tda

Other dependencies:
    pip install numpy scipy scikit-learn pandas matplotlib

Recommended run:
    python nts_rtd_srtd_controlled_validation.py \
        --out_dir results_nts_rtd_srtd \
        --n_repeats 10

Full RTD/SRTD are always computed, so giotto-tda must be installed;
the script raises ImportError when it is missing.
"""

from __future__ import annotations

import argparse
import os
from dataclasses import asdict, dataclass
from typing import Dict, List, Sequence, Tuple


import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.stats import kendalltau, pearsonr, spearmanr
from sklearn.metrics import adjusted_rand_score, pairwise_distances

# giotto-tda is imported defensively so that this module can still be loaded
# without it. The flag is checked by persistence_sum_precomputed() and
# run_experiment(), which raise ImportError when it is False.
try:
    from gtda.homology import VietorisRipsPersistence
    GTDA_AVAILABLE = True
except Exception:
    # Any failure while importing giotto-tda is treated as "not available".
    GTDA_AVAILABLE = False


# ---------------------------------------------------------------------
# Scenario configuration
# ---------------------------------------------------------------------

@dataclass(frozen=True)
class Scenario:
    """Immutable description of one data-generating condition of the benchmark."""

    # Label written to the "scenario" column of the results and used in plots.
    name: str
    # Samples per hierarchy leaf (exact when balanced; sets the total sample
    # budget when imbalance > 0).
    samples_per_leaf: int = 8
    # Dimensionality of the leaf centers and of the per-sample signal noise.
    signal_dim: int = 12
    # Number of extra columns appended to the representation that contain only
    # noise (no hierarchy signal); 0 disables them.
    noise_dims: int = 0
    # Multiplier applied to the per-sample noise added to the leaf centers.
    sigma: float = 0.20
    # Log-normal sigma of the leaf-size weights; values <= 1e-12 give perfectly
    # balanced leaves.
    imbalance: float = 0.0


# Heterogeneous conditions pooled by the cross-scenario analysis. Each entry
# changes a single factor relative to "baseline" (sample size, noise level,
# signal dimension, pure-noise dimensions, or leaf imbalance).
# Keep sample sizes moderate because full RTD/SRTD is expensive.
SCENARIOS: List[Scenario] = [
    Scenario("baseline", samples_per_leaf=8, signal_dim=12, noise_dims=0, sigma=0.20, imbalance=0.0),
    Scenario("small_n", samples_per_leaf=5, signal_dim=12, noise_dims=0, sigma=0.20, imbalance=0.0),
    Scenario("large_n", samples_per_leaf=12, signal_dim=12, noise_dims=0, sigma=0.20, imbalance=0.0),
    Scenario("high_noise", samples_per_leaf=8, signal_dim=12, noise_dims=0, sigma=0.50, imbalance=0.0),
    Scenario("high_dim", samples_per_leaf=8, signal_dim=32, noise_dims=0, sigma=0.20, imbalance=0.0),
    Scenario("noise_dims", samples_per_leaf=8, signal_dim=12, noise_dims=24, sigma=0.20, imbalance=0.0),
    Scenario("imbalanced", samples_per_leaf=8, signal_dim=12, noise_dims=0, sigma=0.20, imbalance=0.70),
]


# ---------------------------------------------------------------------
# Planted hierarchy benchmark
# ---------------------------------------------------------------------

def make_sample_leaf_ids(
    n_leaves: int,
    samples_per_leaf: int,
    imbalance: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """
    Draw a shuffled vector of leaf labels, one entry per sample.

    When ``imbalance`` is (numerically) zero every leaf receives exactly
    ``samples_per_leaf`` samples. Otherwise each leaf is guaranteed two
    samples, and the remaining ``n_leaves * samples_per_leaf - 2 * n_leaves``
    samples (floored at zero) are distributed multinomially according to
    log-normal leaf weights whose log-scale is ``imbalance``.

    The returned integer array holds leaf ids in ``[0, n_leaves)`` in an
    order shuffled with ``rng``.
    """
    balanced = imbalance <= 1e-12
    if balanced:
        per_leaf = np.full(n_leaves, samples_per_leaf, dtype=int)
    else:
        leaf_weights = rng.lognormal(mean=0.0, sigma=imbalance, size=n_leaves)
        # Two samples per leaf are reserved up front; only the rest is random.
        remainder = max(n_leaves * samples_per_leaf - 2 * n_leaves, 0)
        per_leaf = rng.multinomial(remainder, leaf_weights / leaf_weights.sum()) + 2

    blocks = []
    for leaf, count in enumerate(per_leaf):
        blocks.append(np.full(count, leaf, dtype=int))
    labels = np.concatenate(blocks)
    rng.shuffle(labels)
    return labels


def hierarchy_partitions(leaf_positions: np.ndarray, depth: int) -> List[np.ndarray]:
    """
    Return the group label of every sample at hierarchy levels 1..depth.

    Level ``l`` keeps the top ``l`` bits of the ``depth``-bit leaf index, so
    the first entry is the coarsest split and the last is the leaf id itself.
    """
    partitions: List[np.ndarray] = []
    # Level 1 drops depth-1 low bits, level `depth` drops none.
    for dropped_bits in range(depth - 1, -1, -1):
        partitions.append((leaf_positions >> dropped_bits).astype(int))
    return partitions


def ground_truth_hierarchy_distance(
    base_positions: np.ndarray,
    perturbed_positions: np.ndarray,
    depth: int,
) -> float:
    """
    Reference distance between two sample-to-leaf assignments.

    It is defined without reference to NTS, RTD, or SRTD:
        GTDistance = 1 - mean_l ARI(partition_l(base), partition_l(perturbed)),
    where the mean runs over hierarchy levels 1..depth. The result is clipped
    to [0, 1].
    """
    level_scores = []
    for base_level, perturbed_level in zip(
        hierarchy_partitions(base_positions, depth),
        hierarchy_partitions(perturbed_positions, depth),
    ):
        level_scores.append(adjusted_rand_score(base_level, perturbed_level))
    distance = 1.0 - np.mean(level_scores)
    return float(np.clip(distance, 0.0, 1.0))


def corrupt_leaf_positions(
    leaf_ids: np.ndarray,
    n_leaves: int,
    perturb_frac: float,
    depth: int,
    rng: np.random.Generator,
    cross_top_group: bool = True,
) -> np.ndarray:
    """
    Reassign a controlled fraction of samples to a different leaf.

    ``round(perturb_frac * n)`` distinct samples are picked without
    replacement. Each one is moved to a leaf drawn uniformly from either
    the leaves under the *other* top-level group (``cross_top_group=True``,
    a clear hierarchy-level perturbation) or all leaves except its current
    one. ``leaf_ids`` is not modified; a new array is returned.
    """
    result = leaf_ids.copy()
    n_samples = len(leaf_ids)
    n_changed = int(round(perturb_frac * n_samples))
    if n_changed == 0:
        return result

    # The top-level group of a leaf is its most significant bit.
    top_shift = depth - 1
    all_leaves = range(n_leaves)
    for sample in rng.choice(n_samples, size=n_changed, replace=False):
        old_leaf = int(leaf_ids[sample])
        if cross_top_group:
            old_top = old_leaf >> top_shift
            pool = [leaf for leaf in all_leaves if (leaf >> top_shift) != old_top]
        else:
            pool = [leaf for leaf in all_leaves if leaf != old_leaf]
        result[sample] = int(rng.choice(pool))
    return result


def generate_hierarchy_centers(
    n_leaves: int,
    depth: int,
    signal_dim: int,
    rng: np.random.Generator,
    level_scales: Sequence[float] | None = None,
) -> np.ndarray:
    """
    Generate leaf centers from a planted binary hierarchy.

    Every group at every level 1..depth gets a random unit-norm prototype
    vector. A leaf center is the sum, over levels, of the prototype of the
    group containing the leaf, weighted by that level's scale.

    Parameters
    ----------
    n_leaves : number of leaves; leaf ids are ``0 .. n_leaves - 1``.
    depth : number of hierarchy levels (leaf ids are read as depth-bit ints).
    signal_dim : dimensionality of the centers.
    rng : random generator used for the prototypes.
    level_scales : one scale per level, coarsest first. Defaults to
        ``8, 4, 2, 1, 0.5, ...`` (halving per level), which matches the
        previous fixed ``[8, 4, 2, 1]`` for ``depth <= 4`` and also works for
        deeper hierarchies.

    Returns
    -------
    Array of shape ``(n_leaves, signal_dim)``.

    Raises
    ------
    ValueError
        If ``level_scales`` has fewer than ``depth`` entries, or if
        ``n_leaves`` exceeds the ``2 ** depth`` leaves a binary hierarchy of
        this depth can hold.
    """
    if level_scales is None:
        # Geometric halving; exact in floating point, so depth <= 4 reproduces
        # the historical [8.0, 4.0, 2.0, 1.0] prefix bit-for-bit.
        level_scales = [8.0 / (2.0 ** k) for k in range(depth)]
    elif len(level_scales) < depth:
        raise ValueError(
            f"level_scales needs at least depth={depth} entries, got {len(level_scales)}"
        )
    if depth >= 1 and n_leaves > 2 ** depth:
        raise ValueError(
            f"n_leaves={n_leaves} does not fit a binary hierarchy of depth={depth}"
        )

    # Draw prototypes level by level, group by group; this order fixes the
    # sequence of random draws and therefore reproducibility.
    prototypes: Dict[Tuple[int, int], np.ndarray] = {}
    for level in range(1, depth + 1):
        for group_id in range(2 ** level):
            v = rng.normal(size=signal_dim)
            v /= np.linalg.norm(v) + 1e-12  # epsilon guards a zero vector
            prototypes[(level, group_id)] = v

    centers = np.zeros((n_leaves, signal_dim), dtype=float)
    for leaf in range(n_leaves):
        for level in range(1, depth + 1):
            # The group at `level` is given by the top `level` bits of the leaf.
            group_id = leaf >> (depth - level)
            centers[leaf] += level_scales[level - 1] * prototypes[(level, group_id)]
    return centers


def generate_points_with_shared_noise(
    leaf_positions: np.ndarray,
    centers: np.ndarray,
    sigma: float,
    noise_dims: int,
    eps_signal: np.ndarray,
    eps_noise: np.ndarray | None,
) -> np.ndarray:
    """
    Build a representation whose per-sample noise is supplied by the caller.

    Each sample is its leaf center plus ``sigma * eps_signal``; when
    ``noise_dims > 0`` the columns of ``eps_noise`` are appended unscaled.
    Passing the same noise arrays for the base and the perturbed assignment
    makes perturb_frac=0 exactly identical between X0 and Xq:
        D_NTS = RTD = SRTD = RTD-lite = SRTD-lite = 0.

    Raises ValueError if ``noise_dims > 0`` but ``eps_noise`` is None.
    """
    signal = centers[leaf_positions] + sigma * eps_signal
    if not (noise_dims > 0):
        return signal.astype(float)
    if eps_noise is None:
        raise ValueError("eps_noise is required when noise_dims > 0")
    combined = np.concatenate([signal, eps_noise], axis=1)
    return combined.astype(float)


# ---------------------------------------------------------------------
# Basic distance / MST utilities
# ---------------------------------------------------------------------

def upper_triangular_values(D: np.ndarray) -> np.ndarray:
    """Return the strictly-upper-triangular entries of ``D`` as a flat array (row-major)."""
    rows, cols = np.triu_indices_from(D, k=1)
    return D[rows, cols]


def normalize_by_quantile(D: np.ndarray, q: float = 0.90) -> np.ndarray:
    """
    Rescale ``D`` by the ``q``-quantile of its positive off-diagonal entries.

    Only the strictly positive upper-triangular values enter the quantile.
    If there are none the scale is 1; if the scale is numerically zero an
    unscaled copy is returned. The input is never modified.
    """
    positive = upper_triangular_values(D)
    positive = positive[positive > 0]
    if positive.size > 0:
        scale = float(np.quantile(positive, q))
    else:
        scale = 1.0
    if scale <= 1e-12:
        return D.copy()
    return D / scale


def mst_edges_and_weight(D: np.ndarray) -> Tuple[List[Tuple[int, int]], np.ndarray, float]:
    """
    Compute a minimum spanning tree of the weighted graph given by ``D``.

    Returns ``(edges, weights, total)``: the tree edges as ``(low, high)``
    index pairs, the matching edge weights as a float array, and their sum.

    NOTE(review): SciPy's csgraph routines appear to treat zero entries of a
    dense input as missing edges, so exactly coincident points could yield a
    forest rather than a tree -- confirm if duplicates can occur.
    """
    tree = minimum_spanning_tree(D).tocoo()
    edges: List[Tuple[int, int]] = [
        (min(int(r), int(c)), max(int(r), int(c)))
        for r, c in zip(tree.row, tree.col)
    ]
    weights_arr = np.asarray([float(w) for w in tree.data], dtype=float)
    return edges, weights_arr, float(weights_arr.sum())


def safe_spearman(x: np.ndarray, y: np.ndarray) -> float:
    """
    Spearman rank correlation of ``x`` and ``y``, or NaN when undefined.

    NaN is returned for fewer than two points or when either input is
    (numerically) constant.
    """
    if len(x) < 2:
        return np.nan
    if np.std(x) <= 1e-12 or np.std(y) <= 1e-12:
        return np.nan
    result = spearmanr(x, y)
    return float(result.statistic)


def tree_adjacency(n: int, edges: Sequence[Tuple[int, int]], weights: Sequence[float]) -> List[List[Tuple[int, float]]]:
    """
    Build an undirected adjacency list for ``n`` nodes.

    Entry ``k`` of the result lists ``(neighbour, weight)`` tuples for node
    ``k``; every edge is recorded in both directions.
    """
    neighbours: List[List[Tuple[int, float]]] = [[] for _ in range(n)]
    for edge, weight in zip(edges, weights):
        u, v = edge
        w = float(weight)
        neighbours[u].append((v, w))
        neighbours[v].append((u, w))
    return neighbours


def bottleneck_path_value(adj: List[List[Tuple[int, float]]], source: int, target: int) -> float:
    """
    Return the largest edge weight on the tree path from ``source`` to ``target``.

    Performs an iterative depth-first search that never steps straight back
    to the node it came from, tracking the running maximum edge weight. The
    value for ``source == target`` is 0.0.

    Raises RuntimeError if ``target`` is not reachable from ``source``.
    """
    # Each entry: (node, node we arrived from, max edge weight so far).
    pending = [(source, -1, 0.0)]
    while pending:
        node, came_from, widest = pending.pop()
        if node == target:
            return widest
        pending.extend(
            (neighbour, node, max(widest, weight))
            for neighbour, weight in adj[node]
            if neighbour != came_from
        )
    raise RuntimeError("Tree path not found")


def bottleneck_values(
    n: int,
    edges: Sequence[Tuple[int, int]],
    weights: Sequence[float],
    pairs: Sequence[Tuple[int, int]],
) -> np.ndarray:
    """
    Evaluate the tree bottleneck (max edge weight on the connecting path)
    for every node pair in ``pairs``.

    The tree is given by ``edges``/``weights`` over ``n`` nodes. The result
    is a float array aligned with ``pairs``.
    """
    adjacency = tree_adjacency(n, edges, weights)
    values = []
    for u, v in pairs:
        values.append(bottleneck_path_value(adjacency, u, v))
    return np.asarray(values, dtype=float)


# ---------------------------------------------------------------------
# NTS-E / NTS-M
# ---------------------------------------------------------------------

def nts_metrics(D1: np.ndarray, D2: np.ndarray) -> Dict[str, float]:
    """
    Compute NTS-E / NTS-M similarities and their distance forms.

    The comparison set is the union of the MST edges of ``D1`` and ``D2``.
    NTS-E is the Spearman correlation of the direct distances on those
    pairs; NTS-M is the Spearman correlation of the MST bottleneck values
    on the same pairs. Each similarity ``s`` is also reported as the
    distance ``(1 - s) / 2`` (NaN when ``s`` is undefined).
    """

    def as_distance(similarity: float) -> float:
        if np.isfinite(similarity):
            return (1.0 - similarity) / 2.0
        return np.nan

    n_points = D1.shape[0]
    tree1_edges, tree1_weights, _ = mst_edges_and_weight(D1)
    tree2_edges, tree2_weights, _ = mst_edges_and_weight(D2)
    shared_pairs = sorted(set(tree1_edges) | set(tree2_edges))

    direct1 = np.asarray([D1[i, j] for i, j in shared_pairs], dtype=float)
    direct2 = np.asarray([D2[i, j] for i, j in shared_pairs], dtype=float)
    bottleneck1 = bottleneck_values(n_points, tree1_edges, tree1_weights, shared_pairs)
    bottleneck2 = bottleneck_values(n_points, tree2_edges, tree2_weights, shared_pairs)

    sim_edge = safe_spearman(direct1, direct2)
    sim_mst = safe_spearman(bottleneck1, bottleneck2)
    return {
        "NTS_E": sim_edge,
        "NTS_M": sim_mst,
        "D_NTS_E": as_distance(sim_edge),
        "D_NTS_M": as_distance(sim_mst),
    }


# ---------------------------------------------------------------------
# RTD-lite / SRTD-lite
# ---------------------------------------------------------------------

def rtd_lite_metrics(D1: np.ndarray, D2: np.ndarray, q: float = 0.90) -> Dict[str, float]:
    """
    MST-based "lite" versions of RTD and SRTD.

    Both matrices are first rescaled by their ``q``-quantile. With
    ``mst(.)`` denoting total MST weight:
        RTD_lite_dir_12 = mst(W1) - mst(min(W1, W2))
        RTD_lite_dir_21 = mst(W2) - mst(min(W1, W2))
        RTD_lite        = mean of the two directions
        SRTD_lite       = mst(max(W1, W2)) - mst(min(W1, W2))
    """

    def total_mst_weight(W: np.ndarray) -> float:
        return mst_edges_and_weight(W)[2]

    W1 = normalize_by_quantile(D1, q)
    W2 = normalize_by_quantile(D2, q)

    weight_1 = total_mst_weight(W1)
    weight_2 = total_mst_weight(W2)
    weight_min = total_mst_weight(np.minimum(W1, W2))
    weight_max = total_mst_weight(np.maximum(W1, W2))

    forward = weight_1 - weight_min
    backward = weight_2 - weight_min
    return {
        "RTD_lite": float((forward + backward) / 2.0),
        "RTD_lite_dir_12": float(forward),
        "RTD_lite_dir_21": float(backward),
        "SRTD_lite": float(weight_max - weight_min),
    }


# ---------------------------------------------------------------------
# Full RTD / SRTD using auxiliary matrices
# ---------------------------------------------------------------------

def plus_matrix(D: np.ndarray, inf_value: float) -> np.ndarray:
    """
    Return M+: a copy of ``D`` whose strictly upper triangular part is
    overwritten with ``inf_value``. The diagonal and lower triangle are kept,
    and ``D`` itself is left untouched.
    """
    result = D.copy()
    rows, cols = np.triu_indices_from(result, k=1)
    result[rows, cols] = inf_value
    return result


def sanitize_auxiliary_matrix(M: np.ndarray) -> np.ndarray:
    """
    Return a copy of ``M`` that is safe to hand to a persistent-homology
    library expecting finite distances.

    Every non-finite entry is replaced by a large finite value chosen to lie
    beyond all real filtration values (at least 10x and at least +10 above
    the largest finite entry; 1e6 if that entry is exactly 0), and the
    diagonal is forced to 0.
    """
    finite_mask = np.isfinite(M)
    finite_entries = M[finite_mask]
    if finite_entries.size:
        max_finite = float(np.max(finite_entries))
    else:
        max_finite = 1.0

    floor_when_all_zero = 1e6 if max_finite == 0 else 0.0
    large = max(10.0 * max_finite, max_finite + 10.0, floor_when_all_zero)

    cleaned = M.copy()
    cleaned[~np.isfinite(cleaned)] = large
    np.fill_diagonal(cleaned, 0.0)
    return cleaned


def build_mmin(W: np.ndarray, Wt: np.ndarray) -> np.ndarray:
    """
    Directional RTD auxiliary matrix Mmin(W, Wt).

    For n points this is a (2n+1) x (2n+1) block matrix

        [ W      (W+)^T        0   ]
        [ W+     min(W, Wt)    inf ]
        [ 0      inf           0   ]

    where W+ is ``W`` with its strict upper triangle set to "inf", and "inf"
    is a finite stand-in (10 * max entry + 10) larger than any real weight.
    """
    n = W.shape[0]
    largest = max(float(np.max(W)), float(np.max(Wt)), 1.0)
    pseudo_inf = 10.0 * largest + 10.0

    W_plus = plus_matrix(W, pseudo_inf)
    zero_col = np.zeros((n, 1))
    far_col = np.full((n, 1), pseudo_inf)

    assembled = np.block([
        [W, W_plus.T, zero_col],
        [W_plus, np.minimum(W, Wt), far_col],
        [zero_col.T, far_col.T, np.zeros((1, 1))],
    ])
    return sanitize_auxiliary_matrix(assembled)


def build_msym(W: np.ndarray, Wt: np.ndarray) -> np.ndarray:
    """
    Symmetric SRTD auxiliary matrix Msym(W, Wt).

    For n points this is a (2n+1) x (2n+1) block matrix

        [ Wmax      (Wmax+)^T     0   ]
        [ Wmax+     Wmin          inf ]
        [ 0         inf           0   ]

    with Wmax / Wmin the element-wise max / min of the inputs, Wmax+ equal to
    Wmax with its strict upper triangle set to "inf", and "inf" a finite
    stand-in (10 * max entry + 10) larger than any real weight.
    """
    n = W.shape[0]
    largest = max(float(np.max(W)), float(np.max(Wt)), 1.0)
    pseudo_inf = 10.0 * largest + 10.0

    lower = np.minimum(W, Wt)
    upper = np.maximum(W, Wt)
    upper_plus = plus_matrix(upper, pseudo_inf)
    zero_col = np.zeros((n, 1))
    far_col = np.full((n, 1), pseudo_inf)

    assembled = np.block([
        [upper, upper_plus.T, zero_col],
        [upper_plus, lower, far_col],
        [zero_col.T, far_col.T, np.zeros((1, 1))],
    ])
    return sanitize_auxiliary_matrix(assembled)


def persistence_sum_precomputed(M: np.ndarray, h_dim: int = 1) -> float:
    """
    Sum of bar lengths in dimension ``h_dim`` of the Vietoris-Rips
    persistence of the precomputed distance matrix ``M``.

    Bars with non-finite or numerically zero length (<= 1e-12) are ignored;
    0.0 is returned when no bar remains.

    Raises ImportError if giotto-tda could not be imported.
    """
    if not GTDA_AVAILABLE:
        raise ImportError("giotto-tda is required for full RTD/SRTD. Install with: pip install giotto-tda")

    transformer = VietorisRipsPersistence(homology_dimensions=[h_dim], metric="precomputed")
    diagrams = transformer.fit_transform([M])
    if diagrams.size == 0 or diagrams[0].size == 0:
        return 0.0

    # Diagram rows are read as (birth, death, homology dimension).
    diagram = diagrams[0]
    selected = diagram[diagram[:, 2] == h_dim]
    if selected.shape[0] == 0:
        return 0.0

    lifetimes = selected[:, 1] - selected[:, 0]
    keep = np.isfinite(lifetimes) & (lifetimes > 1e-12)
    return float(np.sum(lifetimes[keep]))


def full_rtd_srtd_metrics(D1: np.ndarray, D2: np.ndarray, q: float = 0.90, h_dim: int = 1) -> Dict[str, float]:
    """
    Full (persistent-homology) RTD and SRTD.

    Both inputs are rescaled by their ``q``-quantile first. RTD is the
    average of the two directional values,
        RTD = (RTD(D1,D2) + RTD(D2,D1)) / 2,
    each obtained from an Mmin auxiliary matrix; SRTD comes from a single
    Msym auxiliary matrix. All values are persistence sums in dimension
    ``h_dim``.
    """
    W1 = normalize_by_quantile(D1, q)
    W2 = normalize_by_quantile(D2, q)

    forward = persistence_sum_precomputed(build_mmin(W1, W2), h_dim=h_dim)
    backward = persistence_sum_precomputed(build_mmin(W2, W1), h_dim=h_dim)
    symmetric = persistence_sum_precomputed(build_msym(W1, W2), h_dim=h_dim)

    averaged = (forward + backward) / 2.0
    return {
        "RTD": float(averaged),
        "RTD_dir_12": float(forward),
        "RTD_dir_21": float(backward),
        "SRTD": float(symmetric),
    }


def compute_metrics(
    X0: np.ndarray,
    Xq: np.ndarray,
    q: float = 0.90,
    h_dim: int = 1,
) -> Dict[str, float]:
    """
    Evaluate every metric family on a pair of representations.

    Euclidean distance matrices of ``X0`` and ``Xq`` are compared with NTS,
    RTD-lite/SRTD-lite, and full RTD/SRTD; the three result dictionaries are
    merged in that order into one flat mapping.
    """
    dist_base = pairwise_distances(X0, metric="euclidean")
    dist_query = pairwise_distances(Xq, metric="euclidean")

    merged: Dict[str, float] = {
        **nts_metrics(dist_base, dist_query),
        **rtd_lite_metrics(dist_base, dist_query, q=q),
        **full_rtd_srtd_metrics(dist_base, dist_query, q=q, h_dim=h_dim),
    }
    return merged


# ---------------------------------------------------------------------
# Experiment runner
# ---------------------------------------------------------------------

def run_experiment(
    out_dir: str,
    n_repeats: int,
    n_leaves: int = 8,
    depth: int = 3,
    perturb_fracs: Sequence[float] = (0.0, 0.05, 0.10, 0.20, 0.30, 0.50, 0.70),
    base_seed: int = 123,
    q: float = 0.90,
    h_dim: int = 1,
) -> pd.DataFrame:
    """
    Run the full benchmark grid and return one row per evaluation.

    For every scenario in SCENARIOS and every repeat, a base dataset X0 is
    generated from a planted hierarchy. For every perturbation fraction the
    leaf assignment is corrupted, a perturbed dataset Xq is generated with
    the same centers and the same per-sample noise, and all metrics are
    computed between X0 and Xq together with the ground-truth distance.

    Parameters
    ----------
    out_dir : directory (created if needed) receiving "raw_results.csv".
    n_repeats : independent repetitions per scenario.
    n_leaves, depth : size and depth of the planted hierarchy.
    perturb_fracs : fractions of samples whose leaf is reassigned.
    base_seed : root of all random seeds used.
    q : quantile used to normalize distance matrices for RTD-type metrics.
    h_dim : homology dimension for full RTD/SRTD.

    Returns
    -------
    DataFrame with scenario settings, ground-truth distance and all metrics.

    Raises
    ------
    ImportError
        If giotto-tda is unavailable (checked after ``out_dir`` is created).
    """
    os.makedirs(out_dir, exist_ok=True)
    rows = []

    if not GTDA_AVAILABLE:
        raise ImportError("Full RTD/SRTD require giotto-tda. Install with: pip install giotto-tda")

    for scenario_id, scenario in enumerate(SCENARIOS):
        for repeat in range(n_repeats):
            # One generator per (scenario, repeat); scenarios are spaced
            # 10000 apart so their seed ranges do not meet for small n_repeats.
            seed = base_seed + 10000 * scenario_id + repeat
            rng = np.random.default_rng(seed)

            # The order of these draws determines the generated data.
            leaf_ids = make_sample_leaf_ids(n_leaves, scenario.samples_per_leaf, scenario.imbalance, rng)
            centers = generate_hierarchy_centers(n_leaves, depth, scenario.signal_dim, rng)

            # Noise is drawn once and reused for X0 and every Xq, so the only
            # difference between them is the leaf assignment.
            eps_signal = rng.normal(size=(len(leaf_ids), scenario.signal_dim))
            eps_noise = rng.normal(size=(len(leaf_ids), scenario.noise_dims)) if scenario.noise_dims > 0 else None
            X0 = generate_points_with_shared_noise(leaf_ids, centers, scenario.sigma, scenario.noise_dims, eps_signal, eps_noise)

            for frac in perturb_fracs:
                # Separate generator per perturbation level, derived from the
                # repeat seed and the fraction.
                pert_rng = np.random.default_rng(seed + 777 + int(round(frac * 10000)))
                perturbed_positions = corrupt_leaf_positions(leaf_ids, n_leaves, frac, depth, pert_rng)
                gt = ground_truth_hierarchy_distance(leaf_ids, perturbed_positions, depth)
                Xq = generate_points_with_shared_noise(perturbed_positions, centers, scenario.sigma, scenario.noise_dims, eps_signal, eps_noise)

                metrics = compute_metrics(X0, Xq, q=q, h_dim=h_dim)
                # asdict(scenario) adds the scenario fields (including "name")
                # as columns next to the explicit "scenario" label.
                row = {
                    "scenario": scenario.name,
                    "repeat": repeat,
                    "n_samples": int(len(leaf_ids)),
                    "perturb_frac": float(frac),
                    "gt_hierarchy_distance": float(gt),
                    **asdict(scenario),
                    **metrics,
                }
                rows.append(row)
                # Progress line for this evaluation.
                print(
                    f"scenario={scenario.name:>10s} repeat={repeat:02d} frac={frac:.2f} "
                    f"gt={gt:.3f} D_NTS_E={row['D_NTS_E']:.3f} "
                    f"RTD={row['RTD']:.3f} SRTD={row['SRTD']:.3f} "
                    f"RTD_lite={row['RTD_lite']:.3f} SRTD_lite={row['SRTD_lite']:.3f}"
                )

    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(out_dir, "raw_results.csv"), index=False)
    return df


# ---------------------------------------------------------------------
# Analysis utilities
# ---------------------------------------------------------------------

def safe_corr(x: np.ndarray, y: np.ndarray, method: str) -> float:
    """
    Correlation between ``x`` and ``y`` restricted to jointly finite entries.

    ``method`` is one of "spearman", "kendall", "pearson". NaN is returned
    when fewer than three finite pairs remain or either side is
    (numerically) constant; this check happens before ``method`` is
    validated. An unknown ``method`` raises ValueError.
    """
    finite = np.isfinite(x) & np.isfinite(y)
    xs = x[finite]
    ys = y[finite]
    if len(xs) < 3:
        return np.nan
    if np.std(xs) <= 1e-12 or np.std(ys) <= 1e-12:
        return np.nan

    estimators = {
        "spearman": spearmanr,
        "kendall": kendalltau,
        "pearson": pearsonr,
    }
    if method not in estimators:
        raise ValueError(method)
    return float(estimators[method](xs, ys).statistic)


def pooled_correlation_with_ground_truth(df: pd.DataFrame, metric_cols: Sequence[str]) -> pd.DataFrame:
    """
    Correlate each metric with the ground-truth distance over all rows.

    Only the pooled, cross-scenario correlation is computed. This addresses
    the reviewer question directly: once heterogeneous scenarios are mixed,
    does a metric still order the planted perturbations correctly?

    Returns one row per metric with Spearman, Kendall and Pearson values.
    """
    estimators = (
        ("spearman_rho", "spearman"),
        ("kendall_tau", "kendall"),
        ("pearson_r", "pearson"),
    )
    records = []
    for metric in metric_cols:
        values = df[metric].to_numpy()
        truth = df["gt_hierarchy_distance"].to_numpy()
        record = {
            "group": "POOLED_ALL_SCENARIOS",
            "metric": metric,
            "n": int(len(df)),
        }
        for column, method in estimators:
            record[column] = safe_corr(values, truth, method)
        records.append(record)
    return pd.DataFrame(records)


def metric_to_metric_correlation(df: pd.DataFrame) -> pd.DataFrame:
    """
    Correlate each NTS distance with each RTD-family metric.

    Every pairing of (D_NTS_E, D_NTS_M) with (RTD, SRTD, RTD_lite,
    SRTD_lite) whose columns are both present in ``df`` yields one row with
    Spearman, Kendall and Pearson correlations.
    """
    nts_columns = ("D_NTS_E", "D_NTS_M")
    rtd_columns = ("RTD", "SRTD", "RTD_lite", "SRTD_lite")
    estimators = (
        ("spearman_rho", "spearman"),
        ("kendall_tau", "kendall"),
        ("pearson_r", "pearson"),
    )

    records = []
    for left in nts_columns:
        for right in rtd_columns:
            if left not in df.columns or right not in df.columns:
                continue
            left_values = df[left].to_numpy()
            right_values = df[right].to_numpy()
            record = {"metric_a": left, "metric_b": right}
            for column, method in estimators:
                record[column] = safe_corr(left_values, right_values, method)
            records.append(record)
    return pd.DataFrame(records)


def pairwise_ranking_accuracy(df: pd.DataFrame, metric_cols: Sequence[str]) -> pd.DataFrame:
    """
    Fraction of row pairs that a metric orders the same way as the ground truth.

    Rows with a non-finite ground truth or metric value are dropped. Pairs
    whose ground-truth values differ by at most 1e-6 are not counted. A pair
    on which the metric is tied (gap <= 1e-12) earns half credit; a pair
    ordered in the same direction earns full credit. The accuracy is NaN when
    no pair is comparable.
    """
    truth = df["gt_hierarchy_distance"].to_numpy()
    records = []
    for metric in metric_cols:
        scores = df[metric].to_numpy()
        keep = np.isfinite(truth) & np.isfinite(scores)
        t, v = truth[keep], scores[keep]

        n_pairs = 0
        n_correct = 0.0
        # Compare row i against all later rows at once.
        for i in range(len(t) - 1):
            true_gap = t[i] - t[i + 1:]
            metric_gap = v[i] - v[i + 1:]
            comparable = ~(np.abs(true_gap) <= 1e-6)
            tied = comparable & (np.abs(metric_gap) <= 1e-12)
            agree = comparable & ~tied & (np.sign(true_gap) == np.sign(metric_gap))
            n_pairs += int(comparable.sum())
            n_correct += 0.5 * int(tied.sum()) + 1.0 * int(agree.sum())

        accuracy = n_correct / n_pairs if n_pairs else np.nan
        records.append({"metric": metric, "num_pairs": n_pairs, "pairwise_ranking_accuracy": accuracy})
    return pd.DataFrame(records)


def fixed_level_stability(df: pd.DataFrame, metric_cols: Sequence[str]) -> pd.DataFrame:
    """
    Summarize how much each metric varies at a fixed perturbation level.

    For every ``perturb_frac`` group and every metric, the finite values are
    reduced to their mean, population standard deviation, and coefficient of
    variation ``std / (|mean| + 1e-12)``. Groups with no finite value for a
    metric are omitted.
    """
    records = []
    for level, group in df.groupby("perturb_frac"):
        for metric in metric_cols:
            raw = group[metric].to_numpy()
            finite = raw[np.isfinite(raw)]
            if len(finite) == 0:
                continue
            center = float(np.mean(finite))
            spread = float(np.std(finite))
            records.append({
                "perturb_frac": float(level),
                "metric": metric,
                "mean": center,
                "std": spread,
                "coef_variation": float(spread / (abs(center) + 1e-12)),
            })
    return pd.DataFrame(records)


# ---------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------

def make_scatter_plots(df: pd.DataFrame, out_dir: str, metric_cols: Sequence[str]) -> None:
    """
    Save one scatter plot per metric: metric value against ground-truth
    hierarchy distance, with one point series per scenario.

    Metrics missing from ``df`` are skipped. Files are written to
    ``out_dir`` as ``scatter_<metric>.png``.
    """
    present = [metric for metric in metric_cols if metric in df.columns]
    for metric in present:
        plt.figure(figsize=(7, 5))
        for scenario_name, group in df.groupby("scenario"):
            plt.scatter(
                group["gt_hierarchy_distance"],
                group[metric],
                s=22,
                alpha=0.75,
                label=scenario_name,
            )
        plt.xlabel("Ground-truth hierarchy distance")
        plt.ylabel(metric)
        plt.title(f"{metric} vs ground-truth hierarchy distance")
        plt.legend(fontsize=8, frameon=False)
        plt.tight_layout()
        target = os.path.join(out_dir, f"scatter_{metric}.png")
        plt.savefig(target, dpi=200)
        plt.close()


def make_correlation_bar_plot(corr_df: pd.DataFrame, out_dir: str) -> None:
    """
    Save a bar chart of the pooled Spearman correlation of each metric with
    the ground truth as ``pooled_spearman_bar.png`` in ``out_dir``.

    Only rows of group "POOLED_ALL_SCENARIOS" with a defined Spearman value
    are drawn. The y-axis spans [0, 1.05] when all correlations are
    non-negative; if any is negative the lower limit is extended below it
    (and a zero line is drawn) so that the bar is not clipped out of view.
    """
    pooled = corr_df[corr_df["group"] == "POOLED_ALL_SCENARIOS"].copy()
    pooled = pooled.dropna(subset=["spearman_rho"])

    # Spearman's rho lies in [-1, 1]; a fixed lower limit of 0 would hide
    # metrics that are anti-correlated with the ground truth.
    lowest = float(pooled["spearman_rho"].min()) if len(pooled) else 0.0
    y_min = lowest - 0.05 if lowest < 0.0 else 0.0

    plt.figure(figsize=(8, 5))
    plt.bar(pooled["metric"], pooled["spearman_rho"])
    if y_min < 0.0:
        plt.axhline(0.0, color="black", linewidth=0.8)
    plt.ylabel("Pooled Spearman correlation with ground truth")
    plt.xticks(rotation=35, ha="right")
    plt.ylim(y_min, 1.05)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "pooled_spearman_bar.png"), dpi=200)
    plt.close()


# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------

def main() -> None:
    """
    Command-line entry point.

    Parses the options, runs the experiment, writes the four analysis tables
    as CSV, saves the plots, and prints the pooled correlation, the
    metric-to-metric correlation, the ranking accuracy, and the list of
    saved files.
    """
    parser = argparse.ArgumentParser()
    for flag, kind, default in (
        ("--out_dir", str, "results_nts_rtd_srtd"),
        ("--n_repeats", int, 10),
        ("--base_seed", int, 123),
        ("--q", float, 0.90),
        ("--h_dim", int, 1),
    ):
        parser.add_argument(flag, type=kind, default=default)
    args = parser.parse_args()
    out_dir = args.out_dir

    metric_cols = ["D_NTS_E", "D_NTS_M", "RTD", "SRTD", "RTD_lite", "SRTD_lite"]

    df = run_experiment(
        out_dir=out_dir,
        n_repeats=args.n_repeats,
        base_seed=args.base_seed,
        q=args.q,
        h_dim=args.h_dim,
    )

    # Compute every table before anything is written.
    corr_df = pooled_correlation_with_ground_truth(df, metric_cols)
    metric_corr_df = metric_to_metric_correlation(df)
    ranking_df = pairwise_ranking_accuracy(df, metric_cols)
    stability_df = fixed_level_stability(df, metric_cols)

    tables = {
        "pooled_correlation_with_ground_truth.csv": corr_df,
        "metric_to_metric_correlation.csv": metric_corr_df,
        "pairwise_ranking_accuracy.csv": ranking_df,
        "fixed_level_stability.csv": stability_df,
    }
    for filename, table in tables.items():
        table.to_csv(os.path.join(out_dir, filename), index=False)

    make_scatter_plots(df, out_dir, metric_cols)
    make_correlation_bar_plot(corr_df, out_dir)

    reports = (
        ("\n=== Pooled cross-scenario correlation with ground-truth hierarchy distance ===", corr_df),
        ("\n=== Metric-to-metric correlation ===", metric_corr_df),
        ("\n=== Cross-scenario pairwise ranking accuracy ===", ranking_df),
    )
    for heading, table in reports:
        print(heading)
        print(table.to_string(index=False))

    print("\nSaved outputs:")
    saved_files = ["raw_results.csv", *tables, "pooled_spearman_bar.png"]
    for filename in saved_files:
        print(f"- {os.path.join(out_dir, filename)}")


if __name__ == "__main__":
    main()
