from __future__ import annotations

import argparse
import os
import warnings
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple


import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.stats import spearmanr
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances

# Ignore every FutureWarning raised while this module is loaded and run.
warnings.filterwarnings("ignore", category=FutureWarning)

# giotto-tda is imported optionally: GTDA_AVAILABLE records whether
# VietorisRipsPersistence could be imported, so that code needing it can raise
# a clear ImportError instead of failing on an undefined name.
try:
    from gtda.homology import VietorisRipsPersistence

    GTDA_AVAILABLE = True
except Exception:
    # NOTE(review): any exception raised during the import (not only
    # ImportError) marks giotto-tda as unavailable.
    GTDA_AVAILABLE = False


def zscore_features(X: np.ndarray, eps: float = 1e-9) -> np.ndarray:
    """Standardise each column of *X* to zero mean and (roughly) unit variance.

    Args:
        X: Array whose axis 0 indexes samples.
        eps: Added to every column's standard deviation so that constant
            columns do not cause a division by zero.

    Returns:
        ``(X - column_mean) / (column_std + eps)``.
    """
    column_mean = X.mean(axis=0, keepdims=True)
    column_std = X.std(axis=0, keepdims=True)
    return (X - column_mean) / (column_std + eps)


def upper_triangular_values(D: np.ndarray) -> np.ndarray:
    """Return the entries of *D* strictly above the main diagonal as a 1-D array."""
    row_idx, col_idx = np.triu_indices_from(D, k=1)
    return D[row_idx, col_idx]


def normalize_by_quantile(D: np.ndarray, q: float = 0.90) -> np.ndarray:
    """Rescale *D* by the *q*-quantile of its positive upper-triangular entries.

    Args:
        D: Square matrix (used as a distance matrix by the callers in this file).
        q: Quantile in ``[0, 1]`` forwarded to ``np.quantile``.

    Returns:
        ``D / scale``. When there are no positive off-diagonal entries the
        scale is 1.0; when the scale is at most ``1e-12`` an unscaled copy of
        *D* is returned instead.
    """
    off_diagonal = upper_triangular_values(D)
    positive = off_diagonal[off_diagonal > 0]
    if positive.size > 0:
        scale = float(np.quantile(positive, q))
    else:
        scale = 1.0
    if scale <= 1e-12:
        return D.copy()
    return D / scale


def mst_edges_and_weight(D: np.ndarray) -> Tuple[List[Tuple[int, int]], np.ndarray, float]:
    """Compute a minimum spanning tree of the weighted graph given by *D*.

    Args:
        D: Square weight matrix handed to
            ``scipy.sparse.csgraph.minimum_spanning_tree``.

    Returns:
        A triple ``(edges, weights, total)`` where ``edges`` is a list of
        ``(a, b)`` index pairs with ``a <= b``, ``weights`` is the float array
        of the matching edge weights and ``total`` is their sum.
    """
    # NOTE(review): scipy treats zero entries of a dense matrix as "no edge",
    # so coincident points may yield a forest rather than a tree - confirm.
    tree = minimum_spanning_tree(D).tocoo()
    # Order each edge's endpoints so the smaller index comes first.
    low = np.minimum(tree.row, tree.col)
    high = np.maximum(tree.row, tree.col)
    edges: List[Tuple[int, int]] = list(zip(low.tolist(), high.tolist()))
    weights_arr = np.array(tree.data, dtype=float)
    return edges, weights_arr, float(weights_arr.sum())


def safe_spearman(x: np.ndarray, y: np.ndarray) -> float:
    """Spearman rank correlation of *x* and *y*, or NaN when it is undefined.

    NaN is returned when *x* has fewer than two elements or when either input
    has a standard deviation of at most ``1e-12`` (i.e. is effectively
    constant).
    """
    if len(x) < 2:
        return np.nan
    if np.std(x) <= 1e-12 or np.std(y) <= 1e-12:
        return np.nan
    result = spearmanr(x, y)
    return float(result.statistic)


def tree_adjacency(
    n: int,
    edges: Sequence[Tuple[int, int]],
    weights: Sequence[float],
) -> List[List[Tuple[int, float]]]:
    """Build an undirected adjacency list from parallel edge/weight sequences.

    Args:
        n: Number of nodes; node ids are expected in ``range(n)``.
        edges: ``(a, b)`` node pairs.
        weights: Weight of each edge, aligned with *edges*.

    Returns:
        A list of length *n* whose ``i``-th entry lists ``(neighbour, weight)``
        tuples for node ``i``; every edge is recorded in both directions.
    """
    neighbours: List[List[Tuple[int, float]]] = [[] for _ in range(n)]
    for edge, weight in zip(edges, weights):
        u, v = edge
        w = float(weight)
        neighbours[u].append((v, w))
        neighbours[v].append((u, w))
    return neighbours


def bottleneck_path_value(
    adj: List[List[Tuple[int, float]]],
    source: int,
    target: int,
) -> float:
    """Return the largest edge weight on the path from *source* to *target*.

    The search is an iterative depth-first walk that only avoids stepping back
    to the node it just came from, so *adj* must describe a tree (or forest):
    no visited set is kept.

    Args:
        adj: Adjacency list of ``(neighbour, weight)`` tuples per node.
        source: Start node.
        target: End node.

    Returns:
        The maximum weight along the path, or ``0.0`` when
        ``source == target``.

    Raises:
        RuntimeError: If *target* cannot be reached from *source*.
    """
    # Each entry is (node, node we arrived from, heaviest edge seen so far).
    pending = [(source, -1, 0.0)]
    while pending:
        here, came_from, heaviest = pending.pop()
        if here == target:
            return heaviest
        pending.extend(
            (neighbour, here, max(heaviest, weight))
            for neighbour, weight in adj[here]
            if neighbour != came_from
        )
    raise RuntimeError("Tree path not found")


def bottleneck_values(
    n: int,
    edges: Sequence[Tuple[int, int]],
    weights: Sequence[float],
    pairs: Sequence[Tuple[int, int]],
) -> np.ndarray:
    """Evaluate the tree bottleneck (max edge on the path) for many node pairs.

    Args:
        n: Number of nodes in the tree.
        edges: Tree edges as ``(a, b)`` pairs.
        weights: Edge weights aligned with *edges*.
        pairs: ``(a, b)`` node pairs to query.

    Returns:
        Float array with one bottleneck value per entry of *pairs*.
    """
    adjacency = tree_adjacency(n, edges, weights)
    values = np.empty(len(pairs), dtype=float)
    for slot, (start, end) in enumerate(pairs):
        values[slot] = bottleneck_path_value(adjacency, start, end)
    return values


def nts_metrics(D1: np.ndarray, D2: np.ndarray) -> Dict[str, float]:
    """Compute the NTS-E and NTS-M scores for two distance matrices.

    Both scores are Spearman correlations evaluated on the union of the MST
    edges of *D1* and *D2*:

    * ``"NTS-E"`` correlates the raw distances ``D1[a, b]`` and ``D2[a, b]``.
    * ``"NTS-M"`` correlates the bottleneck values of the same pairs measured
      inside each matrix's own MST.

    Args:
        D1: Distance matrix of the first point cloud.
        D2: Distance matrix of the second point cloud; indexed with the same
            node ids as *D1*.

    Returns:
        Dict with keys ``"NTS-E"`` and ``"NTS-M"`` (values may be NaN).
    """
    n_points = D1.shape[0]
    tree1_edges, tree1_weights, _ = mst_edges_and_weight(D1)
    tree2_edges, tree2_weights, _ = mst_edges_and_weight(D2)
    core_edges = sorted(set(tree1_edges) | set(tree2_edges))

    direct_1 = np.fromiter((D1[a, b] for a, b in core_edges), dtype=float)
    direct_2 = np.fromiter((D2[a, b] for a, b in core_edges), dtype=float)
    bottleneck_1 = bottleneck_values(n_points, tree1_edges, tree1_weights, core_edges)
    bottleneck_2 = bottleneck_values(n_points, tree2_edges, tree2_weights, core_edges)

    edge_score = safe_spearman(direct_1, direct_2)
    bottleneck_score = safe_spearman(bottleneck_1, bottleneck_2)
    return {"NTS-E": edge_score, "NTS-M": bottleneck_score}


def rtd_lite_metrics(D1: np.ndarray, D2: np.ndarray, q: float = 0.90) -> Dict[str, float]:
    """Compute MST-weight based divergences between two distance matrices.

    Both matrices are first rescaled with ``normalize_by_quantile``. With
    ``mst(.)`` denoting total MST weight:

    * ``"RTD-lite"``  = ``0.5 * ((mst(W1) - mst(min)) + (mst(W2) - mst(min)))``
    * ``"SRTD-lite"`` = ``mst(max) - mst(min)``

    where ``min`` / ``max`` are the element-wise minimum / maximum of the two
    rescaled matrices.

    Args:
        D1: First distance matrix.
        D2: Second distance matrix, same shape as *D1*.
        q: Quantile used for the rescaling.

    Returns:
        Dict with keys ``"RTD-lite"`` and ``"SRTD-lite"``.
    """

    def total_weight(M: np.ndarray) -> float:
        return mst_edges_and_weight(M)[2]

    scaled_1 = normalize_by_quantile(D1, q)
    scaled_2 = normalize_by_quantile(D2, q)

    mst1 = total_weight(scaled_1)
    mst2 = total_weight(scaled_2)
    mst_min = total_weight(np.minimum(scaled_1, scaled_2))
    mst_max = total_weight(np.maximum(scaled_1, scaled_2))

    symmetric = 0.5 * ((mst1 - mst_min) + (mst2 - mst_min))
    spread = mst_max - mst_min
    return {"RTD-lite": float(symmetric), "SRTD-lite": float(spread)}


def plus_matrix(D: np.ndarray, inf_value: float) -> np.ndarray:
    """Return a copy of *D* whose strictly-upper triangle is set to *inf_value*.

    The diagonal and the lower triangle keep their original values; *D* itself
    is not modified.
    """
    result = D.copy()
    upper_rows, upper_cols = np.triu_indices_from(result, k=1)
    result[upper_rows, upper_cols] = inf_value
    return result


def sanitize_auxiliary_matrix(M: np.ndarray) -> np.ndarray:
    """Return a copy of *M* with non-finite entries replaced and a zero diagonal.

    Every NaN / +-inf entry is replaced by one finite value ``large`` derived
    from the largest finite entry ``m`` of *M* (``m = 1.0`` if there is none):
    ``large = max(10 * m, m + 10)``, or at least ``1e6`` when ``m == 0``.
    The diagonal of the result is then set to 0. *M* is left untouched.
    """
    finite_mask = np.isfinite(M)
    finite_vals = M[finite_mask]
    if finite_vals.size:
        max_finite = float(np.max(finite_vals))
    else:
        max_finite = 1.0
    zero_case_floor = 1e6 if max_finite == 0 else 0.0
    large = max(10.0 * max_finite, max_finite + 10.0, zero_case_floor)

    cleaned = M.copy()
    cleaned[~finite_mask] = large
    np.fill_diagonal(cleaned, 0.0)
    return cleaned


def build_mmin(W: np.ndarray, Wt: np.ndarray) -> np.ndarray:
    """Assemble the ``(2n + 1) x (2n + 1)`` auxiliary matrix built from ``W`` and ``min(W, Wt)``.

    Block layout (``W+`` is *W* with its strict upper triangle replaced by a
    large finite sentinel, ``INF`` is that sentinel)::

        [ W     (W+)^T       0   ]
        [ W+    min(W, Wt)   INF ]
        [ 0     INF          0   ]

    The sentinel is ``10 * max(max(W), max(Wt), 1) + 10``. The assembled
    matrix is passed through ``sanitize_auxiliary_matrix`` before being
    returned.

    Args:
        W: ``n x n`` matrix.
        Wt: ``n x n`` matrix compared against *W*.
    """
    # NOTE(review): presumably the R-Cross-Barcode matrix of the RTD paper - confirm.
    size = W.shape[0]
    ceiling = max(float(np.max(W)), float(np.max(Wt)), 1.0)
    sentinel = 10.0 * ceiling + 10.0

    elementwise_min = np.minimum(W, Wt)
    W_plus = plus_matrix(W, sentinel)
    zero_col = np.zeros((size, 1))
    sentinel_col = np.full((size, 1), sentinel)

    block_rows = [
        np.hstack([W, W_plus.T, zero_col]),
        np.hstack([W_plus, elementwise_min, sentinel_col]),
        np.hstack([zero_col.T, sentinel_col.T, np.zeros((1, 1))]),
    ]
    return sanitize_auxiliary_matrix(np.vstack(block_rows))


def build_msym(W: np.ndarray, Wt: np.ndarray) -> np.ndarray:
    """Assemble the ``(2n + 1) x (2n + 1)`` auxiliary matrix built from ``max(W, Wt)`` and ``min(W, Wt)``.

    Block layout (``M = max(W, Wt)``, ``M+`` is ``M`` with its strict upper
    triangle replaced by a large finite sentinel, ``INF`` is that sentinel)::

        [ M     (M+)^T       0   ]
        [ M+    min(W, Wt)   INF ]
        [ 0     INF          0   ]

    The sentinel is ``10 * max(max(W), max(Wt), 1) + 10``. The assembled
    matrix is passed through ``sanitize_auxiliary_matrix`` before being
    returned.

    Args:
        W: ``n x n`` matrix.
        Wt: ``n x n`` matrix; the construction is symmetric in *W* and *Wt*.
    """
    size = W.shape[0]
    ceiling = max(float(np.max(W)), float(np.max(Wt)), 1.0)
    sentinel = 10.0 * ceiling + 10.0

    elementwise_min = np.minimum(W, Wt)
    elementwise_max = np.maximum(W, Wt)
    max_plus = plus_matrix(elementwise_max, sentinel)
    zero_col = np.zeros((size, 1))
    sentinel_col = np.full((size, 1), sentinel)

    block_rows = [
        np.hstack([elementwise_max, max_plus.T, zero_col]),
        np.hstack([max_plus, elementwise_min, sentinel_col]),
        np.hstack([zero_col.T, sentinel_col.T, np.zeros((1, 1))]),
    ]
    return sanitize_auxiliary_matrix(np.vstack(block_rows))


def persistence_sum_precomputed(M: np.ndarray, h_dim: int = 1) -> float:
    """Sum the bar lengths of the Vietoris-Rips persistence diagram of *M*.

    *M* is passed to giotto-tda's ``VietorisRipsPersistence`` as a precomputed
    distance matrix. Only bars whose third column equals *h_dim* are kept, and
    of those only the ones with a finite length greater than ``1e-12``.

    Args:
        M: Square precomputed distance matrix.
        h_dim: Homology dimension to evaluate.

    Returns:
        Total length (column 1 minus column 0) of the retained bars, or
        ``0.0`` when the diagram is empty.

    Raises:
        ImportError: If giotto-tda could not be imported.
    """
    if not GTDA_AVAILABLE:
        raise ImportError("giotto-tda is required for full RTD/SRTD.")

    rips = VietorisRipsPersistence(homology_dimensions=[h_dim], metric="precomputed")
    diagrams = rips.fit_transform([M])
    if diagrams.size == 0:
        return 0.0
    diagram = diagrams[0]
    if diagram.size == 0:
        return 0.0

    in_dimension = diagram[:, 2] == h_dim
    bars = diagram[in_dimension]
    if bars.shape[0] == 0:
        return 0.0

    lengths = bars[:, 1] - bars[:, 0]
    keep = np.isfinite(lengths) & (lengths > 1e-12)
    return float(np.sum(lengths[keep]))


def full_rtd_srtd_metrics(D1: np.ndarray, D2: np.ndarray, q: float = 0.90, h_dim: int = 1) -> Dict[str, float]:
    """Compute the persistence-based RTD and SRTD scores.

    Both matrices are rescaled with ``normalize_by_quantile`` first.

    * ``"RTD"`` is the mean of the persistence sums of ``build_mmin(W1, W2)``
      and ``build_mmin(W2, W1)``.
    * ``"SRTD"`` is the persistence sum of ``build_msym(W1, W2)``.

    Args:
        D1: First distance matrix.
        D2: Second distance matrix, same shape as *D1*.
        q: Quantile used for the rescaling.
        h_dim: Homology dimension forwarded to ``persistence_sum_precomputed``.

    Returns:
        Dict with keys ``"RTD"`` and ``"SRTD"``.
    """
    scaled_1 = normalize_by_quantile(D1, q)
    scaled_2 = normalize_by_quantile(D2, q)

    forward = persistence_sum_precomputed(build_mmin(scaled_1, scaled_2), h_dim=h_dim)
    backward = persistence_sum_precomputed(build_mmin(scaled_2, scaled_1), h_dim=h_dim)
    rtd = 0.5 * (forward + backward)

    srtd = persistence_sum_precomputed(build_msym(scaled_1, scaled_2), h_dim=h_dim)
    return {"RTD": float(rtd), "SRTD": float(srtd)}


def compute_all_metrics(X: np.ndarray, Y: np.ndarray, q: float = 0.90) -> Dict[str, float]:
    """Evaluate every metric of this module on a pair of point clouds.

    Each cloud is z-scored per feature and converted to a Euclidean distance
    matrix; row ``i`` of *X* is compared against row ``i`` of *Y*.

    Args:
        X: First point cloud, one sample per row.
        Y: Second point cloud with the same number of rows as *X*.
        q: Quantile used by the RTD-style metrics for rescaling.

    Returns:
        Dict with keys, in order: ``"NTS-E"``, ``"NTS-M"``, ``"RTD"``,
        ``"SRTD"``, ``"RTD-lite"``, ``"SRTD-lite"``.
    """
    dist_x = pairwise_distances(zscore_features(X), metric="euclidean")
    dist_y = pairwise_distances(zscore_features(Y), metric="euclidean")

    rank_scores = nts_metrics(dist_x, dist_y)
    full_scores = full_rtd_srtd_metrics(dist_x, dist_y, q=q, h_dim=1)
    lite_scores = rtd_lite_metrics(dist_x, dist_y, q=q)
    # Merge order fixes the key order that downstream tables rely on.
    return {**rank_scores, **full_scores, **lite_scores}


@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable parameters of the synthetic A/B/C/D point-cloud experiment."""

    # Number of coordinates of every generated point.
    dim: int = 50
    # Number of clusters in the structured cloud (centres placed on a ring).
    n_clusters: int = 8
    # Number of points sampled around each cluster centre.
    points_per_cluster: int = 10
    # Ring radius in the first two coordinates; also scales the random offset
    # of the centres in the remaining coordinates.
    separation: float = 8.0
    # Scale of the Gaussian noise added around each cluster centre; a small
    # multiple of it is also added as noise to the warped cloud.
    sigma: float = 0.5
    # Per-cluster scaling factors of the warp span
    # [1 / warp_strength, warp_strength]; also scales the per-cluster shift.
    warp_strength: float = 12.0
    # Quantile passed to the metrics for distance-matrix rescaling.
    q: float = 0.90


def make_ring_centers(rng: np.random.Generator, n_clusters: int, dim: int, separation: float) -> np.ndarray:
    """Place *n_clusters* centres evenly on a circle in the first two coordinates.

    The circle has radius *separation*. When ``dim > 2`` the remaining
    coordinates receive Gaussian offsets scaled by ``0.35 * separation``;
    *rng* is only consumed in that case.

    Args:
        rng: Random generator used for the off-plane offsets.
        n_clusters: Number of centres.
        dim: Dimension of each centre.
        separation: Ring radius.

    Returns:
        Array of shape ``(n_clusters, dim)``.
    """
    theta = np.linspace(0.0, 2.0 * np.pi, n_clusters, endpoint=False)
    ring = np.zeros((n_clusters, dim), dtype=float)
    ring[:, 0] = separation * np.cos(theta)
    ring[:, 1] = separation * np.sin(theta)
    extra_dims = dim - 2
    if extra_dims > 0:
        ring[:, 2:] = 0.35 * separation * rng.normal(size=(n_clusters, extra_dims))
    return ring


def make_structured_cloud(
    rng: np.random.Generator,
    cfg: ExperimentConfig,
) -> Tuple[np.ndarray, np.ndarray]:
    """Sample a clustered point cloud around ring-arranged centres.

    Args:
        rng: Random generator; consumed first for the centres, then for the
            per-point noise.
        cfg: Supplies ``n_clusters``, ``points_per_cluster``, ``dim``,
            ``separation`` and ``sigma``.

    Returns:
        ``(X, labels)`` where ``X`` has one row per point and ``labels`` holds
        the cluster index of each row (clusters are stored contiguously).
    """
    centers = make_ring_centers(rng, cfg.n_clusters, cfg.dim, cfg.separation)
    cluster_ids = np.arange(cfg.n_clusters)
    labels = np.repeat(cluster_ids, cfg.points_per_cluster)
    jitter = rng.normal(size=(len(labels), cfg.dim))
    points = centers[labels] + cfg.sigma * jitter
    return points.astype(float), labels.astype(int)


def make_related_warped_cloud(
    rng: np.random.Generator,
    X: np.ndarray,
    labels: np.ndarray,
    cfg: ExperimentConfig,
) -> np.ndarray:
    """Warp every cluster of *X* with its own random linear map and shift.

    For each label the cluster is centred on its mean, multiplied by
    ``Q diag(s) Q^T`` (``Q`` from the QR factorisation of a Gaussian matrix,
    ``s`` a shuffled geometric grid from ``1 / warp_strength`` to
    ``warp_strength``), moved back to the mean, shifted by a Gaussian vector
    scaled by ``0.10 * warp_strength`` and perturbed by Gaussian noise scaled
    by ``0.05 * sigma``. Row ``i`` of the result corresponds to row ``i`` of
    *X*.

    Args:
        rng: Random generator driving all random draws.
        X: Source cloud, one point per row.
        labels: Cluster label of each row of *X*.
        cfg: Supplies ``warp_strength`` and ``sigma``.

    Returns:
        Float array with the same shape as *X*.
    """
    warped = np.empty_like(X)
    n_features = X.shape[1]
    strength = cfg.warp_strength
    for cluster in np.unique(labels):
        members = np.where(labels == cluster)[0]
        block = X[members]
        center = block.mean(axis=0, keepdims=True)

        # The order of the random draws below determines the output for a
        # given seed: basis, scale shuffle, shift, then noise.
        basis, _ = np.linalg.qr(rng.normal(size=(n_features, n_features)))
        scales = np.geomspace(1.0 / strength, strength, n_features)
        rng.shuffle(scales)
        transform = basis @ np.diag(scales) @ basis.T
        shift = 0.10 * strength * rng.normal(size=(1, n_features))
        stretched = (block - center) @ transform
        noise = 0.05 * cfg.sigma * rng.normal(size=block.shape)

        warped[members] = center + stretched + shift + noise
    return warped.astype(float)


def make_shell_cloud(rng: np.random.Generator, n_samples: int, dim: int) -> np.ndarray:
    """Sample points near a sphere of radius ``sqrt(dim)``.

    Directions are normalised Gaussian vectors; each point's radius is
    ``sqrt(dim) * (1 + 0.025 * g)`` with ``g`` standard normal.

    Args:
        rng: Random generator (directions are drawn before radii).
        n_samples: Number of points.
        dim: Dimension of each point.

    Returns:
        Float array of shape ``(n_samples, dim)``.
    """
    directions = rng.normal(size=(n_samples, dim))
    lengths = np.linalg.norm(directions, axis=1, keepdims=True)
    directions = directions / (lengths + 1e-12)
    radial_jitter = rng.normal(size=(n_samples, 1))
    radius = np.sqrt(dim) * (1.0 + 0.025 * radial_jitter)
    return (radius * directions).astype(float)


def generate_abcd(cfg: ExperimentConfig, seed: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Generate the four point clouds of one experiment run.

    * ``A``: structured (clustered) cloud.
    * ``B``: cluster-wise warped version of ``A``.
    * ``C``, ``D``: two independently sampled shell clouds.

    All clouds have ``cfg.n_clusters * cfg.points_per_cluster`` rows and are
    drawn from a single generator seeded with *seed*, in the order A, B, C, D.

    Args:
        cfg: Experiment parameters.
        seed: Seed for ``np.random.default_rng``.

    Returns:
        ``(A, B, C, D, labels)`` where ``labels`` are the cluster labels of
        the rows of ``A`` (and ``B``).
    """
    generator = np.random.default_rng(seed)
    total_points = cfg.n_clusters * cfg.points_per_cluster

    structured, labels = make_structured_cloud(generator, cfg)
    warped = make_related_warped_cloud(generator, structured, labels, cfg)
    shell_first = make_shell_cloud(generator, total_points, cfg.dim)
    shell_second = make_shell_cloud(generator, total_points, cfg.dim)
    return structured, warped, shell_first, shell_second, labels


def _format_mean_std(values: np.ndarray) -> str:
    """Format *values* as ``mean±std`` (sample std, ``ddof=1``) with four decimals."""
    mean = float(np.mean(values)) if values.size else float("nan")
    # The sample standard deviation is undefined for fewer than two values.
    std = float(np.std(values, ddof=1)) if values.size > 1 else float("nan")
    return f"{mean:.4f}±{std:.4f}"


def _paired_metric_values(
    related: pd.DataFrame,
    unrelated: pd.DataFrame,
    metric: str,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return AB and CD values of *metric* aligned run by run, without NaN pairs."""
    if "seed" in related.columns and "seed" in unrelated.columns:
        # Align on the seed so the comparison does not depend on row order.
        merged = related[["seed", metric]].merge(
            unrelated[["seed", metric]],
            on="seed",
            suffixes=("_ab", "_cd"),
        )
        ab = merged[f"{metric}_ab"].to_numpy(dtype=float)
        cd = merged[f"{metric}_cd"].to_numpy(dtype=float)
    else:
        # Without a seed column the only available pairing is positional.
        ab = related[metric].to_numpy(dtype=float)
        cd = unrelated[metric].to_numpy(dtype=float)
        if ab.shape != cd.shape:
            raise ValueError(
                f"Cannot pair {metric!r} values positionally: "
                f"{ab.shape[0]} AB rows vs {cd.shape[0]} CD rows"
            )
    # Drop a run only when at least one side of the pair is missing.
    complete = ~(np.isnan(ab) | np.isnan(cd))
    return ab[complete], cd[complete]


def summarize_results(df: pd.DataFrame) -> pd.DataFrame:
    """Summarise per-seed metric values into one row per metric.

    For every metric the table reports the mean±std over the related
    (``AB_related``) and unrelated (``CD_unrelated``) runs and the fraction of
    runs in which the metric ranks the related pair as more similar than the
    unrelated pair. For ``NTS-E`` / ``NTS-M`` larger means more similar; for
    the RTD-style metrics smaller means more similar.

    The ranking compares the AB and CD values of the *same* run: rows are
    matched on the ``seed`` column when present (positionally otherwise) and a
    run is skipped only if either of its two values is NaN. The mean±std
    columns still use every non-NaN value of each group.

    Args:
        df: One row per (seed, pair) with a ``pair`` column and one column per
            metric; a ``seed`` column is used for pairing when available.

    Returns:
        DataFrame with columns ``metric``, ``AB related mean±std``,
        ``CD unrelated mean±std`` and ``correct ranking`` (all strings).

    Raises:
        ValueError: If there is no ``seed`` column and the two groups have
            different numbers of rows.
    """
    metrics = ["NTS-E", "NTS-M", "RTD", "SRTD", "RTD-lite", "SRTD-lite"]
    higher_is_better = {"NTS-E", "NTS-M"}

    related = df.loc[df["pair"] == "AB_related"]
    unrelated = df.loc[df["pair"] == "CD_unrelated"]

    rows = []
    for metric in metrics:
        ab_text = _format_mean_std(related[metric].dropna().to_numpy())
        cd_text = _format_mean_std(unrelated[metric].dropna().to_numpy())

        ab, cd = _paired_metric_values(related, unrelated, metric)
        wins = ab > cd if metric in higher_is_better else ab < cd
        # No complete pair means the ranking rate is undefined.
        correct = float(np.mean(wins)) if wins.size else float("nan")

        rows.append(
            {
                "metric": metric,
                "AB related mean±std": ab_text,
                "CD unrelated mean±std": cd_text,
                "correct ranking": f"{correct:.3f}",
            }
        )
    return pd.DataFrame(rows)


def visualize_abcd(
    A: np.ndarray,
    B: np.ndarray,
    C: np.ndarray,
    D: np.ndarray,
    labels: np.ndarray,
    out_path: str,
) -> None:
    """Save a 2x2 figure with a joint 2-D PCA view of the four clouds.

    All four clouds are stacked, z-scored and projected together, so the
    panels share one coordinate system. ``A`` and ``B`` are coloured by
    *labels*; ``C`` and ``D`` are drawn in gray.

    Args:
        A, B, C, D: Point clouds; each is sliced out of the projection using
            the row count of ``A``.
        labels: Cluster label per row of ``A`` / ``B``.
        out_path: Destination image path.
    """
    stacked = np.vstack([A, B, C, D])
    projector = PCA(n_components=2, random_state=0)
    coords = projector.fit_transform(zscore_features(stacked))

    n = A.shape[0]
    projected = [coords[k * n : (k + 1) * n] for k in range(4)]
    colourings = (labels, labels, None, None)
    titles = ("A", "B", "C", "D")

    fig, axes = plt.subplots(2, 2, figsize=(8, 7), sharex=True, sharey=True)
    # axes.ravel() walks the grid row by row: A, B on top, C, D below.
    for ax, points, colouring, title in zip(axes.ravel(), projected, colourings, titles):
        if colouring is not None:
            ax.scatter(points[:, 0], points[:, 1], c=colouring, s=18, alpha=0.85, cmap="tab10")
        else:
            ax.scatter(points[:, 0], points[:, 1], s=18, alpha=0.75, color="gray")
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    fig.savefig(out_path, dpi=240)
    plt.close(fig)


def visualize_pair_correspondence(
    A: np.ndarray,
    B: np.ndarray,
    C: np.ndarray,
    D: np.ndarray,
    labels: np.ndarray,
    out_path: str,
    max_lines: int,
    seed: int,
) -> None:
    """Save a two-panel figure linking matching points of (A, B) and (C, D).

    Each pair of clouds is z-scored and PCA-projected jointly. A random subset
    of at most *max_lines* row indices is drawn once and, in both panels, a
    thin line connects row ``i`` of the first cloud to row ``i`` of the
    second.

    Args:
        A, B: First pair of clouds, coloured by *labels*.
        C, D: Second pair of clouds, drawn in gray / dark orange.
        labels: Cluster label per row of ``A`` / ``B``.
        out_path: Destination image path.
        max_lines: Upper bound on the number of connecting lines per panel.
        seed: Seed for choosing which rows get a connecting line.
    """
    n = A.shape[0]
    picker = np.random.default_rng(seed)
    chosen = picker.choice(n, size=min(max_lines, n), replace=False)

    def project_pair(first: np.ndarray, second: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        joint = zscore_features(np.vstack([first, second]))
        flat = PCA(n_components=2, random_state=0).fit_transform(joint)
        return flat[:n], flat[n:]

    def connect(ax, start: np.ndarray, end: np.ndarray) -> None:
        for i in chosen:
            ax.plot(
                [start[i, 0], end[i, 0]],
                [start[i, 1], end[i, 1]],
                linewidth=0.5,
                alpha=0.25,
                color="black",
            )

    def finish(ax, title: str) -> None:
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

    A2, B2 = project_pair(A, B)
    C2, D2 = project_pair(C, D)

    fig, (left, right) = plt.subplots(1, 2, figsize=(11, 4.5))

    left.scatter(A2[:, 0], A2[:, 1], c=labels, s=20, alpha=0.85, cmap="tab10", marker="o")
    left.scatter(B2[:, 0], B2[:, 1], c=labels, s=20, alpha=0.50, cmap="tab10", marker="^")
    connect(left, A2, B2)
    finish(left, "A → B")

    right.scatter(C2[:, 0], C2[:, 1], s=20, alpha=0.78, color="gray", marker="o")
    right.scatter(D2[:, 0], D2[:, 1], s=20, alpha=0.50, color="darkorange", marker="^")
    connect(right, C2, D2)
    finish(right, "C → D")

    fig.tight_layout()
    fig.savefig(out_path, dpi=240)
    plt.close(fig)


def main() -> None:
    """Run the A/B vs. C/D benchmark from the command line.

    For ``--seeds`` consecutive seeds starting at ``--base_seed`` the four
    clouds are generated and all metrics are computed for the related pair
    (A, B) and the unrelated pair (C, D). The raw per-seed values, the summary
    table and two figures (based on the first seed) are written to
    ``--out_dir``; the summary and the output paths are printed.

    Invalid numeric arguments are rejected up front through
    ``parser.error`` (usage message, exit status 2) instead of surfacing later
    as an unrelated ``KeyError`` / ``IndexError`` / ``ValueError``.

    Raises:
        ImportError: If giotto-tda is not available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, default="outputs/abcd_clean")
    parser.add_argument("--seeds", type=int, default=20)
    parser.add_argument("--base_seed", type=int, default=10123)
    parser.add_argument("--max_lines", type=int, default=70)
    args = parser.parse_args()

    # With no seeds there is nothing to summarise and no first seed to plot.
    if args.seeds < 1:
        parser.error("--seeds must be at least 1")
    # numpy's default_rng rejects negative seeds.
    if args.base_seed < 0:
        parser.error("--base_seed must be non-negative")
    # A negative sample size cannot be drawn when choosing lines to plot.
    if args.max_lines < 0:
        parser.error("--max_lines must be non-negative")

    if not GTDA_AVAILABLE:
        raise ImportError("This script requires giotto-tda because it reports RTD and SRTD.")

    os.makedirs(args.out_dir, exist_ok=True)
    raw_path = os.path.join(args.out_dir, "raw_results.csv")
    table_path = os.path.join(args.out_dir, "results_table.csv")
    pca_path = os.path.join(args.out_dir, "abcd_pca_visualization.png")
    pairs_path = os.path.join(args.out_dir, "pair_correspondence_visualization.png")

    cfg = ExperimentConfig()
    rows = []
    seeds = [args.base_seed + i for i in range(args.seeds)]

    for seed in seeds:
        A, B, C, D, labels = generate_abcd(cfg, seed)
        rows.append({"seed": seed, "pair": "AB_related", **compute_all_metrics(A, B, q=cfg.q)})
        rows.append({"seed": seed, "pair": "CD_unrelated", **compute_all_metrics(C, D, q=cfg.q)})

    df = pd.DataFrame(rows)
    summary_df = summarize_results(df)
    df.to_csv(raw_path, index=False)
    summary_df.to_csv(table_path, index=False)

    # The figures always show the clouds of the first seed.
    A, B, C, D, labels = generate_abcd(cfg, seeds[0])
    visualize_abcd(A, B, C, D, labels, pca_path)
    visualize_pair_correspondence(
        A,
        B,
        C,
        D,
        labels,
        pairs_path,
        max_lines=args.max_lines,
        seed=seeds[0],
    )

    print(summary_df.to_string(index=False))
    print()
    print(f"Saved: {raw_path}")
    print(f"Saved: {table_path}")
    print(f"Saved: {pca_path}")
    print(f"Saved: {pairs_path}")


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
