"""Script to visualize StableBaselines3 SWMPO policies on the Salamander
environment."""
import argparse
from pathlib import Path
import gymnasium
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from swmpo.state_machine import deserialize_state_machine
from swmpo.gymnasium_wrapper import DeepSynthWrapper
import salamander_env
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import VecNormalize
import swmpo_experiments.salamander_utils.ground_truth_wrapper


def plot(
        env_id: str,
        state_machine_path: Path,
        policy_path: Path,
        normalization_path: Path,
        output_dir: Path,
        plot_episode_n: int,
        extrinsic_reward_scale: float,
        exploration_window_size: int,
        dt: float,
        max_eval_steps: int = 2000,
        ) -> None:
    """Roll out a trained SAC policy for several episodes, one animation each.

    The state machine, the SAC policy and the observation-normalization
    statistics are loaded once. For each of the ``plot_episode_n`` episodes a
    fresh environment (wrapped in ``DeepSynthWrapper``) is created with
    ``animation_output_dir=output_dir/<episode index>``. The policy is then run
    deterministically until the episode terminates, is truncated, or
    ``max_eval_steps`` steps have been taken.

    Args:
        env_id: ID of the registered gymnasium environment.
        state_machine_path: Serialized state machine passed to
            ``deserialize_state_machine``.
        policy_path: StableBaselines3 SAC policy file.
        normalization_path: Saved ``VecNormalize`` statistics, used to
            normalize observations before they are given to the policy.
        output_dir: Existing directory. One sub-directory per episode is
            created inside it; ``mkdir`` raises if it already exists.
        plot_episode_n: Number of episodes to roll out.
        extrinsic_reward_scale: Forwarded to ``DeepSynthWrapper``.
        exploration_window_size: Forwarded to ``DeepSynthWrapper``.
        dt: Forwarded to ``DeepSynthWrapper``.
        max_eval_steps: Maximum number of environment steps per episode.
    """
    # Load state machine
    state_machine = deserialize_state_machine(state_machine_path)

    # Load policy
    model = SAC.load(policy_path, device="cpu")

    def make_env(animation_output_dir: Path | None = None):
        """Build the DeepSynth-wrapped env, forwarding the animation dir."""
        env = gymnasium.make(
            env_id,
            render_mode=None,
            animation_output_dir=animation_output_dir,
        )
        return DeepSynthWrapper(
            env=env,
            state_machine=state_machine,
            initial_state_machine_state=0,
            extrinsic_reward_scale=extrinsic_reward_scale,
            exploration_window_size=exploration_window_size,
            dt=dt,
        )

    # The normalization statistics are the same for every episode, so load
    # them once. This env is only used for `normalize_obs`; it is never
    # stepped and is built without an animation directory.
    normalized_env = VecNormalize.load(
        load_path=str(normalization_path),
        venv=DummyVecEnv([make_env]),
    )
    try:
        for animation_i in range(plot_episode_n):
            animation_output_dir = output_dir/f"{animation_i}"
            animation_output_dir.mkdir()
            env = make_env(animation_output_dir)
            try:
                obs, _ = env.reset()
                for _ in range(max_eval_steps):
                    obs = normalized_env.normalize_obs(obs)
                    action, _ = model.predict(obs, deterministic=True)
                    obs, _, terminated, truncated, _ = env.step(action)
                    if terminated or truncated:
                        break
            finally:
                # Always close the episode env, including when the step
                # budget is exhausted without termination or an error occurs.
                env.close()
    finally:
        normalized_env.close()


if __name__ == "__main__":
    # Command-line entry point: parse arguments, create the (new) output
    # directory and roll out the requested number of episodes.
    parser = argparse.ArgumentParser(
        prog='Salamander visualization',
        description='Run SWMPO with a pre-synthesized state machine',
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files'
    )
    parser.add_argument(
        '--env_id',
        type=str,
        required=True,
        help='ID of the gymnasium environment.'
    )
    parser.add_argument(
        '--state_machine_zip',
        type=Path,
        required=True,
        help='Serialized state machine ZIP file'
    )
    parser.add_argument(
        '--plot_episode_n',
        type=int,
        required=True,
        help='Number of episodes to plot after training.'
    )
    parser.add_argument(
        '--extrinsic_reward_scale',
        type=float,
        required=True,
        help='Extrinsic reward constant for the DeepSynth wrapper.',
    )
    parser.add_argument(
        '--exploration_window_size',
        type=int,
        required=True,
        help='Size of the reward window for the exploration reward.',
    )
    parser.add_argument(
        '--dt',
        type=float,
        required=True,
        help='Integration constant for the dynamical system.',
    )
    parser.add_argument(
        '--policy_zip',
        type=Path,
        required=True,
        help='StableBaselines3 policy ZIP file'
    )
    parser.add_argument(
        '--normalization_zip',
        type=Path,
        required=True,
        help='StableBaselines3 normalization ZIP file'
    )
    args = parser.parse_args()
    output_dir = args.output_dir
    # Fails if the directory already exists, so earlier outputs are never
    # overwritten.
    output_dir.mkdir()
    plot(
        env_id=args.env_id,
        state_machine_path=args.state_machine_zip,
        policy_path=args.policy_zip,
        normalization_path=args.normalization_zip,
        output_dir=output_dir,
        plot_episode_n=args.plot_episode_n,
        extrinsic_reward_scale=args.extrinsic_reward_scale,
        exploration_window_size=args.exploration_window_size,
        dt=args.dt,
    )
