from __future__ import annotations

import io
import os
import warnings
from typing import Callable, List, Optional, Tuple

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (kept: older Matplotlib needs it for projection="3d")
from PIL import Image
from stable_baselines3.common.callbacks import BaseCallback
from torch.utils.tensorboard import SummaryWriter
# Names exported by a star-import of this module.
# NOTE(review): VisualizerCallback is expected to be defined elsewhere in this
# module — confirm; `_generate_all` is exported despite its private-style name.
__all__ = ["VisualizerCallback", "_generate_all"]
def _plot_to_image(fig: plt.Figure) -> np.ndarray:
    """Render *fig* to an in-memory PNG and return its pixels channel-first.

    The figure is closed with ``plt.close`` once it has been written to the
    buffer. Only the first three channels of the decoded image are kept, and
    the channel axis is moved to the front (axes ``(2, 0, 1)``), which is the
    layout passed to ``SummaryWriter.add_image`` elsewhere in this file.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format="png", bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)

    # Rewind so PIL reads the PNG from its first byte.
    png_buffer.seek(0)
    decoded = np.asarray(Image.open(png_buffer))
    first_three_channels = decoded[:, :, :3]
    return first_three_channels.transpose(2, 0, 1)
def _fft_smoothness(actions: np.ndarray, fs: float = 1.0) -> float:
    """Return a frequency-weighted spectral score of an action sequence.

    For every action channel the score is
    ``2 / (n * fs) * sum_k f_k * |FFT(a)_k|`` over the first ``n // 2``
    frequency bins, and the result is the mean over channels. Larger values
    mean more energy at higher frequencies; a constant signal scores 0.

    Args:
        actions: Sequence whose first axis is time. Any trailing axes are
            flattened and treated as independent channels, so 1-D, 2-D and
            higher-rank inputs are all handled the same way.
        fs: Sampling frequency used to scale the frequency axis.

    Returns:
        The mean per-channel score, or ``0.0`` when there are fewer than two
        time steps or no channel data at all.
    """
    a = np.array(actions, dtype=float)
    if a.ndim == 0:
        # A bare scalar is a one-sample sequence.
        a = a.reshape(1)
    n = a.shape[0]
    if n < 2 or a.size == 0:
        return 0.0

    # Collapse every trailing axis into one channel axis. Without this the
    # (n // 2, 1) frequency column would broadcast against the wrong axis for
    # inputs with more than two dimensions.
    a = a.reshape(n, -1)

    half = n // 2
    magnitude = np.abs(np.fft.fft(a, axis=0)[:half])
    freqs = np.fft.fftfreq(n, d=1 / fs)[:half, None]
    smooth_per_dim = (2.0 / (n * fs)) * np.sum(freqs * magnitude, axis=0)
    return float(np.mean(smooth_per_dim))
def _avg_fft_plot(action_seqs: List[np.ndarray], fs: float = 1.0) -> plt.Figure:
    """Plot the real-FFT magnitude averaged over sequences and action channels.

    Each sequence is reshaped to ``(T, channels)`` (1-D sequences become a
    single channel), truncated to the shortest sequence length so they can be
    stacked, transformed with ``np.fft.rfft`` along time, and the magnitudes
    are averaged over sequences and channels.

    Args:
        action_seqs: Action sequences with time on the first axis. Sequences
            without any samples are ignored.
        fs: Sampling frequency used for the frequency axis.

    Returns:
        A new figure with the averaged spectrum, or a figure that only shows
        the text "No data" when no usable sequence is available.
    """
    usable = []
    for seq in action_seqs:
        arr = np.asarray(seq, dtype=float)
        if arr.ndim == 0 or arr.size == 0:
            # Nothing to transform; an empty time axis would make rfft raise.
            continue
        usable.append(arr.reshape(arr.shape[0], -1))

    if not usable:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.text(0.5, 0.5, "No data", ha="center")
        return fig

    min_T = min(arr.shape[0] for arr in usable)
    # Stacked layout: (sequence, time, channel).
    data = np.stack([arr[:min_T] for arr in usable], 0)
    mag = np.abs(np.fft.rfft(data, axis=1)).mean(axis=(0, 2))
    freqs = np.fft.rfftfreq(min_T, d=1 / fs)

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(freqs, mag, label="Avg FFT")
    ax.grid(True)
    ax.set_xlabel("Freq [Hz]")
    ax.set_ylabel("Amplitude")
    ax.legend()
    fig.tight_layout()
    return fig
def _get_actor(model):
    """Return ``model.policy.actor``, or ``None`` when it is not available.

    Returning ``None`` (instead of raising ``AttributeError``) lets callers
    that test ``actor is None`` skip visualisation for models whose policy
    has no separate actor network.
    """
    policy = getattr(model, "policy", None)
    return getattr(policy, "actor", None)
def _get_q1_fn(model):
    critic = model.policy.critic
    return lambda s, a: critic(s, a)[0]
def _grad_Q(state, action, q1_fn):
    """Return the gradient of ``q1_fn(state, action).sum()`` with respect to *action*.

    The gradient is taken at a detached copy of *action*, so the caller's
    tensor is neither modified nor marked as requiring grad.
    """
    action_leaf = action.detach().clone().requires_grad_(True)
    q_total = q1_fn(state, action_leaf).sum()
    (dq_da,) = torch.autograd.grad(q_total, action_leaf)
    return dq_da
def _surface(ax, X, Y, Z, title, zlab):
    """Draw ``Z`` over the ``(X, Y)`` mesh as a viridis surface on *ax* and label it.

    *ax* must provide ``plot_surface`` and ``set_zlabel`` (callers in this file
    create it with ``projection="3d"``). The x/y labels are fixed to x₁ / x₂;
    *title* and *zlab* set the title and z-axis label.
    """
    surface_style = dict(cmap="viridis", rstride=1, cstride=1, linewidth=0)
    ax.plot_surface(X, Y, Z, **surface_style)
    ax.set_title(title)
    ax.set_xlabel("x₁")
    ax.set_ylabel("x₂")
    ax.set_zlabel(zlab)
def _contour(ax, X, Y, Z, title: str, clab: str):
    """Draw a 20-level filled viridis contour of ``Z`` on *ax* and return it.

    The returned object is what ``ax.contourf`` produced, so the caller can
    attach a colorbar to it. ``clab`` is accepted for interface compatibility
    but is not used by this helper.
    """
    filled = ax.contourf(X, Y, Z, levels=20, cmap="viridis")
    ax.set_title(title)
    ax.set_xlabel("x₁")
    ax.set_ylabel("x₂")

    # Axes objects without ``set_box_aspect`` fall back to an equal data aspect.
    if not hasattr(ax, "set_box_aspect"):
        ax.set_aspect("equal", adjustable="box")
        return filled

    ax.set_box_aspect(1)
    return filled
def _save_and_log_figure(writer: SummaryWriter, tag: str, fig, step: int,
                         png_path: Optional[str] = None) -> None:
    """Write *fig* to *png_path* (when given) and log it to TensorBoard as *tag*."""
    if png_path is not None:
        # Save to disk first: `_plot_to_image` closes the figure when it is done.
        fig.savefig(png_path)
    writer.add_image(tag, _plot_to_image(fig), step)
    plt.close(fig)


def _reset_obs(env):
    """Reset *env* and return only the observation for either Gym reset convention."""
    out = env.reset()
    # Newer Gym versions return ``(obs, info_dict)``; older ones return ``obs``.
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
        return out[0]
    return out


def _step_obs_done(env, action):
    """Step *env* and return ``(obs, done)`` for both 4- and 5-tuple step results."""
    out = env.step(action)
    if len(out) == 5:
        obs, _, terminated, truncated, _ = out
        return obs, bool(terminated or truncated)
    obs, _, done, _ = out
    return obs, done


def _probe_grid(model, q1_fn, X: np.ndarray, Y: np.ndarray, span: np.ndarray,
                *, n_samples: int, delta_scale: float):
    """Evaluate policy, Q and dQ/da on the mesh and estimate local Lipschitz ratios.

    Returns ``(Z_act, Z_q, Z_qg, p_lip, qg_lip)``: three arrays shaped like
    ``X`` holding the first action component, the Q value and the first
    component of dQ/da, plus two lists with one entry per grid point holding
    the largest observed ``||Δaction|| / ||Δstate||`` and
    ``||Δ(dQ/da)|| / ||Δstate||`` over ``n_samples`` random state perturbations.
    """
    device = model.device
    # Policies without ``scale_action`` get the identity mapping.
    scale = getattr(model.policy, "scale_action", lambda a: a)
    # Loop-invariant perturbation width; float32 keeps perturbed states in the
    # same dtype as the state tensors even if the bounds are float64.
    span_t = torch.tensor(span, dtype=torch.float32, device=device)

    def to_batch(array):
        return torch.tensor(array, dtype=torch.float32, device=device).unsqueeze(0)

    Z_act = np.zeros_like(X)
    Z_q = np.zeros_like(X)
    Z_qg = np.zeros_like(X)
    p_lip, qg_lip = [], []
    rows, cols = X.shape
    for i in range(rows):
        for j in range(cols):
            obs_np = np.array([X[i, j], Y[i, j]], dtype=np.float32)
            act_env, _ = model.predict(obs_np, deterministic=True)
            Z_act[i, j] = act_env[0]

            s_torch = to_batch(obs_np)
            a_raw = to_batch(scale(act_env))
            with torch.no_grad():
                Z_q[i, j] = q1_fn(s_torch, a_raw).cpu().numpy()[0, 0]
            g = _grad_Q(s_torch, a_raw, q1_fn)
            Z_qg[i, j] = g.detach().cpu().numpy()[0, 0]

            local_p_max = 0.0
            local_qg_max = 0.0
            for _ in range(n_samples):
                delta = (torch.rand_like(s_torch) - 0.5) * delta_scale * span_t
                d = torch.norm(delta).item()
                if d == 0.0:
                    # Degenerate perturbation (e.g. zero-width bounds): no ratio.
                    continue
                s2 = s_torch + delta
                act2_env, _ = model.predict(s2.cpu().numpy().squeeze(), deterministic=True)
                a2_raw = to_batch(scale(act2_env))

                val_p = torch.norm(a_raw - a2_raw).item() / d
                if val_p > local_p_max:
                    local_p_max = val_p

                # Gradient at the perturbed state but the ORIGINAL action.
                g2 = _grad_Q(s2, a_raw, q1_fn)
                val_qg = torch.norm(g - g2).item() / d
                if val_qg > local_qg_max:
                    local_qg_max = val_qg
            p_lip.append(local_p_max)
            qg_lip.append(local_qg_max)
    return Z_act, Z_q, Z_qg, p_lip, qg_lip


def _run_eval_episodes(model, env, eval_eps: int, max_steps: int):
    """Roll out deterministic episodes and collect action statistics.

    Returns ``(fluct, smooth, seqs, traj)``: per-step action change norms,
    per-episode `_fft_smoothness` scores, per-episode action arrays, and the
    first observation component after every step (across all episodes).
    """
    fluct, smooth, seqs, traj = [], [], [], []
    for _ in range(eval_eps):
        obs = _reset_obs(env)
        done = False
        # Seeding with the first action makes the first fluctuation of every
        # episode the difference between two predictions on the same state.
        last_a, _ = model.predict(obs, deterministic=True)
        ep_seq = []
        while not done and len(ep_seq) < max_steps:
            a, _ = model.predict(obs, deterministic=True)
            obs, done = _step_obs_done(env, a)
            ep_seq.append(a)
            traj.append(obs[0])
            fluct.append(np.linalg.norm(a - last_a))
            last_a = a
        if ep_seq:
            seqs.append(np.asarray(ep_seq))
            smooth.append(_fft_smoothness(ep_seq))
    return fluct, smooth, seqs, traj


def _generate_all(model,
                  env_func: Callable[[], gym.Env],
                  writer: SummaryWriter,
                  step: int,
                  *,
                  grid: int = 50,
                  boundary: Optional[Tuple[float, float]] = None,
                  eval_eps: int = 10,
                  max_steps: int = 1000,
                  png_dir: Optional[str] = None,
                  n_samples: int = 32,
                  delta_scale: float = 0.01):
    """Log policy/critic visualisations and rollout statistics to TensorBoard.

    The first two observation dimensions are swept on a ``grid`` x ``grid``
    mesh. At every mesh point the deterministic action, the first critic's Q
    value and dQ/da are evaluated, and local Lipschitz-style ratios are
    estimated from random state perturbations. Surfaces, contours and summary
    scalars are written to *writer*. Afterwards ``eval_eps`` deterministic
    episodes are rolled out to log action fluctuation, spectral smoothness,
    the averaged action spectrum and the trajectory of the first observation
    component.

    Args:
        model: Trained model exposing ``predict``, ``device`` and ``policy``.
        env_func: Zero-argument factory returning a fresh environment. The
            environment is always closed before this function returns.
        writer: TensorBoard writer receiving scalars and images.
        step: Global step attached to every logged value.
        grid: Number of mesh points per axis.
        boundary: Optional ``(low, high)`` applied to every observation
            dimension instead of the observation-space bounds. Non-finite
            bounds would make the mesh non-finite, so pass this for
            unbounded spaces.
        eval_eps: Number of evaluation episodes.
        max_steps: Maximum number of steps per evaluation episode.
        png_dir: If truthy, directory (created when missing) where every
            figure is also saved as ``<name>_<step>.png``.
        n_samples: Random perturbations drawn per mesh point.
        delta_scale: Perturbation width as a fraction of the per-dimension
            bound range.

    Returns:
        None. When the actor network is unavailable a warning is issued and
        nothing is logged.
    """
    actor = _get_actor(model)
    if actor is None:
        warnings.warn("Actor network not accessible – skip viz")
        return
    q1_fn = _get_q1_fn(model)

    if png_dir:
        os.makedirs(png_dir, exist_ok=True)

    def png_path(stem: str) -> Optional[str]:
        return os.path.join(png_dir, f"{stem}_{step}.png") if png_dir else None

    env = env_func()
    try:
        # Work on copies so `boundary` never overwrites the environment's own
        # observation-space arrays.
        low = np.array(env.observation_space.low)
        high = np.array(env.observation_space.high)
        if boundary is not None:
            low[:], high[:] = boundary[0], boundary[1]

        x1 = np.linspace(low[0], high[0], grid)
        x2 = np.linspace(low[1], high[1], grid)
        X, Y = np.meshgrid(x1, x2)

        Z_act, Z_q, Z_qg, p_lip, qg_lip = _probe_grid(
            model, q1_fn, X, Y, high - low,
            n_samples=n_samples, delta_scale=delta_scale,
        )

        writer.add_scalar("grid/policy_lip_mean", float(np.mean(p_lip)) if p_lip else 0.0, step)
        writer.add_scalar("grid/qg_lip_mean",    float(np.mean(qg_lip)) if qg_lip else 0.0, step)
        writer.add_scalar("grid/policy_lip_max", float(np.max(p_lip)) if p_lip else 0.0, step)
        writer.add_scalar("grid/qg_lip_max",    float(np.max(qg_lip)) if qg_lip else 0.0, step)

        for tag, Z, title, zlab in [
            ("policy3d", Z_act, "Policy Surface", "Action (env-scale)"),
            ("qgrad3d",  Z_qg,  "Q-Grad Surface", "∂Q/∂a"),
            ("q3d",      Z_q,   "Q Surface",      "Q"),
        ]:
            fig = plt.figure(figsize=(8, 6))
            _surface(fig.add_subplot(111, projection="3d"), X, Y, Z, title, zlab)
            _save_and_log_figure(writer, f"{tag}", fig, step, png_path(tag))

        for tag, Z, title, clab in [
            ("policy_contour", Z_act, "Policy Contour", "a (env)"),
            ("qgrad_contour",  Z_qg,  "∂Q/∂a Contour",  "∂Q/∂a"),
            ("q_contour",      Z_q,   "Q Contour",      "Q"),
        ]:
            fig, ax = plt.subplots(figsize=(6, 5))
            cf = _contour(ax, X, Y, Z, title, clab)
            plt.colorbar(cf, ax=ax)
            _save_and_log_figure(writer, f"{tag}", fig, step, png_path(tag))

        fluct, smooth, seqs, traj = _run_eval_episodes(model, env, eval_eps, max_steps)
        writer.add_scalar("ep/action_fluct", float(np.mean(fluct)) if fluct else 0, step)
        writer.add_scalar("ep/smooth",       float(np.mean(smooth)) if smooth else 0, step)

        if seqs:
            fig_fft = _avg_fft_plot(seqs)
            _save_and_log_figure(writer, "ep/fft_avg", fig_fft, step, png_path("fft"))

        if traj:
            fig, ax = plt.subplots(figsize=(8, 3))
            ax.plot(traj)
            ax.axhline(0, color="r", ls="--")
            ax.set_title("Particle x₁ trajectory")
            ax.set_xlabel("t")
            ax.set_ylabel("x₁")
            _save_and_log_figure(writer, "particle_traj", fig, step, png_path("traj"))
    finally:
        # Release the environment even when evaluation or plotting fails.
        env.close()