import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from mpl_toolkits.mplot3d import Axes3D
from typing import Any
from scipy.stats import binned_statistic_2d
from scipy.interpolate import griddata

def plot_l63_full_single(trajectory: np.ndarray[tuple[int, int], Any], **kwargs):
    """
    Plot one trajectory as a 3D phase portrait plus per-component time series.

    Left column: 3D line through the first three columns, with the first point
    marked by a red "x" ("Start") and the last by a blue "o" ("End").
    Right column: one panel per component plotted against the sample index.

    Args:
        trajectory: Array of shape (T, C) with C >= 3; columns 0, 1, 2 are
            plotted as X, Y, Z.
        **kwargs: Optional settings:
            plotted_variable: name used in the suptitle (default "Trajectory").
            s_comp_labels: y-labels for the three component panels.
            u_comp_labels: y-labels used only when ``s_comp_labels`` is absent.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``trajectory`` is not 2D with at least three columns.
    """
    trajectory = np.asarray(trajectory)
    if trajectory.ndim != 2 or trajectory.shape[1] < 3:
        raise ValueError(
            f"trajectory must have shape (T, 3), got {trajectory.shape}"
        )

    T = trajectory.shape[0]
    x, y, z = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]
    fig = plt.figure(figsize=(12, 7))
    gs = fig.add_gridspec(3, 2, hspace=0.3)

    # --- 3D phase-space plot ---
    ax3d = fig.add_subplot(gs[:, 0], projection='3d')
    ax3d.plot(x, y, z, color='steelblue', lw=0.5)
    ax3d.scatter(x[0], y[0], z[0], color='red', marker='x', s=50, label='Start')
    ax3d.scatter(x[-1], y[-1], z[-1], color='blue', marker='o', s=100, label='End')
    ax3d.set_xlabel("X")
    ax3d.set_ylabel("Y")
    ax3d.set_zlabel("Z")
    ax3d.legend()

    plotted_var = kwargs.get('plotted_variable', 'Trajectory')

    if 's_comp_labels' in kwargs:
        labels = kwargs['s_comp_labels']
    elif 'u_comp_labels' in kwargs:
        labels = kwargs['u_comp_labels']
    else:
        labels = ['comp_1(t)', 'comp_2(t)', 'comp_3(t)']

    # --- Component-wise time series ---
    colors = ['tab:cyan', 'tab:pink', 'tab:orange']
    time = np.arange(T)
    last_panel = len(colors) - 1

    for i, (coord, label, color) in enumerate(zip([x, y, z], labels, colors)):
        ax = fig.add_subplot(gs[i, 1])
        ax.plot(time, coord, color=color, lw=1)
        ax.set_ylabel(label)
        ax.grid(True, ls='--', alpha=0.6)
        if i < last_panel:
            # Upper panels show the same time range as the bottom one: hide
            # their tick labels and leave the x-label to the bottom panel so
            # it does not overlap the panel below.
            ax.tick_params(labelbottom=False)
        else:
            ax.set_xlabel("Time")

    fig.suptitle(f"{plotted_var} Visualization", fontsize=16, y=0.95)
    return fig

def plot_l63_full_double(
    traj1: np.ndarray[tuple[int, int], Any],
    traj2: np.ndarray[tuple[int, int], Any],
    **kwargs,
):
    """
    Plot two Lorenz-63 trajectories in 3D and as component-wise time series.

    Left column: both trajectories in one 3D axes (steelblue / darkorange),
    with start points marked "x" and end points marked "o".
    Right column: one panel per component; trajectory 1 solid, trajectory 2
    dashed, each against its own sample index.

    Args:
        traj1: First trajectory, shape (T1, 3).
        traj2: Second trajectory, shape (T2, 3).
        **kwargs: Optional settings:
            plotted_variable_1 / plotted_variable_2: names used in the 3D
                legend and the suptitle (defaults "Trajectory 1"/"Trajectory 2").
            first / second: per-component legend labels for traj1 / traj2;
                used only when both are given.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if either trajectory is not of shape (T, 3).
    """
    traj1 = np.asarray(traj1)
    traj2 = np.asarray(traj2)
    # Explicit validation instead of `assert`, which is stripped under `python -O`.
    for name, traj in (("traj1", traj1), ("traj2", traj2)):
        if traj.ndim != 2 or traj.shape[1] != 3:
            raise ValueError(f"{name} must have shape (T, 3), got {traj.shape}")

    T1, T2 = traj1.shape[0], traj2.shape[0]
    x1, y1, z1 = traj1[:, 0], traj1[:, 1], traj1[:, 2]
    x2, y2, z2 = traj2[:, 0], traj2[:, 1], traj2[:, 2]

    fig = plt.figure(figsize=(12, 7))
    gs = fig.add_gridspec(3, 2, hspace=0.3)

    plotted_var_1 = kwargs.get("plotted_variable_1", "Trajectory 1")
    plotted_var_2 = kwargs.get("plotted_variable_2", "Trajectory 2")

    # --- 3D phase-space plot ---
    ax3d = fig.add_subplot(gs[:, 0], projection="3d")
    ax3d.plot(x1, y1, z1, color="steelblue", lw=0.8, label=plotted_var_1)
    ax3d.plot(x2, y2, z2, color="darkorange", lw=0.8, label=plotted_var_2)

    ax3d.scatter(x1[0], y1[0], z1[0], color="red", marker="x", s=40)
    ax3d.scatter(x2[0], y2[0], z2[0], color="darkred", marker="x", s=40)
    ax3d.scatter(x1[-1], y1[-1], z1[-1], color="blue", marker="o", s=60)
    ax3d.scatter(x2[-1], y2[-1], z2[-1], color="navy", marker="o", s=60)

    ax3d.set_xlabel("X")
    ax3d.set_ylabel("Y")
    ax3d.set_zlabel("Z")
    ax3d.legend()

    # --- Component-wise time series ---
    if "first" in kwargs and "second" in kwargs:
        labels_1 = kwargs["first"]
        labels_2 = kwargs["second"]
    else:
        labels_1 = ["comp_1(t)", "comp_2(t)", "comp_3(t)"]
        labels_2 = ["comp_1(t)", "comp_2(t)", "comp_3(t)"]

    colors1 = ["tab:cyan", "tab:pink", "tab:orange"]
    time1, time2 = np.arange(T1), np.arange(T2)
    last_panel = len(colors1) - 1

    for i, (coord1, coord2, label_1, label_2, c1) in enumerate(
        zip([x1, y1, z1], [x2, y2, z2], labels_1, labels_2, colors1)
    ):
        ax = fig.add_subplot(gs[i, 1])
        # Same color per component; line style distinguishes the trajectories.
        ax.plot(time1, coord1, color=c1, linestyle='-', lw=1, label=label_1)
        ax.plot(time2, coord2, color=c1, linestyle='--', lw=1, alpha=0.8, label=label_2)
        ax.set_ylabel(f"Dim {i+1}")
        ax.grid(True, ls="--", alpha=0.6)
        if i < last_panel:
            # Only the bottom panel carries the time tick labels and x-label.
            ax.tick_params(labelbottom=False)
        else:
            ax.set_xlabel("Time")
        ax.legend(loc="upper right", fontsize=8)

    fig.suptitle(f"{plotted_var_1} vs {plotted_var_2}", fontsize=16, y=0.95)

    return fig

def plot_summary_stats_comparison(
        s_true: np.ndarray, 
        s_pred: np.ndarray, 
        title: str
    ):
    """
    Plot true vs predicted summary statistics.

    Args:
        s_true: True summary statistics, shape [D,] or [T, D]
        s_pred: Predicted summary statistics, same shape as s_true (a 1D
            array and its single-column [N, 1] form are treated as equal)
        title: Plot title

    Returns:
        matplotlib Figure object for wandb logging

    Raises:
        ValueError: if s_true and s_pred do not have matching shapes.
    """
    s_true = np.asarray(s_true)
    s_pred = np.asarray(s_pred)
    # Promote each 1D input to a single column independently; promoting only
    # one of them would make (N, 1) - (N,) broadcast to (N, N) in the MSE.
    if s_true.ndim == 1:
        s_true = s_true.reshape(-1, 1)
    if s_pred.ndim == 1:
        s_pred = s_pred.reshape(-1, 1)
    if s_true.shape != s_pred.shape:
        raise ValueError(
            f"s_true and s_pred must have the same shape, "
            f"got {s_true.shape} and {s_pred.shape}"
        )

    fig, axes = plt.subplots(1, 1, figsize=(10, 5), squeeze=False)
    ax = axes[0, 0]

    indices = np.arange(s_true.shape[0])

    # Plot column by column so each series gets a single legend entry even
    # when there are several columns (D > 1).
    for series, fmt, color, name in (
        (s_true, 'o-', 'steelblue', 'True'),
        (s_pred, 's--', 'coral', 'Predicted'),
    ):
        for d in range(series.shape[1]):
            ax.plot(indices, series[:, d], fmt, color=color,
                    label=name if d == 0 else '_nolegend_',
                    markersize=4, linewidth=.5, alpha=0.8)

    ax.set_ylabel('1D Statistic', fontsize=10)
    ax.set_xlabel("Index", fontsize=10)
    ax.grid(True, ls='--', alpha=0.4)
    ax.legend()

    mse = np.mean((s_true - s_pred) ** 2)
    # Transform along axis 0 (the index axis). np.fft.fft defaults to the last
    # axis, which has length 1 for a column vector and makes the FFT a no-op.
    # NOTE: the unnormalized FFT is linear with ||F v|| = sqrt(N) * ||v||, so
    # this equals sqrt(N) times the Frobenius norm of (s_true - s_pred).
    spectral_distance = np.linalg.norm(
        np.fft.fft(s_true, axis=0) - np.fft.fft(s_pred, axis=0)
    )
    # One text box for both metrics avoids the two boxes overlapping.
    ax.text(0.02, 0.95,
            f'MSE: {mse:.4e}\nSpectral Dist: {spectral_distance:.4f}',
            transform=ax.transAxes, fontsize=9, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

    fig.suptitle(title, fontsize=14, y=0.995)
    fig.tight_layout()

    return fig

def _as_1d_timeseries(x: np.ndarray, name: str) -> np.ndarray:
    x = np.asarray(x)

    if x.ndim == 2 and x.shape[1] == 1:
        return x[:, 0]
    if x.ndim == 1:
        return x

    raise ValueError(f"{name} must have shape (T,) or (T,1), got {x.shape}")

def _normalize(x: np.ndarray) -> np.ndarray:
    """
    Z-score ``x``: subtract its mean and divide by its standard deviation.

    A small epsilon (1e-8) is added to the standard deviation so a constant
    input yields zeros instead of a division by zero.
    """
    eps = 1e-8
    mu = x.mean()
    sigma = x.std()
    centered = x - mu
    return centered / (sigma + eps)

def _pearsonr_safe(x: np.ndarray, y: np.ndarray) -> float:
    """
    Pearson correlation of ``x`` and ``y``, or NaN when it is undefined.

    Returns NaN (instead of letting ``np.corrcoef`` warn) when either input
    is empty or has zero standard deviation.
    """
    a = np.asarray(x)
    b = np.asarray(y)
    # Short-circuit on emptiness first so np.std is never called on an empty array.
    undefined = a.size == 0 or b.size == 0 or np.std(a) == 0.0 or np.std(b) == 0.0
    if undefined:
        return float("nan")
    corr_matrix = np.corrcoef(a, b)
    return float(corr_matrix[0, 1])

def plot_l63_regime_identification(
    u_true: np.ndarray,   # (T, 3)
    s_true: np.ndarray,   # (T,) or (T,1)
    max_t: int | None = None,
    regime_threshold: float = 0.0,
):
    """
    Regime identification plot using x(t) relative to a threshold as regime indicator.

    Top: x(t) with ``regime_threshold`` as a dashed horizontal line.
    Bottom: s_true(t) with horizontal lines for the mean of s_true in each
    regime (x > threshold and x <= threshold) and their midpoint
    ("boundary"). The regime lines are drawn only when both regimes occur
    in the plotted window.

    Args:
        u_true: State trajectory, shape (T, 3); column 0 is x.
        s_true: Summary time series, shape (T,) or (T, 1).
        max_t: If given, only the first ``max_t`` steps are plotted.
        regime_threshold: Value of x separating the two regimes.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if u_true is not (T, 3), s_true is not (T,)/(T, 1), or
            their time dimensions differ.
    """
    u_true = np.asarray(u_true)
    if u_true.ndim != 2 or u_true.shape[1] != 3:
        raise ValueError(f"u_true must have shape (T,3), got {u_true.shape}")

    s_t = _as_1d_timeseries(s_true, "s_true")
    T = u_true.shape[0]
    if s_t.shape[0] != T:
        raise ValueError(
            f"Time dimension mismatch: u_true has T={T}, s_true={s_t.shape[0]}"
        )

    if max_t is not None:
        u_true = u_true[:max_t]
        s_t = s_t[:max_t]
        T = len(u_true)

    t = np.arange(T)
    x = u_true[:, 0]

    fig, axs = plt.subplots(2, 1, figsize=(12, 6), sharex=True)
    axs[0].plot(t, x, lw=0.8, color="tab:blue")
    axs[0].axhline(regime_threshold, color="k", ls="--", lw=0.8, alpha=0.6)
    axs[0].set_title("Regime indicator: x(t)")
    axs[0].set_ylabel("x(t)")

    axs[1].plot(t, s_t, lw=0.8, color="tab:purple", label="s_true")
    mask_pos = x > regime_threshold
    mask_neg = ~mask_pos  # includes x == threshold
    if mask_pos.any() and mask_neg.any():
        mean_pos = float(np.mean(s_t[mask_pos]))
        mean_neg = float(np.mean(s_t[mask_neg]))
        boundary = 0.5 * (mean_pos + mean_neg)
        # Legend labels reflect the actual threshold and the inclusive
        # comparison of the negative mask (previously hard-coded to 0).
        thr = f"{regime_threshold:g}"
        axs[1].axhline(mean_pos, color="tab:green", ls="-", lw=1.0, label=f"mean x>{thr}")
        axs[1].axhline(mean_neg, color="tab:orange", ls="-", lw=1.0, label=f"mean x≤{thr}")
        axs[1].axhline(boundary, color="k", ls="--", lw=0.8, alpha=0.6, label="boundary")
    axs[1].set_title("Summary vs regime")
    axs[1].set_ylabel("s_true")
    axs[1].set_xlabel("Time")
    axs[1].legend(frameon=False, fontsize=8)

    fig.tight_layout()
    return fig

def plot_summary_vs_state(
    u_true: np.ndarray,   # (T, 3)
    s_true: np.ndarray,   # (T,) or (T,1)
    max_points: int | None = 5000,
):
    """
    Scatter plots of s_true vs x, y, z, and energy.

    Energy is x**2 + y**2 + z**2. Each panel title reports the Pearson
    correlation between s_true and the panel's variable ("n/a" when it is
    undefined). When ``max_points`` is set and the series is longer, an
    evenly spaced subset of ``max_points`` steps is plotted.
    """
    states = np.asarray(u_true)
    if states.ndim != 2 or states.shape[1] != 3:
        raise ValueError(f"u_true must have shape (T,3), got {states.shape}")

    summary = _as_1d_timeseries(s_true, "s_true")
    n_steps = states.shape[0]
    if summary.shape[0] != n_steps:
        raise ValueError(
            f"Time dimension mismatch: u_true has T={n_steps}, s_true={summary.shape[0]}"
        )

    # Thin long series with evenly spaced indices to keep the scatter readable.
    if max_points is not None and n_steps > max_points:
        keep = np.linspace(0, n_steps - 1, max_points).astype(int)
        states, summary = states[keep], summary[keep]

    xs, ys, zs = states[:, 0], states[:, 1], states[:, 2]
    panels = (
        ("x", xs),
        ("y", ys),
        ("z", zs),
        ("energy", xs ** 2 + ys ** 2 + zs ** 2),
    )

    fig, axs = plt.subplots(2, 2, figsize=(10, 8))
    for ax, (var_name, values) in zip(axs.ravel(), panels):
        corr = _pearsonr_safe(summary, values)
        ax.scatter(values, summary, s=6, alpha=0.4)
        corr_text = "n/a" if not np.isfinite(corr) else f"{corr:.2f}"
        ax.set_title(f"s_true vs {var_name} (r={corr_text})")
        ax.set_xlabel(var_name)
        ax.set_ylabel("s_true")
        ax.grid(True, ls="--", alpha=0.3)

    fig.tight_layout()
    return fig

def plot_l63_projections_colored_by_summary(
    u_true: np.ndarray,   # (T, 3)
    s_true: np.ndarray,   # (T,) or (T,1)
    max_points: int | None = 10000,
    cmap: str = "viridis",
    cbar_label: str = "s_true",
    vmin: float | None = None,
    vmax: float | None = None,
):
    """
    Three 2D projections (x,y), (x,z), (y,z) colored by the summary with a shared scale.

    The color range defaults to [min(s_true), max(s_true)] over the plotted
    points unless ``vmin``/``vmax`` are given. When ``max_points`` is set and
    the series is longer, an evenly spaced subset of steps is plotted. A
    single colorbar is attached to all three panels.
    """
    states = np.asarray(u_true)
    if states.ndim != 2 or states.shape[1] != 3:
        raise ValueError(f"u_true must have shape (T,3), got {states.shape}")

    summary = _as_1d_timeseries(s_true, "s_true")
    n_steps = states.shape[0]
    if summary.shape[0] != n_steps:
        raise ValueError(
            f"Time dimension mismatch: u_true has T={n_steps}, s_true={summary.shape[0]}"
        )

    if max_points is not None and n_steps > max_points:
        keep = np.linspace(0, n_steps - 1, max_points).astype(int)
        states, summary = states[keep], summary[keep]

    coords = {"x": states[:, 0], "y": states[:, 1], "z": states[:, 2]}

    # Shared color scale across the three panels.
    vmin = float(np.min(summary)) if vmin is None else vmin
    vmax = float(np.max(summary)) if vmax is None else vmax

    fig, axs = plt.subplots(
        1, 3, figsize=(12, 4), sharex=False, sharey=False, constrained_layout=True
    )
    handles = []
    for ax, (h_name, v_name) in zip(axs, (("x", "y"), ("x", "z"), ("y", "z"))):
        handles.append(
            ax.scatter(
                coords[h_name], coords[v_name],
                c=summary, s=6, cmap=cmap, vmin=vmin, vmax=vmax,
            )
        )
        ax.set_xlabel(h_name)
        ax.set_ylabel(v_name)
        ax.set_title(f"({h_name}, {v_name})")

    fig.colorbar(handles[0], ax=axs, fraction=0.03, pad=0.02, label=cbar_label)
    return fig


def plot_l63_discriminator_phase_space(
    u_true: np.ndarray,   # (T, 3)
    u_fake: np.ndarray,   # (T, 3)
    phi_true: np.ndarray, # (T,)
    phi_fake: np.ndarray, # (T,)
    *,
    early_frac: float = 0.2,
    late_frac: float = 0.2,
    max_points: int = 8000,
):
    """
    Phase-space discrimination plot for critic output φ.

    Shows the x-z projection of three point sets: the real trajectory, the
    first ``early_frac`` of the fake trajectory, and the last ``late_frac`` of
    the fake trajectory. Point color and size both encode φ on a scale shared
    by all panels. On the fake panels, a dashed red rectangle marks the
    bounding box of the points whose φ is at or below that panel's 10%
    quantile.

    Args:
        u_true: Real trajectory, shape (T_true, 3).
        u_fake: Generated trajectory, shape (T_fake, 3).
        phi_true: Critic output per step of ``u_true``; flattened to length T_true.
        phi_fake: Critic output per step of ``u_fake``; flattened to length T_fake.
        early_frac: Fraction of fake steps (from the start) in the "early" panel.
        late_frac: Fraction of fake steps (from the end) in the "late" panel.
        max_points: Maximum points per panel (evenly spaced subsampling).

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: on wrong trajectory shapes, empty φ, or φ lengths that do
            not match their trajectories.
    """
    u_true = np.asarray(u_true)
    u_fake = np.asarray(u_fake)
    phi_true = np.asarray(phi_true).reshape(-1)
    phi_fake = np.asarray(phi_fake).reshape(-1)

    if u_true.ndim != 2 or u_true.shape[1] != 3:
        raise ValueError(f"u_true must have shape (T,3), got {u_true.shape}")
    if u_fake.ndim != 2 or u_fake.shape[1] != 3:
        raise ValueError(f"u_fake must have shape (T,3), got {u_fake.shape}")
    # φ is indexed with the same indices as its trajectory; a length mismatch
    # would silently misalign colors or raise IndexError deep in _sample.
    if phi_true.shape[0] != u_true.shape[0]:
        raise ValueError(
            f"phi_true has {phi_true.shape[0]} values but u_true has T={u_true.shape[0]}"
        )
    if phi_fake.shape[0] != u_fake.shape[0]:
        raise ValueError(
            f"phi_fake has {phi_fake.shape[0]} values but u_fake has T={u_fake.shape[0]}"
        )
    if phi_true.size == 0 or phi_fake.size == 0:
        raise ValueError("phi_true and phi_fake must be non-empty")

    T_fake = u_fake.shape[0]
    early_end = max(1, int(T_fake * early_frac))
    late_start = max(0, int(T_fake * (1.0 - late_frac)))

    def _sample(u, phi, max_points):
        # Evenly spaced subsample of at most max_points rows.
        if u.shape[0] <= max_points:
            return u, phi
        idx = np.linspace(0, u.shape[0] - 1, max_points).astype(int)
        return u[idx], phi[idx]

    def _scale_sizes(phi, vmin, vmax):
        # Map φ linearly onto marker sizes in [10, 60].
        denom = (vmax - vmin) + 1e-8
        frac = (phi - vmin) / denom
        return 10.0 + 50.0 * np.clip(frac, 0.0, 1.0)

    # Shared φ range so colors and sizes are comparable across panels.
    phi_min = float(min(np.min(phi_true), np.min(phi_fake)))
    phi_max = float(max(np.max(phi_true), np.max(phi_fake)))

    u_true_s, phi_true_s = _sample(u_true, phi_true, max_points)
    u_fake_early, phi_fake_early = _sample(u_fake[:early_end], phi_fake[:early_end], max_points)
    u_fake_late, phi_fake_late = _sample(u_fake[late_start:], phi_fake[late_start:], max_points)

    # constrained_layout (not tight_layout) handles a colorbar spanning
    # several axes and keeps the suptitle inside the figure.
    fig, axs = plt.subplots(
        1, 3, figsize=(14, 4), sharex=True, sharey=True, constrained_layout=True
    )
    cmap = "viridis"
    norm = plt.Normalize(vmin=phi_min, vmax=phi_max)

    panels = [
        ("Real", u_true_s, phi_true_s),
        ("Fake (early)", u_fake_early, phi_fake_early),
        ("Fake (late)", u_fake_late, phi_fake_late),
    ]

    scatter_handles = []
    for ax, (title, u_pts, phi_vals) in zip(axs, panels):
        x = u_pts[:, 0]
        z = u_pts[:, 2]
        sizes = _scale_sizes(phi_vals, phi_min, phi_max)
        sc = ax.scatter(x, z, c=phi_vals, s=sizes, cmap=cmap, norm=norm, alpha=0.6, linewidth=0)
        scatter_handles.append(sc)
        ax.set_title(title)
        ax.set_xlabel("x")
        ax.set_ylabel("z")
        ax.grid(True, ls="--", alpha=0.3)

        # Highlight where the lowest-φ (most "fake-looking") points sit.
        if "Fake" in title and len(phi_vals) > 10:
            q = np.quantile(phi_vals, 0.1)
            mask = phi_vals <= q
            if np.any(mask):
                x_low = x[mask]
                z_low = z[mask]
                x_min, x_max = float(np.min(x_low)), float(np.max(x_low))
                z_min, z_max = float(np.min(z_low)), float(np.max(z_low))
                rect = Rectangle((x_min, z_min), x_max - x_min, z_max - z_min,
                                 fill=False, edgecolor="red", linewidth=1.2, linestyle="--")
                ax.add_patch(rect)
                ax.annotate(
                    "low φ",
                    xy=(0.5 * (x_min + x_max), 0.5 * (z_min + z_max)),
                    xytext=(x_min, z_max),
                    arrowprops=dict(arrowstyle="->", color="red", lw=1.0),
                    color="red",
                    fontsize=8,
                )

    cbar = fig.colorbar(scatter_handles[0], ax=axs, fraction=0.03, pad=0.02)
    cbar.set_label("critic output φ")
    fig.suptitle("Phase Space Discrimination (x-z projection)", fontsize=13)
    return fig


def plot_l63_projection(
    u_true: np.ndarray,   # (T, 3)
    s_true: np.ndarray,   # (T,) or (T,1)
    s_hat: np.ndarray,    # (T,) or (T,1)
    max_t: int | None = None,
):
    """
    Four-panel diagnostic comparing a true summary s_true with its estimate s_hat.

    Panels:
        top-left:     x, y, z against time;
        top-right:    z-scored s_true and s_hat against time;
        bottom-left:  z-scored s_hat vs z-scored s_true with the identity line;
        bottom-right: mean of s_true on a 50x50 (z, y) grid as filled contours.
                      Empty bins are filled from the nearest non-empty bin, so
                      the colored area is the whole bounding box of the
                      visited states, not only the attractor itself.

    Args:
        u_true: State trajectory, shape (T, 3), columns (x, y, z).
        s_true: True summary, shape (T,) or (T, 1).
        s_hat: Estimated summary, shape (T,) or (T, 1).
        max_t: If given, only the first ``max_t`` steps are used.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if u_true is not (T, 3), a summary is not (T,)/(T, 1), or
            the time dimensions differ.
    """
    # ── Shape checks
    u_true = np.asarray(u_true)
    if u_true.ndim != 2 or u_true.shape[1] != 3:
        raise ValueError(f"u_true must have shape (T,3), got {u_true.shape}")

    s_t = _as_1d_timeseries(s_true, "s_true")
    s_h = _as_1d_timeseries(s_hat, "s_hat")

    T = u_true.shape[0]
    if s_t.shape[0] != T or s_h.shape[0] != T:
        raise ValueError(
            f"Time dimension mismatch: u_true has T={T}, "
            f"s_true={s_t.shape[0]}, s_hat={s_h.shape[0]}"
        )

    if max_t is not None:
        u_true = u_true[:max_t]
        s_t = s_t[:max_t]
        s_h = s_h[:max_t]
        T = len(u_true)

    t = np.arange(T)

    s_t_n = _normalize(s_t)
    s_h_n = _normalize(s_h)

    fig, axs = plt.subplots(2, 2, figsize=(12, 9))

    # ── Top-left: Lorenz states
    axs[0, 0].plot(t, u_true[:, 0], lw=0.8, label="x")
    axs[0, 0].plot(t, u_true[:, 1], lw=0.8, label="y")
    axs[0, 0].plot(t, u_true[:, 2], lw=0.8, label="z")
    axs[0, 0].set_title("Lorenz-63 states")
    axs[0, 0].legend(frameon=False)

    # ── Top-right: learned projections
    axs[0, 1].plot(t, s_t_n, lw=1.5, label="s_true")
    axs[0, 1].plot(t, s_h_n, lw=1.5, ls="--", label="s_hat")
    axs[0, 1].set_title("Projections (normalized)")
    axs[0, 1].legend(frameon=False)

    # ── Bottom-left: observable alignment
    axs[1, 0].scatter(s_t_n, s_h_n, s=6, alpha=0.5)
    lim = max(np.abs(s_t_n).max(), np.abs(s_h_n).max())
    axs[1, 0].plot([-lim, lim], [-lim, lim], "k--", lw=1)
    axs[1, 0].set_xlabel("s_true")
    axs[1, 0].set_ylabel("s_hat")
    axs[1, 0].set_title("Observable alignment")

    # ── Bottom-right: phase-space coloring
    z_vals = u_true[:, 2]
    y_vals = u_true[:, 1]
    n_bins = 50

    # Mean s_true per (z, y) bin. Use the edges scipy actually binned with
    # instead of rebuilding them by hand: scipy widens a zero-width range
    # (e.g. constant z), which a manual linspace(min, max) grid does not.
    statistic, z_edges, y_edges, _ = binned_statistic_2d(
        z_vals, y_vals, s_t, statistic="mean", bins=(n_bins, n_bins)
    )
    z_centers = 0.5 * (z_edges[:-1] + z_edges[1:])
    y_centers = 0.5 * (y_edges[:-1] + y_edges[1:])
    # Z, Y are (n_y, n_z); their transposes index like `statistic`, which is
    # (n_z, n_y) because z is passed as the first coordinate.
    Z, Y = np.meshgrid(z_centers, y_centers)

    statistic_filled = statistic
    valid_mask = ~np.isnan(statistic)
    if valid_mask.any() and not valid_mask.all():
        # Fill empty (NaN) bins with the value of the nearest non-empty bin.
        valid_points = np.column_stack([Z.T[valid_mask], Y.T[valid_mask]])
        all_points = np.column_stack([Z.T.ravel(), Y.T.ravel()])
        statistic_filled = griddata(
            valid_points, statistic[valid_mask], all_points, method="nearest"
        ).reshape(statistic.shape)

    cf = axs[1, 1].contourf(
        Z, Y, statistic_filled.T,
        levels=20,
        cmap="viridis",
    )
    axs[1, 1].set_xlabel("z")
    axs[1, 1].set_ylabel("y")
    axs[1, 1].set_title("Attractor colored by s_true")
    fig.colorbar(cf, ax=axs[1, 1], fraction=0.046)

    fig.tight_layout()
    return fig

def plot_l63_attractor_colored_3d(
    u_true: np.ndarray,   # (T, 3)
    s_true: np.ndarray,   # (T,) or (T,1)
    max_t: int | None = None,
):
    """
    3D scatter of the trajectory, each point colored by s_true.

    Args:
        u_true: State trajectory, shape (T, 3), columns (x, y, z).
        s_true: Summary time series, shape (T,) or (T, 1).
        max_t: If given, only the first ``max_t`` steps are plotted.

    Returns:
        The matplotlib Figure (one 3D axes plus a colorbar).

    Raises:
        ValueError: on wrong shapes or mismatched time dimensions.
    """
    states = np.asarray(u_true)
    if states.ndim != 2 or states.shape[1] != 3:
        raise ValueError(f"u_true must have shape (T,3), got {states.shape}")

    summary = _as_1d_timeseries(s_true, "s_true")

    n_steps = states.shape[0]
    if summary.shape[0] != n_steps:
        raise ValueError(
            f"Time dimension mismatch: u_true has T={n_steps}, "
            f"s_true={summary.shape[0]}"
        )

    if max_t is not None:
        states, summary = states[:max_t], summary[:max_t]

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(1, 1, 1, projection="3d")

    xs, ys, zs = states[:, 0], states[:, 1], states[:, 2]
    points = ax.scatter(xs, ys, zs, c=summary, s=6, cmap="viridis")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.set_title("3D attractor colored by s_true")

    fig.colorbar(points, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    return fig

# NOTE: a second, behaviorally identical definition of `_as_1d_timeseries`
# used to live here. It silently rebound the module-level name defined
# earlier in this file, so it was removed; the earlier definition is the
# single source of truth for that helper.


def _compute_psd(x: np.ndarray, dt: float = 1.0):
    """
    Compute the one-sided power spectral density of a 1D signal via FFT.

    The mean is removed first. The result matches
    ``scipy.signal.periodogram(x, fs=1/dt)`` (boxcar window, constant detrend,
    ``scaling="density"``): power is scaled by ``dt / T``, and every bin except
    DC and (for even T) Nyquist is doubled to account for the discarded
    negative frequencies.

    Args:
        x: 1D signal of length T (any array-like).
        dt: Sampling interval; must be positive.

    Returns:
        (freqs, psd): frequencies from ``np.fft.rfftfreq(T, dt)`` (starting at
        0) and the PSD at those frequencies.

    Raises:
        ValueError: if ``dt`` is not positive or ``x`` is empty.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")
    x = np.asarray(x, dtype=float)
    T = len(x)
    if T == 0:
        raise ValueError("x must be non-empty")
    x = x - x.mean()

    freqs = np.fft.rfftfreq(T, d=dt)
    fft_vals = np.fft.rfft(x)
    # Density scaling: |X|^2 * dt / T gives power per unit frequency.
    psd = (np.abs(fft_vals) ** 2) * (dt / T)
    # One-sided spectrum: double all bins that have a negative-frequency twin.
    if T % 2:
        psd[1:] *= 2.0
    else:
        psd[1:-1] *= 2.0  # last bin is the unpaired Nyquist frequency

    return freqs, psd

def plot_l63_projection_psd(
    u_true: np.ndarray,   # (T, 3)
    s_true: np.ndarray,   # (T,) or (T,1)
    s_hat: np.ndarray,    # (T,) or (T,1)
    dt: float = 1.0,
    max_freq: float | None = None,
):
    """
    PSD comparison for Lorenz-63 learned projections.

    Plots, on log-log axes, the PSDs of the three state components (faint)
    and of s_true / s_hat (bold, s_hat dashed). The zero-frequency (DC) bin
    is omitted because it cannot be placed on a logarithmic frequency axis.

    Args:
        u_true: State trajectory, shape (T, 3).
        s_true: True summary, shape (T,) or (T, 1).
        s_hat: Estimated summary, shape (T,) or (T, 1).
        dt: Sampling interval passed to the PSD computation.
        max_freq: If given, the x-axis is limited to
            [lowest positive frequency, max_freq].

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: on wrong shapes or mismatched time dimensions.
    """

    # ── Shape checks
    u_true = np.asarray(u_true)
    if u_true.ndim != 2 or u_true.shape[1] != 3:
        raise ValueError(f"u_true must have shape (T,3), got {u_true.shape}")

    s_t = _as_1d_timeseries(s_true, "s_true")
    s_h = _as_1d_timeseries(s_hat, "s_hat")

    if not (len(u_true) == len(s_t) == len(s_h)):
        raise ValueError(
            "Time dimension mismatch: "
            f"u_true has T={len(u_true)}, s_true={len(s_t)}, s_hat={len(s_h)}"
        )

    # ── Compute PSDs
    freqs_x, psd_x = _compute_psd(u_true[:, 0], dt)
    freqs_y, psd_y = _compute_psd(u_true[:, 1], dt)
    freqs_z, psd_z = _compute_psd(u_true[:, 2], dt)

    freqs_s, psd_s_true = _compute_psd(s_t, dt)
    _, psd_s_hat = _compute_psd(s_h, dt)

    # ── Plot (index 0 of an rfft frequency grid is f = 0; drop it for loglog)
    fig, ax = plt.subplots(figsize=(8, 5))

    ax.loglog(freqs_x[1:], psd_x[1:], alpha=0.3, label="x")
    ax.loglog(freqs_y[1:], psd_y[1:], alpha=0.3, label="y")
    ax.loglog(freqs_z[1:], psd_z[1:], alpha=0.3, label="z")

    ax.loglog(freqs_s[1:], psd_s_true[1:], lw=2.5, label="s_true")
    ax.loglog(freqs_s[1:], psd_s_hat[1:], lw=2.5, ls="--", label="s_hat")

    # freqs_s[1] (lowest positive frequency) only exists when T >= 2.
    if max_freq is not None and len(freqs_s) > 1:
        ax.set_xlim(freqs_s[1], max_freq)

    ax.set_xlabel("Frequency")
    ax.set_ylabel("Power spectral density")
    ax.set_title("PSD of Lorenz-63 states and learned projections")
    ax.legend(frameon=False)
    ax.grid(True, which="both", ls="--", alpha=0.3)

    fig.tight_layout()
    return fig
