# Simplified scaling benchmark: fixed noise, vary n and d
import time
import numpy as np
import pandas as pd
from scipy.stats import rankdata
from sklearn.manifold import TSNE
import umap

from curve import RandomFourierCurve
from stage3 import stage_embedding, evaluate_kendall_abs
from experiment_utils import fiedler_permutation, spectral_ordering

# ------------------ config ------------------
# Number of repetitions per (d, n) cell; the run index is also used as the
# dataset seed in the main loop.
N_RUNS    = 10
# Sample sizes n to sweep (passed to make_dataset as n_points).
NS        = [500, 1000, 2000]
# Ambient dimensions d to sweep (passed to RandomFourierCurve as d).
DS        = [50, 100, 200]
# Standard deviation of the Gaussian noise added to the sampled points; also
# scales the sigma handed to the Fiedler and Recanati orderings.
NOISE_SD  = 1.0
# Upper cap on the t-SNE perplexity used by order_tsne.
TSNE_PERP = 100  # safe for n >= 500; adjust if using smaller n

# ---------------- algorithms ----------------
def order_umap(X, d):
    """Rank the rows of ``X`` by their one-dimensional UMAP coordinate.

    ``d`` is not used; it is accepted so that every ``order_*`` helper can be
    called with the same leading arguments.

    Returns the ranks produced by ``scipy.stats.rankdata`` on the flattened
    embedding.
    """
    reducer = umap.UMAP(n_components=1)
    embedding = reducer.fit_transform(X)
    coords = embedding.ravel()
    return rankdata(coords)

def order_fiedler(X, d):
    """Order the rows of ``X`` with ``fiedler_permutation``.

    The kernel width passed as ``sigma`` is ``sqrt(d) * NOISE_SD``. Only the
    second element of the helper's two-element result is returned.
    """
    bandwidth = np.sqrt(d) * NOISE_SD
    _first, ordering = fiedler_permutation(X, sigma=bandwidth)
    return ordering

def order_tsne(X, d, n):
    """Rank the rows of ``X`` by their one-dimensional t-SNE coordinate.

    Parameters
    ----------
    X : data matrix handed to ``TSNE.fit_transform``.
    d : unused; kept so the signature lines up with the other ``order_*``
        helpers.
    n : number of samples, used only to pick a valid perplexity.

    Returns the ranks produced by ``scipy.stats.rankdata`` on the flattened
    embedding.
    """
    # scikit-learn requires perplexity < n_samples. The floor of 5 would break
    # that for n <= 5, so additionally cap at n - 1. For n >= 6 this yields the
    # same value as min(TSNE_PERP, max(5, n // 3)).
    perp = min(TSNE_PERP, max(5, n // 3), n - 1)
    y = TSNE(
        n_components=1,
        perplexity=perp,
        init="random",
        learning_rate="auto",
    ).fit_transform(X).ravel()
    return rankdata(y)

def order_recanati(X, d):
    """Order the rows of ``X`` from the output of ``spectral_ordering``.

    The kernel width passed as ``sigma`` is ``sqrt(d) * NOISE_SD``. The
    helper's output is turned into sorting indices with ``np.argsort``.
    """
    bandwidth = np.sqrt(d) * NOISE_SD
    # NOTE(review): np.argsort yields the indices that sort the scores, whereas
    # order_umap / order_tsne return per-point ranks via rankdata. Confirm
    # which of the two forms evaluate_kendall_abs expects.
    return np.argsort(spectral_ordering(X, sigma=bandwidth))

def order_stage(X, d):
    """Order the rows of ``X`` with ``stage_embedding``.

    ``d`` is not used; it is accepted for signature parity with the other
    ``order_*`` helpers. Only the second element of the helper's two-element
    result is returned.
    """
    options = {"k": 50, "pca_full_dim": True, "embedding": "linreg"}
    _first, ordering = stage_embedding(X, **options)
    return ordering

# Adapters giving every algorithm the uniform call shape fn(X, d, n) that the
# main loop uses. Only t-SNE needs a sample count, and it takes it from X.
def _run_umap(X, d, n):
    return order_umap(X, d)


def _run_fiedler(X, d, n):
    return order_fiedler(X, d)


def _run_tsne(X, d, n):
    return order_tsne(X, d, X.shape[0])


def _run_recanati(X, d, n):
    return order_recanati(X, d)


def _run_stage(X, d, n):
    return order_stage(X, d)


# (display name, callable) pairs; the list order fixes the column order of the
# per-run result arrays.
ALGS = [
    ("UMAP", _run_umap),
    ("Fiedler", _run_fiedler),
    ("t-SNE", _run_tsne),
    ("Recanati", _run_recanati),
    ("STAGE", _run_stage),
]

# -------------- data generator --------------
def make_dataset(d, n_points, seed):
    """Sample a noisy, shuffled point cloud along a random Fourier curve.

    Parameters
    ----------
    d : dimension passed to ``RandomFourierCurve``.
    n_points : number of grid points requested from ``unit_speed_grid``.
    seed : seeds both the curve and the local random generator that drives
        the shuffle and the additive noise, so the same arguments always give
        the same dataset.

    Returns
    -------
    X : curve points evaluated at the shuffled parameters, plus Gaussian noise
        with standard deviation ``NOISE_SD``.
    true_order : ranks of the shuffled parameter values (``rankdata`` with
        ``method="average"``), aligned with the rows of ``X``.
    """
    # A local generator keeps the shuffle and noise tied to `seed` instead of
    # depending on whatever state the global np.random stream happens to be in.
    rng = np.random.default_rng(seed)

    curve = RandomFourierCurve(d=d, K=10, alpha=2.3, span=0.25, seed=seed)
    smooth = curve.stretch_to_curvature(kappa_max=2.0)
    t, _ = smooth.unit_speed_grid(n_points)

    # Shuffle in place so the row order of X carries no information about the
    # position along the curve.
    rng.shuffle(t)
    X = smooth.c(t)
    X = X + rng.normal(scale=NOISE_SD, size=X.shape)

    true_order = rankdata(t, method="average")
    return X, true_order

# --------------- main loop ------------------
# One summary dict per (d, n, algorithm); turned into a DataFrame afterwards.
rows = []
for d in DS:
    for n in NS:
        # Per-run scores and wall-clock times, one column per entry of ALGS.
        taus = np.empty((N_RUNS, len(ALGS)))
        times = np.empty((N_RUNS, len(ALGS)))

        # perf_counter is monotonic and is the clock already used for the
        # per-algorithm timings below; time.time() can jump if the system
        # clock is adjusted.
        start_block = time.perf_counter()
        for r in range(N_RUNS):
            # The dataset is built once per run, so every algorithm is scored
            # on the same X. The run index is passed as the dataset seed.
            X, true_order = make_dataset(d, n, seed=r)
            for j, (_, fn) in enumerate(ALGS):
                t0 = time.perf_counter()
                idx = fn(X, d, n)
                # Only the algorithm call is timed; scoring is excluded.
                times[r, j] = time.perf_counter() - t0
                taus[r, j] = evaluate_kendall_abs(true_order, idx)

        # Aggregate over runs (axis 0). ndarray.std defaults to ddof=0, i.e.
        # the population standard deviation.
        tau_mean = taus.mean(axis=0)
        tau_sd = taus.std(axis=0)
        t_mean = times.mean(axis=0)

        for j, (name, _) in enumerate(ALGS):
            rows.append({
                "d": d,
                "n": n,
                "alg": name,
                "kendall_tau_mean": tau_mean[j],
                "kendall_tau_sd": tau_sd[j],
                "runtime_s_mean": t_mean[j],
            })

        print(f"[d={d}, n={n}] finished in {time.perf_counter() - start_block:.2f}s")

# --------------- reporting ------------------
# Gather the per-cell summaries into one frame ordered by d, then n, then
# algorithm name, with a fresh 0..N-1 index.
df = pd.DataFrame(rows)
df = df.sort_values(by=["d", "n", "alg"])
df = df.reset_index(drop=True)

# Pretty columns for Markdown
def fmt_tau(m, s):
    """Format a mean ``m`` and spread ``s`` as percentages: ``"MM.MM (SS.SS)"``.

    Both values are multiplied by 100 and printed with two decimals.
    """
    mean_pct = m * 100
    sd_pct = s * 100
    return f"{mean_pct:.2f} ({sd_pct:.2f})"

def fmt_time(t):
    """Format a duration ``t`` as a fixed-point string with four decimals."""
    return format(t, ".4f")

def _print_table(frame):
    """Print ``frame`` as a Markdown table, or as plain text if that fails.

    ``DataFrame.to_markdown`` needs the optional ``tabulate`` package
    (ImportError when it is missing) and is absent from old pandas releases
    (AttributeError). Only those two cases trigger the plain-text fallback, so
    genuine errors are no longer silently swallowed.
    """
    try:
        rendered = frame.to_markdown(index=False)
    except (ImportError, AttributeError):
        rendered = frame.to_string(index=False)
    print(rendered)


# Build the display frame: formatted score and time columns next to the keys.
df_report = df.copy()
df_report["τ (%, sd)"] = [
    fmt_tau(m, s)
    for m, s in zip(df["kendall_tau_mean"], df["kendall_tau_sd"])
]
df_report["time (s)"] = [fmt_time(t) for t in df["runtime_s_mean"]]
df_report = df_report[["d", "n", "alg", "τ (%, sd)", "time (s)"]]

# One combined table (nice for paper/docs)
_print_table(df_report)

# Also: one table per (d, n); the key columns move into the heading.
for (d, n), sub in df_report.groupby(["d", "n"]):
    print(f"\n### d={d}, n={n}")
    _print_table(sub.drop(columns=["d", "n"]))