"""experiment_sweep.py

Parallel sweep experiment comparing STAGE and t‑SNE 1‑D embeddings on
Random Fourier‑curve data.  For each parameter setting we generate a fresh
synthetic dataset, embed it, and record the absolute Kendall‑τ agreement with
the ground‑truth ordering as well as wall‑clock runtime.

New in this revision
--------------------
* **Multi‑core execution** via `joblib.Parallel` (use `--n_jobs`)
* **Nicer figures** using seaborn's default theme (ggplot‑like aesthetics)
* **CLI quality‑of‑life**: progress bar per‑task, explicit random‑seed list
* **Re‑organised plotting** so that colour/hue encodes the method and X‑axis
  shows the hyper‑parameter, improving readability when many settings are
  compared.

Usage (examples)
----------------
# replicate defaults (10 MC repeats) using all CPU cores
python experiment_sweep.py

# exhaustive sweep, stick to 8 workers, 5 repeats
python experiment_sweep.py --n_iter 5 --k 30 50 70 --perplexity 50 100 200 \\
                           --n_jobs 8

Dependencies
------------
* numpy, pandas, matplotlib, seaborn, joblib, tqdm, scikit‑learn, scipy
* stage3 (stage_embedding, evaluate_kendall_abs)
* curve.RandomFourierCurve  (data generator)

"""
from __future__ import annotations
import argparse, time, pathlib, itertools, random, os

import numpy as np
import pandas as pd
from scipy.stats import rankdata
from joblib import Parallel, delayed
from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

from curve import RandomFourierCurve            # your generator
from stage3 import stage_embedding, evaluate_kendall_abs
from sklearn.manifold import TSNE

# Module-level side effect: the seaborn theme ("talk" context, "whitegrid"
# style) is set once at import time, before any figure is created in main().
sns.set_theme(context="talk", style="whitegrid")

# ---------------------------------------------------------------------- #
# Dataset generation
# ---------------------------------------------------------------------- #

def gen_dataset(seed: int, n_points: int = 900, span: float = 0.25,
                noise_sd: float = 5.0) -> tuple[np.ndarray, np.ndarray]:
    """Sample a noisy point cloud along a random Fourier curve.

    The same ``seed`` drives both the curve construction and the additive
    Gaussian noise, so a given seed always describes one dataset.

    Args:
        seed: Seed for the curve and for the noise generator.
        n_points: Number of samples taken along the curve.
        span: Forwarded to ``RandomFourierCurve`` as ``span``.
        noise_sd: Standard deviation of the Gaussian noise added per coordinate.

    Returns:
        ``(X, true_order)`` where ``X`` is float32 with ``n_points`` rows and
        ``true_order`` holds the ranks of the curve parameter cast to int32
        (a fractional average rank from tied parameters would be truncated by
        that cast).
    """
    noise_rng = np.random.default_rng(seed)

    # Build the curve, then reparametrise it with a curvature cap of 2.0.
    base_curve = RandomFourierCurve(d=200, K=10, alpha=2.3, span=span, seed=seed)
    curve = base_curve.stretch_to_curvature(kappa_max=2.0)

    # Only the first element returned by the grid is used: the parameter values.
    params, _ = curve.unit_speed_grid(n_points)

    clean = curve.c(params)
    noise = noise_rng.normal(scale=noise_sd, size=(n_points, curve.d))
    noisy = clean + noise

    ranks = rankdata(params, method="average")
    return noisy.astype(np.float32), ranks.astype(np.int32)

# ---------------------------------------------------------------------- #
# Embedding runners
# ---------------------------------------------------------------------- #

def run_stage(X: np.ndarray, true_order: np.ndarray, k: int, seed: int):
    """Time one STAGE embedding of ``X`` and score its recovered ordering.

    Args:
        X: Data matrix handed to ``stage_embedding``.
        true_order: Ground-truth ordering the estimate is compared against.
        k: Neighbourhood size forwarded to ``stage_embedding``.
        seed: Currently unused by this function.

    Returns:
        ``(runtime, score)``: wall-clock seconds spent inside
        ``stage_embedding`` (scoring is not timed) and the value returned by
        ``evaluate_kendall_abs``.
    """
    tic = time.perf_counter()
    _embedding, estimated_order = stage_embedding(
        X, k=k, pca_full_dim=True, embedding="linreg"
    )
    elapsed = time.perf_counter() - tic

    return elapsed, evaluate_kendall_abs(estimated_order, true_order)


def run_tsne(X: np.ndarray, true_order: np.ndarray, perplexity: int, seed: int):
    """Time one 1-D t-SNE embedding of ``X`` and score the induced ordering.

    Args:
        X: Data matrix passed to ``TSNE.fit_transform``.
        true_order: Ground-truth ordering the estimate is compared against.
        perplexity: t-SNE perplexity.
        seed: Used as the estimator's ``random_state`` so that the same
            dataset, perplexity and seed reproduce the same embedding.

    Returns:
        ``(runtime, score)``: wall-clock seconds spent fitting t-SNE (scoring
        is not timed) and the value returned by ``evaluate_kendall_abs`` for
        the ranks of the 1-D coordinates.
    """
    start = time.perf_counter()
    # Pass the seed through: previously it was accepted but ignored, which
    # left the embedding unseeded and the sweep non-repeatable even though
    # the Monte-Carlo seed list is fixed.
    tsne = TSNE(n_components=1, perplexity=perplexity, random_state=seed)
    y = tsne.fit_transform(X).ravel()
    runtime = time.perf_counter() - start

    # Turn the 1-D coordinates into ranks so they are comparable to true_order.
    score = evaluate_kendall_abs(rankdata(y, method="average"), true_order)
    return runtime, score

# ---------------------------------------------------------------------- #
# Single job executed in parallel
# ---------------------------------------------------------------------- #

def _one_run(seed: int, k_vals: list[int], perplex_vals: list[int], n_points: int, noise_sd: float):
    """Run every requested STAGE and t-SNE setting on one freshly generated dataset.

    One dataset is generated from ``seed`` and shared by all settings, so the
    methods are compared on identical data within a repetition.

    Returns:
        A list of record dicts with keys ``method``, ``param``, ``seed``,
        ``runtime`` and ``kendall``: all STAGE records first (in ``k_vals``
        order), followed by all t-SNE records (in ``perplex_vals`` order).
    """
    X, true_order = gen_dataset(seed, n_points=n_points, noise_sd=noise_sd)

    def _record(method, param, outcome):
        # `outcome` is the (runtime, score) pair returned by a runner.
        elapsed, score = outcome
        return {"method": method, "param": param, "seed": seed,
                "runtime": elapsed, "kendall": score}

    # STAGE
    records = [
        _record("STAGE", k, run_stage(X, true_order, k=k, seed=seed))
        for k in k_vals
    ]
    # t-SNE
    records.extend(
        _record("t-SNE", pp, run_tsne(X, true_order, perplexity=pp, seed=seed))
        for pp in perplex_vals
    )
    return records

# ---------------------------------------------------------------------- #
# Main
# ---------------------------------------------------------------------- #

def main():
    parser = argparse.ArgumentParser("Monte‑Carlo sweep for STAGE vs t‑SNE")
    parser.add_argument("--n_iter", type=int, default=10,
                        help="Number of Monte‑Carlo repetitions.")
    parser.add_argument("--noise_sd", type=float, default=5.0,
                        help="Noise standard deviation.")
    parser.add_argument("--k", type=int, nargs="+", default=[50],
                        help="Neighbourhood sizes for STAGE.")
    parser.add_argument("--perplexity", type=int, nargs="+", default=[100],
                        help="Perplexities for t‑SNE.")
    parser.add_argument("--n_points", type=int, default=900,
                        help="Number of points per synthetic dataset.")
    parser.add_argument("--n_jobs", type=int, default=os.cpu_count(),
                        help="Parallel workers (default: all cores).")
    parser.add_argument("--seed0", type=int, default=13,
                        help="Seed for drawing the list of MC seeds.")
    args = parser.parse_args()

    rng = np.random.default_rng(args.seed0)
    seeds = rng.choice(10_000_000, size=args.n_iter, replace=False).tolist()

    print(f"Running {args.n_iter} repetitions across {args.n_jobs} job(s)…")
    par = Parallel(n_jobs=args.n_jobs, verbose=0, prefer="processes")
    results = par(delayed(_one_run)(sd, args.k, args.perplexity, args.n_points, args.noise_sd)
                  for sd in tqdm(seeds, desc="MC runs"))

    # flatten
    df = pd.DataFrame(itertools.chain.from_iterable(results))

    outdir = pathlib.Path("figures"); outdir.mkdir(exist_ok=True)
    df.to_csv(f"results_noise{args.noise_sd}_n{args.n_points}.csv", index=False)
    print(f"Saved raw results to results_noise{args.noise_sd}_n{args.n_points}.csv")

    # Create a plotting-specific DataFrame to avoid modifying the original df
    df_plot = df.copy()
    # Convert 'param' to string type for categorical plotting and correct ordering.
    # This ensures params like [10, 50, 100] are spaced evenly as categories
    # and sorted numerically.
    df_plot["param_cat"] = df_plot["param"].astype(str)
    
    # Get unique sorted categorical parameters for combined plot ordering if needed,
    # or let seaborn handle it if a single "param_cat" column is used across methods.
    # For boxplot, seaborn will create categories for all unique values in 'param_cat'.
    # We might want to ensure they are sorted numerically if not already.
    # Let's sort the DataFrame by param before plotting if 'param_cat' is used directly
    # This helps if seaborn's default order isn't numerical for string categories.
    df_plot = df_plot.sort_values(by="param")


    # ------------------------------------------------------------------ #
    # Plot 1: Kendall-τ boxplots (method hue, combined parameter on x)
    # ------------------------------------------------------------------ #
    plt.figure(figsize=(12, 6)) # Adjusted figure size for potentially more categories
    sns.boxplot(data=df_plot, x="param_cat", y="kendall", hue="method",
                palette="Set2", showmeans=True, meanprops={"marker":"o",
                "markerfacecolor":"white", "markeredgecolor":"black"})
    plt.xlabel("Hyper-parameter (k for STAGE, Perplexity for t-SNE)")
    plt.ylabel("|Kendall-τ|")
    plt.title(f"Embedding Quality vs. Hyper-parameter ({args.n_iter} MC Repetitions)\nNoise SD: {args.noise_sd} n: {args.n_points}")
    plt.legend(title="Method")
    plt.tight_layout()
    plt.savefig(outdir / f"kendall_boxplot_noise{args.noise_sd}_n{args.n_points}.png", dpi=300)
    plt.close()

    # ------------------------------------------------------------------ #
    # Plot 2: Runtime comparison (STAGE vs t-SNE in separate subplots)
    # ------------------------------------------------------------------ #
    # (Using the separate subplots version for runtime as discussed)
    
    df_stage = df_plot[df_plot["method"] == "STAGE"]
    df_tsne = df_plot[df_plot["method"] == "t-SNE"]

    # Get sorted unique categorical parameters for ordering plots
    unique_params_cat_stage = []
    if not df_stage.empty:
        unique_params_cat_stage = sorted(df_stage["param_cat"].unique(), key=int)
    
    unique_params_cat_tsne = []
    if not df_tsne.empty:
        unique_params_cat_tsne = sorted(df_tsne["param_cat"].unique(), key=int)

    palette_dict = {"STAGE": sns.color_palette("Set2")[0], "t-SNE": sns.color_palette("Set2")[1]}
    # Or use the same colors from before
    # color_stage = sns.color_palette("Set2")[0]
    # color_tsne = sns.color_palette("Set2")[1]


    fig_runtime, axes_runtime = plt.subplots(1, 2, figsize=(16, 6), sharey=False) # Runtimes on different scales
    
    fig_runtime.suptitle(f"Runtime Comparison ({args.n_iter} MC Repetitions, Avg ±1 SD)\nNoise SD: {args.noise_sd} n: {args.n_points}", fontsize=16)

    if not df_stage.empty:
        sns.barplot(data=df_stage, x="param_cat", y="runtime", ax=axes_runtime[0],
                    color=palette_dict["STAGE"], order=unique_params_cat_stage,
                    errorbar="sd", capsize=.1)
        axes_runtime[0].set_xlabel("k (Neighborhood size)")
        axes_runtime[0].set_ylabel("Runtime (s)")
        axes_runtime[0].set_title("STAGE")
    else:
        axes_runtime[0].text(0.5, 0.5, "No STAGE data", ha='center', va='center', transform=axes_runtime[0].transAxes)
        axes_runtime[0].set_title("STAGE (No data)")

    if not df_tsne.empty:
        sns.barplot(data=df_tsne, x="param_cat", y="runtime", ax=axes_runtime[1],
                    color=palette_dict["t-SNE"], order=unique_params_cat_tsne,
                    errorbar="sd", capsize=.1)
        axes_runtime[1].set_xlabel("Perplexity")
        axes_runtime[1].set_ylabel("Runtime (s)") 
        axes_runtime[1].set_title("t-SNE")
    else:
        axes_runtime[1].text(0.5, 0.5, "No t-SNE data", ha='center', va='center', transform=axes_runtime[1].transAxes)
        axes_runtime[1].set_title("t-SNE (No data)")
        
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(outdir / f"runtime_comparison_noise{args.noise_sd}_n{args.n_points}.png", dpi=300)
    plt.close(fig_runtime)

    print("Figures saved to", outdir.resolve())

# Script entry point: run the sweep only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
