"""Plotting utilities for the ant environment."""
from environment import get_simulation
from dynamics import simulation_
import tempfile
from pathlib import Path
import PIL.Image
import ffmpeg
import torch
from rl_baseline import TightlyGuidedAntPasswordEnv


def frames_to_video(frames, fps, output_path):
    """Encode a sequence of image frames into a video file with ffmpeg.

    Each frame is written as a PNG into a temporary directory. ffmpeg's image2
    demuxer then reads the frames back as a numbered sequence and encodes them
    to `output_path`, overwriting any existing file there.

    Args:
        frames: iterable of image arrays accepted by `PIL.Image.fromarray`.
        fps: frame rate used for the input sequence (and hence the video).
        output_path: destination file (ffmpeg infers the output format from
            its extension).

    Raises:
        ValueError: if `frames` yields no frames. Otherwise ffmpeg would fail
            with an opaque error on the empty sequence.
        ffmpeg.Error: raised by ffmpeg-python if encoding fails.
    """
    # One printf-style template drives both the PNG file names and the sequence
    # ffmpeg reads. Zero-padding keeps ordering numeric, and no shell glob (or
    # its locale-dependent sort) is involved.
    name_pattern = "%06d.png"
    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)
        n_frames = 0
        for i, frame in enumerate(frames):
            PIL.Image.fromarray(frame).save(frame_dir / (name_pattern % i))
            n_frames += 1
        if n_frames == 0:
            raise ValueError("frames_to_video() needs at least one frame")
        (
            ffmpeg
            .input(
                str(frame_dir / name_pattern),
                framerate=fps,
                start_number=0,
            )
            .output(str(output_path))
            .overwrite_output()
            .run(quiet=True)
        )


def plot_animation(
        parameters: torch.Tensor,
        fps: int,
        output_path: Path,
        password: tuple[int, ...],
        num_buttons: int,
        sub_step_s: float,
        ):
    """Roll out the ant environment under an open-loop control signal and save
    the rendered rollout as a video.

    Row `t` of `parameters` is passed to `env.step` as the action for step `t`.
    The rollout stops early once the environment reports `done`. Every step is
    rendered, and the last rendered frame is then repeated `fps` extra times
    (about one second of video) so the final state stays visible.

    Args:
        parameters: time-varying control signal, expected to be of shape
            `(num_timesteps, 8)`. The second dimension matches
            `actuator_n=8` below.
        fps: frame rate of the output video.
        output_path: file the video is written to.
        password: forwarded to the environment as `target_password`.
        num_buttons: forwarded to the environment as `button_n`.
        sub_step_s: forwarded to the environment as `sub_step_s`.

    Raises:
        ValueError: if `parameters` has no timesteps, since there would be no
            frame to render.
    """
    # Fail before building the environment. Otherwise an empty control signal
    # would surface later as an IndexError on the (empty) frame list.
    if len(parameters) == 0:
        raise ValueError(
            "plot_animation() needs at least one timestep of parameters")

    env = TightlyGuidedAntPasswordEnv(
        target_password=password,
        button_n=num_buttons,
        actuator_n=8,
        # NOTE(review): presumably chosen so the env's own episode limit never
        # truncates the rollout (its length is set by `parameters`) — confirm.
        episode_timestep_n=10000000,
        sub_step_s=sub_step_s,
        password_so_far_encoding_size=5,
    )

    frames = []
    for action in parameters:
        _, _, done, _ = env.step(action)
        frames.append(env.render())
        if done:
            break

    # Hold the final frame for `fps` extra frames (~1 s) before the video ends.
    frames.extend([frames[-1]] * fps)
    frames_to_video(frames, fps, output_path)
