"""Benchmark RL. This is the same as `benchmark_rl.py` except that it
trains with the `benchmark_venv` function from `benchmark_venv_mode_bias.py`
(`swmpo_experiments.benchmark_venv_mode_bias`) instead of the one from
`benchmark_venv.py`."""
import gymnasium
from stable_baselines3.common.env_util import make_vec_env
from swmpo_experiments.benchmark_venv_mode_bias import benchmark_venv
from pathlib import Path
import argparse
import random


def perform_experiment(
        env_id: str,
        output_dir: Path,
        train_timestep_n: int,
        datapoint_n: int,
        plot_episode_n: int,
        seed: str,
        parallel_env_n: int,
        eval_episode_n: int,
        cuda_device: str,
        ):
    """Train and periodically evaluate an RL agent on a gymnasium environment.

    Builds a vectorized training environment, a vectorized evaluation
    environment and, if `plot_episode_n > 0`, a single `rgb_array`
    rendering environment. It then hands them to `benchmark_venv`. All
    environments are closed afterwards, even if `benchmark_venv` raises.

    Every random quantity (the seeds of both vectorized environments and
    the seed forwarded to `benchmark_venv`) is drawn from a
    `random.Random(seed)` instance. Identical arguments therefore yield
    identical seeds.

    :param env_id: ID of the gymnasium environment.
    :param output_dir: Directory forwarded to `benchmark_venv` for its output.
    :param train_timestep_n: Total number of training timesteps.
    :param datapoint_n: Approximate number of evaluations during training;
        must be >= 1.
    :param plot_episode_n: Number of episodes to plot after training; no
        rendering environment is created when this is <= 0.
    :param seed: Seed for the experiment's random number generator.
    :param parallel_env_n: Number of environments in each vectorized
        environment (train and eval); must be >= 1.
    :param eval_episode_n: Number of episodes per evaluation.
    :param cuda_device: Device name forwarded to `benchmark_venv`.
    :raises ValueError: If `datapoint_n` or `parallel_env_n` is < 1.
    """
    # Validate before drawing from the RNG or creating any environment;
    # otherwise these values end in a ZeroDivisionError (== 0) or are
    # silently clamped (< 0) in the eval_freq computation below.
    if datapoint_n < 1:
        raise ValueError(f"datapoint_n must be positive, got {datapoint_n}")
    if parallel_env_n < 1:
        raise ValueError(
            f"parallel_env_n must be positive, got {parallel_env_n}"
        )

    _random = random.Random(seed)

    # Create train and evaluation environments. Each one gets its own
    # 24-bit seed drawn from the experiment RNG (draw order matters for
    # reproducibility: train, then eval, then the benchmark seed).
    def make_env():
        return make_vec_env(
            env_id=env_id,
            n_envs=parallel_env_n,
            seed=int.from_bytes(_random.randbytes(3), 'big', signed=False),
            env_kwargs=dict(
                render_mode=None,
            )
        )
    train_env = make_env()
    eval_env = make_env()

    # Create plotting environment
    if plot_episode_n > 0:
        plot_env = gymnasium.make(env_id, render_mode="rgb_array")
    else:
        plot_env = None

    # Decide how often to evaluate the policy. Dividing by parallel_env_n
    # suggests eval_freq is counted in vectorized-env steps (each covering
    # parallel_env_n timesteps) -- NOTE(review): confirm in benchmark_venv.
    eval_freq = max(1, train_timestep_n//parallel_env_n//datapoint_n)

    # Benchmark RL on environment, always releasing the environments.
    try:
        benchmark_venv(
            plot_env=plot_env,
            train_timestep_n=train_timestep_n,
            train_env=train_env,
            eval_env=eval_env,
            eval_freq=eval_freq,
            eval_episode_n=eval_episode_n,
            plot_episode_n=plot_episode_n,
            cuda_device=cuda_device,
            output_dir=output_dir,
            seed=str(_random.random()),
        )
    finally:
        train_env.close()
        eval_env.close()
        if plot_env is not None:
            plot_env.close()


if __name__ == "__main__":
    def _positive_int(text: str) -> int:
        """Parse a command-line value that must be an integer >= 1."""
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(
                f"expected a positive integer, got {value}"
            )
        return value

    parser = argparse.ArgumentParser(
        prog='RL benchmark',
        description='Run RL',
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files'
    )
    parser.add_argument(
        '--train_timestep_n',
        type=int,
        required=True,
        help='Total number of RL training timesteps'
    )
    parser.add_argument(
        '--env_id',
        type=str,
        required=True,
        help='ID of the gymnasium environment.'
    )
    # Positive ints are enforced at parse time so that invalid values are
    # rejected before `output_dir` is created below; otherwise a failed run
    # leaves an empty directory that makes the rerun's mkdir() fail.
    parser.add_argument(
        '--datapoint_n',
        type=_positive_int,
        required=True,
        help=(
            "Number of datapoints. A datapoint is a measurement of values for"
            " a particular training timestep. A value of `x` means that the"
            " policy will be evaluated every `train_timestep_n//x` timesteps."
        ),
    )
    parser.add_argument(
        '--plot_episode_n',
        type=int,
        required=True,
        help='Number of episodes to plot after training.'
    )
    parser.add_argument(
        '--parallel_env_n',
        type=_positive_int,
        required=True,
        help=(
            'Number of parallel environments in each vectorized'
            ' (train and eval) environment.'
        ),
    )
    parser.add_argument(
        '--eval_episode_n',
        type=int,
        required=True,
        help='Number of episodes to evaluate policies on during training.',
    )
    parser.add_argument(
        '--cuda_device',
        type=str,
        required=True,
        help='CUDA device for SGD optimization',
    )
    parser.add_argument(
        '--seed',
        type=str,
        required=True,
        help='Random number generator seed'
    )
    args = parser.parse_args()
    output_dir = args.output_dir
    # Fails with FileExistsError if the directory already exists, so earlier
    # results are never overwritten.
    output_dir.mkdir()
    perform_experiment(
        env_id=args.env_id,
        output_dir=output_dir,
        seed=args.seed,
        train_timestep_n=args.train_timestep_n,
        datapoint_n=args.datapoint_n,
        plot_episode_n=args.plot_episode_n,
        parallel_env_n=args.parallel_env_n,
        eval_episode_n=args.eval_episode_n,
        cuda_device=args.cuda_device,
    )
