"""Script to visualize StableBaselines3 vanilla-RL policies on the Salamander
environment."""
import argparse
from pathlib import Path
import gymnasium
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from swmpo.state_machine import deserialize_state_machine
from swmpo.gymnasium_wrapper import DeepSynthWrapper
import salamander_env
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import VecNormalize
import swmpo_experiments.salamander_utils.ground_truth_wrapper


def plot(
        env_id: str,
        policy_path: Path,
        normalization_path: Path,
        output_dir: Path,
        plot_episode_n: int,
        ) -> None:
    """Roll out a saved SAC policy for several episodes to visualize it.

    For every episode a subdirectory named after the episode index is created
    under `output_dir` and handed to the environment through its
    `animation_output_dir` keyword (presumably the environment writes its
    animation files there -- confirm against the environment implementation).
    Observations are normalized with the saved `VecNormalize` statistics
    before being fed to the policy, and actions are chosen deterministically.

    Args:
        env_id: ID of the gymnasium environment to instantiate.
        policy_path: Path to the saved StableBaselines3 SAC policy.
        normalization_path: Path to the saved `VecNormalize` statistics.
        output_dir: Existing directory in which one subdirectory per episode
            is created.
        plot_episode_n: Number of episodes to roll out.

    Raises:
        FileExistsError: If an episode subdirectory already exists.
    """
    # Upper bound on the number of steps taken in a single episode.
    max_eval_steps = 2000

    # Load policy
    model = SAC.load(policy_path, device="cpu")

    def make_env(animation_output_dir: Path | None = None):
        return gymnasium.make(
            env_id,
            render_mode=None,
            animation_output_dir=animation_output_dir,
        )

    # The normalization statistics do not depend on the episode, so they are
    # loaded once instead of once per episode. The wrapped DummyVecEnv is
    # only there because `VecNormalize.load` requires a venv; it is never
    # stepped, only `normalize_obs` is used.
    # NOTE: `VecNormalize.load` unpickles the file, so only pass files from
    # a trusted source.
    normalized_env = VecNormalize.load(
        load_path=str(normalization_path),
        venv=DummyVecEnv([make_env]),
    )
    try:
        for animation_i in range(plot_episode_n):
            animation_output_dir = output_dir/f"{animation_i}"
            animation_output_dir.mkdir()
            env = make_env(animation_output_dir)
            try:
                obs, _ = env.reset()
                for t in range(max_eval_steps):
                    obs = normalized_env.normalize_obs(obs)
                    action, _ = model.predict(obs, deterministic=True)
                    obs, _, terminated, truncated, _ = env.step(action)
                    if terminated or truncated:
                        # `t` is the zero-based index of the step just
                        # taken, so `t + 1` steps were executed.
                        print(f"Visualization episode lasted {t + 1} steps")
                        break
            finally:
                # Close the environment even when the rollout fails.
                env.close()
    finally:
        normalized_env.close()


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, create the output
    # directory and roll out the policy with `plot`.
    parser = argparse.ArgumentParser(
        prog='Salamander visualization',
        description=(
            'Visualize a StableBaselines3 vanilla-RL policy on the '
            'Salamander environment'
        ),
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files'
    )
    parser.add_argument(
        '--env_id',
        type=str,
        required=True,
        help='ID of the gymnasium environment.'
    )
    parser.add_argument(
        '--plot_episode_n',
        type=int,
        required=True,
        help='Number of episodes to plot after training.'
    )
    parser.add_argument(
        '--policy_zip',
        type=Path,
        required=True,
        help='StableBaselines3 policy ZIP file'
    )
    parser.add_argument(
        '--normalization_zip',
        type=Path,
        required=True,
        help=(
            'StableBaselines3 VecNormalize statistics file '
            '(pickle file, loaded with VecNormalize.load)'
        )
    )
    args = parser.parse_args()
    output_dir = args.output_dir
    # Deliberately fails if the directory already exists, so that a previous
    # run's output is never mixed with this one.
    output_dir.mkdir()
    plot(
        env_id=args.env_id,
        policy_path=args.policy_zip,
        normalization_path=args.normalization_zip,
        output_dir=output_dir,
        plot_episode_n=args.plot_episode_n,
    )
