r"""Plotting and image helpers."""

import math
import os
from tempfile import mkstemp
from typing import Optional, Sequence, Tuple, Union

import matplotlib.animation as ani
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sb
import torch
from git import List
from IPython.display import Video
from matplotlib.backends.backend_agg import FigureCanvasAgg
from moviepy import ImageSequenceClip
from numpy.typing import ArrayLike
from PIL import Image, ImageDraw, ImageFont

from crps_retrofitting.metrics.fourier import isotropic_power_spectrum


def animate_fields(
    x: ArrayLike,
    y: Optional[ArrayLike] = None,
    tolerance: float = 1.0,
    fields: Optional[Sequence[str]] = None,
    timesteps: Optional[Sequence[int]] = None,
    cmap: str = "RdBu_r",
    figsize: Tuple[float, float] = (2.4, 2.4),
    fps: float = 4.0,
) -> ani.Animation:
    r"""Animate a multi-channel field along its second axis.

    One column is drawn per channel. Without ``y`` each column holds a single
    panel showing ``x``; with ``y`` each column holds three stacked panels
    showing ``x``, ``y`` and the difference ``y - x``.

    Args:
        x: Array or tensor whose shape unpacks as ``(C, L, _, _)``: ``C``
            channels (columns) and ``L`` animation frames.
        y: Optional array or tensor indexed the same way as ``x``.
        tolerance: The difference panel is drawn with colour limits
            ``[-tolerance, tolerance]``.
        fields: Optional column titles, indexed by channel.
        timesteps: Not used by this function; kept so the signature stays
            unchanged for existing callers.
        cmap: Colormap of the ``x`` and ``y`` panels. The difference panel
            always uses ``"RdBu_r"``.
        figsize: Width and height of one panel, in inches.
        fps: Playback rate; the frame interval is ``int(1000 / fps)`` ms.

    Returns:
        A `FuncAnimation` with ``L`` frames.
    """
    C, L, _, _ = x.shape

    if torch.is_tensor(x):
        x = x.numpy(force=True)

    if torch.is_tensor(y):
        y = y.numpy(force=True)

    compare = y is not None
    rows = 3 if compare else 1

    fig, axs = plt.subplots(
        nrows=rows,
        ncols=C,
        figsize=(figsize[0] * C, rows * figsize[1]),
        squeeze=False,
    )

    def panels(i, j):
        """Return the images of column ``i`` at frame ``j``, top to bottom."""
        if compare:
            return (x[i, j], y[i, j], y[i, j] - x[i, j])
        return (x[i, j],)

    # Flat list ordered column by column, ``rows`` artists per column.
    artists = []

    for i in range(C):
        # Colour limits come from ``x`` over all frames and are shared with
        # the ``y`` panel so both are directly comparable.
        vmin = np.quantile(x[i], 0.01) - 1e-2
        vmax = np.quantile(x[i], 0.99) + 1e-2

        if fields:
            axs[0, i].set_title(f"{fields[i]}")

        # The first frame initialises the artists; `animate` swaps the data.
        for row, image in enumerate(panels(i, 0)):
            if row == 2:
                style = dict(cmap="RdBu_r", vmin=-tolerance, vmax=tolerance)
            else:
                style = dict(cmap=cmap, vmin=vmin, vmax=vmax)

            ax = axs[row, i]
            artists.append(ax.imshow(image, interpolation="none", **style))
            ax.set_xticks([])
            ax.set_yticks([])

    for row, label in enumerate(("$x_i$", "$y_i$", "$y_i - x_i$")[:rows]):
        axs[row, 0].set_ylabel(label)

    def animate(j):
        for i in range(C):
            for row, image in enumerate(panels(i, j)):
                artists[rows * i + row].set_array(image)

        return artists

    fig.align_labels()
    fig.tight_layout()

    return ani.FuncAnimation(fig, animate, frames=L, interval=int(1000 / fps))


def plot_fields(
    x: ArrayLike,
    y: Optional[ArrayLike] = None,
    tolerance: float = 1.0,
    fields: Optional[Sequence[str]] = None,
    timesteps: Optional[Sequence[int]] = None,
    cmap: str = "RdBu_r",
    figsize: Tuple[float, float] = (2.4, 2.4),
) -> plt.Figure:
    r"""Plot every step of a multi-channel field as a grid of images.

    Columns are channels. Without ``y`` there is one row per step showing
    ``x``; with ``y`` there are three rows per step showing ``x``, ``y`` and
    ``y - x``.

    Args:
        x: Array or tensor whose shape unpacks as ``(C, L, _, _)``.
        y: Optional array or tensor indexed the same way as ``x``.
        tolerance: The difference rows use colour limits
            ``[-tolerance, tolerance]``.
        fields: Optional column titles, indexed by channel.
        timesteps: Subscripts used in the row labels, one per step. Defaults
            to ``0, 1, ..., L - 1``.
        cmap: Colormap of the ``x`` and ``y`` rows. The difference rows
            always use ``"RdBu_r"``.
        figsize: Width and height of one panel, in inches.

    Returns:
        The created figure.
    """
    n_channels, n_steps, _, _ = x.shape

    x = x.numpy(force=True) if torch.is_tensor(x) else x
    y = y.numpy(force=True) if torch.is_tensor(y) else y

    step_ids = list(range(n_steps)) if timesteps is None else timesteps
    rows_per_step = 1 if y is None else 3

    if y is None:
        height = figsize[1] * n_steps
    else:
        height = 3 * figsize[1] * n_steps

    fig, axs = plt.subplots(
        nrows=rows_per_step * n_steps,
        ncols=n_channels,
        figsize=(figsize[0] * n_channels, height),
        squeeze=False,
    )

    for ch in range(n_channels):
        # Limits are taken from ``x`` and reused for the ``y`` rows.
        lo = np.quantile(x[ch], 0.01) - 1e-2
        hi = np.quantile(x[ch], 0.99) + 1e-2

        if fields:
            axs[0, ch].set_title(f"{fields[ch]}")

        for step in range(n_steps):
            # Each entry: (image, colormap, lower limit, upper limit, label).
            panels = [(x[ch, step], cmap, lo, hi, rf"$x_{{{step_ids[step]}}}$")]

            if y is not None:
                panels.append(
                    (y[ch, step], cmap, lo, hi, rf"$y_{{{step_ids[step]}}}$")
                )
                panels.append(
                    (
                        y[ch, step] - x[ch, step],
                        "RdBu_r",
                        -tolerance,
                        tolerance,
                        rf"$y_{{{step_ids[step]}}} - x_{{{step_ids[step]}}}$",
                    )
                )

            first_row = rows_per_step * step

            for offset, (image, colors, low, high, label) in enumerate(panels):
                row = first_row + offset

                axs[row, ch].imshow(
                    image, cmap=colors, vmin=low, vmax=high, interpolation="none"
                )
                axs[row, ch].set_xticks([])
                axs[row, ch].set_yticks([])

                # Row labels live on the left-most column only.
                axs[row, 0].set_ylabel(label)

    fig.align_labels()
    fig.tight_layout()

    return fig


def plot_psd(
    x: ArrayLike,
    y: Optional[ArrayLike] = None,
    fields: Optional[Sequence[str]] = None,
    figsize: Tuple[float, float] = (3.2, 3.2),
) -> plt.Figure:
    r"""Plot the isotropic power spectrum of each channel on log-log axes.

    One subplot is drawn per channel. The spectrum of ``x`` is drawn in black
    and labelled "GT"; when ``y`` is given its spectrum is overlaid and
    labelled "Sample". The horizontal axis shows ``1 / k`` and is inverted.

    Args:
        x: Array or tensor whose first axis indexes channels. The number of
            remaining axes is forwarded to `isotropic_power_spectrum` as
            ``spatial``.
        y: Optional array or tensor indexed the same way as ``x``.
        fields: Optional subplot titles, indexed by channel.
        figsize: Width and height of one subplot, in inches.

    Returns:
        The created figure.
    """
    n_channels, *trailing = x.shape
    n_spatial = len(trailing)

    x = x.numpy(force=True) if torch.is_tensor(x) else x
    y = y.numpy(force=True) if torch.is_tensor(y) else y

    fig, axs = plt.subplots(
        nrows=1,
        ncols=n_channels,
        figsize=(figsize[0] * n_channels, figsize[1]),
        squeeze=False,
    )

    for ch, ax in enumerate(axs[0]):
        if fields:
            ax.set_title(f"{fields[ch]}")

        spectrum_x, k = isotropic_power_spectrum(x[ch], spatial=n_spatial)
        ax.loglog(1 / k, spectrum_x, base=2, label="GT", color="black")

        if y is not None:
            # The wavenumbers of ``x`` are reused for the sample curve.
            spectrum_y, _ = isotropic_power_spectrum(y[ch], spatial=n_spatial)
            ax.loglog(
                1 / k, spectrum_y, base=2, label="Sample", color="C0", alpha=0.75
            )

        ax.invert_xaxis()

        # Ticks at powers of two below the largest plotted ``1 / k``.
        stop = math.ceil(math.log2(1 / k[0].item()))
        ax.set_xticks([2**exponent for exponent in range(1, stop)])

    first = axs[0, 0]
    first.set_ylabel("power spectrum density")
    first.legend()

    fig.align_labels()
    fig.tight_layout()

    return fig


def field2rgb(
    x: ArrayLike,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    contrast: float = 1.5,
    cmap: str = "RdBu_r",
    bad: str = "whitesmoke",
) -> ArrayLike:
    r"""Map a scalar field to 8-bit RGB colours.

    Values are rescaled from ``[vmin, vmax]`` to ``[-1, 1]``, their magnitude
    is raised to the power ``1 / contrast`` (sign preserved), and the result
    is mapped back to ``[0, 1]``, clipped and passed through the colormap.

    Args:
        x: Array or tensor of values.
        vmin: Lower colour limit. Defaults to the 1% NaN-aware quantile of
            ``x`` minus ``1e-2``.
        vmax: Upper colour limit. Defaults to the 99% NaN-aware quantile of
            ``x`` plus ``1e-2``.
        contrast: Exponent control of the magnitude transform.
        cmap: Palette name handed to `seaborn.color_palette`.
        bad: Colour registered on the colormap through ``set_bad``.

    Returns:
        A ``uint8`` array with the shape of ``x`` plus a trailing axis of
        size 3.
    """
    values = x.numpy(force=True) if torch.is_tensor(x) else x

    lower = np.nanquantile(values, 0.01) - 1e-2 if vmin is None else vmin
    upper = np.nanquantile(values, 0.99) + 1e-2 if vmax is None else vmax

    colormap = sb.color_palette(cmap, as_cmap=True)
    colormap.set_bad(bad)

    unit = (values - lower) / (upper - lower)
    signed = 2 * unit - 1
    stretched = np.sign(signed) * np.abs(signed) ** (1 / contrast)
    clipped = np.clip((stretched + 1) / 2, a_min=0.0, a_max=1.0)

    # The colormap yields RGBA in [0, 1]; alpha is dropped.
    rgba = colormap(clipped)
    rgb = 255 * rgba[..., :3]

    return rgb.astype(np.uint8)


def draw(
    x: ArrayLike,  # (M, N, H, W)
    pad: Union[float, int] = 1 / 64,
    background: str = "white",
    isolate: Sequence[int] = (),
    zoom: int = 1,
    titles: Sequence[str] = (),
    row_titles: Sequence[str] = (),
    fontsize: int = 32,
    bottom_pad: Union[float, int, None] = None,  # Extra space below last row
    **kwargs,
) -> Image.Image:
    """Render a grid of scalar fields as one image, with optional titles.

    The input is colour-mapped with `field2rgb` and laid out as ``M`` rows by
    ``N`` columns of ``H x W`` tiles separated by ``pad`` pixels. Inputs with
    fewer than four dimensions get leading axes of size 1. Column titles are
    drawn above the grid and row titles to its left; both may be used at once.

    Args:
        x: Array or tensor of shape ``(M, N, H, W)`` or with fewer leading
            axes.
        pad: Gap around tiles. A float is a fraction of ``max(H, W)``; an int
            is a pixel count.
        background: Fill colour of the canvas.
        isolate: Axes of ``x`` that keep their own colour limits; the default
            limits are quantiles over all other axes.
        zoom: Integer upscaling factor, applied with nearest-neighbour
            resampling before titles are drawn.
        titles: Column titles, one per column.
        row_titles: Row titles, one per row.
        fontsize: Size of the default font used for titles.
        bottom_pad: Extra space below the last row, in addition to ``pad``.
            ``None`` means the same as ``pad``; a float is a fraction of
            ``max(H, W)``.
        **kwargs: Forwarded to `field2rgb`. ``vmin`` and ``vmax`` are filled
            in from quantiles of ``x`` only when not supplied.

    Returns:
        The assembled RGB image.

    Raises:
        ValueError: If ``x`` has more than four dimensions.
    """
    # Function-scope import kept from the original implementation.
    from PIL import ImageDraw, ImageFont

    if torch.is_tensor(x):
        x = x.numpy(force=True)

    # Quantiles scan the whole array, so compute them only for the limits the
    # caller did not provide (``dict.setdefault`` would always evaluate them).
    if "vmin" not in kwargs or "vmax" not in kwargs:
        axes = tuple(i for i in range(x.ndim) if i not in isolate)

        if "vmin" not in kwargs:
            kwargs["vmin"] = (
                np.nanquantile(x, 0.01, axis=axes, keepdims=True) - 1e-2
            )
        if "vmax" not in kwargs:
            kwargs["vmax"] = (
                np.nanquantile(x, 0.99, axis=axes, keepdims=True) + 1e-2
            )

    x = field2rgb(x, **kwargs)

    if x.ndim > 5:
        raise ValueError(
            f"expected at most 4 dimensions (M, N, H, W), got {x.ndim - 1}"
        )

    while x.ndim < 5:
        x = x[None]

    M, N, H, W, _ = x.shape

    if isinstance(pad, float):
        pad = int(pad * max(H, W))

    if bottom_pad is None:
        bottom_pad = pad
    elif isinstance(bottom_pad, float):
        bottom_pad = int(bottom_pad * max(H, W))

    # Text is measured on a throw-away canvas to reserve room for titles.
    font = ImageFont.load_default(size=fontsize)
    scratch = ImageDraw.Draw(Image.new("RGB", (1, 1), background))

    col_title_height = 0
    if titles:
        # Measuring the joined string gives a height covering every title.
        bottom = scratch.textbbox((0, 0), text="".join(titles), font=font)[3]
        # Round up so the value stays a valid integer pixel size.
        col_title_height = math.ceil(bottom) + pad

    row_title_width = 0
    if row_titles:
        for title in row_titles:
            left, _, right, _ = scratch.textbbox((0, 0), text=title, font=font)
            row_title_width = max(row_title_width, math.ceil(right - left))
        row_title_width += pad

    img = Image.new(
        mode="RGB",
        size=(
            row_title_width + N * (W + pad) + pad,
            col_title_height + M * (H + pad) + pad + bottom_pad,
        ),
        color=background,
    )

    for i in range(M):
        for j in range(N):
            offset = (
                row_title_width + j * (W + pad) + pad,
                col_title_height + i * (H + pad) + pad,
            )

            img.paste(Image.fromarray(x[i][j]), offset)

    if zoom > 1:
        img = img.resize((img.width * zoom, img.height * zoom), Image.NEAREST)
        # Scale the layout so title positions match the enlarged image.
        pad, H, W = pad * zoom, H * zoom, W * zoom
        col_title_height *= zoom
        row_title_width *= zoom

    draw_obj = ImageDraw.Draw(img)

    # Column titles: centred over each column inside the top band.
    for j, title in enumerate(titles):
        offset = (
            row_title_width + j * (W + pad) + pad + W // 2,
            col_title_height // 2,
        )
        draw_obj.text(offset, text=title, anchor="mm", fill="black", font=font)

    # Row titles: centred beside each row inside the left band.
    for i, title in enumerate(row_titles):
        offset = (
            row_title_width // 2,
            col_title_height + i * (H + pad) + pad + H // 2,
        )
        draw_obj.text(offset, text=title, anchor="mm", fill="black", font=font)

    return img


def draw_movie(
    x: ArrayLike,  # (T, M, N, H, W)
    file: Optional[str] = None,
    fps: float = 4.0,
    display: bool = False,
    embed: bool = False,
    isolate: Sequence[int] = (),
    **kwargs,
) -> Union[str, Video]:
    """Render a sequence of field grids as a video or GIF.

    Each slice along the first axis of ``x`` is rendered with `draw` and the
    resulting images are written as consecutive frames.

    Args:
        x: Array or tensor of shape ``(T, M, N, H, W)``.
        file: Output path. A path ending in ``.gif`` is written as a looping
            GIF, anything else as an H.264 video. ``None`` creates a
            temporary ``.mp4`` file, which is not deleted by this function.
        fps: Frames per second.
        display: Return an IPython `Video` instead of the path.
        embed: Forwarded to `Video` when ``display`` is true.
        isolate: Axes of ``x`` that keep their own colour limits; the default
            limits are quantiles over all other axes. Axis 0 is squeezed out
            of the computed limits, so it must not be listed here unless it
            has length 1.
        **kwargs: Forwarded to `draw`. ``vmin`` and ``vmax`` are filled in
            from quantiles of the whole sequence only when not supplied, so
            all frames share the same colour limits.

    Returns:
        The output path, or a `Video` wrapping it when ``display`` is true.
    """
    if torch.is_tensor(x):
        x = x.numpy(force=True)

    # Quantiles scan the whole sequence, so compute them only for the limits
    # the caller did not provide.
    if "vmin" not in kwargs or "vmax" not in kwargs:
        axes = tuple(i for i in range(x.ndim) if i not in isolate)

        if "vmin" not in kwargs:
            kwargs["vmin"] = (
                np.nanquantile(x, 0.01, axis=axes, keepdims=True).squeeze(0) - 1e-2
            )
        if "vmax" not in kwargs:
            kwargs["vmax"] = (
                np.nanquantile(x, 0.99, axis=axes, keepdims=True).squeeze(0) + 1e-2
            )

    imgs = [np.asarray(draw(frame, **kwargs)) for frame in x]

    clip = ImageSequenceClip(imgs, fps=fps)

    if file is None:
        # mkstemp hands back an open descriptor; close it so it is not
        # leaked. Only the path is needed by the writer.
        fd, file = mkstemp(suffix=".mp4")
        os.close(fd)

    if str(file).endswith(".gif"):
        clip.write_gif(file, loop=0, logger=None)
    else:
        clip.write_videofile(file, codec="libx264", logger=None)

    if display:
        return Video(file, embed=embed, width=1280)

    return file


def plot_psd_movie(
    x: np.ndarray,  # Shape: (C, T, H, W)
    y: Optional[
        Union[np.ndarray, List[np.ndarray]]
    ] = None,  # Single array or list of arrays, each with shape (C, T, H, W)
    x_surrogate: Optional[np.ndarray] = None,  # Shape: (C, T, H, W)
    fields: Optional[Sequence[str]] = None,
    file: Optional[str] = None,
    fps: float = 4.0,
    display: bool = False,
    embed: bool = False,
    figsize: Tuple[float, float] = (3.2, 3.2),
    dpi: int = 100,
    sample_colors: Optional[List[str]] = None,  # Colors for each sample
    sample_alpha: float = 0.75,
    ylim: Optional[Tuple[float, float]] = None,  # Fixed y-limits for all frames
    timestep_labels: Optional[
        Sequence[int]
    ] = None,  # Labels shown in the titles of successive frames
    **kwargs,
) -> Union[str, Video]:
    """
    Create a movie of the power spectrum over time with fixed y-axis limits.

    The spectra of every channel and timestep are computed first. Unless
    ``ylim`` is given, their global minimum and maximum set the y-axis limits
    shared by all frames and channels.

    Args:
        x: Ground truth array or tensor with shape (C, T, H, W)
        y: Single sample or list/tuple of samples, each with shape (C, T, H, W)
        x_surrogate: Surrogate prediction with shape (C, T, H, W)
        fields: Field names for each channel
        file: Output file path. A path ending in ``.gif`` is written as a
            looping GIF, anything else as an H.264 video. ``None`` creates a
            temporary ``.mp4`` file, which is not deleted by this function.
        fps: Frames per second
        display: Whether to return an IPython ``Video`` instead of the path
        embed: Whether to embed the video in the notebook
        figsize: Figure size for each subplot
        dpi: DPI for rendering
        sample_colors: List of colors for each sample (defaults to matplotlib
            color cycle)
        sample_alpha: Alpha value for sample lines
        ylim: ``(min, max)`` y-limits for all frames. Defaults to half the
            global minimum and twice the global maximum of all spectra.
        timestep_labels: Labels shown in the subplot titles, paired in order
            with timesteps ``0, 1, ...``. Defaults to ``range(T)``. If fewer
            than ``T`` labels are given, only that many frames are rendered.
        **kwargs: Accepted but not used

    Returns:
        The output path, or a ``Video`` wrapping it when ``display`` is true.

    Raises:
        ValueError: If a sample's shape differs from ``x``, if ``ylim`` does
            not have two entries, or if ``x`` has no channel or timestep.
    """

    def to_numpy(a):
        """Convert a torch tensor to NumPy; pass anything else through."""
        return a.numpy(force=True) if torch.is_tensor(a) else a

    x = to_numpy(x)
    x_surrogate = to_numpy(x_surrogate)

    # Normalise `y` to a list of sample arrays.
    if y is None:
        y_list = []
    elif isinstance(y, (list, tuple)):
        y_list = [to_numpy(yi) for yi in y]
    else:
        y_list = [to_numpy(y)]

    C, T, _, _ = x.shape

    # Explicit raises instead of asserts, which are stripped under ``-O``.
    for i, yi in enumerate(y_list):
        if yi.shape != x.shape:
            raise ValueError(
                f"y[{i}] shape {yi.shape} doesn't match x shape {x.shape}"
            )

    if ylim is not None and len(ylim) != 2:
        raise ValueError("ylim must be a tuple of (min, max)")

    if sample_colors is None:
        # Use matplotlib's default color cycle.
        sample_colors = [f"C{i}" for i in range(10)][: len(y_list)]

    # Wavenumbers of the first spectrum; assumed identical for all others.
    k_values = None

    def channel_spectra(field, t):
        """Return the per-channel spectra of `field` at timestep `t`."""
        nonlocal k_values

        spectra = []
        for c in range(C):
            p, k = isotropic_power_spectrum(field[c, t], spatial=2)
            if k_values is None:
                k_values = k
            spectra.append(p)
        return spectra

    # Indexed as [t][c]; samples are indexed as [sample][t][c].
    all_psd_x = [channel_spectra(x, t) for t in range(T)]

    if k_values is None:
        raise ValueError("x must contain at least one channel and one timestep")

    all_psd_x_surrogate = []
    if x_surrogate is not None:
        all_psd_x_surrogate = [channel_spectra(x_surrogate, t) for t in range(T)]

    all_psd_y = [[channel_spectra(yi, t) for t in range(T)] for yi in y_list]

    if ylim is None:
        # Global limits across every curve keep the frames comparable.
        curves = [p for frame in all_psd_x + all_psd_x_surrogate for p in frame]
        for sample_psds in all_psd_y:
            curves.extend(p for frame in sample_psds for p in frame)

        all_values = np.concatenate(curves)
        y_limits = (np.min(all_values) * 0.5, np.max(all_values) * 2.0)
    else:
        y_limits = (ylim[0], ylim[1])

    # Identical for every frame and channel, so computed once.
    wavelengths = 1 / k_values
    xticks = [2**i for i in range(1, math.ceil(math.log2(1 / k_values[0].item())))]

    def fig_to_array(fig: plt.Figure, dpi: int = 100) -> np.ndarray:
        """Rasterise a matplotlib figure to an RGB array and close it."""
        try:
            fig.set_dpi(dpi)
            canvas = FigureCanvasAgg(fig)
            canvas.draw()
            rgba = np.asarray(canvas.buffer_rgba())
            # Copy so the frame does not stay a view of the RGBA buffer.
            return np.ascontiguousarray(rgba[:, :, :3])
        finally:
            plt.close(fig)

    def create_psd_plot_for_timestep(t: int, t_label: int) -> plt.Figure:
        """Create the PSD plot of all channels at timestep t."""
        fig, axs = plt.subplots(
            nrows=1,
            ncols=C,
            figsize=(figsize[0] * C, figsize[1]),
            squeeze=False,
        )

        for c in range(C):
            ax = axs[0, c]

            if fields:
                ax.set_title(f"{fields[c]} (t={t_label})")
            else:
                ax.set_title(f"Channel {c} (t={t_label})")

            ax.loglog(
                wavelengths,
                all_psd_x[t][c],
                base=2,
                label="GT",
                color="black",
                linewidth=2,
            )

            if x_surrogate is not None:
                ax.loglog(
                    wavelengths,
                    all_psd_x_surrogate[t][c],
                    base=2,
                    label="Surrogate",
                    color="red",
                    linewidth=2,
                )

            for sample_idx, sample_psds in enumerate(all_psd_y):
                # Fall back to the color cycle when too few colors are given.
                color = (
                    sample_colors[sample_idx]
                    if sample_idx < len(sample_colors)
                    else f"C{sample_idx}"
                )

                # Number the samples only when there are several.
                if len(y_list) > 1:
                    label = f"Sample {sample_idx + 1}"
                else:
                    label = "Sample"

                ax.loglog(
                    wavelengths,
                    sample_psds[t][c],
                    base=2,
                    label=label,
                    color=color,
                    alpha=sample_alpha,
                    linewidth=1.5,
                )

            ax.invert_xaxis()
            ax.set_xticks(xticks)
            ax.set_ylim(y_limits[0], y_limits[1])

        # Labels and legend only on the first subplot. The legend is shown
        # whenever more than the ground-truth curve is drawn.
        axs[0, 0].set_ylabel("power spectrum density")
        if y_list or x_surrogate is not None:
            axs[0, 0].legend()

        fig.align_labels()
        fig.tight_layout()

        return fig

    if timestep_labels is None:
        timestep_labels = range(T)

    imgs = [
        fig_to_array(create_psd_plot_for_timestep(t, t_label), dpi=dpi)
        for t, t_label in zip(range(T), timestep_labels)
    ]

    clip = ImageSequenceClip(imgs, fps=fps)

    if file is None:
        # mkstemp hands back an open descriptor; close it so it is not
        # leaked. Only the path is needed by the writer.
        fd, file = mkstemp(suffix=".mp4")
        os.close(fd)

    if str(file).endswith(".gif"):
        clip.write_gif(file, loop=0, logger=None)
    else:
        clip.write_videofile(file, codec="libx264", logger=None)

    print(f"Video saved to: {file}")

    if display:
        return Video(file, embed=embed, width=1280)

    return file
