#!/usr/bin/env python3
"""
Record HRL system frames to PNG images.
"""
import torch
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
import os
import sys
from datetime import datetime
from PIL import Image

# Make modules in the current working directory (e.g. `agent`, `envs`) importable.
# NOTE: '.' on sys.path is resolved against the CWD, not this script's directory,
# so run the script from the directory that contains those modules.
sys.path.append('.')

# The local modules are star-imported for their names / import side effects.
# Presumably `agent` defines the classes pickled into the continuous-policy
# checkpoint (torch.load needs them importable) and `envs` registers custom
# environments such as "ant_px" — TODO confirm. Each import is attempted on its
# own so that a failure in one does not silently skip the other.
try:
    from agent import *  # noqa: F401,F403
except ImportError as exc:
    print(f"Warning: Could not import agent module ({exc}). Model loading may fail.", file=sys.stderr)
try:
    from envs import *  # noqa: F401,F403
except ImportError as exc:
    print(f"Warning: Could not import envs module ({exc}). Model loading may fail.", file=sys.stderr)

def get_frame_from_env(env, fallback_shape=(480, 640, 3)):
    """Return the current rendered frame of *env*, or a black placeholder.

    Calls ``env.render()``. If it raises an ``Exception`` or returns None, an
    all-zero uint8 array of ``fallback_shape`` (default 480x640x3) is returned
    instead so frame recording can continue. ``KeyboardInterrupt`` and
    ``SystemExit`` are deliberately not swallowed.
    """
    try:
        frame = env.render()
    except Exception:  # best effort: a broken renderer yields a placeholder frame
        frame = None
    if frame is not None:
        return frame
    return np.zeros(fallback_shape, dtype=np.uint8)

# def record_hrl_frames(continuous_model_path, ppo_model_path, env_name="Ant-v5", max_steps=1000):
#     """Record HRL system frames to PNG images"""
    
#     # Load continuous action policy
#     print(f"Loading continuous action policy from: {continuous_model_path}")
#     continuous_model = torch.load(continuous_model_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     continuous_model.to(device)
    
#     # Load PPO model
#     print(f"Loading PPO model from: {ppo_model_path}")
#     ppo_model = PPO.load(ppo_model_path)
    
#     # Load VecNormalize stats with proper dummy environment
#     vec_normalize_path = os.path.join(os.path.dirname(ppo_model_path), "vec_normalize_stats.pkl")
#     if os.path.exists(vec_normalize_path):
#         print(f"Loading normalization stats from: {vec_normalize_path}")
#         # Create dummy vectorized environment for VecNormalize
#         from stable_baselines3.common.vec_env import DummyVecEnv
#         dummy_env = DummyVecEnv([lambda: gym.make(env_name)])
#         vec_normalize = VecNormalize.load(vec_normalize_path, venv=dummy_env)
#         dummy_env.close()
#     else:
#         vec_normalize = None
    
#     # Create base environment for offscreen rendering
#     base_env = gym.make(env_name, camera_id=0, render_mode="rgb_array")
    
#     # Create wrapper environment
#     class HRLWrapper(gym.Wrapper):
#         def __init__(self, env, continuous_policy, device, num_actions=64):
#             super().__init__(env)
#             self.continuous_policy = continuous_policy
#             self.device = device
#             self.action_space = gym.spaces.Discrete(num_actions)
#             self.obs = None
            
#         def reset(self, **kwargs):
#             obs, info = self.env.reset(**kwargs)
#             self.obs = obs
#             return obs, info
            
#         def step(self, discrete_action):
#             obs_tensor = torch.as_tensor(
#                 np.expand_dims(self.obs, axis=0), 
#                 dtype=torch.float32
#             ).to(self.device)
            
#             with torch.no_grad():
#                 continuous_action, *_ = self.continuous_policy.pi(
#                     obs_tensor,
#                     deterministic=True,
#                     manual_indices=torch.tensor([int(discrete_action)]).to(self.device)
#                 )
            
#             continuous_action_np = continuous_action.cpu().numpy()[0]
#             obs, reward, terminated, truncated, info = self.env.step(continuous_action_np)
#             self.obs = obs
#             return obs, reward, terminated, truncated, info
    
#     # Create wrapped environment
#     env = HRLWrapper(base_env, continuous_model, device)
    
#     # Create output directory
#     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
#     model_dir = os.path.dirname(os.path.dirname(ppo_model_path))
#     out_dir = os.path.join(model_dir, f"hrl_frames_{timestamp}")
#     os.makedirs(out_dir, exist_ok=True)
    
#     # Create metadata file
#     meta_path = os.path.join(out_dir, "metadata.txt")
#     with open(meta_path, "w") as mf:
#         mf.write(f"continuous_model: {continuous_model_path}\n")
#         mf.write(f"ppo_model: {ppo_model_path}\n")
#         mf.write(f"environment: {env_name}\n")
#         mf.write(f"timestamp: {timestamp}\n")
#         mf.write("notes: frames saved as frame_000000.png ...\n")
    
#     print(f"Saving frames to: {out_dir}")
    
#     with torch.no_grad():
#         obs, _ = env.reset()
#         done = False
#         step = 0
#         total_reward = 0
#         frame_idx = 0
        
#         while not done and step < max_steps:
#             # Normalize observation if needed
#             if vec_normalize is not None:
#                 obs_norm = vec_normalize.normalize_obs(obs)
#             else:
#                 obs_norm = obs
            
#             # Get discrete action from PPO model
#             discrete_action, _states = ppo_model.predict(obs_norm, deterministic=True)
#             # Get current frame
#             frame = get_frame_from_env(env)
            
#             # Convert frame to proper format
#             arr = np.asarray(frame)
#             if arr.ndim == 2:  # grayscale -> RGB
#                 arr = np.stack([arr]*3, axis=-1)
#             if arr.dtype != np.uint8:
#                 maxv = arr.max() if arr.size > 0 else 1.0
#                 if maxv <= 1.0:
#                     arr = (np.clip(arr, 0.0, 1.0) * 255.0).round().astype(np.uint8)
#                 else:
#                     arr = np.clip(arr, 0, 255).astype(np.uint8)
            
#             # Save frame every 10 steps to reduce file count
#             if step % 10 == 0:
#                 fname = os.path.join(out_dir, f"frame_{frame_idx:06d}.png")
#                 Image.fromarray(arr).save(fname, format="PNG")
#                 frame_idx += 1
            
#             # Step environment
#             obs, reward, terminated, truncated, info = env.step(discrete_action)
#             done = terminated or truncated
#             step += 1
#             total_reward += reward
            
#             if step % 100 == 0:
#                 print(f"Step: {step}, Discrete Action: {int(discrete_action)}, Reward: {total_reward:.2f}")
    
#     # Update metadata with final stats
#     with open(meta_path, "a") as mf:
#         mf.write(f"saved_frames: {frame_idx}\n")
#         mf.write(f"episode_length: {step}\n")
#         mf.write(f"episode_return: {total_reward}\n")
    
#     env.close()
#     print(f"Episode finished. Saved {frame_idx} frames to {out_dir}")
#     print(f"Episode length: {step}, Total reward: {total_reward:.2f}")


def _frame_to_uint8_rgb(frame):
    """Convert a rendered frame into a uint8 array that PIL can save as PNG.

    Two-dimensional (grayscale) frames are replicated to three channels. A
    non-uint8 frame whose maximum is <= 1.0 is treated as floats in [0, 1] and
    scaled to [0, 255]; any other non-uint8 frame is clipped to [0, 255] and cast.
    """
    arr = np.asarray(frame)
    if arr.ndim == 2:  # grayscale -> RGB
        arr = np.stack([arr] * 3, axis=-1)
    if arr.dtype != np.uint8:
        maxv = arr.max() if arr.size > 0 else 1.0
        if maxv <= 1.0:
            arr = (np.clip(arr, 0.0, 1.0) * 255.0).round().astype(np.uint8)
        else:
            arr = np.clip(arr, 0, 255).astype(np.uint8)
    return arr


def _render_rgb_frame(env, fallback_shape=(480, 640, 3)):
    """Best-effort ``env.render()`` that also reports why it fell back.

    Returns ``(frame, problem)``. On success ``problem`` is None. If ``render()``
    raises an ``Exception`` or returns None, ``frame`` is an all-black uint8 array
    of ``fallback_shape`` and ``problem`` is a short description, so the caller
    can keep recording while still surfacing the failure.
    """
    try:
        frame = env.render()
    except Exception as exc:  # deliberate best effort: never abort a rollout over one frame
        return np.zeros(fallback_shape, dtype=np.uint8), f"env.render() raised {exc!r}"
    if frame is None:
        return np.zeros(fallback_shape, dtype=np.uint8), "env.render() returned None"
    return frame, None


def record_hrl_frames(
    continuous_model_path: str,
    ppo_model_path: str,
    env_name: str = "ant_px",
    max_steps: int = 1000,
    seed: int = 0,
    num_actions: int = 128,
    save_every: int = 10,
    deterministic: bool = True,
):
    """
    Evaluate an HRL stack for one episode and save rendered frames as PNGs.

    The stack consists of:
      - a discrete high-level policy (SB3 PPO) choosing an index in [0, num_actions);
      - a continuous low-level policy (a pickled PyTorch module exposing
        ``.pi(obs, deterministic=..., manual_indices=...)``) that turns the index
        into an environment action;
      - optional VecNormalize statistics (``vec_normalize_stats.pkl`` in the PPO
        model's directory), loaded frozen (no running updates).

    Notes:
      - Observations sent to PPO are normalized with the frozen VecNormalize stats
        if present.
      - The continuous policy receives raw (unnormalized) observations. This
        assumes VecNormalize wrapped *outside* the HRL wrapper during training
        — confirm against the training script.
      - The episode return is the sum of raw env rewards. When stats are present,
        a normalized return is also reported using the reward normalization saved
        in the stats (it equals the raw return if they were saved with
        ``norm_reward=False``).
      - Both policies are put in eval mode; Python, NumPy and torch RNGs and the
        env reset are seeded with ``seed``.

    Args:
        continuous_model_path: Path to the pickled low-level policy (torch.load).
        ppo_model_path: Path to the SB3 PPO model. Its directory is searched for
            ``vec_normalize_stats.pkl``; the ``hrl_frames_<timestamp>`` output
            folder is created in the parent of that directory.
        env_name: Gymnasium environment id.
        max_steps: Maximum number of environment steps.
        seed: Seed for the RNGs and ``env.reset``.
        num_actions: Size of the discrete action space exposed to PPO.
        save_every: Save one frame every ``save_every`` steps (must be >= 1).
        deterministic: If True (default), PPO takes its most likely action;
            if False, it samples from its action distribution.

    Returns:
        The output directory containing the frames and ``metadata.txt``.

    Raises:
        ValueError: If ``save_every`` or ``num_actions`` is smaller than 1.
    """
    # Validate before any heavy import / model loading so bad arguments fail fast
    # (save_every=0 previously crashed with ZeroDivisionError mid-rollout).
    if save_every < 1:
        raise ValueError(f"save_every must be >= 1, got {save_every}")
    if num_actions < 1:
        raise ValueError(f"num_actions must be >= 1, got {num_actions}")

    import inspect
    import os
    import random
    from datetime import datetime
    import numpy as np
    from PIL import Image
    import torch
    import gymnasium as gym
    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv

    # ---- Load models --------------------------------------------------------------------------
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Continuous low-level policy: the checkpoint is a whole pickled nn.Module (we call
    # .to/.eval/.pi on it), so weights-only loading cannot work. torch>=2.6 defaults to
    # weights_only=True, hence the explicit opt-out where the parameter exists.
    # NOTE(security): unpickling can execute arbitrary code - only load trusted checkpoints.
    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    continuous_model = torch.load(continuous_model_path, **load_kwargs)
    continuous_model.to(device)
    continuous_model.eval()  # disable dropout/batchnorm updates, etc.

    # High-level PPO policy (SB3), switched to eval mode.
    ppo_model: PPO = PPO.load(ppo_model_path, device=device)
    set_training_mode = getattr(ppo_model.policy, "set_training_mode", None)
    if set_training_mode is not None:
        set_training_mode(False)
    else:  # older SB3 versions without set_training_mode
        ppo_model.policy.eval()

    # ---- VecNormalize stats (frozen) ----------------------------------------------------------
    vec_normalize = None
    vec_normalize_path = os.path.join(os.path.dirname(ppo_model_path), "vec_normalize_stats.pkl")
    if os.path.exists(vec_normalize_path):
        # A dummy env with the same spaces is needed to attach the stats.
        dummy_env = DummyVecEnv([lambda: gym.make(env_name)])
        try:
            vec_normalize = VecNormalize.load(vec_normalize_path, venv=dummy_env)
        finally:
            dummy_env.close()
        vec_normalize.training = False  # freeze running stats
        # Keep the saved `norm_reward` flag: forcing it False turns normalize_reward()
        # into an identity, which made the reported normalized return a copy of the raw
        # one. Raw returns are unaffected: they are summed directly from env.step.

    # ---- HRL wrapper: maps discrete action -> continuous via low-level policy -----------------
    class HRLWrapper(gym.Wrapper):
        """
        Exposes a Discrete(num_actions) action space to PPO and maps each index to a
        continuous env action through the low-level policy's ``pi(..., manual_indices=...)``.
        """

        def __init__(self, env, continuous_policy, device, num_actions=64):
            super().__init__(env)
            self.continuous_policy = continuous_policy
            self.device = device
            self.action_space = gym.spaces.Discrete(num_actions)
            self._last_obs = None  # raw observation the next step() conditions on

        def reset(self, **kwargs):
            obs, info = self.env.reset(**kwargs)
            self._last_obs = obs
            return obs, info

        def step(self, discrete_action):
            if self._last_obs is None:
                raise RuntimeError("HRLWrapper.step() called before reset()")
            # The low-level policy consumes RAW observations (see the note in the docstring).
            obs_tensor = torch.as_tensor(
                np.expand_dims(self._last_obs, axis=0), dtype=torch.float32, device=self.device
            )
            index = torch.tensor([int(discrete_action)], device=self.device)
            with torch.no_grad():
                continuous_action, *_ = self.continuous_policy.pi(
                    obs_tensor, deterministic=True, manual_indices=index
                )
            cont_act = continuous_action.detach().cpu().numpy()[0]
            obs, reward, terminated, truncated, info = self.env.step(cont_act)
            self._last_obs = obs
            return obs, reward, terminated, truncated, info

    # ---- Base env (off-screen rendering) + wrapper ---------------------------------------------
    base_env = gym.make(env_name, camera_id=0, render_mode="rgb_array")
    env = HRLWrapper(base_env, continuous_model, device, num_actions=num_actions)

    try:
        # ---- Reproducibility (the env itself is seeded by the reset below) --------------------
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # ---- Output directory and metadata ----------------------------------------------------
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_dir = os.path.dirname(os.path.dirname(ppo_model_path))
        out_dir = os.path.join(model_dir, f"hrl_frames_{timestamp}")
        os.makedirs(out_dir, exist_ok=True)

        meta_path = os.path.join(out_dir, "metadata.txt")
        with open(meta_path, "w", encoding="utf-8") as mf:
            mf.write(f"continuous_model: {continuous_model_path}\n")
            mf.write(f"ppo_model: {ppo_model_path}\n")
            mf.write(f"vecnormalize_stats: {vec_normalize_path if vec_normalize is not None else 'None'}\n")
            mf.write(f"environment: {env_name}\n")
            mf.write(f"seed: {seed}\n")
            mf.write(f"num_actions: {num_actions}\n")
            mf.write(f"max_steps: {max_steps}\n")
            mf.write(f"save_every: {save_every}\n")
            mf.write(f"deterministic: {deterministic}\n")
            mf.write(f"timestamp: {timestamp}\n")
            mf.write("notes: frames saved as frame_000000.png, every N steps defined by save_every\n")

        print(f"[INFO] Saving frames to: {out_dir}")

        # ---- Rollout --------------------------------------------------------------------------
        have_norm = vec_normalize is not None
        step = 0
        frame_idx = 0
        total_reward_raw = 0.0
        total_reward_norm = 0.0
        render_warned = False

        with torch.no_grad():
            obs, _ = env.reset(seed=seed)
            done = False
            while not done and step < max_steps:
                # PPO sees normalized observations when stats are available.
                obs_in = vec_normalize.normalize_obs(obs) if have_norm else obs
                discrete_action, _ = ppo_model.predict(obs_in, deterministic=deterministic)

                # Save the frame of the state the action was chosen in.
                if step % save_every == 0:
                    frame, problem = _render_rgb_frame(env)
                    if problem is not None and not render_warned:
                        print(f"[WARN] {problem}; saved a black frame instead (further render warnings suppressed)")
                        render_warned = True
                    Image.fromarray(_frame_to_uint8_rgb(frame)).save(
                        os.path.join(out_dir, f"frame_{frame_idx:06d}.png"), format="PNG"
                    )
                    frame_idx += 1

                obs, reward, terminated, truncated, _info = env.step(discrete_action)
                done = bool(terminated or truncated)
                step += 1

                total_reward_raw += float(reward)
                if have_norm:
                    r_norm = float(vec_normalize.normalize_reward(np.array([reward], dtype=np.float32))[0])
                    total_reward_norm += r_norm

                if step % 100 == 0:
                    if have_norm:
                        print(f"[INFO] step={step}  act={int(discrete_action)}  R_raw={total_reward_raw:.2f}  R_norm={total_reward_norm:.2f}")
                    else:
                        print(f"[INFO] step={step}  act={int(discrete_action)}  R_raw={total_reward_raw:.2f}")

        # ---- Persist stats --------------------------------------------------------------------
        with open(meta_path, "a", encoding="utf-8") as mf:
            mf.write(f"saved_frames: {frame_idx}\n")
            mf.write(f"episode_length: {step}\n")
            mf.write(f"episode_return_raw: {total_reward_raw}\n")
            if have_norm:
                mf.write(f"episode_return_norm: {total_reward_norm}\n")
    finally:
        # Release the simulator / renderer even if loading, rendering or stepping fails.
        env.close()

    print(f"[DONE] Episode finished. Saved {frame_idx} frames to {out_dir}")
    if have_norm:
        print(f"[DONE] Episode length: {step}, Return(raw)={total_reward_raw:.2f}, Return(norm)={total_reward_norm:.2f}")
    else:
        print(f"[DONE] Episode length: {step}, Return(raw)={total_reward_raw:.2f}")
    return out_dir



if __name__ == "__main__":
    import argparse

    # The defaults reproduce the experiment this script was originally hard-coded for;
    # override them on the command line to record a different run.
    parser = argparse.ArgumentParser(description="Record HRL system frames to PNG images.")
    parser.add_argument(
        "--ppo-model",
        default='hrl/models/ant_px_upf_ant_mpf_alpha_0.10_nc_128_gamma_0.9_ep_200_epochs_100_30mil_hard_ppo/model.zip',
        help="Path to the SB3 PPO high-level policy (.zip).",
    )
    parser.add_argument(
        "--continuous-model",
        default='hrl/models/upf_ant_mpf_alpha_0.10_nc_128_gamma_0.9_ep_200_epochs_100/upf_ant_mpf_alpha_0.10_nc_128_gamma_0.9_ep_200_epochs_100_s0/pyt_save/model.pt',
        help="Path to the pickled low-level continuous policy (.pt).",
    )
    parser.add_argument("--env", default="ant_px", help="Gymnasium environment id.")
    parser.add_argument("--max-steps", type=int, default=1000, help="Maximum rollout length.")
    args = parser.parse_args()

    ppo_model_path = args.ppo_model
    continuous_model_path = args.continuous_model

    # Fail early with a clear message instead of deep inside torch / SB3 loading.
    # sys.exit (not the site-provided exit()) works even when `site` is disabled.
    if not os.path.exists(ppo_model_path):
        print(f"Error: PPO model not found at {ppo_model_path}", file=sys.stderr)
        sys.exit(1)
    if not os.path.exists(continuous_model_path):
        print(f"Error: Continuous model not found at {continuous_model_path}", file=sys.stderr)
        sys.exit(1)

    record_hrl_frames(continuous_model_path, ppo_model_path, args.env, args.max_steps)
    