"""Plotting utilities for the marble environment."""
from pathlib import Path
import tempfile
from typing import Optional
from trajectory_dynamics import State
import ffmpeg
from matplotlib.patches import Circle
from matplotlib.lines import Line2D
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.collections import PatchCollection


# https://stackoverflow.com/a/42972469
class LineDataUnits(Line2D):
    """A ``Line2D`` whose line width is given in data units instead of points.

    The width passed as ``linewidth`` (or its alias ``lw``) is stored as a
    data-unit length and converted to points every time matplotlib reads
    ``_linewidth``, so the stroke scales with the axes' data transform.
    Calling ``set_linewidth`` later also sets the width in data units.
    """

    def __init__(self, *args, **kwargs):
        # Accept matplotlib's ``lw`` alias too; otherwise Line2D would consume
        # it and the assignment below would reset the width to the default.
        # ``linewidth`` takes precedence if both are supplied.
        lw_alias = kwargs.pop("lw", 1)
        _lw_data = kwargs.pop("linewidth", lw_alias)
        super().__init__(*args, **kwargs)
        # Set after super().__init__, whose own default-width assignment goes
        # through the ``_linewidth`` property setter and would otherwise win.
        self._lw_data = _lw_data

    def _get_lw(self):
        """Return the stored data-unit width converted to points."""
        if self.axes is not None:
            # Points per display pixel at the figure's resolution.
            ppd = 72./self.axes.figure.dpi
            trans = self.axes.transData.transform
            # Vertical display distance spanned by the width in data units.
            # abs() keeps the width positive when the y axis is inverted.
            return abs(((trans((1, self._lw_data))-trans((0, 0)))*ppd)[1])
        else:
            # Not attached to an Axes yet: no data transform is available.
            return 1

    def _set_lw(self, lw):
        """Store ``lw`` as the line width in data units."""
        self._lw_data = lw

    # Line2D reads and writes ``_linewidth``; routing it through this property
    # makes every read convert the data-unit width to points.
    _linewidth = property(_get_lw, _set_lw)


# Axis limits laid out as ((xmin, xmax), (ymin, ymax)): the x range first,
# then the y range, matching the pair ``ax.get_xlim(), ax.get_ylim()``.
_Interval = tuple[float, float]
AxisLimits = tuple[_Interval, _Interval]


def plot_state(
        state: State,
        goal_position: tuple[float, float],
        goal_radius: float,
        output_path: Path,
        axis_limits: Optional[AxisLimits],
        ) -> AxisLimits:
    """Plot the given state to the given path.

    The marble and the goal are drawn as circles in a single
    ``PatchCollection`` (colormap values 0 and 1 respectively, alpha 0.9).
    Each segment is drawn as a line whose width is ``2 * segment.radius`` in
    data units.

    Args:
        state: State to draw; its ``marble`` (position and radius) and
            ``segments`` (``p1``, ``p2`` and ``radius``) are read.
        goal_position: ``(x, y)`` centre of the goal circle.
        goal_radius: Radius of the goal circle.
        output_path: Where the figure is saved via ``Figure.savefig``.
        axis_limits: ``((xmin, xmax), (ymin, ymax))`` to apply to the axes,
            or ``None`` to autoscale to the drawn objects with an equal
            aspect ratio.

    Returns:
        The ``(xlim, ylim)`` of the axes after saving. Passing this back as
        ``axis_limits`` keeps consecutive frames on the same view.
    """
    # Build the figure without pyplot so no global figure state is involved.
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot()

    marble = Circle(
        (
            state.marble.position[0].item(),
            state.marble.position[1].item(),
        ),
        state.marble.radius,
    )
    goal = Circle((goal_position[0], goal_position[1]), goal_radius)

    # set_array assigns each circle a colormap value (marble 0, goal 1) so
    # they get distinct colours.
    circles = PatchCollection([marble, goal], alpha=0.9)
    circles.set_array([0, 1])
    # add_collection (autolim=True by default) adds the circles' extent to
    # the axes' data limits.
    ax.add_collection(circles)

    # add_line likewise extends the data limits with each segment.
    for segment in state.segments:
        ax.add_line(
            LineDataUnits(
                (segment.p1[0], segment.p2[0]),
                (segment.p1[1], segment.p2[1]),
                linewidth=segment.radius*2,
            )
        )

    if axis_limits is None:
        # Deliberately no ax.relim(): relim() rebuilds the data limits from
        # lines/patches/images only and does not support Collections. It
        # would discard the marble/goal extent recorded by add_collection,
        # letting autoscale leave them out of view.
        ax.autoscale_view()
        ax.axis('equal')
    else:
        ax.set_xlim(axis_limits[0])
        ax.set_ylim(axis_limits[1])

    # Saving draws the figure, which is when the equal-aspect limit
    # adjustment is applied, so the limits read below reflect it.
    fig.savefig(output_path)

    return ax.get_xlim(), ax.get_ylim()


def plot_animation(
        states: list[State],
        fps: int,
        goal_position: tuple[float, float],
        goal_radius: float,
        output_path: Path,
        ):
    """Render a sequence of states to a video file using ffmpeg.

    Each state is drawn to a PNG frame with ``plot_state``. The final state
    is repeated for ``fps`` extra frames, i.e. it stays on screen for an
    additional second. The frames are then encoded into ``output_path``,
    overwriting it if it already exists. The axis limits computed for the
    first frame are reused for all later frames, so the view stays fixed.

    Args:
        states: States to render, in playback order. Must be non-empty.
        fps: Frame rate of the output video. Must be positive.
        goal_position: ``(x, y)`` centre of the goal circle.
        goal_radius: Radius of the goal circle.
        output_path: Destination video file.

    Raises:
        ValueError: If ``states`` is empty or ``fps`` is not positive.
        ffmpeg.Error: If the ffmpeg invocation fails.
    """
    if not states:
        raise ValueError("states must contain at least one state")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    # Plot final state for an additional second.
    frames = [*states, *([states[-1]] * fps)]

    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)
        axis_limits = None
        for i, state in enumerate(frames):
            # Zero-padded, contiguous names matching the "%06d.png" sequence
            # pattern below (no spaces in file names, no glob needed).
            axis_limits = plot_state(
                state,
                goal_position=goal_position,
                goal_radius=goal_radius,
                output_path=frame_dir / f"{i:06d}.png",
                axis_limits=axis_limits,
            )
        # A printf-style sequence pattern is used instead of a glob, because
        # glob pattern support is missing from some ffmpeg builds.
        (
            ffmpeg
            .input(
                str(frame_dir / "%06d.png"),
                pattern_type="sequence",
                start_number=0,
                framerate=fps,
            )
            .output(str(output_path))
            .overwrite_output()
            .run(quiet=True)
        )
