"""Benchmark SWMPO by augmenting the MDP with the state machine."""
import gymnasium
import terrain_mass
import autonomous_car_verification
import salamander_env
from swmpo.state_machine import deserialize_state_machine
from swmpo.gymnasium_wrapper import DeepSynthWrapper
from stable_baselines3.common.env_util import make_vec_env
from swmpo_experiments.benchmark_venv import benchmark_venv
from pathlib import Path
import argparse
import random


def perform_experiment(
        env_id: str,
        state_machine_path: Path,
        output_dir: Path,
        seed: str,
        datapoint_n: int,
        train_timestep_n: int,
        plot_episode_n: int,
        parallel_env_n: int,
        eval_episode_n: int,
        extrinsic_reward_scale: float,
        exploration_window_size: int,
        cuda_device: str,
        dt: float,
        ) -> None:
    """Benchmark RL on `env_id` wrapped with a pre-synthesized state machine.

    The state machine is deserialized from `state_machine_path`, every
    environment is wrapped in a `DeepSynthWrapper` configured with it, and
    the resulting train/eval/plot environments are handed to
    `benchmark_venv`.

    Args:
        env_id: ID of the gymnasium environment.
        state_machine_path: Path given to `deserialize_state_machine`.
        output_dir: Output directory forwarded to `benchmark_venv`.
        seed: Seed of the local random number generator from which the
            environment seeds and the `benchmark_venv` seed are derived.
        datapoint_n: Number of evaluation datapoints; must be positive.
            The evaluation frequency passed to `benchmark_venv` is
            `train_timestep_n // parallel_env_n // datapoint_n`, clamped
            to at least 1.
        train_timestep_n: Training timestep budget forwarded to
            `benchmark_venv`.
        plot_episode_n: Number of episodes to plot. A plotting environment
            (with `render_mode="rgb_array"`) is only created when this is
            greater than zero.
        parallel_env_n: Number of parallel environments in each vectorized
            environment; must be positive.
        eval_episode_n: Number of evaluation episodes forwarded to
            `benchmark_venv`.
        extrinsic_reward_scale: Forwarded to `DeepSynthWrapper`.
        exploration_window_size: Forwarded to `DeepSynthWrapper`.
        cuda_device: Device string forwarded to `benchmark_venv`.
        dt: Forwarded to `DeepSynthWrapper`; should match the value used
            by the state machine synthesis script.

    Raises:
        ValueError: If `datapoint_n` or `parallel_env_n` is not positive.
    """
    # Fail fast: both values are divisors of the evaluation frequency below,
    # and a zero would otherwise only surface as a ZeroDivisionError after
    # the state machine and all environments have already been built.
    if datapoint_n < 1:
        raise ValueError(
            f"datapoint_n must be a positive integer, got {datapoint_n}"
        )
    if parallel_env_n < 1:
        raise ValueError(
            f"parallel_env_n must be a positive integer, got {parallel_env_n}"
        )

    _random = random.Random(seed)

    # Load state machine
    state_machine = deserialize_state_machine(state_machine_path)

    # Single source of truth for the wrapper configuration, so that the
    # train/eval environments and the plotting environment cannot drift.
    wrapper_kwargs = dict(
        state_machine=state_machine,
        initial_state_machine_state=0,
        extrinsic_reward_scale=extrinsic_reward_scale,
        exploration_window_size=exploration_window_size,
        dt=dt,  # from the state machine synthesis script
    )

    # Create train and evaluation environments
    def make_env():
        return make_vec_env(
            env_id=env_id,
            n_envs=parallel_env_n,
            # Each vectorized environment draws its own seed from `_random`.
            seed=int.from_bytes(_random.randbytes(3), 'big', signed=False),
            env_kwargs=dict(
                render_mode=None,
            ),
            wrapper_class=DeepSynthWrapper,
            wrapper_kwargs=dict(wrapper_kwargs),
        )
    # NOTE: the order of these two calls determines which seed each
    # environment receives; keep train first for reproducibility.
    train_env = make_env()
    eval_env = make_env()

    # Create plotting environment
    if plot_episode_n > 0:
        plot_env = DeepSynthWrapper(
            env=gymnasium.make(env_id, render_mode="rgb_array"),
            **wrapper_kwargs,
        )
    else:
        plot_env = None

    # Decide how often to evaluate the policy
    eval_freq = max(1, train_timestep_n//parallel_env_n//datapoint_n)

    # Benchmark RL on environment
    benchmark_venv(
        plot_env=plot_env,
        train_timestep_n=train_timestep_n,
        train_env=train_env,
        eval_freq=eval_freq,
        eval_env=eval_env,
        eval_episode_n=eval_episode_n,
        plot_episode_n=plot_episode_n,
        cuda_device=cuda_device,
        output_dir=output_dir,
        seed=str(_random.random()),
    )


# Command-line entry point: parse the experiment configuration, create the
# output directory, and run the benchmark.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='SWMPO benchmark',
        description='Run SWMPO with a pre-synthesized state machine',
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files'
    )
    parser.add_argument(
        '--env_id',
        type=str,
        required=True,
        help='ID of the gymnasium environment.'
    )
    parser.add_argument(
        '--state_machine_zip',
        type=Path,
        required=True,
        help='Path to the serialized state machine (zip file) to load.'
    )
    parser.add_argument(
        '--train_timestep_n',
        type=int,
        required=True,
        help='Total number of timesteps to train the policy for.'
    )
    parser.add_argument(
        '--datapoint_n',
        type=int,
        required=True,
        help=(
            "Number of datapoints. A datapoint is a measurement of values for"
            " a particular training timestep. A value of `x` means that the"
            " policy will be evaluated every `train_timestep_n//x` timesteps."
        ),
    )
    parser.add_argument(
        '--plot_episode_n',
        type=int,
        required=True,
        help='Number of episodes to plot after training.'
    )
    parser.add_argument(
        '--eval_episode_n',
        type=int,
        required=True,
        help='Number of episodes to evaluate policies on during training.',
    )
    parser.add_argument(
        '--exploration_window_size',
        type=int,
        required=True,
        help='Size of the reward window for the exploration reward.',
    )
    parser.add_argument(
        '--extrinsic_reward_scale',
        type=float,
        required=True,
        help='Extrinsic reward constant for the DeepSynth wrapper.',
    )
    parser.add_argument(
        '--dt',
        type=float,
        required=True,
        help='Integration constant for the dynamical system.',
    )
    parser.add_argument(
        '--parallel_env_n',
        type=int,
        required=True,
        help='Number of parallel environments in each worker.'
    )
    parser.add_argument(
        '--cuda_device',
        type=str,
        required=True,
        help='CUDA device for SGD optimization',
    )
    parser.add_argument(
        '--seed',
        type=str,
        required=True,
        help='Random number generator seed'
    )
    args = parser.parse_args()
    output_dir = args.output_dir
    # Deliberately no `exist_ok`: an already-existing output directory raises
    # FileExistsError, so results of a previous run are never overwritten.
    output_dir.mkdir()
    perform_experiment(
        env_id=args.env_id,
        state_machine_path=args.state_machine_zip,
        output_dir=output_dir,
        seed=args.seed,
        datapoint_n=args.datapoint_n,
        train_timestep_n=args.train_timestep_n,
        plot_episode_n=args.plot_episode_n,
        parallel_env_n=args.parallel_env_n,
        eval_episode_n=args.eval_episode_n,
        extrinsic_reward_scale=args.extrinsic_reward_scale,
        cuda_device=args.cuda_device,
        exploration_window_size=args.exploration_window_size,
        dt=args.dt,
    )
