"""Plotting and animation functions."""
from pathlib import Path
from terrain_mass.environment import EnvironmentInstance
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.collections import PatchCollection
from matplotlib.patches import Circle
from matplotlib.patches import Polygon
from matplotlib.figure import Figure
from dataclasses import dataclass
import tempfile
import ffmpeg
import torch


@dataclass
class Bounds:
    """Axis-aligned rectangle given by its extreme coordinates.

    Field order matters: instances may be built positionally as
    ``Bounds(x_min, x_max, y_min, y_max)``.
    """

    # Horizontal extent (smallest and largest x coordinate).
    x_min: float
    x_max: float
    # Vertical extent (smallest and largest y coordinate).
    y_min: float
    y_max: float


def get_bounds(
        states: list[torch.Tensor],
        environment: EnvironmentInstance,
        target_position: tuple[float, float] | None,
        ) -> Bounds:
    """Compute the bounding box of everything that will be drawn.

    The box encloses:
    * the mass position of every state (via ``environment.get_x_pos`` /
      ``environment.get_y_pos``),
    * the target position, when one is given,
    * the four corners ``a``, ``b``, ``c``, ``d`` of every island.

    Raises ``ValueError`` (from ``min``) when there is nothing to enclose.
    """
    mass_xs = [environment.get_x_pos(s).item() for s in states]
    mass_ys = [environment.get_y_pos(s).item() for s in states]

    if target_position is None:
        target_xs, target_ys = [], []
    else:
        target_xs, target_ys = [target_position[0]], [target_position[1]]

    corners = [
        corner
        for island in environment.islands
        for corner in (island.a, island.b, island.c, island.d)
    ]
    corner_xs = [corner[0] for corner in corners]
    corner_ys = [corner[1] for corner in corners]

    all_xs = mass_xs + target_xs + corner_xs
    all_ys = mass_ys + target_ys + corner_ys
    return Bounds(
        x_min=min(all_xs),
        x_max=max(all_xs),
        y_min=min(all_ys),
        y_max=max(all_ys),
    )


def plot_state(
        x: torch.Tensor,
        x_min: float,
        x_max: float,
        y_min: float,
        y_max: float,
        mass_radius: float,
        environment: EnvironmentInstance,
        target_position: tuple[float, float] | None,
        output_path: Path,
        ):
    """Render a single state to an image file.

    Draws the water background, every island polygon, the mass as a circle
    and (optionally) the target position as a circle, then saves the figure
    to ``output_path``.

    The visible region is a square of side
    ``max(x_max - x_min, y_max - y_min)`` anchored at ``(x_min, y_min)`` and
    padded by ``mass_radius`` on every side, so circles lying on the bounds
    are not clipped.

    Args:
        x: State passed to ``environment.get_x_pos`` / ``get_y_pos`` to locate
            the mass (layout defined by the environment).
        x_min, x_max, y_min, y_max: Bounds of the scene (see ``get_bounds``).
        mass_radius: Radius used for both the mass and the target circles.
        environment: Provides the mass position accessors and ``islands``,
            each island having corner points ``a``, ``b``, ``c``, ``d``.
        target_position: ``(x, y)`` of the target, or ``None`` to omit it.
        output_path: File the rendered figure is written to.
    """
    # Colour indices mapped through the collection's colormap.
    mass_color = 0
    target_color = 1
    island_color = 2

    # Agg canvas lets the figure render and save without pyplot state.
    fig = Figure()
    FigureCanvas(fig)
    ax = fig.add_subplot()

    # Square viewport so both axes share one scale.
    side = max(x_max - x_min, y_max - y_min)
    ax.set_xlim(left=x_min - mass_radius, right=x_min + side + mass_radius)
    ax.set_ylim(bottom=y_min - mass_radius, top=y_min + side + mass_radius)
    ax.set_aspect('equal', adjustable="box")

    # Water is the axes background.
    ax.set_facecolor("xkcd:light blue")

    patches = []
    colors = []

    # Islands first so the mass and target are drawn on top of them.
    for island in environment.islands:
        patches.append(Polygon([island.a, island.b, island.c, island.d]))
        colors.append(island_color)

    mass_center = (
        environment.get_x_pos(x).item(),
        environment.get_y_pos(x).item(),
    )
    patches.append(Circle(mass_center, mass_radius))
    colors.append(mass_color)

    if target_position is not None:
        patches.append(Circle(target_position, mass_radius))
        colors.append(target_color)

    collection = PatchCollection(patches, alpha=1.0)
    collection.set_array(colors)
    # Pin the colour scale. Autoscaling would map the indices to different
    # colours depending on which shape kinds are present, e.g. the target
    # would change colour when there are no islands.
    collection.set_clim(mass_color, island_color)
    ax.add_collection(collection)

    fig.savefig(output_path)


def plot_animation(
        states: list[torch.Tensor],
        mass_radius: float,
        environment: EnvironmentInstance,
        fps: int,
        target_position: tuple[float, float] | None,
        output_path: Path,
        ):
    """Render a sequence of states as a video/animation file.

    Each state becomes one frame rendered by ``plot_state``. All frames share
    the bounds returned by ``get_bounds``, so the view does not move. The
    final state is held for one extra second (``fps`` additional frames). The
    frames are encoded with ffmpeg at ``fps`` frames per second into
    ``output_path``, which is overwritten if it already exists.

    Raises:
        ValueError: If ``states`` is empty or ``fps`` is not positive.
        ffmpeg.Error: If the ffmpeg invocation fails (output is captured).
    """
    import shutil

    if not states:
        raise ValueError("states must contain at least one state")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    b = get_bounds(states, environment, target_position)
    n_frames = len(states) + fps
    # Fixed-width zero padding keeps the frame files a contiguous numbered
    # sequence that ffmpeg's image2 sequence pattern can read.
    width = len(str(n_frames - 1))

    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)

        def frame_path(index: int) -> Path:
            return frame_dir / f"{index:0{width}d}.png"

        for i, state in enumerate(states):
            plot_state(
                state,
                x_min=b.x_min,
                x_max=b.x_max,
                y_min=b.y_min,
                y_max=b.y_max,
                mass_radius=mass_radius,
                environment=environment,
                output_path=frame_path(i),
                target_position=target_position,
            )

        # Hold the final state for an additional second. The image is
        # identical, so copy it instead of re-rendering it fps times.
        last_frame = frame_path(len(states) - 1)
        for i in range(len(states), n_frames):
            shutil.copyfile(last_frame, frame_path(i))

        # Portable printf-style sequence pattern; glob patterns need a
        # libavformat built with glob support. A literal '%' in the
        # directory name must be escaped as '%%'.
        pattern = str(frame_dir).replace("%", "%%") + f"/%0{width}d.png"
        (
            ffmpeg
            .input(pattern, pattern_type="sequence", framerate=fps)
            .output(str(output_path))
            .overwrite_output()
            .run(quiet=True)
        )
