"""
Train a VPS-based option set on an Atari or MiniGrid environment and
visualize each learned option by rolling it out and saving a short GIF.

This script:
- Trains `ContinuousVPSAgent` in three stages (Value → VPS → Option Q).
- Saves all trained networks to a checkpoint file.
- Replays each option for up to a fixed number of steps (stopping early if
  the episode ends) and writes one GIF per option.
"""

import os, time, imageio, torch
import numpy as np
import gymnasium as gym
import argparse
from continuous_vps_agent import ContinuousVPSAgent
from bottleneck_env import SimpleEnv

# CLI game alias -> Gymnasium environment id.  Every Atari entry follows the
# "ALE/<RomTitle>-v5" pattern, so only the ROM title is spelled out here.
# 'gridworld' has no Gymnasium id (None); the script builds a SimpleEnv for it.
games = {
    alias: f'ALE/{rom_title}-v5'
    for alias, rom_title in (
        ('montezuma', 'MontezumaRevenge'),
        ('venture', 'Venture'),
        ('pong', 'Pong'),
        ('solaris', 'Solaris'),
        ('freeway', 'Freeway'),
        ('gravitar', 'Gravitar'),
        ('adventure', 'Adventure'),
        ('private', 'PrivateEye'),
        ('pitfall', 'Pitfall'),
        ('breakout', 'Breakout'),
        ('pacman', 'MsPacman'),
    )
}
games['gridworld'] = None

# Command-line interface.
# --game is lower-cased by argparse and validated against the keys of `games`,
# so a misspelled name produces a usage error listing the valid choices
# instead of a KeyError traceback when the environment id is looked up.
parser = argparse.ArgumentParser(description='Implementation for VPS options')
parser.add_argument('--game',          type=str.lower, default='gravitar',
                    choices=sorted(games),             help='Game name (case-insensitive)')
parser.add_argument('--buffer_size',   type=int,   default=200,    help='Replay-buffer capacity')
parser.add_argument('--value_iters',   type=int,   default=200,    help='Value-network iterations')
parser.add_argument('--vps_iters',     type=int,   default=200,    help='VPS-network iterations')
parser.add_argument('--option_iters',  type=int,   default=200,    help='Option-DQN iterations')
parser.add_argument('--num_options',   type=int,   default=8,          help='Number of options')
parser.add_argument('--frame_stack',   type=int,   default=1,          help='Frames per stack')
parser.add_argument('--batch_size',    type=int,   default=128,        help='Training batch size')
args = parser.parse_args()

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f'Current device: {DEVICE}')

# Gymnasium id of the selected game (None for 'gridworld').
ENV_ID = games[args.game.lower()]
NUM_OPTIONS = args.num_options

# The checkpoint and the GIFs are written under "<game>/outputs";
# create the directory up front.
SAVE_DIR = os.path.join(args.game.lower(), 'outputs')
os.makedirs(SAVE_DIR, exist_ok=True)

# ---------------- Training ----------------
# 'gridworld' is served by the project's SimpleEnv; every other game is
# created through Gymnasium from its registered id.
if args.game.lower() == "gridworld":
    env = SimpleEnv(render_mode='rgb_array', highlight=False)
else:
    env = gym.make(ENV_ID)

# Discount factors and the episode-length cap are fixed here; the remaining
# settings come from the command line.
agent = ContinuousVPSAgent(
    env,
    gamma_v=0.99,
    gamma_q=0.9,
    k_options=NUM_OPTIONS,
    device=DEVICE,
    buffer_cap=args.buffer_size,
    frame_stack_len=args.frame_stack,
    batch_size=args.batch_size,
    max_episode_length=500,
)

# Run training with the per-stage iteration budgets given on the CLI.
agent.train(
    vps_iters=args.vps_iters,
    value_iters=args.value_iters,
    option_iters=args.option_iters,
)

# Persist the trained networks next to the GIFs.
checkpoint_path = os.path.join(SAVE_DIR, "networks.pt")
agent.save_all(checkpoint_path)

# --------------- Visualization ------------
@torch.no_grad()
def rollout_option(opt_id: int, n_steps: int = 100) -> None:
    """Roll out one option head greedily and save the rendered frames as a GIF.

    A fresh rendering environment is created for the roll-out (the training
    env is not reused). At every step the current frame is rendered, the
    state is passed through ``agent.backbone_q`` and ``agent.opt_heads[opt_id]``,
    and the arg-max action is executed. The loop stops after ``n_steps``
    steps or as soon as the environment reports termination/truncation.
    Frames are captured *before* each step, so the frame that follows the
    final step is not part of the GIF.

    The GIF is written to ``SAVE_DIR/option_<opt_id, 2 digits>.gif``.

    Args:
        opt_id: Index into ``agent.opt_heads`` selecting the option to replay;
            also used in the output file name.
        n_steps: Maximum number of environment steps to roll out. If it is
            zero or negative no frame is rendered and no GIF is written.

    The rendering environment is always closed, even if the roll-out or the
    GIF export raises.
    """
    env_vis = (
        SimpleEnv(render_mode="rgb_array", highlight=False)
        if args.game.lower() == "gridworld"
        else gym.make(ENV_ID, render_mode="rgb_array")
    )

    try:
        # Random seed per roll-out so repeated calls start from varied resets.
        obs, _ = env_vis.reset(seed=int(np.random.randint(1_000_000)))
        frames = []

        agent.reset_frame_stack()
        stacked = agent.obs_to_state(obs)        # (F,84,84) per the original note

        # The Q-network expects a frame-difference channel when its first
        # layer takes more input channels than the frame stack provides.
        diff_enabled = agent.backbone_q.net[0].weight.shape[1] > agent.F

        def add_diff(st):
            # Append (newest frame - previous frame) as one extra channel.
            # NOTE(review): requires at least two stacked frames — confirm
            # the diff channel is never enabled with a frame stack of 1.
            diff = (st[-1] - st[-2]).unsqueeze(0)     # (1,84,84)
            return torch.cat([st, diff], 0)           # (F+1,84,84)

        state = add_diff(stacked) if diff_enabled else stacked

        for _ in range(n_steps):
            frames.append(env_vis.render())

            # NOTE(review): assumes obs_to_state already returns a tensor on
            # the same device as the networks — TODO confirm.
            feat = agent.backbone_q(state.unsqueeze(0))
            q = agent.opt_heads[opt_id](feat)
            act = int(q.argmax().item())

            next_obs, _, term, trunc, _ = env_vis.step(act)
            stacked = agent.obs_to_state(next_obs)
            state = add_diff(stacked) if diff_enabled else stacked
            if term or trunc:
                break

        if not frames:
            # Nothing was rendered (n_steps <= 0); skip the export rather
            # than handing an empty frame list to the GIF writer.
            print(f"[!] No frames rendered for option {opt_id}; skipping GIF")
            return

        gif_path = os.path.join(SAVE_DIR, f"option_{opt_id:02d}.gif")
        # NOTE(review): GIF frame delays are coarse; fps=150 may not be
        # honoured by viewers — confirm the intended playback speed.
        imageio.mimsave(gif_path, frames, fps=150)
        print(f"[✓] Saved roll-out → {gif_path}")
    finally:
        # Release the rendering environment on every exit path.
        env_vis.close()


print("\n=== Visualizing options … ===")
try:
    # Write one GIF per learned option.
    for k in range(NUM_OPTIONS):
        rollout_option(k)
finally:
    # Release the training environment even if a roll-out raises.
    env.close()
