"""Helper function to benchmark an RL algorithm (SAC) on a vectorized
environment wrapped in `stable_baselines3.common.vec_env.VecNormalize`."""
import gymnasium
from stable_baselines3 import SAC
from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.callbacks import EvalCallback
from gymnasium.utils.save_video import save_video
from pathlib import Path
import random

import terrain_mass
import autonomous_car_verification
import salamander_env
import bipedal_walker_hardcore_modes
import swmpo_experiments.autonomous_driving_utils.ground_truth_wrapper
import swmpo_experiments.terrain_mass_utils.ground_truth_wrapper
import swmpo_experiments.salamander_utils.ground_truth_wrapper
import swmpo_experiments.bipedal_walker_hardcore_utils.ground_truth_wrapper


# Restrict PyTorch to a single intra-op CPU thread so that this script does
# not monopolize the machine's resources on a shared lab computer. These
# lines can be removed when running on a dedicated machine.
import torch

torch.set_num_threads(1)


def _record_policy_videos(
        model: SAC,
        normalizer: VecNormalize,
        plot_env: gymnasium.Env,
        video_path: Path,
        episode_n: int,
        max_step_n: int,
        ):
    """Roll out ``model`` deterministically on ``plot_env``, one video per episode.

    Args:
        model: Trained policy; it is fed observations normalized by ``normalizer``.
        normalizer: Loaded ``VecNormalize`` used only for ``normalize_obs``.
        plot_env: Single (non-vectorized) env; ``render()`` is assumed to
            return frames (e.g. ``render_mode="rgb_array"``) -- not checked here.
        video_path: Folder passed to ``gymnasium.utils.save_video.save_video``.
        episode_n: Number of episodes to record (nothing is recorded if <= 0).
        max_step_n: Cap on the total number of ``plot_env.step`` calls across
            all episodes; an episode cut short by the cap is still saved.
    """
    # gymnasium's RecordVideo wrapper uses the same fallback when the env
    # does not declare a frame rate.
    fps = plot_env.metadata.get("render_fps", 30)

    # Without an explicit trigger, save_video falls back to a capped cubic
    # schedule and silently skips most episodes (only 0, 1, 8, 27, ...).
    def record_every_episode(_episode_index: int) -> bool:
        return True

    step_index = 0
    for episode_index in range(episode_n):
        if step_index >= max_step_n:
            break
        step_starting_index = step_index
        obs, _ = plot_env.reset()
        frames = [plot_env.render()]  # initial frame of every episode
        done = False
        while not done and step_index < max_step_n:
            action, _ = model.predict(
                normalizer.normalize_obs(obs),
                deterministic=True,
            )
            obs, _, terminated, truncated, _ = plot_env.step(action)
            frames.append(plot_env.render())
            step_index += 1
            done = terminated or truncated
        save_video(
            frames,
            str(video_path),
            episode_trigger=record_every_episode,
            fps=fps,
            step_starting_index=step_starting_index,
            episode_index=episode_index,
        )


def benchmark_venv(
        train_timestep_n: int,
        train_env: VecEnv,
        eval_env: VecEnv,
        plot_env: gymnasium.Env | None,
        eval_freq: int,
        eval_episode_n: int,
        output_dir: Path,
        plot_episode_n: int,
        cuda_device: str,
        seed: str | int,
        max_plot_step_n: int = 1_000_000,
        ):
    """Train SAC on observation-normalized vectorized envs and save the result.

    ``train_env`` and ``eval_env`` are wrapped in ``VecNormalize``, which
    normalizes observations but not rewards. The policy is trained with an
    ``EvalCallback`` on ``eval_env``. The policy and the normalization
    statistics are then written to ``output_dir`` and loaded back. When
    ``plot_env`` is given, videos of ``plot_episode_n`` deterministic
    episodes are saved under ``output_dir / "videos"``.

    Args:
        train_timestep_n: ``total_timesteps`` passed to ``SAC.learn``.
        train_env: Vectorized training environment.
        eval_env: Vectorized environment used by ``EvalCallback``.
        plot_env: Single environment to render, or ``None`` to skip videos.
        eval_freq: ``EvalCallback`` ``eval_freq``.
        eval_episode_n: ``EvalCallback`` ``n_eval_episodes``.
        output_dir: Destination for logs, the policy, the statistics and the videos.
        plot_episode_n: Number of episodes to record.
        cuda_device: Torch device string used for training.
        seed: Seed of a local RNG from which the 24-bit SAC seed is derived.
        max_plot_step_n: Cap on the total ``plot_env`` steps while recording.
    """
    # Normalize observations only. The eval wrapper does not update its own
    # statistics (training=False); SB3's EvalCallback syncs them from the
    # training wrapper before evaluating.
    train_env = VecNormalize(
        venv=train_env,
        training=True,
        norm_obs=True,
        norm_reward=False,
    )
    eval_env = VecNormalize(
        venv=eval_env,
        training=False,
        norm_obs=True,
        norm_reward=False,
    )

    # Optimize policy. The SAC seed is drawn from an RNG seeded with `seed`
    # so that runs are reproducible.
    rng = random.Random(seed)
    model = SAC(
        "MlpPolicy",
        train_env,
        verbose=1,
        seed=int.from_bytes(rng.randbytes(3), 'big', signed=False),
        # -1: as many gradient steps as env steps collected per rollout.
        gradient_steps=-1,
        device=cuda_device,
    )
    log_path = output_dir/"log"
    model.set_logger(configure(str(log_path), ["stdout", "csv", "tensorboard"]))
    eval_callback = EvalCallback(
        eval_env,
        eval_freq=eval_freq,
        n_eval_episodes=eval_episode_n,
        deterministic=True,
        render=False,
        verbose=1,
    )
    model.learn(
        total_timesteps=train_timestep_n,
        log_interval=100,
        callback=eval_callback,
        progress_bar=False,
    )

    # Serialize policy and normalization statistics.
    output_policy_path = output_dir/"RL_SB3_policy.zip"
    output_normalization_path = output_dir/"SB3_normalized_venv.zip"
    # SB3 adds the ".zip" suffix itself when it is missing.
    model.save(output_policy_path.with_suffix(''))
    train_env.save(str(output_normalization_path))
    print(f"Wrote {output_policy_path}")
    print(f"Wrote {output_normalization_path}")
    del model

    # Load the policy and statistics back from disk, so the videos reflect
    # exactly what was serialized. The loaded wrapper is only used for
    # `normalize_obs` and is never stepped. It wraps the raw venv (not the
    # training VecNormalize) to avoid nesting two normalizers, and its
    # statistics are frozen.
    model = SAC.load(output_policy_path, device="cpu")
    normalized_env = VecNormalize.load(
        load_path=str(output_normalization_path),
        venv=train_env.venv,
    )
    normalized_env.training = False

    if plot_env is None:
        return
    _record_policy_videos(
        model=model,
        normalizer=normalized_env,
        plot_env=plot_env,
        video_path=output_dir/"videos",
        episode_n=plot_episode_n,
        max_step_n=max_plot_step_n,
    )
