import pathlib
from typing import Tuple, Union

import matplotlib.animation
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from matplotlib.text import Text
def _update_scatter_plot(
    i: int, traj: np.ndarray, scatter_plot: PathCollection, time_text: Text
) -> None:
    """Move the scatter points and the time label to frame `i`, in place.

    Frame callback passed to `matplotlib.animation.FuncAnimation` by
    `save_2d_trajectory_as_gif`; the animation calls it once per rendered frame.

    Args:
        i (int): Index of the frame to display (first axis of `traj`).
        traj (np.ndarray): Trajectory of shape (n_frames, n_points, 2).
        scatter_plot (PathCollection): Scatter artist whose point positions are
            replaced with the positions at frame `i`.
        time_text (Text): Text artist whose label is set to ``"t = <i>"``.
    """
    # Replace all point positions with the (n_points, 2) slice for frame i.
    scatter_plot.set_offsets(traj[i, ...])
    # The label shows the raw frame index, not a physical time.
    time_text.set_text(f"t = {i}")


def save_2d_trajectory_as_gif(
    trajectory: Union[np.ndarray, "torch.Tensor"],
    filename: Union[str, pathlib.Path],
    *,
    fps: int = 10,
    t_skip: int = 1,
    dpi: int = 30,
    xlim: Union[Tuple[float, float], None] = None,
    ylim: Union[Tuple[float, float], None] = None,
    figsize: Tuple[int, int] = (8, 8),
    **scatter_kwargs,
) -> matplotlib.animation.FuncAnimation:
    """Plots a trajectory of a 2D diffusion process as an animated GIF.

    Each GIF frame is a scatter plot of all points at one time step, labelled
    ``t = <frame index>`` in the upper-left corner of the axes.

    Args:
        trajectory (Union[np.ndarray, torch.Tensor]): Trajectory of shape
            (n_frames, n_points, n_dims) with n_dims=2 (x, y) coordinates. If
            `torch.Tensor`, it must be on the CPU (it is converted with
            `np.asarray`).
        filename (Union[str, pathlib.Path]): Filename of the GIF. Must end with
            .gif.
        fps (int, optional): Frames per second. Must be positive. Defaults to 10.
        t_skip (int, optional): Step between consecutive trajectory frames used
            in the GIF (1 uses every frame). Must be >= 1. Defaults to 1.
        dpi (int, optional): Dots per inch. Defaults to 30.
        xlim (Tuple[float, float], optional): x-axis limits. Defaults to None. If
            None, the limits are set to the minimum and maximum x-coordinates in
            the trajectory.
        ylim (Tuple[float, float], optional): y-axis limits. Defaults to None. If
            None, the limits are set to the minimum and maximum y-coordinates in
            the trajectory.
        figsize (Tuple[int, int], optional): Figure size. Defaults to (8, 8).
        **scatter_kwargs: Additional keyword arguments passed to
            `matplotlib.axes.Axes.scatter`.

    Returns:
        matplotlib.animation.FuncAnimation: Animation object. The GIF has
            already been written to `filename` when this function returns. The
            underlying figure is closed via `plt.close` so repeated calls do not
            accumulate open pyplot figures.

    Raises:
        ValueError: If `filename` does not end with .gif, `trajectory` is not of
            shape (n_frames, n_points, 2), it has no frames or no points, `fps` is
            not positive, or `t_skip` is less than 1.
    """
    traj = np.asarray(trajectory)
    filename = pathlib.Path(filename)

    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    if filename.suffix != ".gif":
        raise ValueError("Filename must end with .gif")
    if traj.ndim != 3 or traj.shape[-1] != 2:
        raise ValueError(
            "Trajectory must be of shape (n_frames, n_points, n_dims) with n_dims=2, "
            f"got shape {traj.shape}"
        )
    if traj.shape[0] == 0 or traj.shape[1] == 0:
        # The min/max defaults for xlim/ylim and the first frame need data.
        raise ValueError(
            f"Trajectory must contain at least one frame and one point, got shape {traj.shape}"
        )
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    if t_skip < 1:
        raise ValueError(f"t_skip must be >= 1, got {t_skip}")

    n_frames = traj.shape[0]
    if xlim is None:
        xlim = (traj[:, :, 0].min(), traj[:, :, 0].max())
    if ylim is None:
        ylim = (traj[:, :, 1].min(), traj[:, :, 1].max())

    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)

        scatter_plot = ax.scatter(x=traj[0, :, 0], y=traj[0, :, 1], **scatter_kwargs)
        # Time label in axes-relative coordinates, so it stays in the upper-left
        # corner regardless of the data limits.
        time_text = ax.text(
            0.05,
            0.95,
            "t = 0",
            horizontalalignment="left",
            verticalalignment="top",
            transform=ax.transAxes,
        )

        animation = matplotlib.animation.FuncAnimation(
            fig,
            _update_scatter_plot,
            frames=range(0, n_frames, t_skip),
            fargs=(traj, scatter_plot, time_text),
            interval=1000.0 / fps,
        )

        # Pass fps explicitly rather than relying on it being derived from
        # `interval`. NOTE(review): GIF frame delays are stored in hundredths of
        # a second and many viewers clamp very short delays, so high fps values
        # may not play faster — confirm if speed-up is needed.
        # The 'imagemagick' writer requires ImageMagick to be installed.
        animation.save(str(filename), writer="imagemagick", fps=fps, dpi=dpi)
    finally:
        # Release the pyplot figure even if saving fails, so figures do not leak.
        plt.close(fig)

    return animation