"""
Utility functions for creating visualization videos of trajectories.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import imageio.v3 as iio
from pathlib import Path


def _figure_to_rgb(fig) -> np.ndarray:
    """Render ``fig`` and return its pixels as an ``(H, W, 3)`` uint8 array."""
    fig.canvas.draw()
    # np.asarray on the RGBA memoryview yields the real (H, W, 4) pixel grid;
    # reshaping via get_width_height() breaks when the canvas has a device
    # pixel ratio != 1 (HiDPI).
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha and copy so the frame owns contiguous memory that does not
    # depend on the (soon to be closed) figure's canvas buffer.
    return rgba[..., :3].copy()


def create_field_video(
    ground_truth: np.ndarray,
    prediction: np.ndarray,
    loss: np.ndarray,
    output_path: Path,
    title: str = "Field",
    fps: int = 1,
    *,
    channel_names: tuple = ("p", "rho", "T", "u", "v"),
) -> None:
    """Create one side-by-side GT/prediction GIF per channel of a field.

    For every channel in ``channel_names`` a GIF named
    ``f"{title}_{name}.gif"`` is written into ``output_path``. Each frame
    shows the ground truth (left) and the prediction (right) at one
    timestep, with the per-timestep loss in the prediction's title.

    Parameters
    ----------
    ground_truth : np.ndarray
        Ground truth data array of shape (timesteps, height, width, channels)
    prediction : np.ndarray
        Prediction data array of shape (timesteps, height, width, channels)
    loss : np.ndarray
        Loss data array of shape (timesteps, channels)
    output_path : Path
        Directory in which the per-channel GIFs are written
    title : str, optional
        Filename prefix for the GIFs, by default "Field"
    fps : int, optional
        Frames per second for the GIFs, by default 1
    channel_names : tuple, optional
        Names of the leading channels to render, in channel order, by default
        ("p", "rho", "T", "u", "v")

    Raises
    ------
    ValueError
        If ``fps`` is not positive or the arrays have fewer channels than
        ``channel_names``.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    n_channels = ground_truth.shape[-1]
    if len(channel_names) > n_channels:
        raise ValueError(
            f"{len(channel_names)} channel names given but data has only "
            f"{n_channels} channels"
        )
    # imageio's pillow GIF writer takes a per-frame duration in milliseconds;
    # the ``fps`` keyword is not a Pillow GIF option.
    duration_ms = 1000 / fps

    for channel, name in enumerate(channel_names):
        gt = ground_truth[..., channel]
        pred = prediction[..., channel]

        # One color scale, derived from the ground truth, is shared by both
        # panels and all frames so GT and prediction are directly comparable.
        norm = Normalize(vmin=gt.min(), vmax=gt.max())

        frames = []
        for t in range(ground_truth.shape[0]):
            fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(10, 5))
            try:
                im_gt = ax_gt.imshow(gt[t], norm=norm)
                im_pred = ax_pred.imshow(pred[t], norm=norm)

                ax_gt.set_title(f"Ground Truth - {name} (t={t})")
                ax_pred.set_title(
                    f"Prediction - {name} (t={t})\nLoss: {loss[t, channel]:.6f}"
                )

                # add colorbar for each subplot
                fig.colorbar(im_gt, ax=ax_gt)
                fig.colorbar(im_pred, ax=ax_pred)

                frames.append(_figure_to_rgb(fig))
            finally:
                # Close this specific figure even if plotting fails, so
                # figures do not accumulate across frames.
                plt.close(fig)

        # Save as GIF (looping, consistent with generate_channel_gif)
        vid_path = output_path / f"{title}_{name}.gif"
        print(f"Writing {len(frames)} frames to {vid_path}")
        iio.imwrite(vid_path, frames, duration=duration_ms, loop=0)
        print(f"Animation saved to {vid_path}")


def generate_channel_gif(
    data: np.ndarray, output_path: Path, fps: int = 1, cmap: str = "viridis"
) -> None:
    """Generate a looping GIF for a single channel of data over time.

    All timesteps share one color scale spanning ``data.min()`` to
    ``data.max()``, so colors are comparable across frames.

    Parameters
    ----------
    data : np.ndarray
        Data array of shape (timesteps, height, width)
    output_path : Path
        Path where the output GIF will be saved
    fps : int, optional
        Frames per second for the GIF, by default 1
    cmap : str, optional
        Matplotlib colormap name, by default "viridis"

    Raises
    ------
    ValueError
        If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
    # pyplot.get_cmap remains available across versions.
    colormap = plt.get_cmap(cmap)
    norm = Normalize(vmin=data.min(), vmax=data.max())

    # Colorize every timestep at once: (T, H, W) -> (T, H, W, 4) RGBA floats.
    rgba = colormap(norm(data))
    # Convert to RGB uint8 (drop alpha; truncating cast as before).
    rgb = (rgba[..., :3] * 255).astype(np.uint8)
    frames = list(rgb)

    print(f"Writing {len(frames)} frames to {output_path}")
    # Pillow's GIF writer takes per-frame duration in milliseconds.
    iio.imwrite(output_path, frames, duration=1000 / fps, loop=0)
    print(f"Animation saved to {output_path}")
