"""Trajectory dataset generation for the `BipedalWalkerHardcore`
environment."""
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.vec_env import DummyVecEnv
import gymnasium
import bipedal_walker_hardcore_modes
import ffmpeg
from pathlib import Path
import argparse
import random
import tempfile
from PIL import Image
import torch
from swmpo.transition import Transition
from swmpo_experiments.trajectory_dataset import serialize_trajectory
from concurrent.futures import ProcessPoolExecutor


def generate_trajectory(
    output_dir: Path,
    expert_policy_dir: Path,
    seed: str,
):
    """Roll out the expert policy for one episode and serialize the result.

    One episode of `BipedalWalkerHardcoreModes-v3` is played by the PPO
    expert found in `expert_policy_dir`, acting deterministically. The
    per-step transitions and ground-truth modes are written with
    `serialize_trajectory`, and a rendering of the episode is written to
    `output_dir/"trajectory.mp4"`.

    Args:
        output_dir: Existing directory that receives the serialized
            trajectory and the video.
        expert_policy_dir: Directory containing `BipedalWalkerHardcore-v3.zip`
            and `BipedalWalkerHardcore-v3/vecnormalize.pkl`.
        seed: Arbitrary string. It is mapped deterministically to the integer
            seed passed to `env.reset`, so episodes are reproducible for a
            given seed (assuming all environment randomness flows from the
            reset seed).

    Raises:
        FileNotFoundError: If the policy or the normalization file is missing.
    """
    expert_policy_normalization_path = (
        expert_policy_dir/"BipedalWalkerHardcore-v3"/"vecnormalize.pkl"
    )
    expert_policy_path = expert_policy_dir/"BipedalWalkerHardcore-v3.zip"
    for required_path in (expert_policy_path, expert_policy_normalization_path):
        if not required_path.exists():
            raise FileNotFoundError(
                f"Path '{required_path.resolve()}' doesn't exist!"
            )

    env = gymnasium.make(
        "BipedalWalkerHardcoreModes-v3",
        render_mode="rgb_array",
    )
    # The wrapped vec env is never stepped. VecNormalize is only used for the
    # stored observation statistics that the expert policy was trained with.
    normalizer = VecNormalize.load(
        str(expert_policy_normalization_path),
        DummyVecEnv([
            lambda: gymnasium.make(
                "BipedalWalkerHardcoreModes-v3",
                render_mode="rgb_array",
            )]),
    )
    try:
        model = PPO.load(str(expert_policy_path))
        transitions, ground_truth_modes, frames = _rollout(
            env, normalizer, model, _seed_to_int(seed),
        )
        fps = env.metadata["render_fps"]
    finally:
        # Release the renderer/physics resources of both environments even if
        # the rollout fails.
        env.close()
        normalizer.close()

    serialize_trajectory(
        transitions=transitions,
        ground_truth_modes=ground_truth_modes,
        output_dir=output_dir,
    )

    video_output_path = output_dir/"trajectory.mp4"
    _write_video(frames, video_output_path, fps)
    print(f"Wrote {video_output_path}")


def _seed_to_int(seed: str) -> int:
    """Deterministically map a string seed to a 32-bit integer for `env.reset`."""
    # Seeding random.Random with a str uses all of the string's bits and does
    # not depend on the salted builtin `hash`, so the result is stable across
    # worker processes and interpreter runs.
    return random.Random(seed).randrange(2**32)


def _rollout(env, normalizer, model, seed: int):
    """Play one deterministic episode; return (transitions, modes, frames)."""
    obs, _ = env.reset(seed=seed)
    transitions = []
    ground_truth_modes = []
    frames = []
    while True:
        # Only the policy input is normalized. The stored states stay raw so
        # that `source_state` and `next_state` live in the same space.
        # (`normalize_obs` returns a new array and leaves `obs` untouched.)
        action, _ = model.predict(
            normalizer.normalize_obs(obs), deterministic=True,
        )
        next_obs, _, terminated, truncated, info = env.step(action)

        transitions.append(Transition(
            source_state=torch.from_numpy(obs),
            action=torch.from_numpy(action),
            next_state=torch.from_numpy(next_obs),
        ))
        ground_truth_modes.append(info["ground_truth_mode"])
        frames.append(env.render())

        obs = next_obs
        if terminated or truncated:
            return transitions, ground_truth_modes, frames


def _write_video(frames: list, video_output_path: Path, fps) -> None:
    """Encode RGB frames into a video at `video_output_path` via ffmpeg."""
    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)
        for i, frame in enumerate(frames):
            # Zero-padded names make ffmpeg's lexicographic glob order match
            # the frame order.
            Image.fromarray(frame).save(frame_dir/f"{i:08d}.png")
        (
            ffmpeg
            .input(
                str(frame_dir/"*.png"),
                pattern_type="glob",
                framerate=fps,
            )
            .output(str(video_output_path))
            .run(quiet=True)
        )


def main(
    output_dir: Path,
    trajectory_n: int,
    seed: str,
    worker_n: int,
    expert_policy_dir: Path,
):
    """Generate `trajectory_n` trajectories into `output_dir` in parallel.

    Trajectory `i` is written to `output_dir/str(i)`, which is created here
    and must not already exist. `output_dir` itself is assumed to exist.
    Per-episode seed strings are drawn from `random.Random(seed)`, so the
    seeds handed to the workers are reproducible from `seed`.

    Args:
        output_dir: Existing directory that receives one subdirectory per
            trajectory.
        trajectory_n: Number of trajectories to generate.
        seed: Seed for the stream of per-episode seeds.
        worker_n: Number of worker processes.
        expert_policy_dir: Directory holding the expert PPO policy and its
            normalization statistics.

    Raises:
        Exception: Whatever the first failing worker raised, re-raised by
            `Future.result()`.
    """
    from concurrent.futures import as_completed

    _random = random.Random(seed)

    with ProcessPoolExecutor(worker_n) as executor:
        futures = []
        for i in range(trajectory_n):
            episode_dir = output_dir/f"{i}"
            episode_dir.mkdir()
            futures.append(executor.submit(
                generate_trajectory,
                output_dir=episode_dir,
                seed=str(_random.random()),
                expert_policy_dir=expert_policy_dir,
            ))

        # Report progress in completion order so that a single slow early
        # episode does not hide the progress of the others.
        for done_n, future in enumerate(as_completed(futures), start=1):
            future.result()  # re-raises any exception from the worker
            print(f"Done {done_n}/{len(futures)}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='Generate trajectories.',
        description='Generate a dataset of trajectories from the `BipedalWalkerHardcore` environment using an expert policy.',
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files'
    )
    parser.add_argument(
        '--trajectory_n',
        type=int,
        required=True,
        help='Number of trajectories (episodes) to generate'
    )
    parser.add_argument(
        '--worker_n',
        type=int,
        required=True,
        help='Number of workers to sample episodes',
    )
    parser.add_argument(
        '--seed',
        type=str,
        required=True,
        help='Random number generator seed'
    )
    parser.add_argument(
        '--expert_policy_dir',
        type=Path,
        required=True,
        help='Path to a StableBaselines3 PPO BipedalWalkerHardcore policy with normalization statistics.'
    )
    args = parser.parse_args()

    # Validate everything before touching the filesystem, so a bad invocation
    # does not leave behind an empty output directory.
    if args.trajectory_n < 0:
        parser.error("--trajectory_n must be non-negative")
    if args.worker_n < 1:
        parser.error("--worker_n must be at least 1")
    if not args.expert_policy_dir.is_dir():
        parser.error(f"--expert_policy_dir '{args.expert_policy_dir}' is not a directory")

    # The output directory must not already exist (mkdir raises
    # FileExistsError), so runs are never mixed. Missing parents are created.
    args.output_dir.mkdir(parents=True)
    main(
        output_dir=args.output_dir,
        trajectory_n=args.trajectory_n,
        seed=args.seed,
        worker_n=args.worker_n,
        expert_policy_dir=args.expert_policy_dir,
    )
