# experiments/flow_matching_angle_to_angle.py
from pathlib import Path
from matplotlib.patches import Patch

import hydra
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
from hydra.utils import instantiate
from matplotlib.animation import FuncAnimation
from mpl_toolkits.axes_grid1 import ImageGrid
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import trange

from group_discovery.geometry_2d import angle_to_vector, wrap_angle
from group_discovery.utils import fig_to_img, sample_random_batch, seed_all


def plot_histogram(angles):
    """Render a histogram of angles given in degrees and return it as an image.

    Args:
        angles: Array-like of angles in degrees; forwarded unchanged to
            ``Axes.hist``.

    Returns:
        The value produced by ``fig_to_img`` for the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    # 72 bins, each 5 degrees wide. The edges are shifted by -0.5, so they
    # run from -180.5 to 179.5.
    # NOTE(review): with this shift, samples in (179.5, 180] fall outside the
    # last bin and are not counted -- confirm this is intended.
    bins = np.linspace(-180, 180, 73) - 0.5
    ax.hist(angles, bins=bins, color="k", alpha=0.7)
    ax.set(
        xlabel="Angle (degrees)",
        ylabel="Count",
        xticks=[-180, -135, -90, -45, 0, 45, 90, 135, 180],
        xlim=[-180, 180],
    )

    fig.canvas.draw()
    img_rgb = fig_to_img(fig)

    # pyplot keeps every figure alive until it is explicitly closed; this
    # function is called repeatedly during training, so release it here.
    plt.close(fig)

    return img_rgb


def plot_angles_grid(x, titles=None):
    """Draw one unit-circle scatter panel per leading entry of ``x``.

    Args:
        x: Tensor of angles. It is passed through ``angle_to_vector`` and the
            result is indexed as ``[panel, point, (x, y)]``
            (presumably ``[T+1, B, 1]`` angles on input -- TODO confirm).
        titles: Optional sequence of panel titles. Only the first
            ``x.shape[0]`` entries are read; extra entries are ignored.

    Returns:
        The value produced by ``fig_to_img`` for the rendered figure.
    """
    nr, nc = 1, x.shape[0]
    fig = plt.figure(figsize=(1.5 * nc, 1.8 * nr), dpi=100)
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(nr, nc),
        axes_pad=0,
        share_all=True,
    )
    # Convert angles to 2D points so they can be scattered on the plane.
    points = angle_to_vector(x).cpu().numpy()
    for i in range(nc):
        panel = grid[i]
        panel.set_aspect("equal")
        panel.set(xlim=(-1.1, 1.1), ylim=(-1.1, 1.1))
        panel.scatter(points[i, :, 0], points[i, :, 1], color="blue", s=2)

        # Drop ticks and labels but keep the spines as panel borders.
        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_xlabel("")
        panel.set_ylabel("")

        if titles is not None:
            panel.set_title(titles[i], pad=10)

    # Leave head-room so the titles are not clipped. Address the figure
    # directly rather than relying on pyplot's "current figure".
    fig.subplots_adjust(top=0.85)
    fig.canvas.draw()

    img_rgb = fig_to_img(fig)

    # pyplot keeps every figure alive until it is explicitly closed; this
    # function is called repeatedly during training, so release it here.
    plt.close(fig)
    return img_rgb


def plot_x_t(results):
    """Plot every trajectory ``x_t`` against time, marking its two endpoints.

    Args:
        results: Mapping with tensors ``"t_values"`` and ``"x_t_all"``.
            ``x_t_all`` is indexed as ``[time, sample, 0]``.

    Returns:
        The value produced by ``fig_to_img`` for the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    t_vals = results["t_values"].detach().cpu().numpy()
    # Keep only the first coordinate: [n_timesteps, n_trajectories].
    x_t = results["x_t_all"].detach().cpu().numpy()[:, :, 0]
    n_traj = x_t.shape[1]

    # A 2D y-array makes matplotlib draw one line per column, so three calls
    # replace three calls per trajectory. Skip entirely for an empty batch.
    if n_traj > 0:
        ax.plot(t_vals, x_t, "-", linewidth=1, alpha=0.1, color="blue")
        ax.plot(
            np.full(n_traj, t_vals[0]),
            x_t[0],
            "o",
            color="blue",
            alpha=0.3,
            markersize=4,
        )
        ax.plot(
            np.full(n_traj, t_vals[-1]),
            x_t[-1],
            "o",
            color="blue",
            alpha=0.3,
            markersize=4,
        )

    ax.set_xlabel("Time t")
    ax.set_ylabel("x_t")
    ax.set_title("All Trajectories")
    ax.grid(True, alpha=0.3)

    fig.canvas.draw()

    img_rgb = fig_to_img(fig)

    # Release the figure; pyplot otherwise keeps it alive indefinitely.
    plt.close(fig)

    return img_rgb


def plot_p_x1_given_xt(results, group_order):
    """Stack-plot the mean ``p(x_1 | x_t)`` over time, grouped by initial bin.

    Each sample is assigned to the component mean closest (in wrapped angular
    distance) to its initial position ``x_t_all[0]``. One panel per component
    shows the time evolution of the posterior averaged over that group.

    Args:
        results: Mapping with tensors ``"p_x1_given_xt_all"``, ``"t_values"``,
            ``"component_means"`` and ``"x_t_all"``.
        group_order: Number of components, which is also the number of panels.
            Panels are laid out two per row; for ``group_order == 4`` this is
            a 2x2 grid.

    Returns:
        The value produced by ``fig_to_img`` for the rendered figure.
    """
    # Shape annotations below are the original author's; the producer of
    # ``results`` is not visible here -- TODO confirm.
    p_x1_given_xt_all = results["p_x1_given_xt_all"].cpu()  # [n_timesteps, batch, M]
    t_values = results["t_values"].cpu()  # [n_timesteps]
    # NOTE(review): Tensor.sort() sorts along the LAST dim. If the means are
    # [n_components, 1] as annotated, this is a no-op, which keeps means[k]
    # aligned with probability column k. Confirm that is the intent.
    means = results["component_means"].cpu().sort()[0]  # [n_components, 1]
    diffs = wrap_angle(results["x_t_all"][0].cpu().unsqueeze(1) - means.unsqueeze(0))
    distances = torch.abs(diffs.squeeze(-1))
    # Index of the closest component for every sample.
    _, bin_idx = torch.min(distances, dim=1)  # [batch]

    # Two panels per row, with as many rows as needed. A fixed 2x2 grid would
    # fail for group_order > 4 and leave blank panels for group_order < 4.
    n_cols = 2
    n_rows = max(1, -(-group_order // n_cols))
    fig, axs = plt.subplots(
        n_rows,
        n_cols,
        figsize=(12, 3 * n_rows),
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    axs = axs.flatten()  # Flatten for easier indexing
    first_bottom_idx = (n_rows - 1) * n_cols

    legend_handles = None
    stackplot_colors = None
    stack_labels = [f"$x_1={means[k].item():.2f}$" for k in range(group_order)]

    for i in range(group_order):
        ax = axs[i]
        # Samples whose initial position is closest to component i.
        mask = (bin_idx == i).squeeze(-1)  # [batch]
        masked_probs = p_x1_given_xt_all[:, mask, :]  # [T, n_samples, group_order]
        p_x1_given_xt_mean = masked_probs.mean(dim=1)  # [T, group_order]

        handles = ax.stackplot(
            t_values,
            p_x1_given_xt_mean.T,
            labels=stack_labels,
            alpha=0.8,
        )

        # Reuse the first panel's colours for every title badge and the legend.
        if stackplot_colors is None:
            stackplot_colors = [h.get_facecolor() for h in handles]
        if legend_handles is None:
            legend_handles = handles

        # The title has two parts: plain text, then a colour-coded value.
        ax.set_title("$x_0$ closest to", fontsize=18, pad=8)

        bbox_props = dict(
            boxstyle="round,pad=0.2",
            facecolor=stackplot_colors[i],
            alpha=0.8,
            edgecolor="darkgray",
            linewidth=1,
        )

        # Hand-tuned axes-fraction position, just right of the title text.
        ax.text(
            0.68,
            1.11,
            f"{means[i].item():.2f}",
            transform=ax.transAxes,
            fontsize=18,
            ha="left",
            va="center",
            bbox=bbox_props,
            fontweight="bold",
            color="white",
        )

        ax.set_ylabel("Probability", fontsize=20)
        if i >= first_bottom_idx:  # Bottom row
            ax.set_xlabel("Time $t$", fontsize=20)

        ax.tick_params(axis="both", labelsize=12)

    # Hide any grid cells that have no component (odd group_order).
    for unused_ax in axs[group_order:]:
        unused_ax.set_visible(False)

    # One shared legend at the bottom centre.
    fig.legend(
        legend_handles,
        [f"$x_1 = {means[k].item():.2f}$" for k in range(group_order)],
        loc="lower center",
        ncol=group_order,  # All items in one row
        fontsize=18,
        bbox_to_anchor=(0.5, -0.05),
        frameon=True,
    )

    # Adjust layout and spacing to fit the legend at the bottom.
    fig.tight_layout(rect=[0, 0.02, 1, 0.96])
    fig.canvas.draw()

    img_rgb = fig_to_img(fig)

    # Release the figure; pyplot otherwise keeps it alive indefinitely.
    plt.close(fig)

    return img_rgb


def plot_component_weights(results, x_0, n_samples=5):
    """Plot per-component weights over time for the first ``n_samples`` samples.

    Args:
        results: Mapping with tensors ``"component_means"``, ``"t_values"`` and
            ``"weights_all"``. ``weights_all`` is indexed as
            ``[time, sample, component]``.
        x_0: Tensor of initial points; ``x_0[i, 0]`` is shown in panel ``i``'s
            title.
        n_samples: Number of samples (panels) to draw, taken from the start of
            the batch.

    Returns:
        The value produced by ``fig_to_img`` for the rendered figure.
    """
    component_means = results["component_means"].cpu()
    if component_means.dim() > 1:
        # Drop a trailing singleton dimension if present.
        component_means = component_means.squeeze(-1)
    t_values = results["t_values"].cpu()  # [n_timesteps]

    # squeeze=False always returns a 2D array of axes, so n_samples == 1
    # needs no special-casing.
    fig, axs = plt.subplots(
        n_samples, 1, figsize=(8, 2 * n_samples), sharex=True, squeeze=False
    )
    axs = axs[:, 0]

    legend_handles = None
    for i in range(n_samples):
        ax = axs[i]
        weights = results["weights_all"][:, i].cpu()  # [n_timesteps, n_components]

        handles = []
        for j, mu in enumerate(component_means):
            # plot() returns a list of lines; unpack the single Line2D.
            (line,) = ax.plot(t_values, weights[:, j], label=f"x_1={mu.item():.2f}")
            handles.append(line)

        if legend_handles is None:
            legend_handles = handles

        if i == n_samples - 1:
            ax.set_xlabel("Time t")

        ax.set_ylabel("Weight")
        ax.set_title(f"Component Weights for x_0={x_0[i, 0].item():.2f}")
        ax.set_ylim(0, 1)
        ax.grid(True, alpha=0.3)

    # Shared legend on the right.
    fig.legend(
        legend_handles,
        [f"$x_1 = {mu.item():.2f}$" for mu in component_means],
        loc="center right",
        bbox_to_anchor=(1.00, 0.5),
        fontsize=10,
    )

    fig.suptitle("Component Weights Over Time for Different $x_0$", fontsize=16, y=0.95)
    fig.tight_layout(rect=[0, 0, 0.85, 0.90])  # leave room on right and top
    fig.subplots_adjust(top=0.90)

    fig.canvas.draw()

    img_rgb = fig_to_img(fig)

    # Release the figure; pyplot otherwise keeps it alive indefinitely.
    plt.close(fig)

    return img_rgb


def plot_entropy_p_x1_given_xt(results):
    """Plot the batch mean (+/- one std) of the entropy of ``p(x_1 | x_t)``.

    Args:
        results: Mapping with tensors ``"p_x1_given_xt_all"`` (reduced over
            dim 2 for the entropy, then over dim 1 for the statistics) and
            ``"t_values"``.

    Returns:
        The value produced by ``fig_to_img`` for the rendered figure.
    """
    p_x1_given_xt_all = results["p_x1_given_xt_all"].cpu()  # [n_timesteps, batch, M]
    t_values = results["t_values"].cpu()  # [n_timesteps]

    # Clamp away exact zeros so log() stays finite.
    eps = 1e-10
    p_safe = torch.clamp(p_x1_given_xt_all, min=eps)
    entropy_all = -(p_safe * torch.log(p_safe)).sum(dim=2)  # [n_timesteps, batch]
    entropy_mean = entropy_all.mean(dim=1).numpy()  # [n_timesteps]
    entropy_std = entropy_all.std(dim=1).numpy()  # [n_timesteps]
    t_values_np = t_values.numpy()

    fig, ax = plt.subplots(figsize=(10, 6))

    # Mean entropy curve.
    ax.plot(t_values_np, entropy_mean, "b-", linewidth=3, label="Mean entropy")

    # +/- one standard deviation band.
    ax.fill_between(
        t_values_np,
        entropy_mean - entropy_std,
        entropy_mean + entropy_std,
        alpha=0.3,
        color="blue",
    )

    # Reference line: entropy of the uniform distribution, log(M).
    n_components = p_x1_given_xt_all.shape[2]
    max_entropy = torch.log(torch.tensor(n_components)).item()
    ax.axhline(
        y=max_entropy,
        color="red",
        linestyle="--",
        linewidth=2,
        alpha=0.7,
        label=f"Max entropy (uniform): {max_entropy:.2f}",
    )

    # Labels and formatting with LaTeX.
    ax.set_xlabel("Time $t$", fontsize=22)
    ax.set_ylabel(r"Entropy $H_p(x_1 | x_t)$ [nats]", fontsize=22)
    ax.tick_params(axis="both", labelsize=14)
    ax.grid(True, alpha=0.3, linewidth=0.5)

    # y-limits with some padding around both the data and the reference line.
    y_min = min(0, entropy_mean.min() - 0.1)
    y_max = max(max_entropy + 0.1, entropy_mean.max() + entropy_std.max() + 0.1)
    ax.set_ylim(y_min, y_max)
    ax.set_xlim(0, 1)

    # Address this figure explicitly instead of pyplot's "current figure".
    fig.tight_layout()
    fig.canvas.draw()

    img_rgb = fig_to_img(fig)

    # Release the figure; pyplot otherwise keeps it alive indefinitely.
    plt.close(fig)

    return img_rgb


def plot_velocity_field(model, n_x=100, n_t=50, component_means=None):
    """
    Visualize the learned velocity field v(x, t) for SO(2) flow matching.

    The model is evaluated on an ``n_x`` by ``n_t`` grid covering
    x in [-pi, pi) and t in [0, 1], and drawn as a filled contour plot.

    Note: the model is switched to eval mode and is left in that mode.

    Args:
        model: Trained flow model; must expose ``.device`` and be callable as
            ``model(x, t)`` with column tensors of shape ``[N, 1]``.
        n_x: Number of spatial grid points
        n_t: Number of time grid points
        component_means: Optional tensor of target component locations for overlay

    Returns:
        The value produced by ``fig_to_img`` for the rendered figure.
    """
    model.eval()
    device = model.device

    # Create grids
    x = torch.linspace(-np.pi, np.pi - 1e-8, n_x)
    t = torch.linspace(0, 1, n_t)
    X, T = torch.meshgrid(x, t, indexing="ij")

    # Flatten to [n_x * n_t, 1] columns for a single batched model call.
    X_flat = X.reshape(-1, 1).to(device)
    T_flat = T.reshape(-1, 1).to(device)

    with torch.no_grad():
        V_flat = model(X_flat, T_flat)
        V = V_flat.cpu().reshape(X.shape)

    fig, ax = plt.subplots(figsize=(12, 8))

    # Main velocity field heatmap (time on the x-axis, angle on the y-axis).
    im = ax.contourf(T, X, V, levels=50, cmap="RdBu_r", vmin=-np.pi, vmax=np.pi)
    ax.set_xlabel("Time t", fontsize=14)
    ax.set_ylabel("Position x (radians)", fontsize=14)
    ax.set_title("Velocity Field v(x, t) on SO(2)", fontsize=16)

    # Thin labelled contour lines on top of the heatmap.
    contours = ax.contour(T, X, V, levels=10, colors="black", alpha=0.3, linewidths=0.5)
    ax.clabel(contours, inline=True, fontsize=8, fmt="%.2f")

    # Mark component locations if provided.
    if component_means is not None:
        for i, mean in enumerate(component_means):
            mean_val = mean.item()
            ax.axhline(
                y=mean_val, color="green", linestyle="--", alpha=0.7, linewidth=2
            )
            ax.text(
                0.02,
                mean_val + 0.1,
                f"μ_{i}",
                fontsize=12,
                color="green",
                bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.7),
            )

    # Colorbar attached to this figure, not pyplot's "current figure".
    cbar = fig.colorbar(im, ax=ax, label="Velocity (rad/s)")
    cbar.ax.tick_params(labelsize=10)

    ax.grid(True, alpha=0.2)

    fig.tight_layout()
    fig.canvas.draw()
    img_rgb = fig_to_img(fig)

    # Release the figure; pyplot otherwise keeps it alive indefinitely.
    plt.close(fig)

    return img_rgb


def _compass_color(angle):
    """Map an angle (radians) to a marker colour: N/E/S/W -> red/yellow/green/blue."""
    # Normalize angle to [0, 2π)
    angle_norm = angle % (2 * np.pi)

    if np.pi / 4 <= angle_norm < 3 * np.pi / 4:  # Near North (top)
        return "red"
    if angle_norm < np.pi / 4 or angle_norm >= 7 * np.pi / 4:  # Near East (right)
        return "yellow"
    if 5 * np.pi / 4 <= angle_norm < 7 * np.pi / 4:  # Near South (bottom)
        return "green"
    return "blue"  # Near West (left), 3π/4 to 5π/4


def plot_velocity_streamlines(
    model,
    times=None,
    n_trajectories=9,
    n_steps=100,
    component_means=None,
    x_0_samples=None,
):
    """
    Plot streamlines showing particle evolution with velocity quivers at different times.

    One panel is drawn per entry of ``times``. Each panel shows the particles'
    positions on the unit circle at that time, with a tangent arrow scaled by
    the model's velocity at that position.

    Note: the model is switched to eval mode and is left in that mode.

    Args:
        model: Flow model exposing ``.device``, ``sample_all(n, n_steps, x_0=...)``
            and ``model(x, t)``.
        times: Iterable of times in [0, 1]; defaults to ten values from 0.1 to 1.
            Each time is mapped to the nearest integration step.
        n_trajectories: Number of evenly spaced start angles; ignored when
            ``x_0_samples`` is given.
        n_steps: Number of integration steps passed to ``sample_all``.
        component_means: Optional tensor of component angles to mark.
        x_0_samples: Optional tensor of start angles (first dim = trajectories).

    Returns:
        The value produced by ``fig_to_img`` for the rendered figure.

    Raises:
        ValueError: If ``times`` is empty.
    """
    model.eval()
    device = model.device

    # Use default times if not provided
    if times is None:
        times = np.linspace(0.1, 1, 10)
    times = list(times)
    if not times:
        raise ValueError("times must contain at least one time point")

    # One panel per requested time: 3.6in wide each, i.e. 36in for the default
    # ten panels. squeeze=False keeps `axes` an array for a single panel too.
    n_panels = len(times)
    fig, axes = plt.subplots(1, n_panels, figsize=(3.6 * n_panels, 6), squeeze=False)
    axes = axes.flatten()

    # Generate starting points
    if x_0_samples is None:
        x_0 = (
            torch.linspace(-np.pi, np.pi - 1e-12, n_trajectories).view(-1, 1).to(device)
        )
    else:
        x_0 = x_0_samples.to(device)
        n_trajectories = x_0.shape[0]

    # Integrate trajectories
    trajectories = model.sample_all(
        n_trajectories, n_steps, x_0=x_0
    ).cpu()  # [T+1, B, 1]
    last_idx = trajectories.shape[0] - 1

    # One colour per trajectory, shared across all panels.
    colors = cm.plasma(np.linspace(0, 1.0, n_trajectories))

    for ax, t_slice in zip(axes, times):
        # Nearest stored step. round() avoids float truncation (e.g.
        # int(0.29 * 100) == 28); the clamp keeps the index inside the rollout.
        t_idx = min(max(int(round(float(t_slice) * n_steps)), 0), last_idx)
        positions = trajectories[t_idx, :, 0].numpy()

        # Evaluate the velocity field at the particle positions.
        x_eval = torch.from_numpy(positions).view(-1, 1).to(device).float()
        t_eval = torch.full((n_trajectories, 1), float(t_slice), device=device)

        with torch.no_grad():
            # reshape(-1) stays 1D for a single trajectory; squeeze() would
            # return a 0-d array that cannot be indexed below.
            velocities = model(x_eval, t_eval).cpu().reshape(-1).numpy()

        # Convert to Cartesian coordinates
        x_cart = np.cos(positions)
        y_cart = np.sin(positions)

        # Particles, coloured per trajectory.
        ax.scatter(
            x_cart,
            y_cart,
            c=colors,
            s=150,
            alpha=0.8,
            edgecolors="black",
            linewidth=1,
            zorder=5,
        )

        # Velocity arrows in the matching colours.
        for i in range(n_trajectories):
            angle = positions[i]
            velocity = velocities[i]

            # Direction tangent to the circle (perpendicular to the radius).
            # Positive velocity means counter-clockwise.
            dx = -np.sin(angle) * velocity * 0.3  # scale for visibility
            dy = np.cos(angle) * velocity * 0.3

            if abs(velocity) > 0.05:  # Show all but tiny velocities
                ax.arrow(
                    x_cart[i],
                    y_cart[i],
                    dx,
                    dy,
                    head_width=0.06,
                    head_length=0.06,
                    fc=colors[i],
                    ec=colors[i],
                    alpha=0.9,
                    linewidth=6,
                    zorder=4,
                )

        # Draw circle
        circle = plt.Circle((0, 0), 1, fill=False, edgecolor="gray", linewidth=2)
        ax.add_patch(circle)

        # Mark component locations with NESW->RYGB colors
        if component_means is not None:
            for mean in component_means:
                mean_angle = mean.item()
                ax.plot(
                    np.cos(mean_angle),
                    np.sin(mean_angle),
                    "X",
                    color=_compass_color(mean_angle),
                    markersize=30,  # Bigger markers
                    markeredgecolor="black",
                    markeredgewidth=2,
                    zorder=6,
                    alpha=0.6,
                )

        # Formatting
        ax.set_xlim(-1.4, 1.4)
        ax.set_ylim(-1.4, 1.4)
        ax.set_aspect("equal")
        ax.set_title(
            f"t = {t_slice:.2f}", fontsize=36, ha="center", va="center", zorder=10
        )

        ax.grid(True, alpha=0.3)

        # Remove axis ticks and labels
        ax.set_xticks([])
        ax.set_yticks([])

        # Remove all borders/spines
        for side in ("top", "bottom", "left", "right"):
            ax.spines[side].set_visible(False)

    # Use negative wspace to create overlap between subplots
    fig.subplots_adjust(wspace=-0.1, hspace=0, left=0, right=1, top=1, bottom=0)

    fig.canvas.draw()

    img_rgb = fig_to_img(fig)

    # Release the figure; pyplot otherwise keeps it alive indefinitely.
    plt.close(fig)

    return img_rgb


def plot_velocity_phase_portrait(model, n_grid=30, component_means=None):
    """
    Create a phase portrait showing velocity field at different times.

    Draws a 2x5 grid of panels. Each panel plots v(x, t) against x on
    [-pi, pi) for one fixed time, shades positive and negative velocity, and
    marks the interpolated zero crossings.

    Note: the model is switched to eval mode and is left in that mode.

    Args:
        model: Flow model exposing ``.device`` and callable as ``model(x, t)``.
        n_grid: Number of angular grid points per panel.
        component_means: Optional tensor of component angles, drawn as
            vertical dashed lines.

    Returns:
        The value produced by ``fig_to_img`` for the rendered figure.
    """
    model.eval()
    device = model.device

    # 2x5 layout with shared axes
    n_rows, n_cols = 2, 5
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(24, 6), sharex=True, sharey=True)
    axes = axes.flatten()

    # Ten times from 0.1 to 0.99. The panel labels use one decimal, so the
    # last panel reads "t = 1.0" although it is evaluated at t = 0.99.
    time_points = np.linspace(0.1, 0.99, 10)

    # The angular grid is identical for every panel; build it once.
    theta = torch.linspace(-np.pi, np.pi - 1e-8, n_grid)
    theta_np = theta.numpy()
    x_grid = theta.view(-1, 1).to(device)

    for idx, (ax, t_val) in enumerate(zip(axes, time_points)):
        t_grid = torch.full((n_grid, 1), t_val, device=device)

        with torch.no_grad():
            # reshape(-1) keeps a 1D result even when n_grid == 1.
            v_np = model(x_grid, t_grid).cpu().reshape(-1).numpy()

        # Velocity as a function of position, plus the zero line.
        ax.plot(theta_np, v_np, "k-", linewidth=3)
        ax.axhline(y=0, color="gray", linestyle="-", alpha=0.5, linewidth=2)

        # Fill positive velocities with blue, negative with red
        ax.fill_between(
            theta_np,
            0,
            v_np,
            where=(v_np >= 0),
            alpha=0.3,
            color="blue",
            interpolate=True,
        )
        ax.fill_between(
            theta_np,
            0,
            v_np,
            where=(v_np < 0),
            alpha=0.3,
            color="red",
            interpolate=True,
        )

        # Mark fixed points (where v = 0). A sign change between neighbouring
        # samples means the two values differ, so the division is safe.
        zero_crossings = np.where(np.diff(np.sign(v_np)))[0]
        for zc in zero_crossings:
            # Linear interpolation for a more accurate zero location.
            theta_zero = theta_np[zc] + (theta_np[zc + 1] - theta_np[zc]) * (
                -v_np[zc] / (v_np[zc + 1] - v_np[zc])
            )
            ax.plot(
                theta_zero,
                0,
                "ro",
                markersize=8,
                markeredgecolor="darkred",
                markeredgewidth=2,
            )

        # Mark component locations
        if component_means is not None:
            for mean in component_means:
                ax.axvline(
                    x=mean.item(),
                    color="green",
                    linestyle="--",
                    alpha=0.7,
                    linewidth=2,
                )

        # Only label the outer edge of the grid.
        if idx >= (n_rows - 1) * n_cols:  # Bottom row
            ax.set_xlabel("Position x (rad)", fontsize=22)
        if idx % n_cols == 0:  # Left column
            ax.set_ylabel("Velocity v(x, t)", fontsize=22)

        # Add time as text in subplot
        ax.text(
            0.5,
            0.96,
            f"t = {t_val:.1f}",
            transform=ax.transAxes,
            fontsize=22,
            horizontalalignment="center",
            verticalalignment="top",
        )

        ax.grid(True, alpha=0.3)
        ax.set_xlim(-np.pi, np.pi)
        ax.tick_params(axis="both", labelsize=14)

    # Shared legend with filled boxes instead of lines.
    legend_elements = [
        Patch(
            facecolor="blue",
            edgecolor="black",
            alpha=0.3,
            linewidth=2,
            label="v > 0 (CCW)",
        ),
        Patch(
            facecolor="red",
            edgecolor="black",
            alpha=0.3,
            linewidth=2,
            label="v < 0 (CW)",
        ),
    ]
    fig.legend(
        handles=legend_elements,
        loc="center right",
        fontsize=16,
        frameon=True,
        fancybox=False,
        shadow=False,
        borderpad=0.5,
        bbox_to_anchor=(0.97, 0.52),
    )

    # Tighten the spacing between subplots and leave room for the legend.
    fig.subplots_adjust(
        wspace=0.05, hspace=0.05, left=0.08, right=0.87, top=0.96, bottom=0.08
    )

    fig.canvas.draw()
    img_rgb = fig_to_img(fig)

    # Release the figure; pyplot otherwise keeps it alive indefinitely.
    plt.close(fig)

    return img_rgb


def create_animations(x_t, x_1, n_steps, titles, save_path):
    """Animate samples flowing from the source circle to the target circle.

    The source distribution is drawn around x = -5 and the target around
    x = +5. The red points slide from left to right as they follow ``x_t``.
    The first and last frames are held for ``n_steps // 5`` extra frames each.

    Args:
        x_t: Tensor of angles along the trajectory; its first dimension must
            have ``n_steps + 1`` entries to match the per-frame shift.
        x_1: Tensor of target-sample angles.
        n_steps: Number of integration steps in ``x_t``.
        titles: Two labels, ``[source_title, target_title]``.
        save_path: Output path handed to ``FuncAnimation.save``.
    """
    shift_min = np.array([-5, 0])
    shift_max = np.array([5, 0])

    # Per-frame horizontal offset, [T+1, 2].
    dots_shift = np.linspace(shift_min, shift_max, n_steps + 1)

    # Trajectory points with the sliding offset applied, [T+1, B, 2].
    x_t_shifted = angle_to_vector(x_t).cpu().numpy() + dots_shift[:, None, :]

    # Target data, fixed on the right.
    x_1_shifted = angle_to_vector(x_1).cpu().numpy() + shift_max[None, :]

    # Set up the plot
    fig, ax = plt.subplots(figsize=(12, 4))
    try:
        ax.axis("off")
        ax.set_xlim(shift_min[0] * 1.3, shift_max[0] * 1.3)
        ax.set_ylim(-1.6, 1.6)
        ax.set_aspect("equal")
        ax.text(shift_min[0], 1.3, titles[0], fontsize=16, ha="center", color="blue")
        ax.text(shift_max[0], 1.3, titles[1], fontsize=16, ha="center", color="blue")

        # Plot static source and target points
        ax.scatter(
            x_t_shifted[0, :, 0],
            x_t_shifted[0, :, 1],
            c="blue",
            alpha=0.2,
            label=titles[0],
            s=20,
        )
        ax.scatter(
            x_1_shifted[:, 0],
            x_1_shifted[:, 1],
            c="blue",
            alpha=0.2,
            label=titles[1],
            s=20,
        )

        # Dynamic interpolated points
        flow_dots = ax.scatter(
            x_t_shifted[0, :, 0], x_t_shifted[0, :, 1], c="red", alpha=0.4, s=20
        )
        time_text = ax.text(
            shift_min[0], -1.4, f"t={0:.2f}", fontsize=12, ha="center", color="black"
        )
        times = np.linspace(0, 1, n_steps + 1)

        # Hold the first and last frames by repeating them (edge padding).
        dummy_frames = n_steps // 5
        x_t_shifted = np.pad(
            x_t_shifted, [(dummy_frames, dummy_frames), (0, 0), (0, 0)], mode="edge"
        )
        times = np.pad(times, dummy_frames, mode="edge")
        dots_shift = np.pad(
            dots_shift, [(dummy_frames, dummy_frames), (0, 0)], mode="edge"
        )

        def update(i):
            """Move the red points and the time label to frame ``i``."""
            flow_dots.set_offsets(x_t_shifted[i])
            time_text.set_text(f"t={times[i]:.2f}")
            time_text.set_x(dots_shift[i, 0])
            return flow_dots, time_text

        # Take the frame count from the padded array so the two cannot drift
        # apart (equals n_steps + 2 * dummy_frames + 1).
        ani = FuncAnimation(
            fig, update, frames=len(x_t_shifted), interval=100, blit=True
        )

        # Save animation
        ani.save(save_path, fps=10, dpi=150)
    finally:
        # Release the figure even if saving fails (e.g. missing writer).
        plt.close(fig)


@hydra.main(version_base=None, config_path="../conf", config_name="train")
def main(cfg):
    """Train a flow-matching model on angles and log diagnostics.

    Steps, all driven by the hydra config ``cfg``:
      1. Seed, build the logger, datasets, model and optimizer.
      2. Train for ``cfg.train.epochs`` epochs. Every
         ``cfg.test.epoch_interval`` epochs, log a sampling progression and an
         angle histogram, and save model/optimizer checkpoints.
      3. After training, log the posterior, entropy and velocity-field plots
         and save a progression animation to ``cfg.save_dir``.
    """
    # Seed
    seed_all(cfg.seed)
    torch.backends.cudnn.deterministic = True

    # Set up wandb and make checkpoints dir
    wandb_config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    logger = instantiate(
        cfg.logger, wandb_config=wandb_config, _convert_="all", _recursive_=False
    )

    # Datasets/Dataloaders
    train_dataset = instantiate(cfg.dataset.train)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=cfg.dataset.batch_size,
        shuffle=True,
        num_workers=cfg.dataset.num_workers,
    )
    test_dataset = instantiate(cfg.dataset.test)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=cfg.dataset.batch_size,
        shuffle=False,
        num_workers=cfg.dataset.num_workers,
    )

    # Save dir
    ckpt_dir = Path(cfg.save_dir) / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Model
    model = instantiate(cfg.model).to(cfg.device)

    # Training
    optimizer = instantiate(cfg.optimizer, model.parameters())

    # Step indices shown in the test progression: roughly every tenth of the
    # rollout plus the final step. Guard the stride so n_steps < 10 does not
    # give np.arange a zero step.
    stride = max(1, cfg.test.n_steps // 10)
    times = np.arange(0, cfg.test.n_steps, stride).astype(int).tolist()
    times.append(cfg.test.n_steps)
    titles = [f"t={t / cfg.test.n_steps}" for t in times]

    # Pre-bind `epoch` so the post-training logging below still has a step
    # value when cfg.train.epochs == 0 and the loop body never runs.
    epoch = 0
    for epoch in trange(cfg.train.epochs, desc="Train"):
        # Train
        loss = model.train_net(train_dataloader, optimizer)
        logger.log_all("train", {"loss": loss}, {"epoch": epoch})

        # Test
        loss = model.eval_net(test_dataloader)
        logger.log_all("test", {"loss": loss}, {"epoch": epoch})

        # Test interval
        if (epoch + 1) % cfg.test.epoch_interval == 0:
            # Plot progression
            B = 100
            # [T+1, B, 1]
            test_progression = model.sample_all(B, cfg.test.n_steps)
            test_progression = test_progression[times]

            # plot_angles_grid reads only the first len(times) titles, so the
            # trailing "Orig" entry is currently unused.
            progression_img = plot_angles_grid(test_progression, titles + ["Orig"])
            logger.log_image(
                "test/progression",
                progression_img,
                {"epoch": epoch},
            )

            # Plot angles histogram
            B = 10_000
            # [B, 1]
            test_samples = model.sample(B, cfg.test.n_steps)
            angles = np.rad2deg(test_samples.cpu().numpy().flatten())
            angles_hist_img = plot_histogram(angles)
            logger.log_image(
                "test/angles_histogram",
                angles_hist_img,
                {"epoch": epoch},
            )

            # Save model checkpoint (overwritten at every test interval).
            torch.save(model.state_dict(), ckpt_dir / "model.pt")
            torch.save(optimizer.state_dict(), ckpt_dir / "optimizer.pt")

    # Compute p(x_1 | x_t)
    B = 1000
    x_0 = model.prior_dist.sample((B,)).to(cfg.device)
    p1_dist = test_dataset.dist
    component_means = p1_dist.locs.squeeze(-1)

    results = model.approx_p_x1_given_xt(x_0, p1_dist, n_steps=cfg.test.n_steps)

    traj_img = plot_x_t(results)
    logger.log_image(
        "test/prob_path",
        traj_img,
        {"epoch": epoch},
    )

    p_x1_given_xt_img = plot_p_x1_given_xt(
        results, cfg.dataset.test.dist.locs.group_order
    )
    logger.log_image(
        "test/p_x1_given_xt",
        p_x1_given_xt_img,
        {"epoch": epoch},
    )

    component_weights_img = plot_component_weights(results, x_0)
    logger.log_image(
        "test/component_weights",
        component_weights_img,
        {"epoch": epoch},
    )
    entropy_p_x1_given_xt_img = plot_entropy_p_x1_given_xt(results)
    logger.log_image(
        "test/entropy_p_x1_given_xt",
        entropy_p_x1_given_xt_img,
        {"epoch": epoch},
    )

    velocity_field_img = plot_velocity_field(
        model,
        n_x=500,
        n_t=cfg.test.n_steps,
        component_means=component_means,
    )
    logger.log_image(
        "test/velocity_field",
        velocity_field_img,
        {"epoch": epoch},
    )

    velocity_streamlines_img = plot_velocity_streamlines(
        model,
        n_trajectories=16 + 1,
        n_steps=cfg.test.n_steps,
        component_means=component_means,
    )
    logger.log_image(
        "test/velocity_streamlines",
        velocity_streamlines_img,
        {"epoch": epoch},
    )

    velocity_phase_img = plot_velocity_phase_portrait(
        model,
        n_grid=100,
        component_means=component_means,
    )
    logger.log_image(
        "test/velocity_phase",
        velocity_phase_img,
        {"epoch": epoch},
    )

    # Test batch
    B = 200
    test_batch = sample_random_batch(test_dataset, B, cfg.device)
    x_t = model.sample_all(B, cfg.test.n_steps)

    # Create progression animation
    create_animations(
        x_t,
        test_batch,
        n_steps=cfg.test.n_steps,
        titles=[model.prior_dist.group, cfg.dataset.group],
        save_path=Path(cfg.save_dir) / "progression.mp4",
    )


# Script entry point: hydra parses the command line and supplies `cfg`.
if __name__ == "__main__":
    main()
