"""Generate a dataset of terrain-mass trajectories by sampling episodes with an MPC expert policy."""
from pathlib import Path
from swmpo_experiments.terrain_mass_utils.collect_data import sample_task_episodes
from terrain_mass.task import get_example_task
import argparse
import random


def main(
        output_dir: Path,
        episode_n: int,
        worker_n: int,
        mpc_iter_n: int,
        mpc_plan_len: int,
        simulation_step_n: int,
        mpc_initial_stdev: float,
        seed: str,
        animation_fps: int,
        ):
    """Build ``episode_n`` terrain-mass tasks and sample one episode per task.

    Each task is created with ``get_example_task`` from its own seed string.
    Those strings are drawn from a single ``random.Random`` seeded with
    ``seed``, so the per-task seed strings are reproducible for a given
    ``seed``. The task list and all MPC/simulation settings are then
    forwarded to ``sample_task_episodes``.

    Args:
        output_dir: Output location forwarded to ``sample_task_episodes``
            (the ``__main__`` entry point creates it before calling).
        episode_n: Number of tasks, and therefore episodes, to generate.
        worker_n: Worker count forwarded to ``sample_task_episodes``.
        mpc_iter_n: MPC iteration count forwarded to the sampler.
        mpc_plan_len: MPC plan length forwarded to the sampler.
        simulation_step_n: Simulation step count forwarded to the sampler.
        mpc_initial_stdev: Initial MPC standard deviation forwarded to the
            sampler.
        seed: Seed for the task-seed RNG; also forwarded unchanged to
            ``sample_task_episodes``.
        animation_fps: Animation frame rate forwarded to the sampler.
    """
    # Derive one seed string per task from a single seeded RNG, so the whole
    # task set is determined by ``seed`` alone.
    task_rng = random.Random(seed)
    tasks = [
        get_example_task(str(task_rng.random()))
        for _ in range(episode_n)
    ]

    # Collect training data
    print("Collecting data...")
    sample_task_episodes(
        tasks=tasks,
        worker_n=worker_n,
        seed=seed,
        mpc_iter_n=mpc_iter_n,
        mpc_plan_len=mpc_plan_len,
        simulation_step_n=simulation_step_n,
        mpc_initial_stdev=mpc_initial_stdev,
        animation_fps=animation_fps,
        output_dir=output_dir,
    )
    print("Done.")


if __name__ == "__main__":
    # Command-line entry point: parse the sampling configuration, create the
    # (fresh) output directory, and run ``main``.
    parser = argparse.ArgumentParser(
        prog='Generate trajectories.',
        description='Generate a dataset of trajectories from the terrain mass environment using an expert policy.',
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files'
    )
    parser.add_argument(
        '--episode_n',
        type=int,
        required=True,
        help='Number of episodes to sample (one task is generated per episode)'
    )
    parser.add_argument(
        '--worker_n',
        type=int,
        required=True,
        help='Number of workers to sample episodes',
    )
    parser.add_argument(
        '--seed',
        type=str,
        required=True,
        help='Random number generator seed'
    )
    parser.add_argument(
        '--mpc_plan_len',
        type=int,
        required=True,
        help='Length of the MPC local plan.',
    )
    parser.add_argument(
        '--mpc_initial_stdev',
        type=float,
        required=True,
        help='Standard deviation for the MPC optimization algorithm.',
    )
    parser.add_argument(
        '--mpc_iter_n',
        type=int,
        required=True,
        help='Maximum number of iterations for the MPC loop.'
    )
    parser.add_argument(
        '--simulation_step_n',
        type=int,
        required=True,
        help='Maximum number of steps for the simulations.'
    )
    # Previously hard-coded to 60; exposed as an optional flag with the same
    # default so existing invocations behave identically.
    parser.add_argument(
        '--animation_fps',
        type=int,
        default=60,
        help='Animation frame rate passed to the episode sampler (default: 60).',
    )
    args = parser.parse_args()

    # The output directory must be new so an existing dataset is never mixed
    # with or overwritten by this run. Missing parents are created; an
    # existing directory is reported as a usage error instead of a traceback.
    try:
        args.output_dir.mkdir(parents=True)
    except FileExistsError:
        parser.error(f"--output_dir '{args.output_dir}' already exists")

    main(
        output_dir=args.output_dir,
        episode_n=args.episode_n,
        seed=args.seed,
        worker_n=args.worker_n,
        mpc_iter_n=args.mpc_iter_n,
        mpc_plan_len=args.mpc_plan_len,
        mpc_initial_stdev=args.mpc_initial_stdev,
        simulation_step_n=args.simulation_step_n,
        animation_fps=args.animation_fps,
    )
