#!/usr/bin/env python3
"""Plot the total VPS signal along Atari trajectories and save peak frames.

Loads a trained VPS network, samples frames from an Atari env (manual
keyboard control optional), computes Σ_i φ_i(s) across options, plots the
curve with an inset for overflow peaks, and saves local peak frames.
"""
from __future__ import annotations
import argparse, os, time, warnings
from pathlib import Path
import numpy as np
import torch, gymnasium as gym
import matplotlib.pyplot as plt
import imageio.v2 as imageio
from tqdm import tqdm

# ---------- your own network imports ----------
from networks             import AtariCNN, VPSNet
from continuous_vps_agent import ContinuousVPSAgent

# Allow duplicate OpenMP runtimes to coexist instead of aborting.
# NOTE(review): this runs after torch/numpy were imported above; if the
# OpenMP runtime reads the variable at import time it may be too late here
# — confirm, and consider exporting it in the shell instead.
os.environ.update(KMP_DUPLICATE_LIB_OK='TRUE')
# Silence every UserWarning for the remainder of the script.
warnings.simplefilter('ignore', category=UserWarning)

# ---------- keyboard support ----------
# pygame is optional: it is only needed for manual keyboard control.
try:
    import pygame
except ImportError:
    _HAS_PYGAME = False
else:
    _HAS_PYGAME = True

# Short CLI alias → Gymnasium environment id.
GAMES = dict(
    montezuma='ALE/MontezumaRevenge-v5',
    venture='ALE/Venture-v5',
    pong='ALE/Pong-v5',
    solaris='ALE/Solaris-v5',
    freeway='ALE/Freeway-v5',
    gravitar='ALE/Gravitar-v5',
    adventure='ALE/Adventure-v5',
    private='ALE/PrivateEye-v5',
    pitfall='ALE/Pitfall-v5',
    breakout='ALE/Breakout-v5',
    pacman='ALE/MsPacman-v5',
)

# ---------------- CLI ----------------
pa = argparse.ArgumentParser(
    description='Plot the summed VPS signal along an Atari trajectory '
                'and save the frames at local peaks.')
pa.add_argument('--ckpt',   default='./freeway/outputs/networks.pt',
                help='checkpoint file holding the trained networks')
pa.add_argument('--env',    choices=list(GAMES), default='freeway',
                help='game alias (see GAMES)')
pa.add_argument('--steps',  type=int, default=1000,
                help='number of environment steps to sample')
pa.add_argument('--win',    type=int, default=100,
                help='window length (in steps) for peak detection')
pa.add_argument('--out',    default='vps_sum_curve',
                help='output directory; a per-game sub-folder is created')
pa.add_argument('--device', default='cuda',
                help='torch device used for the networks')
# Manual control stays the default (as before). Previously `store_true` with
# default=True made it impossible to switch off, so the random-action path
# was unreachable; `--no-manual` now selects it.
pa.add_argument('--manual', dest='manual', action='store_true', default=True,
                help='drive the game from the keyboard (default)')
pa.add_argument('--no-manual', dest='manual', action='store_false',
                help='sample random actions instead of keyboard control')
pa.add_argument('--fps_print', type=int, default=200,
                help='print progress every this many steps')
pa.add_argument('--ylim',   type=float, default=20,
                help='y-lim in main plot; overflow values are shown in the inset')
args = pa.parse_args()

# Manual control polls the keyboard through pygame, so fail early without it.
if args.manual and not _HAS_PYGAME:
    raise RuntimeError('Install pygame to use --manual')

# The default device is 'cuda'; on a machine without CUDA fall back to CPU
# instead of failing later while loading the checkpoint. args.device is
# updated too so that every later consumer of it sees the same device.
if str(args.device).startswith('cuda') and not torch.cuda.is_available():
    print(f'[Warn] CUDA is not available; using CPU instead of {args.device!r}')
    args.device = 'cpu'

ENV_ID   = GAMES[args.env]               # Gymnasium id of the chosen game
OUT_DIR  = Path(args.out) / args.env     # <out>/<game>/
PEAK_DIR = OUT_DIR / 'peaks'             # <out>/<game>/peaks/
PEAK_DIR.mkdir(parents=True, exist_ok=True)
dev = torch.device(args.device)

# ============================================================
# 1. Networks
# ============================================================
# NOTE(review): torch.load relies on pickle — only open checkpoints that
# come from a trusted source.
ckpt = torch.load(args.ckpt, map_location=dev)
k_opt, stack_len = ckpt['k'], ckpt['frame_stack_len']

# Backbone: built for `stack_len` input channels, then restored.
backbone = AtariCNN(in_channels=stack_len)
backbone = backbone.to(dev).eval()
backbone.load_state_dict(ckpt['backbone_vps'])

# VPS network on top of the backbone; only its head is restored separately.
vps_net = VPSNet(k_opt, backbone)
vps_net = vps_net.to(dev).eval()
vps_net.head.load_state_dict(ckpt['vps_head'])

# This agent is never trained or stepped in this script; it is only used for
# obs_to_state() and reset_frame_stack(), hence the buffer capacity of 1.
_agent_kwargs = dict(
    k_options=k_opt,
    frame_stack_len=stack_len,
    device=args.device,
    buffer_cap=1,
)
dummy = ContinuousVPSAgent(
    gym.make(ENV_ID, render_mode='rgb_array', frameskip=3),
    **_agent_kwargs,
)

@torch.no_grad()
def rgb_state(rgb):
    """Turn one rendered RGB frame into a single-item batch for the network.

    The frame is passed through ``dummy.obs_to_state`` and a leading
    dimension of size 1 is added with ``unsqueeze(0)``.
    Presumably ``obs_to_state`` also maintains the frame stack — confirm
    against the agent implementation.
    """
    state = dummy.obs_to_state(rgb)
    return state.unsqueeze(0)

# ============================================================
# 2. Environments
# ============================================================
# Two copies of the same game are stepped in lock-step with identical
# actions: env_rgb supplies the frames that are recorded and scored, env_vis
# only draws the window.
_ENV_KWARGS = dict(frameskip=3, repeat_action_probability=0.0)
env_rgb = gym.make(ENV_ID, render_mode='rgb_array', **_ENV_KWARGS)
env_vis = gym.make(ENV_ID, render_mode='human', **_ENV_KWARGS)

# Reset both copies with the same seed. Unseeded, each env picks its own
# random seed, so the displayed game and the recorded game may start from
# different emulator state and drift apart. The seed itself is drawn at
# random so separate runs still differ.
_sync_seed = int(np.random.randint(0, 2**31 - 1))
env_rgb.reset(seed=_sync_seed)
env_vis.reset(seed=_sync_seed)
dummy.reset_frame_stack()

# Keyboard mapping (for manual mode)
if args.manual:
    pygame.init()

# Key (or key-combination label) → action index passed to env.step().
# NOTE(review): these are raw action indices; what each index means depends
# on the action set of the selected game — verify per game.
# The table is built only when pygame was imported: referencing pygame.K_*
# unconditionally raised NameError on machines without pygame.
if _HAS_PYGAME:
    KEY2BTN = {
        pygame.K_k: 0,
        pygame.K_SPACE: 1,
        pygame.K_f: 1,
        pygame.K_UP: 2,
        pygame.K_RIGHT: 3,
        pygame.K_LEFT: 4,
        pygame.K_DOWN: 5,
        'LEFT+FIRE': 15,
        'RIGHT+FIRE': 14,
        'UP+FIRE': 12,
    }
else:
    # Without pygame, manual mode is rejected earlier, so poll() never runs.
    KEY2BTN = {}

def poll():
    """Read the keyboard and return the action index for this step.

    Returns ``None`` when a pygame QUIT event is pending. Otherwise returns
    an int: with SPACE or F held, the LEFT / RIGHT / UP combination from
    ``KEY2BTN`` (in that priority) or 1 if no direction is held; without
    them, 4 / 3 / 2 / 5 for LEFT / RIGHT / UP / DOWN (in that priority);
    0 when nothing relevant is pressed.
    """
    if any(ev.type == pygame.QUIT for ev in pygame.event.get()):
        return None

    held = pygame.key.get_pressed()

    if held[pygame.K_SPACE] or held[pygame.K_f]:
        fire_combos = (
            (pygame.K_LEFT, 'LEFT+FIRE'),
            (pygame.K_RIGHT, 'RIGHT+FIRE'),
            (pygame.K_UP, 'UP+FIRE'),
        )
        for key, label in fire_combos:
            if held[key]:
                return KEY2BTN[label]
        return 1

    directions = (
        (pygame.K_LEFT, 4),
        (pygame.K_RIGHT, 3),
        (pygame.K_UP, 2),
        (pygame.K_DOWN, 5),
    )
    for key, action in directions:
        if held[key]:
            return action

    # No relevant key held (K_k maps to the same value): action 0.
    return 0

# ============================================================
# 3. Sampling
# ============================================================
frames_rgb, states = [], []   # raw frames and their network-ready states
t0 = time.time()
try:
    for t in range(args.steps):
        # Record the current frame before acting on it.
        rgb = env_rgb.render()
        frames_rgb.append(rgb.copy())
        states.append(rgb_state(rgb))

        act = poll() if args.manual else env_rgb.action_space.sample()
        if act is None:            # window closed by the user
            print('[Quit]')
            break
        if act >= env_rgb.action_space.n:
            # The key table may contain indices this game does not have.
            print(f'[Warn] action {act} out of range → NO-OP')
            act = 0

        # Apply the same action to both copies so they stay aligned.
        _, _, term, trunc, _ = env_rgb.step(act)
        env_vis.step(act)
        if term or trunc:
            env_rgb.reset()
            env_vis.reset()
            dummy.reset_frame_stack()

        if (t + 1) % args.fps_print == 0:
            fps = (t + 1) / (time.time() - t0 + 1e-9)
            print(f'[Progress] {t + 1}/{args.steps}  {fps:6.1f} FPS')
finally:
    # Release the emulator windows and pygame even if sampling fails or is
    # interrupted; any exception still propagates afterwards.
    env_rgb.close()
    env_vis.close()
    if args.manual:
        pygame.quit()

T = len(frames_rgb)
print('[Info] Sampling finished:', T, 'frames')

# ============================================================
# 4. Σφ  (sum over k options)
# ============================================================
# Nothing to score (e.g. --steps 0): stop with a clear message instead of
# letting torch.cat fail on an empty list.
if not states:
    raise SystemExit('[Error] No frames were sampled; use --steps > 0.')

# All sampled states are evaluated as one batch, so memory use grows with
# the number of sampled steps. The network output is summed over axis 1.
with torch.no_grad():
    vps_sum = vps_net(torch.cat(states).to(dev)).cpu().numpy().sum(1)

# ============================================================
# 5. Peak detection
# ============================================================
# Split the curve into consecutive windows of `win` steps (the last one may
# be shorter) and keep the arg-max of each as (step index, value).
win = max(1, args.win)
peak_steps = [
    lo + int(vps_sum[lo : lo + win].argmax())
    for lo in range(0, T, win)
]
peaks = [(step, float(vps_sum[step])) for step in peak_steps]

# ============================================================
# 6. Plotting
# ============================================================
fig = plt.figure(figsize=(10, 4))
ax = fig.add_subplot(1, 1, 1)
ax.plot(vps_sum, color='#CC3333')
ax.set_ylim(0, args.ylim)
ax.set_xlabel('Steps')
# Use the option count read from the checkpoint; the label used to say
# "8 Options" regardless of the loaded network.
ax.set_ylabel(f'Total VPS of {k_opt} Options')
ax.set_title(f'{ENV_ID}')

# Peaks that fit under the y-limit are marked on the main axes; the rest
# are collected and shown in the inset, keeping their peak number.
overflow = []
for n, (i, v) in enumerate(peaks):
    if v <= args.ylim:
        ax.scatter(i, v, c='#FF9900')
        ax.text(i, v, f'{n}', fontsize=12, ha='right', va='bottom')
    else:
        overflow.append((n, i, v))

# Inset for overflow peaks: the same curve without the y-limit.
if overflow:
    ax2 = fig.add_axes([0.72, 0.6, 0.25, 0.2])
    ax2.plot(vps_sum, color='#CC3333')
    ax2.set_title('Overflow Peaks', fontsize=10)
    for n, i, v in overflow:
        ax2.scatter(i, v, c='#FF9900')
        ax2.text(i, v, f'{n}', fontsize=10, ha='right', va='bottom')

OUT_DIR.mkdir(parents=True, exist_ok=True)
plt.tight_layout()
plt.savefig(OUT_DIR / 'vps_sum_curve.png')
plt.show()

# ============================================================
# 7. Save up to 4 consecutive frames starting at each peak
# ============================================================
print('[Info] Saving peak frames …')
for rank, (start, score) in enumerate(peaks):
    # The slice stops at the end of the recording, so peaks near the end
    # simply yield fewer than 4 frames.
    for off, frame in enumerate(frames_rgb[start : start + 4]):
        target = PEAK_DIR / f'peak{rank:02d}_frame{off}_{score:+.4f}.png'
        imageio.imwrite(target, frame)
print('[✓] Done  →', PEAK_DIR.resolve())
