"""
RTS recovery plotting utilities (publication-quality figure style)
RTS 回收的高质量绘图工具（顶会风格）

Functions / 函数:
- save_rts_recovery_figs: save three figures (composite, mean error, covariance error)
  保存三张图（综合图、均值误差、协方差误差）
"""

from __future__ import annotations

import os
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


def _set_style():
    """Install the shared figure style into the global ``matplotlib.rcParams``.

    Sets text/title/label/tick font sizes, hides the top and right spines and
    sets the default figure DPI to 300. The change is global: it persists
    after this call and affects every later figure in the process.
    """
    style_overrides = (
        ("font.size", 11),
        ("axes.titlesize", 12),
        ("axes.labelsize", 11),
        ("xtick.labelsize", 10),
        ("ytick.labelsize", 10),
        ("axes.spines.top", False),
        ("axes.spines.right", False),
        ("figure.dpi", 300),
    )
    for rc_key, rc_value in style_overrides:
        mpl.rcParams[rc_key] = rc_value


def _lollipop(ax, x, y, title, xlabel, ylabel):
    """Draw a log-scale lollipop (stem + marker) plot of ``y`` over ``x`` on ``ax``.

    ``y`` is clipped from below at ``1e-16`` so exact zeros remain visible on
    the log axis; every stem starts at that floor. An inset axes in the upper
    right corner repeats the plot for the last quarter of the points (indices
    ``int(0.75 * len(x))`` to the end) and is titled "zoom".

    Args:
        ax: Matplotlib ``Axes`` to draw into.
        x: 1-D positions; the inset's two x-ticks are ``int`` casts of its
            first and last entries within the zoomed range.
        y: 1-D values with the same length as ``x`` (presumably non-negative
            absolute errors — the log scale assumes so).
        title: Title of the main axes.
        xlabel: X-axis label of the main axes.
        ylabel: Y-axis label of the main axes.

    Upper y-limits are computed from the finite entries of ``y`` only. A NaN
    or inf value would otherwise make ``set_ylim`` raise and abort the whole
    figure, which is exactly when the error plot is most useful.
    """
    eps = 1e-16

    def _ylim_top(vals):
        # Largest finite value (never below eps) with 30% headroom.
        return float(jnp.max(jnp.where(jnp.isfinite(vals), vals, eps))) * 1.3

    y = jnp.clip(y, eps, None)
    ax.set_yscale("log")
    ax.vlines(x, ymin=eps, ymax=y, color="#6c7a89", alpha=0.6, linewidth=1.1)
    # Minimal aesthetic: single accent colour, thin dark outline so markers survive print.
    ax.scatter(x, y, s=26, color="#1f77b4", edgecolor="#222222", linewidths=0.2)
    ax.set_ylim(eps, _ylim_top(y))
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True, which="both", linestyle="--", linewidth=0.4, alpha=0.4)

    # Inset zoom on the last quarter of the series.
    inz = inset_axes(ax, width="38%", height="60%", loc="upper right", borderpad=1.0)
    sl = slice(int(0.75 * len(x)), len(x))
    inz.set_yscale("log")
    inz.vlines(x[sl], ymin=eps, ymax=y[sl], color="#34495e", alpha=0.35, linewidth=0.8)
    inz.scatter(x[sl], y[sl], s=18, color="#1f77b4", edgecolor="#222222", linewidths=0.2)
    inz.set_ylim(eps, _ylim_top(y[sl]))
    inz.set_xticks([int(x[sl.start]), int(x[-1])])
    inz.grid(True, which="both", linestyle="--", linewidth=0.3, alpha=0.3)
    inz.set_title("zoom", fontsize=9)


def _cov_bar(ax, x, y, title, xlabel):
    """Draw ``log10(y)`` as a single-row colour strip on ``ax`` and return the image.

    ``y`` is clipped from below at ``1e-16`` before taking the log. The image
    spans ``[-0.5, len(x) - 0.5]`` horizontally, so column ``k`` is centred at
    ``k`` (assumes ``len(y) == len(x)`` — not checked). The y-axis carries no
    information, so its ticks are removed.

    The colour range (``vmin``/``vmax``) is taken from the finite log-values
    only. Previously one NaN made the range NaN, blanking the whole strip, and
    one inf squashed every finite cell onto a single colour. Non-finite cells
    are left to matplotlib's invalid-value handling (the colormap's "bad"
    colour).

    Args:
        ax: Matplotlib ``Axes`` to draw into.
        x: Sequence whose length sets the horizontal extent.
        y: 1-D non-negative values (e.g. per-step errors).
        title: Axes title.
        xlabel: X-axis label.

    Returns:
        The ``AxesImage`` created by ``imshow`` (pass it to ``fig.colorbar``).
    """
    eps = 1e-16
    vals = jnp.log10(jnp.clip(y, eps, None))

    finite = jnp.isfinite(vals)
    if bool(jnp.any(finite)):
        vmin = float(jnp.min(jnp.where(finite, vals, jnp.inf)))
        vmax = float(jnp.max(jnp.where(finite, vals, -jnp.inf)))
    else:
        # Nothing finite to scale against: pin the range to the clip floor.
        vmin = vmax = float(jnp.log10(eps))

    n = len(x)
    im = ax.imshow(
        vals[None, :],
        aspect="auto",
        cmap="magma",
        extent=[-0.5, n - 0.5, 0.0, 1.0],
        vmin=vmin,
        vmax=vmax,
    )
    ax.set_yticks([])
    ax.set_xlim(-0.5, n - 0.5)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    return im


def save_rts_recovery_figs(mean_abs_err, cov_abs_err, grid, out_dir: str, dpi: int = 300):
    os.makedirs(out_dir, exist_ok=True)
    _set_style()

    K = int(mean_abs_err.shape[0])
    x = jnp.arange(K)

    # 1) Composite figure / 综合图
    fig = plt.figure(figsize=(11.0, 3.6))
    gs = fig.add_gridspec(1, 3, width_ratios=[2.2, 2.2, 1.2], wspace=0.25)

    ax0 = fig.add_subplot(gs[0, 0])
    _lollipop(
        ax0,
        x,
        mean_abs_err,
        r"A. $|\hat{\mu}_k - \mu_k^{\mathrm{RTS}}|$",
        "time k",
        "abs error",
    )

    ax1 = fig.add_subplot(gs[0, 1])
    im = _cov_bar(
        ax1,
        x,
        cov_abs_err,
        r"B. $\log_{10}$ – $\|\hat{\Sigma}_k - \Sigma_k^{\mathrm{RTS}}\|_{\infty}$",
        "time k",
    )
    cbar = fig.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
    cbar.set_label(r"$\log_{10}(\mathrm{abs\ error})$")

    ax2 = fig.add_subplot(gs[0, 2])
    ax2.axis("off")
    txt = (
        "RTS recovery (1D)\n\n"
        f"max |μ̂−μ^{{RTS}}|  = {float(jnp.max(mean_abs_err)):.2e}\n"
        f"max ||Σ̂−Σ^{{RTS}}||∞ = {float(jnp.max(cov_abs_err)):.2e}\n\n"
        f"K = {K}\n"
        f"grid = {grid.n_points}\n"
        f"bounds = ({grid.bounds[0]:.2f}, {grid.bounds[1]:.2f})"
    )
    ax2.text(0.02, 0.98, txt, va="top", ha="left", family="DejaVu Sans Mono", fontsize=10)
    ax2.plot([0.0, 0.0], [0, 1], color="#1f77b4", linewidth=3)

    fig.subplots_adjust(top=0.92, bottom=0.18, left=0.08, right=0.98, wspace=0.25)
    plt.savefig(os.path.join(out_dir, "rts_recovery_curves.png"), dpi=dpi, bbox_inches="tight")
    plt.close(fig)

    # 2) Standalone mean error figure / 单独均值误差图
    fig2, axm = plt.subplots(figsize=(5.5, 3.4))
    _lollipop(axm, x, mean_abs_err, r"$|\hat{\mu}_k - \mu_k^{\mathrm{RTS}}|$", "time k", "abs error")
    fig2.subplots_adjust(bottom=0.22, left=0.12, right=0.98)
    plt.savefig(os.path.join(out_dir, "rts_mean_error.png"), dpi=dpi, bbox_inches="tight")
    plt.close(fig2)

    # 3) Standalone covariance error figure / 单独协方差误差图
    fig3, axc = plt.subplots(figsize=(5.5, 1.4))
    im3 = _cov_bar(axc, x, cov_abs_err, r"$\log_{10}$ – $\|\hat{\Sigma}_k - \Sigma_k^{\mathrm{RTS}}\|_{\infty}$", "time k")
    fig3.colorbar(im3, ax=axc, fraction=0.12, pad=0.2).set_label(r"$\log_{10}(\mathrm{abs\ error})$")
    fig3.subplots_adjust(bottom=0.55, left=0.10, right=0.98)
    plt.savefig(os.path.join(out_dir, "rts_cov_error.png"), dpi=dpi, bbox_inches="tight")
    plt.close(fig3)


