import os

import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

from utils import marginal_prob_std, diff_coeff


@torch.no_grad()
def visualize_sampling_2d(
    model, sigma, n=1000, steps=400, xlim=4, mode="sde", name="example"
):
    """
    Animate the reverse-time integration of a 2-D generative model.

    Draws ``n`` points from a Gaussian scaled by ``marginal_prob_std`` at
    t=1, then integrates backwards over ``steps`` uniformly spaced time
    points from t=1 to t=1e-5. Each time point is one integration step and
    one animation frame. The animation is written to ``videos/<name>.gif``,
    and the ``videos`` directory is created if needed.

    Args:
        model: Module-like callable (it must expose ``parameters()``),
            called as ``model(x, t)`` with ``x`` of shape (n, 2) and ``t`` of
            shape (n,). Its output is scaled and added to ``x``, so it must
            broadcast against (n, 2). Presumably it returns the score; this
            is not confirmed here.
        sigma: Noise-scale parameter forwarded to ``marginal_prob_std`` and
            ``diff_coeff``.
        n: Number of particles.
        steps: Number of time points, which is also the number of
            integration steps and GIF frames. Must be at least 2.
        xlim: Half-width of the square plotting window.
        mode: ``"sde"`` for Euler-Maruyama steps of the reverse SDE. Any
            other value uses deterministic Euler steps with half the drift
            (the ODE counterpart of the SDE).
        name: File stem of the output GIF.

    Raises:
        ValueError: If ``steps < 2``, since no step size can be formed.
    """
    if steps < 2:
        raise ValueError(f"steps must be >= 2, got {steps}")

    device = next(model.parameters()).device

    # Sample the starting particles at t=1.
    t = torch.ones(n, device=device)
    x = torch.randn(n, 2, device=device) * marginal_prob_std(t, sigma)[:, None]
    t_range = torch.linspace(1, 1e-5, steps, device=device)
    # linspace is uniform, so every step has this same size.
    dt = t_range[0] - t_range[1]

    fig, ax = plt.subplots(figsize=(6, 6))
    scat = ax.scatter(x[:, 0].cpu(), x[:, 1].cpu(), s=5)
    ax.set_xlim(-xlim, xlim)
    ax.set_ylim(-xlim, xlim)
    ax.set_aspect("equal")

    def init():
        # FuncAnimation draws an initial frame before the frame loop, and
        # save() does so again. Without an init_func it calls update(0) for
        # that, which would advance the stateful sampler by extra steps.
        # Returning the untouched scatter keeps one step per frame.
        return (scat,)

    def update(frame):
        nonlocal x
        t_step = t_range[frame]
        t_feed = torch.full((x.size(0),), t_step, device=device)
        g = diff_coeff(t_feed, sigma)
        drift = (g**2)[:, None] * model(x, t_feed)

        if mode == "sde":
            # Euler-Maruyama step: deterministic drift plus Gaussian noise.
            mean = x + drift * dt
            x = mean + torch.sqrt(dt) * g[:, None] * torch.randn_like(x)
        else:
            # Deterministic (ODE) variant: half the drift and no noise.
            x = x + drift * dt / 2

        scat.set_offsets(x.detach().cpu().numpy())
        ax.set_title(f"{mode.upper()} — t={t_step:.4f}")
        return (scat,)

    os.makedirs("videos", exist_ok=True)
    try:
        anim = FuncAnimation(
            fig, update, frames=len(t_range), init_func=init, interval=30
        )
        anim.save(f"videos/{name}.gif", writer=PillowWriter(fps=30))
    finally:
        # Close exactly this figure, even if saving fails.
        plt.close(fig)


def visualize_loss(losses, path="loss.png"):
    """
    Plot a training-loss curve and save it as an image.

    Args:
        losses: Sequence of loss values, plotted against their index (the
            "Step" axis).
        path: Output image path. Defaults to ``"loss.png"`` in the current
            working directory.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.plot(losses)
        ax.set_title("Training Loss")
        ax.set_xlabel("Step")
        ax.set_ylabel("Loss")
        ax.grid(alpha=0.2)
        fig.savefig(path, dpi=300, bbox_inches="tight")
    finally:
        # Close the figure created here, even if plotting or saving fails,
        # so repeated calls don't accumulate open figures.
        plt.close(fig)
