"""Generate all exploration analysis plots.

Produces scatter plots, trajectory plots, heatmaps, and bar charts
for the exploration analysis report.
"""
from __future__ import annotations

from pathlib import Path
from typing import Optional

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

from .primitive_classification import PRIMITIVES, LEARNED_PRIMITIVES


def _detect_primitives(data: dict | pd.DataFrame) -> list[str]:
    """Return the primitives whose ``<prim>_per_1k_mean`` values are not all zero.

    ``data`` is either a DataFrame (one column per metric) or a mapping of
    ``{id: {metric_name: value}}``. Candidates are ``PRIMITIVES`` followed by
    ``LEARNED_PRIMITIVES`` (first occurrence wins, duplicates dropped);
    ``OTHER`` is never reported. If no candidate has a nonzero value, every
    entry of ``PRIMITIVES`` except ``OTHER`` is returned instead.
    """
    is_frame = isinstance(data, pd.DataFrame)
    if is_frame:
        available = set(data.columns)
    else:
        available = {name for record in data.values() for name in record.keys()}

    def _has_signal(column: str) -> bool:
        # True when at least one value in the column is nonzero.
        if is_frame:
            return data[column].abs().sum() > 0
        return any(abs(record.get(column, 0)) > 0 for record in data.values())

    # Ordered, de-duplicated union of base and learned primitives.
    candidates = list(dict.fromkeys(PRIMITIVES + LEARNED_PRIMITIVES))
    found = [
        prim
        for prim in candidates
        if prim != "OTHER"
        and f"{prim}_per_1k_mean" in available
        and _has_signal(f"{prim}_per_1k_mean")
    ]
    if found:
        return found
    return [prim for prim in PRIMITIVES if prim != "OTHER"]


def _save(fig, path: Path, dpi: int = 150):
    """Write ``fig`` to ``path`` with a tight bounding box, then close it.

    Closing matters in a batch script: pyplot keeps every figure it created
    alive until it is explicitly closed. The saved path is echoed to stdout.
    """
    save_options = dict(dpi=dpi, bbox_inches="tight")
    fig.savefig(path, **save_options)
    plt.close(fig)
    print(f"  Saved: {path}")


# ---------------------------------------------------------------------------
# 1. KEY PLOT: Puzzle gain vs Math gain scatter
# ---------------------------------------------------------------------------

def plot_puzzle_vs_math_gain(
    df: pd.DataFrame,
    output_path: Path,
    puzzle_col: str = "puzzle_pass32_gain",
    math_col: str = "math_pass32_gain",
    color_col: Optional[str] = "sft_base",
    label_col: Optional[str] = "checkpoint_id",
):
    """Scatter plot of puzzle pass@32 gain vs math pass@32 gain.

    Each point = one GSPO checkpoint (or checkpoint x task). Rows where either
    gain is missing are dropped. With at least three points, the Spearman
    correlation is reported in the title. A least-squares line is overlaid
    when the puzzle gains are not all identical.

    Args:
        df: table holding at least ``puzzle_col`` and ``math_col``.
        output_path: destination of the saved figure.
        puzzle_col: column plotted on the x axis.
        math_col: column plotted on the y axis.
        color_col: optional column whose values group/color the points. It is
            ignored when falsy or absent from ``df``. Rows with a missing group
            value are still plotted, under the group label "nan".
        label_col: optional column used to annotate each point, with the text
            truncated to 25 characters. Ignored when falsy or absent.
    """
    if puzzle_col not in df.columns or math_col not in df.columns:
        print(f"  Skipping puzzle_vs_math_gain: missing columns ({puzzle_col} or {math_col})")
        return

    mask = df[puzzle_col].notna() & df[math_col].notna()
    plot_df = df.loc[mask]

    if plot_df.empty:
        print("  Skipping puzzle_vs_math_gain: no data with both gains")
        return

    fig, ax = plt.subplots(figsize=(8, 6))

    grouped = bool(color_col) and color_col in plot_df.columns
    if grouped:
        # dropna=False keeps rows whose group value is missing; the default
        # would silently remove those points from the scatter.
        for group, gdf in plot_df.groupby(color_col, dropna=False):
            ax.scatter(gdf[puzzle_col], gdf[math_col], label=str(group), s=60, alpha=0.8)
    else:
        ax.scatter(plot_df[puzzle_col], plot_df[math_col], s=60, alpha=0.8)

    # Add labels
    if label_col and label_col in plot_df.columns:
        for _, row in plot_df.iterrows():
            ax.annotate(
                str(row[label_col])[:25],
                (row[puzzle_col], row[math_col]),
                fontsize=7, alpha=0.7,
                xytext=(5, 5), textcoords="offset points",
            )

    x = plot_df[puzzle_col].to_numpy(dtype=float)
    y = plot_df[math_col].to_numpy(dtype=float)
    if len(x) >= 3:
        rho, rho_p = stats.spearmanr(x, y)
        ax.set_title(f"Puzzle vs Math Gain (Spearman ρ={rho:.2f}, p={rho_p:.3f})")
        # linregress raises ValueError when every x value is identical.
        if x.min() < x.max():
            slope, intercept, *_ = stats.linregress(x, y)
            x_fit = np.linspace(x.min(), x.max(), 100)
            ax.plot(x_fit, slope * x_fit + intercept, "k--", alpha=0.5)
    else:
        ax.set_title("Puzzle vs Math Gain")

    ax.set_xlabel("Puzzle pass@32 gain over SFT")
    ax.set_ylabel("Math (OlymMATH Hard) pass@32 gain over SFT")
    ax.axhline(0, color="gray", linewidth=0.5, linestyle=":")
    ax.axvline(0, color="gray", linewidth=0.5, linestyle=":")
    # Only grouped scatters carry labels; a legend otherwise would be empty.
    if grouped:
        ax.legend()
    fig.tight_layout()
    _save(fig, output_path)


# ---------------------------------------------------------------------------
# 2. Exploration metric vs Math gain scatter grid
# ---------------------------------------------------------------------------

def plot_metric_vs_math_gain(
    df: pd.DataFrame,
    metric_cols: list[str],
    output_path: Path,
    math_col: str = "math_pass32_gain",
    ncols: int = 3,
):
    """Grid of scatter plots: each exploration metric vs math gain.

    There is one subplot per entry of ``metric_cols``. Columns absent from
    ``df`` are skipped with a message. Each subplot drops rows where either
    value is missing. With at least three points, the Spearman correlation is
    shown in the title. A least-squares line is drawn only when the metric is
    not constant, because ``scipy.stats.linregress`` cannot fit identical
    x values.

    Args:
        df: table holding ``math_col`` and the metric columns.
        metric_cols: metric column names, one subplot each, in order.
        output_path: destination of the saved figure.
        math_col: column plotted on every y axis.
        ncols: number of subplot columns in the grid (at least 1 is used).
    """
    if math_col not in df.columns:
        print(f"  Skipping metrics_vs_math_gain: {math_col} not in table")
        return
    missing = [c for c in metric_cols if c not in df.columns]
    if missing:
        print(f"  metrics_vs_math_gain: ignoring columns not in table: {missing}")
    metric_cols = [c for c in metric_cols if c in df.columns]
    n_metrics = len(metric_cols)
    if n_metrics == 0:
        return

    ncols = max(1, ncols)
    nrows = (n_metrics + ncols - 1) // ncols
    # squeeze=False always yields a 2-D array of Axes, whatever the grid shape.
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False
    )

    # axes.flat is row-major, matching the left-to-right, top-to-bottom layout.
    for ax, col in zip(axes.flat, metric_cols):
        mask = df[col].notna() & df[math_col].notna()
        x = df.loc[mask, col].to_numpy(dtype=float)
        y = df.loc[mask, math_col].to_numpy(dtype=float)

        ax.scatter(x, y, s=40, alpha=0.7)

        if len(x) >= 3:
            rho, p = stats.spearmanr(x, y)
            ax.set_title(f"{col}\nρ={rho:.2f}, p={p:.2f}", fontsize=9)
            # A constant metric would make linregress raise ValueError.
            if x.min() < x.max():
                slope, intercept, *_ = stats.linregress(x, y)
                x_fit = np.linspace(x.min(), x.max(), 50)
                ax.plot(x_fit, slope * x_fit + intercept, "r--", alpha=0.5)
        else:
            ax.set_title(col, fontsize=9)

        ax.set_xlabel(col, fontsize=8)
        ax.set_ylabel(math_col, fontsize=8)
        ax.tick_params(labelsize=7)

    # Hide empty subplots
    for ax in axes.flat[n_metrics:]:
        ax.set_visible(False)

    fig.suptitle("Exploration Metrics vs Math Gain", fontsize=12)
    fig.tight_layout()
    _save(fig, output_path)


# ---------------------------------------------------------------------------
# 3. Math ceiling trajectory: SFT epochs vs GSPO steps
# ---------------------------------------------------------------------------

def plot_math_trajectory(
    sft_data: dict[str, float],
    gspo_data: dict[str, float],
    output_path: Path,
    title: str = "OlymMATH Hard pass@32 Trajectory",
):
    """Plot SFT-epoch and GSPO-step pass@32 values along one shared x axis.

    SFT points occupy x positions ``0..len(sft_data)-1``. GSPO points follow
    immediately after, so both series read left to right as one timeline.
    Tick labels are the dict keys in insertion order.

    Args:
        sft_data: {label: pass@32_value} e.g. {"epoch2": 0.78, ...}; may be empty.
        gspo_data: {label: pass@32_value} e.g. {"step10": 0.82, ...}; may be empty.
        output_path: destination of the saved figure.
        title: axes title.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    sft_labels = list(sft_data) if sft_data else []
    gspo_labels = list(gspo_data) if gspo_data else []
    n_sft = len(sft_labels)

    if sft_labels:
        ax.plot(
            range(n_sft), list(sft_data.values()), "o-",
            label="SFT (no GSPO)", color="gray",
        )
    if gspo_labels:
        ax.plot(
            range(n_sft, n_sft + len(gspo_labels)), list(gspo_data.values()), "s-",
            label="GSPO on puzzles", color="blue",
        )

    tick_labels = sft_labels + gspo_labels
    if tick_labels:
        ax.set_xticks(range(len(tick_labels)))
        ax.set_xticklabels(tick_labels, rotation=45, ha="right")

    ax.set_ylabel("OlymMATH Hard pass@32")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    _save(fig, output_path)


# ---------------------------------------------------------------------------
# 4. Primitive distribution heatmap
# ---------------------------------------------------------------------------

def plot_primitive_distribution(
    checkpoint_data: dict[str, dict[str, float]],
    output_path: Path,
):
    """Heatmap of primitive rates (per 1k tokens) across checkpoints.

    Rows are checkpoints, sorted by id. Columns are the primitives returned by
    ``_detect_primitives``. A checkpoint lacking a ``<prim>_per_1k_mean``
    entry gets 0.0 in that cell.

    Args:
        checkpoint_data: {ckpt_id: {primitive_per_1k_mean: float, ...}}.
            Nothing is plotted when it is empty.
        output_path: destination of the saved figure.
    """
    if not checkpoint_data:
        # A zero-row matrix would make seaborn's color-range computation fail.
        print("  Skipping primitive_distribution: no checkpoint data")
        return

    prims = _detect_primitives(checkpoint_data)
    ckpts = sorted(checkpoint_data)

    data = np.array(
        [
            [checkpoint_data[ckpt].get(f"{prim}_per_1k_mean", 0.0) for prim in prims]
            for ckpt in ckpts
        ],
        dtype=float,
    ).reshape(len(ckpts), len(prims))

    fig, ax = plt.subplots(figsize=(10, max(3, len(ckpts) * 0.6)))
    sns.heatmap(data, ax=ax, xticklabels=prims, yticklabels=ckpts,
                annot=True, fmt=".2f", cmap="YlOrRd")
    ax.set_title("Primitive Rates per 1k Tokens")
    fig.tight_layout()
    _save(fig, output_path)


# ---------------------------------------------------------------------------
# 5. Transition matrix heatmap
# ---------------------------------------------------------------------------

def plot_transition_matrices(
    matrices: dict[str, np.ndarray],
    output_path: Path,
    vmin: float = 0.0,
    vmax: float = 0.5,
):
    """Side-by-side transition matrix heatmaps.

    Tick labels are chosen separately for each matrix, based on its size.
    ``LEARNED_PRIMITIVES`` is used when the size equals its length; otherwise
    the leading entries of ``PRIMITIVES`` are used. Any rows/columns beyond
    the available names are labelled by index, so every cell row/column is
    named.

    Args:
        matrices: {ckpt_label: transition matrix} (e.g. 10x10). Nothing is
            drawn when empty.
        output_path: destination of the saved figure.
        vmin: lower limit of the shared color scale.
        vmax: upper limit of the shared color scale; larger values saturate.
    """
    if not matrices:
        return

    def _labels_for(size: int) -> list[str]:
        # Pick the primitive name list that matches this matrix dimension.
        if size == len(LEARNED_PRIMITIVES):
            names = list(LEARNED_PRIMITIVES)
        else:
            names = list(PRIMITIVES)[:size]
        # Pad with indices if the matrix is larger than the name list.
        return names + [str(i) for i in range(len(names), size)]

    n = len(matrices)
    fig, axes = plt.subplots(1, n, figsize=(6 * n, 5), squeeze=False)

    for ax, (ckpt, mat) in zip(axes[0], matrices.items()):
        arr = np.asarray(mat)
        sns.heatmap(arr, ax=ax,
                    xticklabels=_labels_for(arr.shape[-1]),
                    yticklabels=_labels_for(arr.shape[0]),
                    annot=True, fmt=".2f", cmap="Blues", vmin=vmin, vmax=vmax)
        ax.set_title(ckpt, fontsize=10)
        ax.tick_params(labelsize=7)

    fig.suptitle("Primitive Transition Matrices", fontsize=12)
    fig.tight_layout()
    _save(fig, output_path)


# ---------------------------------------------------------------------------
# 6. Diversity bar chart
# ---------------------------------------------------------------------------

def plot_diversity_comparison(
    checkpoint_data: dict[str, dict],
    output_path: Path,
    metrics: list[str] | tuple[str, ...] = ("effective_num_paths_mean", "cluster_entropy_mean",
                                            "successful_effective_num_paths_mean"),
):
    """Grouped bar chart comparing diversity metrics across checkpoints.

    Checkpoints are sorted by id. Each checkpoint gets one bar per metric,
    and a metric missing from a checkpoint's dict is drawn as 0.0. The legend
    labels strip "_mean" from the metric names.

    Args:
        checkpoint_data: {ckpt_id: {metric_name: value}}.
        output_path: destination of the saved figure.
        metrics: metric keys to plot, in bar order. Nothing is plotted when
            empty.
    """
    metrics = list(metrics)
    n_metrics = len(metrics)
    if n_metrics == 0:
        # The per-bar width below would divide by zero.
        print("  Skipping diversity_comparison: no metrics requested")
        return

    ckpts = sorted(checkpoint_data)

    fig, ax = plt.subplots(figsize=(max(6, len(ckpts) * 1.5), 5))
    x = np.arange(len(ckpts))
    width = 0.8 / n_metrics

    for i, metric in enumerate(metrics):
        values = [checkpoint_data[c].get(metric, 0.0) for c in ckpts]
        ax.bar(x + i * width, values, width, label=metric.replace("_mean", ""))

    # Center each tick under its group of bars.
    ax.set_xticks(x + width * (n_metrics - 1) / 2)
    ax.set_xticklabels(ckpts, rotation=45, ha="right")
    ax.set_ylabel("Value")
    ax.set_title("Diversity Metrics by Checkpoint")
    ax.legend(fontsize=8)
    fig.tight_layout()
    _save(fig, output_path)


# ---------------------------------------------------------------------------
# 7. Primitive diff: GSPO-minus-SFT
# ---------------------------------------------------------------------------

def plot_primitive_diff(
    sft_metrics: dict[str, float],
    gspo_metrics: dict[str, float],
    output_path: Path,
    title: str = "Primitive Count Change: GSPO - SFT",
    metric_suffix: str = "_count_mean",
):
    """Bar chart showing which primitives increase/decrease after GSPO.

    Each bar is ``gspo_metrics[key] - sft_metrics[key]`` with
    ``key = f"{prim}{metric_suffix}"``, where a missing key counts as 0.0.
    Bars are green for increases and red otherwise.

    Args:
        sft_metrics: flat metric dict of the SFT checkpoint.
        gspo_metrics: flat metric dict of the GSPO checkpoint.
        output_path: destination of the saved figure.
        title: axes title.
        metric_suffix: key suffix to look up. Default '_count_mean' (per-trace counts).
            Falls back to '_per_1k_mean' unless at least one primitive has the
            requested key in BOTH dicts. Otherwise a key present on only one
            side would be diffed against an implicit 0.
    """
    prims = _detect_primitives({
        "sft": sft_metrics, "gspo": gspo_metrics,
    })
    # Auto-detect available metric suffix (must exist on both sides).
    has_both = any(
        f"{p}{metric_suffix}" in sft_metrics and f"{p}{metric_suffix}" in gspo_metrics
        for p in prims
    )
    if not has_both:
        metric_suffix = "_per_1k_mean"

    diffs = [
        gspo_metrics.get(f"{p}{metric_suffix}", 0.0) - sft_metrics.get(f"{p}{metric_suffix}", 0.0)
        for p in prims
    ]

    ylabel = ("Δ episodes per trace (GSPO - SFT)" if "count" in metric_suffix
              else "Δ rate per 1k tokens (GSPO - SFT)")

    fig, ax = plt.subplots(figsize=(10, 5))
    colors = ["green" if d > 0 else "red" for d in diffs]
    ax.bar(prims, diffs, color=colors, alpha=0.7)
    ax.axhline(0, color="black", linewidth=0.5)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.tick_params(axis="x", rotation=45)
    fig.tight_layout()
    _save(fig, output_path)


def plot_primitive_comparison_per_task(
    trace_df: pd.DataFrame,
    output_path: Path,
    sft_id: str = "dsr_sft_v2",
    gspo_id: str = "gspo_v2_sft_step20",
    primitives: list[str] | None = None,
):
    """Side-by-side bar chart of primitive % per task, shared y-axis.

    There is one subplot per task, sorted by name. For each task, the mean
    per-trace ``<prim>_count`` of the SFT and GSPO checkpoints is turned into
    a percentage of that checkpoint's summed means over ``primitives``. A
    checkpoint with no traces for a task (zero or NaN total) shows 0%.

    Args:
        trace_df: per-trace rows with ``task_name``, ``checkpoint_id`` and
            ``<prim>_count`` columns.
        output_path: destination of the saved figure.
        sft_id: ``checkpoint_id`` of the SFT baseline.
        gspo_id: ``checkpoint_id`` of the GSPO checkpoint.
        primitives: primitives to show. By default this is every one of PLAN,
            ENUMERATE, HYPOTHESIZE, COMPUTE, BACKTRACK, SUMMARIZE and OTHER
            whose count column exists and has a nonzero sum.
    """
    if primitives is None:
        all_prims = ["PLAN", "ENUMERATE", "HYPOTHESIZE", "COMPUTE", "BACKTRACK", "SUMMARIZE", "OTHER"]
        primitives = [p for p in all_prims
                      if f"{p}_count" in trace_df.columns and trace_df[f"{p}_count"].sum() > 0]

    tasks = sorted(trace_df.task_name.unique())
    if not tasks:
        # plt.subplots(1, 0) fails, and there would be nothing to draw.
        print("  Skipping primitive_comparison_per_task: no tasks in trace table")
        return

    fig, axes = plt.subplots(1, len(tasks), figsize=(6 * len(tasks), 5),
                             sharey=True, squeeze=False)
    axes = axes[0]

    def _percentages(subset: pd.DataFrame) -> list[float]:
        # Share of each primitive's mean count in the subset's total.
        means = [subset[f"{p}_count"].mean() for p in primitives]
        total = sum(means)
        # A NaN total (empty subset) also fails `> 0`, so it maps to 0%.
        return [100 * m / total if total > 0 else 0 for m in means]

    x = np.arange(len(primitives))
    w = 0.35
    for ax, task in zip(axes, tasks):
        in_task = trace_df.task_name == task
        sft = trace_df[(trace_df.checkpoint_id == sft_id) & in_task]
        gspo = trace_df[(trace_df.checkpoint_id == gspo_id) & in_task]

        ax.bar(x - w / 2, _percentages(sft), w, label="SFT", color="steelblue", alpha=0.8)
        ax.bar(x + w / 2, _percentages(gspo), w, label="GSPO", color="coral", alpha=0.8)
        ax.set_xticks(x)
        ax.set_xticklabels(primitives, rotation=45, ha="right", fontsize=8)
        ax.set_title(task.replace("_pass32", ""), fontsize=10)
        ax.legend(fontsize=8)

    axes[0].set_ylabel("% of primitives")
    fig.suptitle("Primitive Distribution: SFT vs GSPO", fontsize=12)
    fig.tight_layout()
    _save(fig, output_path)


# ---------------------------------------------------------------------------
# Generate all plots
# ---------------------------------------------------------------------------

def generate_all_plots(
    checkpoint_table: pd.DataFrame,
    checkpoint_metrics: dict[str, dict],
    output_dir: Path,
    sft_trajectory: Optional[dict[str, float]] = None,
    gspo_trajectory: Optional[dict[str, float]] = None,
    transition_matrices: Optional[dict[str, np.ndarray]] = None,
):
    """Generate all standard plots and save to output_dir/plots/.

    Plots whose inputs are missing or empty are skipped.

    Args:
        checkpoint_table: one row per checkpoint (or checkpoint x task) with
            gain columns and ``*_mean`` exploration metrics.
        checkpoint_metrics: {ckpt_id: {metric_name: value}}. It feeds the
            primitive, diversity and primitive-diff plots.
        output_dir: base directory (Path or str). Figures are written to
            ``output_dir / "plots"``.
        sft_trajectory: optional {label: pass@32} series of SFT checkpoints.
        gspo_trajectory: optional {label: pass@32} series of GSPO checkpoints.
        transition_matrices: optional {ckpt_label: matrix} for the transition
            heatmaps.
    """
    plots_dir = Path(output_dir) / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    # 1. Puzzle gain vs math gain
    plot_puzzle_vs_math_gain(
        checkpoint_table, plots_dir / "puzzle_vs_math_gain.png"
    )

    # 2. Exploration metrics vs math gain
    metric_cols = [c for c in checkpoint_table.columns
                   if c.endswith("_mean") and "pass" not in c and "gain" not in c]
    if metric_cols:
        if len(metric_cols) > 12:
            print(f"  metrics_vs_math_gain: plotting first 12 of {len(metric_cols)} metrics")
        plot_metric_vs_math_gain(
            checkpoint_table, metric_cols[:12],
            plots_dir / "metrics_vs_math_gain.png"
        )

    # 3. Math trajectory
    if sft_trajectory or gspo_trajectory:
        plot_math_trajectory(
            sft_trajectory or {}, gspo_trajectory or {},
            plots_dir / "math_trajectory.png"
        )

    # 4. Primitive distribution
    if checkpoint_metrics:
        plot_primitive_distribution(
            checkpoint_metrics, plots_dir / "primitive_distribution.png"
        )

    # 5. Transition matrices
    if transition_matrices:
        plot_transition_matrices(
            transition_matrices, plots_dir / "transition_matrices.png"
        )

    # 6. Diversity comparison
    if checkpoint_metrics:
        plot_diversity_comparison(
            checkpoint_metrics, plots_dir / "diversity_comparison.png"
        )

    # 7. Primitive diff: first SFT checkpoint vs first GSPO checkpoint (sorted ids).
    ckpt_ids = sorted(checkpoint_metrics) if checkpoint_metrics else []
    # GSPO runs may be named after their SFT base (e.g. "gspo_v2_sft_step20"),
    # so ids containing "gspo" are never taken as the SFT baseline. The GSPO
    # pick is also never the same checkpoint as the SFT pick.
    sft_ckpts = [c for c in ckpt_ids if "sft" in c.lower() and "gspo" not in c.lower()]
    if sft_ckpts:
        sft_id = sft_ckpts[0]
        gspo_ckpts = [
            c for c in ckpt_ids
            if c != sft_id and ("gspo" in c.lower() or "step" in c.lower())
        ]
        if gspo_ckpts:
            plot_primitive_diff(
                checkpoint_metrics[sft_id],
                checkpoint_metrics[gspo_ckpts[0]],
                plots_dir / "primitive_diff.png",
            )
