"""experiment_sweep.py

Parallel sweep experiment comparing STAGE, LLE and Isomap 1-D embeddings on
Random Fourier-curve data. For each parameter setting we generate a fresh
synthetic dataset, embed it, and record the absolute Kendall-τ agreement with
the ground-truth ordering as well as wall-clock runtime.

New in this revision
--------------------
* Compare STAGE vs LLE vs Isomap (no t-SNE)
* Support multiple ambient dimensions and sample sizes in a single run
* Multi-core execution via `joblib.Parallel` (use `--n_jobs`)
* Nicer figures using seaborn's `whitegrid` style with the `talk` context
* CLI quality-of-life: progress bar over the MC runs, reproducible list of
  Monte-Carlo seeds drawn from `--seed0`
* Re-organised plotting so that colour/hue encodes the method and X-axis
  shows the hyper-parameter, improving readability when many settings are
  compared.

Usage (examples)
----------------
# replicate defaults (10 MC repeats) using all CPU cores
python experiment_sweep.py

# sweep k for all three methods, d=200, n=900
python experiment_sweep.py --n_iter 10 --k 30 50 70 --lle_k 30 50 70 --isomap_k 30 50 70

# sweep across multiple n and d in one go
python experiment_sweep.py --n_iter 5 --k 50 --lle_k 50 --isomap_k 50 \
                           --n_points 500 1000 --ambient_dim 100 200

Dependencies
------------
* numpy, pandas, matplotlib, seaborn, joblib, tqdm, scikit-learn, scipy
* stage3 (stage_embedding, evaluate_kendall_abs)
* curve.RandomFourierCurve  (data generator)

"""
from __future__ import annotations
import argparse, time, pathlib, itertools, os

import numpy as np
import pandas as pd
from scipy.stats import rankdata
from joblib import Parallel, delayed
from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

from curve import RandomFourierCurve            # your generator
from stage3 import stage_embedding, evaluate_kendall_abs
from sklearn.manifold import LocallyLinearEmbedding, Isomap   # CHANGED

# Global plot styling, applied once at import time: seaborn's "talk" context
# combined with the "whitegrid" style. Every figure produced below inherits it.
sns.set_theme(context="talk", style="whitegrid")

# ---------------------------------------------------------------------- #
# Dataset generation
# ---------------------------------------------------------------------- #

def gen_dataset(seed: int,
                n_points: int = 900,
                span: float = 0.25,
                noise_sd: float = 2.0,
                ambient_dim: int = 200):
    """Sample a noisy point cloud along a random Fourier curve.

    Parameters
    ----------
    seed
        Seeds both the curve generator and the Gaussian-noise generator.
    n_points
        Number of samples requested from the curve.
    span
        Forwarded unchanged to ``RandomFourierCurve``.
    noise_sd
        Standard deviation of the additive Gaussian noise.
    ambient_dim
        Passed to ``RandomFourierCurve`` as ``d``.

    Returns
    -------
    (X, true_order)
        ``X`` is the noisy sample cast to ``float32``; ``true_order`` holds
        the average ranks of the curve parameter ``t`` cast to ``int32``.
    """
    noise_rng = np.random.default_rng(seed)

    base_curve = RandomFourierCurve(
        d=ambient_dim, K=10, alpha=2.3, span=span, seed=seed,
    )
    smooth_curve = base_curve.stretch_to_curvature(kappa_max=2.0)

    # Only the parameter values are needed; the second return value is unused.
    t_grid, _unused = smooth_curve.unit_speed_grid(n_points)

    clean_points = smooth_curve.c(t_grid)
    noise = noise_rng.normal(scale=noise_sd,
                             size=(n_points, smooth_curve.d))
    noisy_points = (clean_points + noise).astype(np.float32)

    ranks = rankdata(t_grid, method="average").astype(np.int32)
    return noisy_points, ranks

# ---------------------------------------------------------------------- #
# Embedding runners
# ---------------------------------------------------------------------- #

def run_stage(X: np.ndarray, true_order: np.ndarray, k: int, seed: int):
    """Time one STAGE embedding and score it against the true ordering.

    ``seed`` is accepted so the signature matches the other runners, but it
    is not used here. Only the ``stage_embedding`` call is timed; the scoring
    step is excluded.

    Returns
    -------
    (runtime, score)
        Wall-clock seconds for the embedding and the value returned by
        ``evaluate_kendall_abs`` for the estimated ordering.
    """
    t_begin = time.perf_counter()
    _embedding, estimated_order = stage_embedding(
        X, k=k, pca_full_dim=True, embedding="linreg",
    )
    elapsed = time.perf_counter() - t_begin

    return elapsed, evaluate_kendall_abs(estimated_order, true_order)


def run_lle(X: np.ndarray, true_order: np.ndarray,
            n_neighbors: int, seed: int):
    """Time a 1-D LLE embedding and score its ordering against the truth.

    The timed region covers model construction, ``fit_transform`` and the
    flattening of the result; ranking and scoring are excluded.

    Returns
    -------
    (runtime, score)
        Wall-clock seconds and the ``evaluate_kendall_abs`` value comparing
        the average ranks of the embedding with ``true_order``.
    """
    t_begin = time.perf_counter()
    model = LocallyLinearEmbedding(
        n_components=1,
        n_neighbors=n_neighbors,
        random_state=seed,
    )
    coords = model.fit_transform(X).ravel()
    elapsed = time.perf_counter() - t_begin

    estimated_ranks = rankdata(coords, method="average")
    return elapsed, evaluate_kendall_abs(estimated_ranks, true_order)


def run_isomap(X: np.ndarray, true_order: np.ndarray,
               n_neighbors: int, seed: int):
    """Time a 1-D Isomap embedding and score its ordering against the truth.

    ``seed`` is accepted so the signature matches the other runners, but it
    is not passed to ``Isomap``. The timed region covers model construction,
    ``fit_transform`` and the flattening of the result.

    Returns
    -------
    (runtime, score)
        Wall-clock seconds and the ``evaluate_kendall_abs`` value comparing
        the average ranks of the embedding with ``true_order``.
    """
    t_begin = time.perf_counter()
    model = Isomap(n_components=1, n_neighbors=n_neighbors)
    coords = model.fit_transform(X).ravel()
    elapsed = time.perf_counter() - t_begin

    estimated_ranks = rankdata(coords, method="average")
    return elapsed, evaluate_kendall_abs(estimated_ranks, true_order)

# ---------------------------------------------------------------------- #
# Single job executed in parallel
# ---------------------------------------------------------------------- #

def _one_run(seed: int,
             k_vals_stage: list[int],
             k_vals_lle: list[int],
             k_vals_isomap: list[int],
             n_points: int,
             noise_sd: float,
             ambient_dim: int):
    """Evaluate every method/hyper-parameter combination on one dataset.

    A single dataset is generated for ``(seed, n_points, ambient_dim)`` and
    reused for all STAGE, LLE and Isomap settings, in that order.

    Returns
    -------
    list[dict]
        One record per (method, k) with keys ``method``, ``param``, ``seed``,
        ``runtime``, ``kendall``, ``n_points`` and ``ambient_dim``.
    """
    X, true_order = gen_dataset(seed,
                                n_points=n_points,
                                noise_sd=noise_sd,
                                ambient_dim=ambient_dim)

    # (label, runner, neighbourhood sizes); the order here fixes record order.
    plan = (
        ("STAGE", run_stage, k_vals_stage),
        ("LLE", run_lle, k_vals_lle),
        ("Isomap", run_isomap, k_vals_isomap),
    )

    records = []
    for label, runner, k_values in plan:
        for k in k_values:
            # All runners share the positional layout (X, true_order, k, seed).
            elapsed, tau = runner(X, true_order, k, seed)
            records.append({
                "method": label,
                "param": k,
                "seed": seed,
                "runtime": elapsed,
                "kendall": tau,
                "n_points": n_points,
                "ambient_dim": ambient_dim,
            })
    return records

# ---------------------------------------------------------------------- #
# Main
# ---------------------------------------------------------------------- #

def main():
    parser = argparse.ArgumentParser(
        "Monte-Carlo sweep for STAGE vs LLE vs Isomap"
    )
    parser.add_argument("--n_iter", type=int, default=10,
                        help="Number of Monte-Carlo repetitions per (n,d).")
    parser.add_argument("--noise_sd", type=float, default=2.0,
                        help="Noise standard deviation.")
    parser.add_argument("--k", type=int, nargs="+", default=[50],
                        help="Neighbourhood sizes for STAGE.")
    parser.add_argument("--lle_k", type=int, nargs="+", default=None,
                        help="Neighbourhood sizes for LLE "
                             "(default: same as --k).")
    parser.add_argument("--isomap_k", type=int, nargs="+", default=None,
                        help="Neighbourhood sizes for Isomap "
                             "(default: same as --k).")
    parser.add_argument("--n_points", type=int, nargs="+", default=[900],
                        help="Number(s) of points per synthetic dataset.")
    parser.add_argument("--ambient_dim", type=int, nargs="+", default=[200],
                        help="Ambient dimensionality (d) of the curve.")
    parser.add_argument("--n_jobs", type=int, default=os.cpu_count(),
                        help="Parallel workers (default: all cores).")
    parser.add_argument("--seed0", type=int, default=13,
                        help="Seed for drawing the list of MC seeds.")
    args = parser.parse_args()

    # Fill in defaults for LLE / Isomap if not given
    lle_k_vals = args.lle_k if args.lle_k is not None else args.k
    isomap_k_vals = args.isomap_k if args.isomap_k is not None else args.k

    rng = np.random.default_rng(args.seed0)

    # Build job list: one job per (seed, n_points, ambient_dim)
    jobs = []
    for d in args.ambient_dim:
        for n in args.n_points:
            seeds_for_pair = rng.choice(
                10_000_000, size=args.n_iter, replace=False
            ).tolist()
            for sd in seeds_for_pair:
                jobs.append((sd, n, d))

    print(f"Running {len(jobs)} total dataset embeddings "
          f"across {args.n_jobs} job(s)…")
    par = Parallel(n_jobs=args.n_jobs, verbose=0, prefer="processes")
    results = par(
        delayed(_one_run)(
            sd, args.k, lle_k_vals, isomap_k_vals, n, args.noise_sd, d
        )
        for (sd, n, d) in tqdm(jobs, desc="MC runs")
    )

    # flatten
    df = pd.DataFrame(itertools.chain.from_iterable(results))

    outdir = pathlib.Path("figures")
    outdir.mkdir(exist_ok=True)

    # Save one combined CSV with all (n,d) combos
    df.to_csv(
        f"results_noise{args.noise_sd}_multi_nd.csv",
        index=False
    )
    print(f"Saved raw results to results_noise{args.noise_sd}_multi_nd.csv")

    # ------------------------------------------------------------------ #
    # Plotting: one set of figures per (n_points, ambient_dim) pair
    # ------------------------------------------------------------------ #

    METHOD_ORDER = ["STAGE", "LLE", "Isomap"]
    METHOD_PALETTE = sns.color_palette("Set2", n_colors=len(METHOD_ORDER))
    METHOD_PALETTE_DICT = dict(zip(METHOD_ORDER, METHOD_PALETTE))

    for (n, d), df_sub in df.groupby(["n_points", "ambient_dim"]):
        df_plot = df_sub.copy()
        df_plot["param_cat"] = df_plot["param"].astype(str)
        df_plot = df_plot.sort_values(by="param")

        # ------------------------- #
        # Plot 1: Kendall-τ boxplot
        # ------------------------- #
        plt.figure(figsize=(12, 6))
        sns.boxplot(
            data=df_plot,
            x="param_cat", y="kendall",
            hue="method",
            hue_order=METHOD_ORDER,              # enforce consistent order
            palette=METHOD_PALETTE_DICT,         # enforce consistent colors
            showmeans=True,
            meanprops={
                "marker": "o",
                "markerfacecolor": "white",
                "markeredgecolor": "black"
            },
        )
        plt.xlabel("Neighbourhood size k")
        plt.ylabel("|Kendall-τ|")
        plt.title(
            f"Embedding Quality vs. k "
            f"({args.n_iter} MC Repetitions)\n"
            f"Noise SD: {args.noise_sd}  n: {n}  d: {d}"
        )
        plt.legend(title="Method")
        plt.tight_layout()
        plt.savefig(
            outdir / f"kendall_boxplot_noise{args.noise_sd}_n{n}_d{d}.png",
            dpi=300
        )
        plt.close()

        # ------------------------- #
        # Plot 2: Runtime comparison
        # ------------------------- #
        df_stage = df_plot[df_plot["method"] == "STAGE"]
        df_lle = df_plot[df_plot["method"] == "LLE"]
        df_iso = df_plot[df_plot["method"] == "Isomap"]

        unique_params_stage = (
            sorted(df_stage["param_cat"].unique(), key=int)
            if not df_stage.empty else []
        )
        unique_params_lle = (
            sorted(df_lle["param_cat"].unique(), key=int)
            if not df_lle.empty else []
        )
        unique_params_iso = (
            sorted(df_iso["param_cat"].unique(), key=int)
            if not df_iso.empty else []
        )

        fig_runtime, axes_runtime = plt.subplots(
            1, 3, figsize=(20, 6), sharey=False
        )

        fig_runtime.suptitle(
            f"Runtime Comparison ({args.n_iter} MC Repetitions, Avg ±1 SD)\n"
            f"Noise SD: {args.noise_sd}  n: {n}  d: {d}",
            fontsize=16
        )

        # STAGE
        ax0 = axes_runtime[0]
        if not df_stage.empty:
            sns.barplot(
                data=df_stage, x="param_cat", y="runtime", ax=ax0,
                color=METHOD_PALETTE_DICT["STAGE"],
                order=unique_params_stage,
                errorbar="sd", capsize=.1
            )
            ax0.set_xlabel("k (Neighbourhood size)")
            ax0.set_ylabel("Runtime (s)")
            ax0.set_title("STAGE")
        else:
            ax0.text(0.5, 0.5, "No STAGE data",
                     ha='center', va='center', transform=ax0.transAxes)
            ax0.set_title("STAGE (No data)")

        # LLE
        ax1 = axes_runtime[1]
        if not df_lle.empty:
            sns.barplot(
                data=df_lle, x="param_cat", y="runtime", ax=ax1,
                color=METHOD_PALETTE_DICT["LLE"],
                order=unique_params_lle,
                errorbar="sd", capsize=.1
            )
            ax1.set_xlabel("k (Neighbourhood size)")
            ax1.set_ylabel("Runtime (s)")
            ax1.set_title("LLE")
        else:
            ax1.text(0.5, 0.5, "No LLE data",
                     ha='center', va='center', transform=ax1.transAxes)
            ax1.set_title("LLE (No data)")

        # Isomap
        ax2 = axes_runtime[2]
        if not df_iso.empty:
            sns.barplot(
                data=df_iso, x="param_cat", y="runtime", ax=ax2,
                color=METHOD_PALETTE_DICT["Isomap"],
                order=unique_params_iso,
                errorbar="sd", capsize=.1
            )
            ax2.set_xlabel("k (Neighbourhood size)")
            ax2.set_ylabel("Runtime (s)")
            ax2.set_title("Isomap")
        else:
            ax2.text(0.5, 0.5, "No Isomap data",
                     ha='center', va='center', transform=ax2.transAxes)
            ax2.set_title("Isomap (No data)")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.savefig(
            outdir / f"runtime_comparison_noise{args.noise_sd}_n{n}_d{d}.png",
            dpi=300
        )
        plt.close(fig_runtime)

    print("Figures saved to", outdir.resolve())


# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()