"""
Train a VPS-based option set on an Atari or MiniGrid environment and
visualize each learned option by rolling it out and saving a short GIF.

This script:
- Trains `ContinuousVPSAgent` in three stages (Value → VPS → Option Q).
- Saves all trained networks to a checkpoint file.
- Replays each option for a fixed number of steps and writes GIFs.
"""

import os, time, imageio, torch
import numpy as np
import gymnasium as gym
import argparse
from continuous_vps_agent import ContinuousVPSAgent
from bottleneck_env import SimpleEnv

class MaxAndSkipEnv(gym.Wrapper):
    """Action-repeat wrapper that returns a max-pooled observation.

    One call to ``step(action)``:
    - forwards ``action`` to the wrapped env up to ``skip`` times, stopping
      early as soon as the env reports ``terminated`` or ``truncated``
    - sums the rewards of the inner steps that actually ran
    - returns the element-wise maximum of the two buffered frames

    When all ``skip`` inner steps ran (and ``skip >= 2``) the buffered frames
    are the last two observations. When the loop stopped early, or when
    ``skip == 1``, both buffer slots hold the final observation, so the
    returned frame is simply that observation.
    """

    def __init__(self, env: gym.Env, skip: int = 4):
        """Wrap ``env`` and repeat every action ``skip`` times.

        Raises:
            ValueError: if ``skip`` is smaller than 1.
        """
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self._skip = int(skip)
        # Two-slot frame buffer; created on reset() once the observation's
        # shape and dtype are known.
        self._obs_buffer = None

    def reset(self, **kwargs):
        """Reset the wrapped env and seed both buffer slots with its frame."""
        first_frame, info = self.env.reset(**kwargs)
        pair = np.zeros((2, *first_frame.shape), dtype=first_frame.dtype)
        pair[...] = first_frame  # broadcast the same frame into both slots
        self._obs_buffer = pair
        return first_frame, info

    def step(self, action):
        """Repeat ``action`` and return the max-pooled frame.

        Returns:
            ``(max_frame, total_reward, terminated, truncated, info)`` where
            the last three items come from the final inner step executed.
        """
        frames = self._obs_buffer
        penultimate = self._skip - 2
        final = self._skip - 1

        reward_sum = 0.0
        terminated = truncated = False
        info = {}
        frame = None
        idx = -1

        for idx in range(self._skip):
            frame, reward, terminated, truncated, info = self.env.step(action)
            # Copy into the buffer right away rather than holding a
            # reference, so a later inner step cannot alter a stored frame.
            if idx == penultimate:
                frames[0] = frame
            elif idx == final:
                frames[1] = frame
            reward_sum += float(reward)
            if terminated or truncated:
                break

        # Fewer than `skip` inner steps ran (or there is only one step per
        # call): make both slots hold the final frame so no stale frame from
        # an earlier call leaks into the max.
        ran_all = idx == final
        if self._skip == 1 or not ran_all:
            frames[0] = frame
            frames[1] = frame

        pooled = np.maximum(frames[0], frames[1])
        return pooled, reward_sum, terminated, truncated, info

# Short CLI names mapped to Gymnasium environment IDs. ``None`` marks the
# grid world, which this script builds from ``SimpleEnv`` instead of going
# through ``gym.make``.
games = {
    'montezuma': 'ALE/MontezumaRevenge-v5',
    'venture': 'ALE/Venture-v5',
    'pong': 'ALE/Pong-v5',
    'solaris': 'ALE/Solaris-v5',
    'freeway': 'ALE/Freeway-v5',
    'gravitar': 'ALE/Gravitar-v5',
    'adventure': 'ALE/Adventure-v5',
    'private': 'ALE/PrivateEye-v5',
    'pitfall': 'ALE/Pitfall-v5',
    'breakout': 'ALE/Breakout-v5',
    'pacman': 'ALE/MsPacman-v5',
    'gridworld': None,
}

parser = argparse.ArgumentParser(description='Implementation for VPS options')
# The value is lower-cased before validation, so ``--game Pong`` is still
# accepted, while an unknown name is rejected with a usage error listing the
# valid choices instead of surfacing later as a KeyError on ``games``.
parser.add_argument('--game',          type=str.lower, default='gridworld',
                    choices=sorted(games), help='Game name')
parser.add_argument('--buffer_size',   type=int,   default=10000,    help='Replay-buffer capacity')
parser.add_argument('--value_iters',   type=int,   default=500000,    help='Value-network iterations')
parser.add_argument('--vps_iters',     type=int,   default=500000,    help='VPS-network iterations')
parser.add_argument('--option_iters',  type=int,   default=500000,    help='Option-DQN iterations')
parser.add_argument('--num_options',   type=int,   default=4,          help='Number of options')
parser.add_argument('--frame_stack',   type=int,   default=1,          help='Frames per stack')
parser.add_argument('--batch_size',    type=int,   default=128,        help='Training batch size')
parser.add_argument('--frame_skip',    type=int,   default=4,          help='Atari action repeat (max-pool last two frames)')
args = parser.parse_args()

DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Current device: {DEVICE}')
GAME          = args.game.lower()
ENV_ID        = games[GAME]
NUM_OPTIONS   = args.num_options
SAVE_DIR      = os.path.join(GAME, 'outputs')
os.makedirs(SAVE_DIR, exist_ok=True)


def _make_env(render_mode=None):
    """Build the environment used for both training and visualization.

    Args:
        render_mode: Forwarded to ``gym.make`` for Atari games when not
            ``None``. The grid world is always created with
            ``render_mode="rgb_array"``.

    Returns:
        A ``SimpleEnv`` for the grid world, otherwise the Atari env wrapped
        in ``MaxAndSkipEnv`` with ``skip=args.frame_skip``.
    """
    if GAME == "gridworld":
        return SimpleEnv(render_mode="rgb_array", highlight=False)

    # ALE "-v5" environments repeat every action for 4 emulator frames by
    # default. Leaving that on underneath MaxAndSkipEnv would multiply the two
    # skips (4 x 4 = 16 frames per agent step) and max-pool frames that are 4
    # emulator frames apart. Turn the built-in skip off so `--frame_skip` is
    # the only action repeat. Every non-gridworld entry in `games` is an ALE
    # env; a non-ALE id added later may not accept the `frameskip` kwarg.
    make_kwargs = {"frameskip": 1}
    if render_mode is not None:
        make_kwargs["render_mode"] = render_mode
    return MaxAndSkipEnv(gym.make(ENV_ID, **make_kwargs), skip=args.frame_skip)


# ---------------- Training ----------------
env = _make_env()

agent = ContinuousVPSAgent(
    env,
    gamma_v=0.99,
    gamma_q=0.9,
    k_options=NUM_OPTIONS,
    device=DEVICE,
    buffer_cap=args.buffer_size,
    frame_stack_len=args.frame_stack,
    batch_size=args.batch_size,
    max_episode_length=500,
)

agent.train(
    vps_iters=args.vps_iters,
    value_iters=args.value_iters,
    option_iters=args.option_iters,
    print_every=50,
)

agent.save_all(os.path.join(SAVE_DIR, "networks.pt"))

# --------------- Visualization ------------
@torch.no_grad()
def rollout_option(opt_id: int, n_steps: int = 100, fps: int = 150):
    """Roll out one option head for up to ``n_steps`` and save a GIF.

    A fresh rendering env is built for every call (and always closed, even
    if the rollout raises). Actions are chosen greedily from the Q-values of
    option head ``opt_id``. The function handles option Q-networks with or
    without a frame-difference input channel.

    Args:
        opt_id: Index into ``agent.opt_heads``.
        n_steps: Maximum number of environment steps to execute.
        fps: Frame rate passed to ``imageio.mimsave``.

    The GIF is written to ``SAVE_DIR/option_<opt_id>.gif`` and contains one
    frame before every step plus one frame showing the state reached by the
    last executed action.
    """
    env_vis = _make_env(render_mode="rgb_array")
    frames = []

    try:
        obs, _ = env_vis.reset(seed=int(np.random.randint(1_000_000)))

        agent.reset_frame_stack()
        state4 = agent.obs_to_state(obs)        # expected (F,84,84) per original note

        # The first layer having more input channels than the frame stack is
        # taken to mean the Q-network expects an extra frame-difference channel.
        diff_enabled = agent.backbone_q.net[0].weight.shape[1] > agent.F

        def add_diff(st):
            # NOTE(review): st[-2] needs at least two stacked frames — confirm
            # the diff channel is never enabled together with --frame_stack 1.
            diff = (st[-1] - st[-2]).unsqueeze(0)     # (1,84,84)
            return torch.cat([st, diff], 0)           # (F+1,84,84)

        state = add_diff(state4) if diff_enabled else state4

        for _ in range(n_steps):
            frames.append(env_vis.render())

            feat = agent.backbone_q(state.unsqueeze(0))
            q    = agent.opt_heads[opt_id](feat)
            act  = int(q.argmax().item())

            next_obs, _, term, trunc, _ = env_vis.step(act)
            state4 = agent.obs_to_state(next_obs)
            state  = add_diff(state4) if diff_enabled else state4
            if term or trunc:
                break

        # Frames are captured *before* each step, so without this the state
        # produced by the final action would never appear in the GIF (and
        # n_steps == 0 would leave nothing to write).
        frames.append(env_vis.render())
    finally:
        env_vis.close()

    gif_path = os.path.join(SAVE_DIR, f"option_{opt_id:02d}.gif")
    imageio.mimsave(gif_path, frames, fps=fps)
    print(f"[✓] Saved roll-out → {gif_path}")


print("\n=== Visualizing options … ===")
for k in range(NUM_OPTIONS):
    rollout_option(k)

env.close()
