﻿from __future__ import annotations

import argparse
import json
from pathlib import Path

import numpy as np


def _apply_plot_style() -> None:
    """Push the shared figure styling into matplotlib's global ``rcParams``.

    The update is process-wide: it changes ``matplotlib.rcParams`` in place and
    therefore affects every figure drawn after the call. Nothing is returned.
    """
    import matplotlib

    # Font family and the point sizes used for the various text elements.
    typography = {
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "DejaVu Sans"],
        "font.size": 8.5,
        "axes.titlesize": 8.5,
        "axes.labelsize": 8.5,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 7.5,
        "mathtext.fontset": "stixsans",
    }
    # Tick orientation and the width of the axes frame.
    axes_geometry = {
        "xtick.direction": "out",
        "ytick.direction": "out",
        "axes.linewidth": 0.8,
    }
    # Font type used when text is written to PDF / PostScript output.
    vector_export = {"pdf.fonttype": 42, "ps.fonttype": 42}

    matplotlib.rcParams.update({**typography, **axes_geometry, **vector_export})


def plot_tbptt_accuracy(src: Path, out_dir: Path) -> None:
    """Draw test accuracy as a function of the TBPTT truncation length ``K``.

    The JSON file at ``src`` must contain ``summary["tbptt_acc_by_k"]``, a
    mapping from the string form of each integer ``K`` to an entry with
    ``mean``, ``lo`` and ``hi``, and ``summary["bptt_acc"]["mean"]``, which is
    drawn as a dashed horizontal reference line.

    Args:
        src: JSON results file, read as UTF-8.
        out_dir: Directory that receives ``row_mnist_tbptt_accuracy_vs_k.pdf``
            and ``row_mnist_tbptt_accuracy_vs_k.png``.
    """
    import matplotlib.pyplot as plt

    summary = json.loads(src.read_text(encoding="utf-8"))["summary"]
    per_k = summary["tbptt_acc_by_k"]
    reference = summary["bptt_acc"]

    # JSON object keys are strings; order the points numerically, then look
    # each row up again through its string key.
    ks = sorted(map(int, per_k))
    rows = [per_k[str(k)] for k in ks]
    means = [row["mean"] for row in rows]
    lows = [row["lo"] for row in rows]
    highs = [row["hi"] for row in rows]

    # errorbar takes distances from the mean, not absolute bounds.
    centre = np.array(means)
    below = centre - np.array(lows)
    above = np.array(highs) - np.array(means)

    _apply_plot_style()
    fig, ax = plt.subplots(1, 1, figsize=(3.35, 2.35), constrained_layout=True)
    ax.errorbar(
        ks,
        means,
        yerr=[below, above],
        marker="o",
        markersize=3.5,
        linewidth=1.1,
        capsize=2.5,
        label="TBPTT",
    )
    ax.axhline(reference["mean"], color="black", linestyle="--", linewidth=1.0, label="Full BPTT")

    ax.set_xscale("log", base=2)
    ax.set_xticks(ks)
    ax.set_xlabel(r"TBPTT truncation $K$")
    ax.set_ylabel("Test accuracy")
    ax.grid(True, linestyle=":", alpha=0.3)
    ax.legend(frameon=False, ncol=1)

    # Vector output first, then a raster copy at 240 dpi.
    for filename, extra in (
        ("row_mnist_tbptt_accuracy_vs_k.pdf", {}),
        ("row_mnist_tbptt_accuracy_vs_k.png", {"dpi": 240}),
    ):
        fig.savefig(out_dir / filename, bbox_inches="tight", **extra)
    plt.close(fig)


def plot_alignment(src: Path, out_dir: Path) -> None:
    """Draw the alignment-versus-``K`` figures for two metrics.

    The JSON file at ``src`` must contain ``summary["tbptt_ks"]`` (the x
    values) plus ``summary["current_delta_cos"]`` and
    ``summary["current_grad_cos_rec"]``. Each metric maps ``str(k)`` to an
    entry with ``mean``, ``lo`` and ``hi``; an optional ``"full"`` entry
    supplies the ``mean`` drawn as a dashed "Full BPTT" reference line.

    Three figures are written, each as ``.pdf`` and ``.png``:

    * ``row_mnist_alignment_vs_k`` - both metrics side by side,
    * ``row_mnist_alignment_delta_vs_k`` - the first metric alone,
    * ``row_mnist_alignment_dw_vs_k`` - the second metric alone.

    Args:
        src: JSON results file, read as UTF-8.
        out_dir: Directory that receives the figure files.
    """
    import matplotlib.pyplot as plt

    summary = json.loads(src.read_text(encoding="utf-8"))["summary"]
    ks = summary["tbptt_ks"]

    def _series(key: str):
        """Return ``(means, lows, highs, full)`` for the metric stored under ``key``."""
        table = summary[key]
        means = [table[str(k)]["mean"] for k in ks]
        lows = [table[str(k)]["lo"] for k in ks]
        highs = [table[str(k)]["hi"] for k in ks]
        return means, lows, highs, table.get("full")

    delta = _series("current_delta_cos")
    grad = _series("current_grad_cos_rec")

    x_caption = r"TBPTT truncation $K$"
    y_caption = r"Alignment $|\cos|$"

    # ---- Combined two-panel figure -------------------------------------
    _apply_plot_style()
    fig, axes = plt.subplots(1, 2, figsize=(6.3, 2.6), constrained_layout=True, sharey=True)

    for panel, (mean, lo, hi, full) in zip(axes, (delta, grad)):
        panel.plot(ks, mean, marker="o", markersize=3.5, linewidth=1.1, label="COLA")
        panel.fill_between(ks, lo, hi, alpha=0.18)
        if full is not None:
            panel.axhline(full["mean"], color="black", linestyle="--", linewidth=1.0, label="Full BPTT")

    for panel in axes:
        panel.set_xscale("log", base=2)
        panel.set_xticks(ks)
        panel.grid(True, linestyle=":", alpha=0.3)
        panel.set_ylim(0.0, 1.0)

    panel_titles = (
        r"Teaching signal ($\delta$)",
        r"Recurrent update ($\partial \mathcal{L}/\partial W_{hh}$)",
    )
    for panel, heading in zip(axes, panel_titles):
        panel.set_title(heading)

    # Use figure-level labels when the Figure object offers them; otherwise
    # place the same text by hand just outside the axes area.
    if not hasattr(fig, "supxlabel"):
        fig.text(0.5, -0.02, x_caption, ha="center")
        fig.text(-0.02, 0.5, y_caption, va="center", rotation="vertical")
    else:
        fig.supxlabel(x_caption)
        fig.supylabel(y_caption)

    for panel in axes:
        panel.legend(frameon=False, loc="lower left")

    fig.savefig(out_dir / "row_mnist_alignment_vs_k.pdf", bbox_inches="tight")
    fig.savefig(out_dir / "row_mnist_alignment_vs_k.png", dpi=240, bbox_inches="tight")
    plt.close(fig)

    # ---- One stand-alone figure per metric -----------------------------
    singles = (
        (delta, "Delta", "row_mnist_alignment_delta_vs_k"),
        (grad, "Recurrent dW", "row_mnist_alignment_dw_vs_k"),
    )
    for (mean, lo, hi, full), heading, stem in singles:
        _apply_plot_style()
        solo_fig, solo_ax = plt.subplots(1, 1, figsize=(3.35, 2.35), constrained_layout=True)
        solo_ax.plot(ks, mean, marker="o", markersize=3.5, linewidth=1.1)
        solo_ax.fill_between(ks, lo, hi, alpha=0.18)
        has_reference = full is not None
        if has_reference:
            solo_ax.axhline(full["mean"], color="black", linestyle="--", linewidth=1.0, label="Full BPTT")
        solo_ax.set_title(heading)
        solo_ax.set_xscale("log", base=2)
        solo_ax.set_xticks(ks)
        solo_ax.set_xlabel(x_caption)
        solo_ax.set_ylabel(y_caption)
        solo_ax.set_ylim(0.0, 1.0)
        solo_ax.grid(True, linestyle=":", alpha=0.3)
        # Only the reference line carries a label, so a legend is drawn
        # only when that line exists.
        if has_reference:
            solo_ax.legend(frameon=False, loc="lower left")
        solo_fig.savefig(out_dir / f"{stem}.pdf", bbox_inches="tight")
        solo_fig.savefig(out_dir / f"{stem}.png", dpi=240, bbox_inches="tight")
        plt.close(solo_fig)


# Curve names that are plotted (in this order) when present in the input.
_PREFERRED_METHODS = ("Local Rule", "BPTT", "TBPTT-1", "TBPTT-4", "TBPTT-16")

# Legend label, colour, line style and line width for each known curve name.
_METHOD_STYLE = {
    "Local Rule": {"label": "COLA", "color": "#d62728", "ls": "-", "lw": 1.2},
    "BPTT": {"label": "BPTT", "color": "#000000", "ls": "-", "lw": 1.2},
    "TBPTT-1": {"label": "TBPTT-1", "color": "#1f77b4", "ls": "--", "lw": 1.1},
    "TBPTT-4": {"label": "TBPTT-4", "color": "#2ca02c", "ls": "--", "lw": 1.1},
    "TBPTT-16": {"label": "TBPTT-16", "color": "#9467bd", "ls": "--", "lw": 1.1},
}


def _plot_accuracy_curves(
    curves: dict,
    x_key: str,
    x_label: str,
    out_dir: Path,
    stem: str,
) -> None:
    """Draw mean accuracy curves with shaded bands and save them as PDF and PNG.

    Args:
        curves: Mapping from curve name to a dict holding the x grid under
            ``x_key`` plus ``mean``, ``lo`` and ``hi`` sequences.
        x_key: Key of the x grid inside each curve dict.
        x_label: Text for the x axis.
        out_dir: Directory that receives ``<stem>.pdf`` and ``<stem>.png``.
        stem: Output file name without extension.
    """
    import matplotlib.pyplot as plt

    _apply_plot_style()
    fig, ax = plt.subplots(1, 1, figsize=(3.35, 2.35), constrained_layout=True)

    # Prefer the known curve names; if none is present, fall back to the
    # first five curves in the order the input lists them.
    methods = [m for m in _PREFERRED_METHODS if m in curves] or list(curves)[:5]

    y_lo = float("inf")
    y_hi = float("-inf")

    fallback_colors = plt.cm.tab10(np.linspace(0.0, 0.9, len(methods)))
    for idx, name in enumerate(methods):
        curve = curves[name]
        x = np.array(curve[x_key], dtype=float)
        mean = np.array(curve["mean"], dtype=float)
        lo = np.array(curve["lo"], dtype=float)
        hi = np.array(curve["hi"], dtype=float)

        cfg = _METHOD_STYLE.get(
            name,
            {"label": name, "color": fallback_colors[idx], "ls": "-", "lw": 1.1},
        )
        color = cfg["color"]
        ax.plot(x, mean, linewidth=cfg["lw"], linestyle=cfg["ls"], label=cfg["label"], color=color)
        ax.fill_between(x, lo, hi, alpha=0.12, color=color)

        # Track the y-range from finite values only: an empty curve would make
        # np.min/np.max raise, and a single NaN would otherwise drop the whole
        # curve from the range and let its finite part be clipped.
        finite_lo = lo[np.isfinite(lo)]
        finite_hi = hi[np.isfinite(hi)]
        if finite_lo.size:
            y_lo = min(y_lo, float(finite_lo.min()))
        if finite_hi.size:
            y_hi = max(y_hi, float(finite_hi.max()))

    ax.set_xlabel(x_label)
    ax.set_ylabel("Test accuracy")
    ax.grid(True, linestyle=":", alpha=0.3)
    if np.isfinite(y_lo) and np.isfinite(y_hi):
        # Small margin around the data, never leaving the [0, 1] interval.
        pad = 0.02
        ax.set_ylim(max(0.0, y_lo - pad), min(1.0, y_hi + pad))
    ax.legend(frameon=False, ncol=1)
    fig.savefig(out_dir / f"{stem}.pdf", bbox_inches="tight")
    fig.savefig(out_dir / f"{stem}.png", dpi=240, bbox_inches="tight")
    plt.close(fig)


def plot_compute_matched(src: Path, out_dir: Path) -> None:
    """Plot test accuracy against wall-clock time.

    Reads ``curves_time`` from the JSON file at ``src``; each curve supplies
    its x values under ``t_grid`` together with ``mean``, ``lo`` and ``hi``.

    Args:
        src: JSON results file, read as UTF-8.
        out_dir: Directory that receives ``row_mnist_compute_matched_time.pdf``
            and ``row_mnist_compute_matched_time.png``.
    """
    curves = json.loads(src.read_text(encoding="utf-8"))["curves_time"]
    _plot_accuracy_curves(
        curves,
        x_key="t_grid",
        x_label="Wall-clock time (s)",
        out_dir=out_dir,
        stem="row_mnist_compute_matched_time",
    )


def plot_update_matched(src: Path, out_dir: Path) -> None:
    """Plot test accuracy against the number of optimization updates.

    Reads ``curves_updates`` from the JSON file at ``src``; each curve supplies
    its x values under ``u_grid`` together with ``mean``, ``lo`` and ``hi``.

    Args:
        src: JSON results file, read as UTF-8.
        out_dir: Directory that receives ``row_mnist_accuracy_vs_updates.pdf``
            and ``row_mnist_accuracy_vs_updates.png``.
    """
    curves = json.loads(src.read_text(encoding="utf-8"))["curves_updates"]
    _plot_accuracy_curves(
        curves,
        x_key="u_grid",
        x_label="Optimization updates",
        out_dir=out_dir,
        stem="row_mnist_accuracy_vs_updates",
    )


def plot_effective_rank(src: Path, out_dir: Path) -> None:
    """Draw the effective-rank statistic as a function of the window ``K``.

    The JSON file at ``src`` must contain
    ``summary["delta_true_eff_rank_participation_by_k"]``, a mapping from the
    string form of each integer ``K`` to an entry with ``mean``, ``lo`` and
    ``hi``.

    Args:
        src: JSON results file, read as UTF-8.
        out_dir: Directory that receives ``row_mnist_effective_rank_vs_k.pdf``
            and ``row_mnist_effective_rank_vs_k.png``.
    """
    import matplotlib.pyplot as plt

    summary = json.loads(src.read_text(encoding="utf-8"))["summary"]
    per_k = summary["delta_true_eff_rank_participation_by_k"]

    # Order the points numerically; the JSON keys themselves are strings.
    ks = sorted(map(int, per_k))
    rows = [per_k[str(k)] for k in ks]
    means = [row["mean"] for row in rows]
    lows = [row["lo"] for row in rows]
    highs = [row["hi"] for row in rows]

    # errorbar takes distances from the mean, not absolute bounds.
    centre = np.array(means)
    below = centre - np.array(lows)
    above = np.array(highs) - np.array(means)

    _apply_plot_style()
    fig, ax = plt.subplots(1, 1, figsize=(3.35, 2.35), constrained_layout=True)
    ax.errorbar(
        ks,
        means,
        yerr=[below, above],
        marker="o",
        markersize=3.5,
        linewidth=1.1,
        capsize=2.5,
    )
    ax.set_xscale("log", base=2)
    ax.set_xticks(ks)
    ax.set_xlabel(r"Evaluation window $K$")
    ax.set_ylabel(r"Effective rank of $\delta^{true}$")
    ax.grid(True, linestyle=":", alpha=0.3)

    for filename, extra in (
        ("row_mnist_effective_rank_vs_k.pdf", {}),
        ("row_mnist_effective_rank_vs_k.png", {"dpi": 240}),
    ):
        fig.savefig(out_dir / filename, bbox_inches="tight", **extra)
    plt.close(fig)


def main() -> None:
    """Parse the command line and render every requested figure.

    ``--tbptt-json``, ``--align-json``, ``--compute-json`` and ``--out-dir``
    are required. ``--eff-rank-json`` is optional; the effective-rank figure
    is produced only when it is given a non-empty value. The output directory
    is created (including parents) if it does not exist yet.
    """
    parser = argparse.ArgumentParser(description="Plot Row-MNIST diagnostics (paper style).")
    for flag in ("--tbptt-json", "--align-json", "--compute-json"):
        parser.add_argument(flag, type=str, required=True)
    parser.add_argument("--eff-rank-json", type=str, default=None)
    parser.add_argument("--out-dir", type=str, required=True)
    opts = parser.parse_args()

    target_dir = Path(opts.out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # (plot function, JSON path) pairs in rendering order; the compute JSON
    # feeds both the wall-clock figure and the per-update figure.
    jobs = [
        (plot_tbptt_accuracy, opts.tbptt_json),
        (plot_alignment, opts.align_json),
        (plot_compute_matched, opts.compute_json),
        (plot_update_matched, opts.compute_json),
    ]
    if opts.eff_rank_json:
        jobs.append((plot_effective_rank, opts.eff_rank_json))

    for render, json_path in jobs:
        render(Path(json_path), target_dir)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
