# poly_ppo_ucb.py

import os
import math
import json
import time
import random
import argparse
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from collections import Counter
import hashlib


try:
    import wandb 
except Exception:
    wandb = None

from ..utils import (
    BabyAI_BC,
    prepare_obs,
    cluster_seq_diversity,
    HIDDEN_DIM,
)

from ..train_utils import setup_environment as setup_environment_multiseed
from fractal_sampling import generate_dataset as generate_fractal_dataset
from collections import defaultdict
from fractal_sampling import obs_to_key


# Default output path of the fine-tuned actor; formatted with ``game=<game_code>``
# in main() when ``--save_path`` is left empty.
FINE_TUNED_MODEL_TMPL = "{game}/pretrained_models_multiseed/polychrome_ppo_{game}.pt"
# Default path of the JSON evaluation log; formatted with ``game=<game_code>``
# in main() when ``--eval_log_path`` is left empty.
EVAL_LOG_JSON_TMPL = "{game}/pretrained_models_multiseed/polychrome_ppo_eval_{game}.json"


# Per-game settings keyed by ``--game_code``. main() consults this table only
# when ``--use_text auto`` is given; game codes missing from the table fall
# back to ``use_text=False`` there.
GAME_CODE_TO_INFO = {
    "open": {"use_text": True},
    "pickup": {"use_text": True},
    "goto": {"use_text": True},
    "unlock": {"use_text": True},
    "synthseq": {"use_text": True},
    "bosslevel": {"use_text": True},
}





def cat_kl_from_logits(p_logits: torch.Tensor, q_logits: torch.Tensor) -> torch.Tensor:
    """Row-wise KL divergence between two categorical distributions.

    Both arguments are unnormalised logits over the last dimension; the result
    is ``KL(softmax(p_logits) || softmax(q_logits))`` with that dimension
    summed out.
    """
    log_p = torch.log_softmax(p_logits, dim=-1)
    log_q = torch.log_softmax(q_logits, dim=-1)
    # Work in log-space and exponentiate only for the weighting term.
    log_ratio = log_p - log_q
    weighted = log_p.exp() * log_ratio
    return weighted.sum(-1)

def cat_entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Shannon entropy (in nats) of ``softmax(logits)`` along the last dimension."""
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    plogp = probs * log_probs
    return -plogp.sum(-1)

def success_from_info_or_reward(terminated: bool, truncated: bool, info: Dict[str, Any], ep_reward: float) -> bool:
    """Decide whether an episode counts as a success.

    An explicit ``info["success"]`` flag wins when ``info`` is a dict that
    carries one. Otherwise the episode is a success only if it terminated
    with a strictly positive episode reward. ``truncated`` is accepted for
    signature symmetry but is not consulted.
    """
    has_explicit_flag = isinstance(info, dict) and "success" in info
    if not has_explicit_flag:
        return bool(terminated and ep_reward > 0.0)
    return bool(info["success"])

def _nan_to_none(x):
    """Coerce ``x`` to a finite ``float``, or return ``None``.

    ``None`` is returned when ``x`` is ``None``, cannot be converted with
    ``float()``, or converts to NaN / +-infinity.
    """
    if x is None:
        return None
    try:
        value = float(x)
    except Exception:
        return None
    return value if math.isfinite(value) else None


# Critic Network

class ValueNetwork(nn.Module):
    """Critic head: a 3-layer ReLU MLP mapping a feature vector to one scalar.

    The layers live in ``self.net`` (an ``nn.Sequential``) in the order
    Linear -> ReLU -> Linear -> ReLU -> Linear.
    """

    def __init__(self, input_dim: int = HIDDEN_DIM, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Return the value estimate with the trailing size-1 dimension removed."""
        out = self.net(feats)
        return out.squeeze(-1)


# PPO Agent 

class PPOAgent:
    """PPO learner: a trainable ``BabyAI_BC`` actor, a frozen copy of it used as
    a KL reference, and a separate ``ValueNetwork`` critic fed with the actor's
    features.

    ``update_from_dataset`` combines three advantage signals: GAE over the
    on-policy trajectories, a group-level "polychromic" advantage (mean group
    reward scaled by ``cluster_seq_diversity``), and an optional count-based
    bonus ``ucb_coef * min(1, sqrt(1 / N(s, a)))``.
    """

    def __init__(
        self,
        n_actions: int,
        device: torch.device,
        use_text: bool,
        actor_lr: float,
        critic_lr: float,
        ent_coef: float,
        kl_coef: float,
        clip_eps: float,
        max_grad_norm: float,
        pretrained_weights: Optional[str] = None,
        ucb_coef: float = 0.0,
    ):
        """Build actor, frozen reference actor, critic and their optimisers.

        Args:
            n_actions: Size of the discrete action space.
            device: Device all networks and batches are placed on.
            use_text: Whether mission tokens are fed to the actor.
            actor_lr: Adam learning rate for the actor.
            critic_lr: Adam learning rate for the critic.
            ent_coef: Weight of the entropy bonus in the actor loss.
            kl_coef: Weight of the KL(actor || reference) penalty.
            clip_eps: PPO ratio clipping range.
            max_grad_norm: Gradient-norm clip applied to actor and critic.
            pretrained_weights: Optional path to an actor state dict. When the
                file is missing the actor starts from scratch (a warning is
                printed).
            ucb_coef: Scale of the count-based advantage bonus; 0 disables it.
        """
        self.device = device
        self.use_text = use_text
        self.ent_coef = ent_coef
        self.kl_coef = kl_coef
        self.clip_eps = clip_eps
        self.max_grad_norm = max_grad_norm
        self.ucb_coef = ucb_coef

        # Fallback vocabulary size; overridden by the checkpoint's embedding
        # table when one is available.
        vocab_size = 200
        state = None
        if pretrained_weights and os.path.isfile(pretrained_weights):
            state = torch.load(pretrained_weights, map_location=self.device)
            print(f"[PPO] Loaded pretrained actor weights from: {pretrained_weights}")
            if self.use_text and "tok_emb.weight" in state:
                vocab_size = int(state["tok_emb.weight"].shape[0])
            elif self.use_text:
                print(f"[PPO] WARNING: No tok_emb.weight in pretrained weights, using vocab_size=200!!!")
        else:
            print(f"[PPO] WARNING: No pretrained weights at: {pretrained_weights} (actor from scratch)")

        self.actor = BabyAI_BC(n_actions=n_actions, use_text=self.use_text, vocab_size=vocab_size).to(self.device)
        if state is not None:
            self.actor.load_state_dict(state, strict=True)

        # The reference actor is a frozen snapshot of the initial actor; it is
        # only used to compute the KL penalty during updates.
        self.ref_actor = BabyAI_BC(n_actions=n_actions, use_text=self.use_text, vocab_size=vocab_size).to(self.device)
        self.ref_actor.load_state_dict(self.actor.state_dict())
        self.ref_actor.eval()
        for p in self.ref_actor.parameters():
            p.requires_grad = False

        self.critic = ValueNetwork(HIDDEN_DIM, HIDDEN_DIM).to(device)

        self.actor_opt = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=critic_lr)

    @torch.no_grad()
    def select_action(self, obs: Dict[str, Any], temperature: float = 1.0) -> Tuple[int, float, float]:
        """Sample one action for a single observation.

        Returns:
            ``(action, log_prob, entropy)`` of the temperature-scaled policy.
        """
        batch = prepare_obs(obs, device=self.device, use_text=self.use_text)
        # Guard against a zero/negative temperature blowing up the logits.
        logits = self.actor(batch) / max(1e-6, float(temperature))
        dist = torch.distributions.Categorical(logits=logits)
        a = dist.sample()
        logp = dist.log_prob(a)
        return int(a.item()), float(logp.item()), float(dist.entropy().item())

    @torch.no_grad()
    def value(self, obs: Dict[str, Any]) -> float:
        """Return the critic's value estimate for a single observation."""
        batch = prepare_obs(obs, device=self.device, use_text=self.use_text)
        feats = self.actor.features(batch)
        v = self.critic(feats)
        return float(v.squeeze(0).item())

    def _batchify_obs(self, obs_list: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Run ``prepare_obs`` on every observation and concatenate along dim 0.

        Raises:
            RuntimeError: If any prepared image tensor has a different shape
                from the first one in ``obs_list``.
        """
        imgs, dirs, toks = [], [], []

        expected_shape = None
        for i, obs in enumerate(obs_list):
            prepped = prepare_obs(obs, device=self.device, use_text=self.use_text)
            img_tensor = prepped["image"]

            if expected_shape is None:
                expected_shape = img_tensor.shape

            if img_tensor.shape != expected_shape:
                raise RuntimeError(
                    f"Mismatched image tensor shape at index {i} in the rollout."
                    f"\n  - Expected shape: {expected_shape}"
                    f"\n  - Got shape:      {img_tensor.shape}"
                    f"\nThis usually means a wrapper is creating a bad observation dictionary."
                    f"\nProblematic raw observation: {obs}"
                )

            imgs.append(img_tensor)
            dirs.append(prepped["direction"])
            if self.use_text and "mission_tokens" in prepped:
                toks.append(prepped["mission_tokens"])

        batch = {
            "image": torch.cat(imgs, dim=0),
            "direction": torch.cat(dirs, dim=0),
        }
        if self.use_text:
            if len(toks) == 0:
                # No observation carried tokens: substitute a single zero
                # token per sample so the batch still has the expected key.
                print(f"Warning: No mission tokens provided for batch {batch}! Even though use_text is True!")
                batch["mission_tokens"] = torch.zeros((batch["image"].shape[0], 1), dtype=torch.long, device=self.device)
            else:
                batch["mission_tokens"] = torch.cat(toks, dim=0)
        return batch

    def _state_key(self, obs: Dict[str, Any]) -> Tuple[str, int, str]:
        """
        Compact, stable key for counting N(s,a) from raw obs dicts.
        Uses mission text, direction (if present), and a short hash of the image bytes.

        Directions given as NumPy integer scalars are accepted as well as
        plain ``int``; anything else maps to ``-1``. An image without
        ``tobytes`` (or one whose hashing fails) maps to ``"noimg"``.
        """
        mission = obs.get("mission", "")
        raw_direction = obs.get("direction")
        # Accept np.integer too: a check against plain ``int`` alone would
        # silently collapse every NumPy-typed direction to -1.
        if isinstance(raw_direction, (int, np.integer)):
            direction = int(raw_direction)
        else:
            direction = -1
        img = obs.get("image", None)
        try:
            h = hashlib.sha1(img.tobytes()).hexdigest()[:8] if hasattr(img, "tobytes") else "noimg"
        except Exception:
            # Best-effort key: never let hashing break a training update.
            h = "noimg"
        return (mission, direction, h)

    @staticmethod
    def _normalize_advantages(adv: torch.Tensor) -> torch.Tensor:
        """Standardise advantages; a single sample is only centred (std would be NaN)."""
        centred = adv - adv.mean()
        if adv.numel() < 2:
            return centred
        return centred / (adv.std() + 1e-8)

    def compute_gae(self, rewards, values, dones, next_value, gamma, gae_lambda):
        """Generalised Advantage Estimation over one trajectory.

        Args:
            rewards, values, dones: 1-D tensors of equal length.
            next_value: Scalar tensor used to bootstrap after the last step.
            gamma: Discount factor.
            gae_lambda: GAE smoothing factor.

        Returns:
            ``(advantages, returns)`` where ``returns = advantages + values``.
        """
        advantages = torch.zeros_like(rewards).to(self.device)
        last_gae_lam = 0
        values = torch.cat((values, next_value.view(1)))
        for t in reversed(range(len(rewards))):
            not_done = 1 - dones[t]
            delta = rewards[t] + gamma * values[t+1] * not_done - values[t]
            advantages[t] = delta + gae_lambda * gamma * not_done * last_gae_lam
            last_gae_lam = advantages[t]
        return advantages, advantages + values[:-1]

    def update_from_dataset(self,
                            on_policy_dataset: Dict,
                            env_grid_size: tuple,
                            ppo_epochs: int,
                            minibatch_size: int,
                            gamma: float,
                            gae_lambda: float,
                            value_loss_coef: float,
                            polychrome_window: int = 5) -> Tuple[float, float, float, List[float]]:
        """Run PPO epochs on an on-policy dataset plus polychromic samples.

        Args:
            on_policy_dataset: Maps a key tuple (whose element ``[1]`` is the
                mission string) to a list of trajectories. Each trajectory is
                a list of step dicts with keys ``observation``, ``action``,
                ``log_prob``, ``reward``, ``done`` and ``env_state``.
            env_grid_size: Passed to ``cluster_seq_diversity`` as ``grid_size``.
            ppo_epochs: Number of passes over the combined batch.
            minibatch_size: Minibatch size for each gradient step.
            gamma: Discount factor for GAE.
            gae_lambda: GAE smoothing factor.
            value_loss_coef: Scale applied to the critic's MSE loss.
            polychrome_window: Number of leading steps of each sampled
                trajectory that receive the group advantage.

        Returns:
            ``(avg_actor_loss, avg_critic_loss, avg_kl, group_diversities)``.
            The same 4-tuple shape is returned for an empty dataset, with
            zeros and an empty list.
        """
        self.actor.train()
        self.critic.train()
        self.ref_actor.eval()

        # ---- 1) Flatten all trajectories, remembering their boundaries. ----
        all_obs, all_actions, all_old_logps, all_rewards, all_dones = [], [], [], [], []
        traj_indices = []
        for trajectories in on_policy_dataset.values():
            for traj in trajectories:
                start_idx = len(all_obs)
                for step in traj:
                    all_obs.append(step['observation'])
                    all_actions.append(step['action'])
                    all_old_logps.append(step['log_prob'])
                    all_rewards.append(step['reward'])
                    all_dones.append(step['done'])
                traj_indices.append((start_idx, len(all_obs)))

        if not all_obs:
            # Keep the arity identical to the normal return so callers that
            # unpack four values do not crash on an empty dataset.
            return 0.0, 0.0, 0.0, []

        # ---- 2) GAE per trajectory with the current critic. ----
        obs_batch_dict = self._batchify_obs(all_obs)
        with torch.no_grad():
            features_batch = self.actor.features(obs_batch_dict)
            # reshape(-1) keeps a 1-D tensor even when there is one sample.
            values_batch = self.critic(features_batch).reshape(-1)

        rewards_batch = torch.tensor(all_rewards, device=self.device, dtype=torch.float)
        dones_batch = torch.tensor(all_dones, device=self.device, dtype=torch.float)
        bootstrap_value = torch.tensor(0.0, device=self.device)

        advantages_list, returns_list = [], []
        for start, end in traj_indices:
            adv, ret = self.compute_gae(
                rewards_batch[start:end], values_batch[start:end], dones_batch[start:end],
                bootstrap_value, gamma, gae_lambda
            )
            advantages_list.append(adv)
            returns_list.append(ret)

        advantages_batch = self._normalize_advantages(torch.cat(advantages_list))
        returns_batch = torch.cat(returns_list)
        actions_batch = torch.tensor(all_actions, device=self.device, dtype=torch.long)
        log_probs_old_batch = torch.tensor(all_old_logps, device=self.device)

        # ---- 3) Polychromic augmentation: group-level advantages. ----
        aug_obs, aug_actions, aug_logps, aug_adv = [], [], [], []
        window = max(0, polychrome_window)

        avg_group_diversity = []
        for vine_key_tuple, trajectories in on_policy_dataset.items():
            if len(trajectories) < 4:
                continue
            groups = [random.sample(trajectories, 4) for _ in range(4)]
            group_values = []
            for group in groups:
                reward = sum(sum(s['reward'] for s in t) for t in group)
                diversity = cluster_seq_diversity(
                    [[step['env_state'][0] for step in t] for t in group],
                    grid_size=env_grid_size,
                )
                group_values.append((reward / len(group)) * diversity)
                avg_group_diversity.append(diversity)

            baseline = sum(group_values) / len(group_values)

            for group, gval in zip(groups, group_values):
                advantage = float(gval - baseline)
                for traj in group:
                    if not traj:
                        continue
                    mission0 = traj[0]['observation']['mission']
                    assert mission0 is not None, "Mission should be in the observation"
                    assert mission0 == vine_key_tuple[1], "Mission should be the same"

                    for step in traj[:window]:
                        mission = step['observation']['mission']
                        assert mission is not None, "Mission should be in the observation"
                        assert mission == mission0, "Mission should be the same"
                        aug_obs.append(step['observation'])
                        aug_actions.append(step['action'])
                        aug_logps.append(step['log_prob'])
                        aug_adv.append(advantage)

        if aug_obs:
            aug_obs_batch_dict = self._batchify_obs(aug_obs)
            # The value target of an augmented sample is the critic's current
            # prediction for it; evaluate all of them in one batched pass.
            with torch.no_grad():
                aug_returns = self.critic(self.actor.features(aug_obs_batch_dict)).reshape(-1)
            for k in obs_batch_dict:
                obs_batch_dict[k] = torch.cat([obs_batch_dict[k], aug_obs_batch_dict[k]], dim=0)

            actions_batch = torch.cat([actions_batch, torch.tensor(aug_actions, device=self.device, dtype=torch.long)])
            log_probs_old_batch = torch.cat([log_probs_old_batch, torch.tensor(aug_logps, device=self.device)])
            returns_batch = torch.cat([returns_batch, aug_returns])
            aug_adv_tensor = torch.tensor(aug_adv, device=self.device, dtype=torch.float)
            advantages_batch = torch.cat([advantages_batch, self._normalize_advantages(aug_adv_tensor)])

        n = actions_batch.size(0)

        # ---- 4) Optional count-based bonus, computed once for all samples. ----
        if self.ucb_coef != 0.0:
            # Counts are taken over the on-policy samples only.
            sa_counts = Counter(
                (self._state_key(ob), int(a)) for ob, a in zip(all_obs, all_actions)
            )
            raw_obs_full = all_obs + aug_obs
            raw_actions_full = all_actions + aug_actions
            assert len(raw_obs_full) == n, (
                f"raw_obs_full ({len(raw_obs_full)}) must match actions_batch ({n}) after augmentation"
            )
            bonus_vals = [
                self.ucb_coef * min(1.0, (1.0 / max(1, sa_counts[(self._state_key(ob), int(a))])) ** 0.5)
                for ob, a in zip(raw_obs_full, raw_actions_full)
            ]
            advantages_batch = advantages_batch + torch.tensor(
                bonus_vals, device=self.device, dtype=torch.float
            )

        # ---- 5) PPO epochs. ----
        actor_losses, critic_losses, kl_hist = [], [], []
        for _ in range(ppo_epochs):
            indices = torch.randperm(n)
            for start in range(0, n, minibatch_size):
                mb_idx = indices[start:start+minibatch_size]
                mb_obs = {k: v[mb_idx] for k, v in obs_batch_dict.items()}
                mb_actions = actions_batch[mb_idx]
                mb_old_logp = log_probs_old_batch[mb_idx]
                mb_returns = returns_batch[mb_idx]
                mb_adv = advantages_batch[mb_idx]

                # Critic step. The critic never back-propagates into the
                # actor, so the features are computed without a graph.
                with torch.no_grad():
                    mb_feats = self.actor.features(mb_obs)
                v = self.critic(mb_feats)
                v_loss = value_loss_coef * nn.functional.mse_loss(v, mb_returns)
                self.critic_opt.zero_grad(set_to_none=True)
                v_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.critic_opt.step()

                # Actor step: clipped surrogate - entropy bonus + KL to reference.
                logits = self.actor(mb_obs)
                dist = torch.distributions.Categorical(logits=logits)
                logp = dist.log_prob(mb_actions)
                ratio = torch.exp(logp - mb_old_logp)
                ent = dist.entropy().mean()
                with torch.no_grad():
                    ref_logits = self.ref_actor(mb_obs)
                kl = cat_kl_from_logits(logits, ref_logits).mean()
                surr1 = ratio * mb_adv
                surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * mb_adv
                a_loss = -torch.min(surr1, surr2).mean() - self.ent_coef * ent + self.kl_coef * kl
                self.actor_opt.zero_grad(set_to_none=True)
                a_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.actor_opt.step()

                actor_losses.append(a_loss.item())
                critic_losses.append(v_loss.item())
                kl_hist.append(kl.item())

        avg_actor = float(np.mean(actor_losses)) if actor_losses else 0.0
        avg_critic = float(np.mean(critic_losses)) if critic_losses else 0.0
        avg_kl = float(np.mean(kl_hist)) if kl_hist else 0.0
        print(f"[Poly-PPO] Update -> actor: {avg_actor:.4f} | critic: {avg_critic:.4f} | KL(ref): {avg_kl:.5f}")
        return avg_actor, avg_critic, avg_kl, avg_group_diversity
        
        
# Evaluation
@torch.no_grad()
def evaluate_policy_simple(
    actor: BabyAI_BC,
    env,
    device: torch.device,
    use_text: bool,
    num_episodes: int,
    max_steps: Optional[int] = None,
    temperature: float = 1.0,
) -> Dict[str, float]:
    """Roll out ``actor`` in ``env`` and summarise reward, success and diversity.

    Actions are sampled from the temperature-scaled policy. The environment's
    own reward is ignored: an episode earns a shaped reward of
    ``1 - 0.5 * steps / max_steps`` on the step it terminates and 0 otherwise.
    When ``max_steps`` is ``None`` (or 0) no step cap is applied and the time
    penalty is dropped, so a terminating episode earns 1.0.

    Args:
        actor: Policy network mapping a prepared observation to action logits.
        env: Environment with ``reset()`` returning ``(obs, info)`` and
            ``step(a)`` returning ``(obs, reward, terminated, truncated, info)``.
        device: Device observations are prepared on.
        use_text: Forwarded to ``prepare_obs``.
        num_episodes: Number of evaluation episodes.
        max_steps: Optional per-episode step cap.
        temperature: Softmax temperature applied to the logits.

    Returns:
        Dict with ``avg_reward``, ``success_rate``, ``avg_entropy`` and
        ``diversity`` (0.0 when no episode was run).
    """
    rewards, successes, entropies = [], [], []
    all_trajs: List[List[Tuple[int, int]]] = []

    def _get_xy(e) -> Tuple[int, int]:
        # Prefer the unwrapped env's position; fall back to the wrapper, then
        # to a (-1, -1) sentinel when no position is exposed.
        try:
            p = getattr(e.unwrapped, "agent_pos", None)
            if p is not None: return int(p[0]), int(p[1])
        except Exception: pass
        p = getattr(e, "agent_pos", None)
        return (int(p[0]), int(p[1])) if p is not None else (-1, -1)

    def _get_grid_size(e) -> Tuple[int, int]:
        # Fall back to 19x19 when the env exposes no grid.
        try:
            g = getattr(e.unwrapped, "grid", None)
            if g is not None: return int(g.width), int(g.height)
        except Exception: pass
        return (19, 19)

    grid_size = None

    for _ in range(num_episodes):
        obs, info = env.reset()
        done = False
        terminated = truncated = False
        ep_r = 0.0
        steps = 0
        ep_ent = []

        traj = []
        if grid_size is None: grid_size = _get_grid_size(env)
        traj.append(_get_xy(env))

        while not done:
            batch = prepare_obs(obs, device=device, use_text=use_text)
            logits = actor(batch) / max(1e-6, float(temperature))
            dist = torch.distributions.Categorical(logits=logits)
            a = dist.sample()
            ep_ent.append(float(dist.entropy().item()))

            obs, reward, terminated, truncated, info = env.step(int(a.item()))
            shaped_r = 0.0
            if terminated:
                # Without a known step budget there is nothing to scale the
                # time penalty by, so it is omitted rather than crashing.
                time_penalty = 0.5 * (steps / max_steps) if max_steps else 0.0
                shaped_r = 1.0 - time_penalty

            done = bool(terminated or truncated)
            ep_r += float(shaped_r)
            steps += 1
            traj.append(_get_xy(env))
            if max_steps is not None and steps >= max_steps:
                break

        rewards.append(ep_r)
        successes.append(success_from_info_or_reward(terminated, truncated, info, ep_r))
        entropies.append(float(np.mean(ep_ent)) if ep_ent else 0.0)
        all_trajs.append(traj)

    if all_trajs:
        diversity, room_sequences = cluster_seq_diversity(all_trajs, grid_size=grid_size, return_sequences=True, grid_dims=(3, 3))
        print(f"room_sequences: {room_sequences}")
    else:
        # No episodes were run: there are no trajectories to cluster.
        diversity = 0.0

    return {
        "avg_reward": float(np.mean(rewards)) if rewards else 0.0,
        "success_rate": float(np.mean(successes)) if successes else 0.0,
        "avg_entropy": float(np.mean(entropies)) if entropies else 0.0,
        "diversity": diversity,
    }

# Argparse
def parse_args():
    """Parse the command-line options of the Polychromic PPO training script.

    Returns:
        ``argparse.Namespace`` holding every option below. ``config_dir`` is
        included because ``main`` builds the path of the per-game YAML file
        from ``args.config_dir``.
    """
    p = argparse.ArgumentParser(description="Polychromic PPO for BabyAI.")
    p.add_argument("--game_code", type=str, default="open", help="open|pickup|goto|unlock|synthseq|bosslevel")
    p.add_argument("--seed", type=int, default=100)
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the default directory name is an assumption -- confirm it
    # matches where the `<game_code>.yaml` files actually live.
    p.add_argument("--config_dir", type=str, default="configs", help="Directory containing the per-game <game_code>.yaml configuration file.")
    p.add_argument("--ppo_epochs", type=int, default=2)
    p.add_argument("--minibatch_size", type=int, default=64)
    p.add_argument("--gamma", type=float, default=1.0)
    p.add_argument("--gae_lambda", type=float, default=0.95)
    p.add_argument("--clip_epsilon", type=float, default=0.2)
    p.add_argument("--actor_lr", type=float, default=1e-5)
    p.add_argument("--critic_lr", type=float, default=1e-4)
    p.add_argument("--value_loss_coef", type=float, default=0.5)
    p.add_argument("--entropy_coef", type=float, default=0.0)
    p.add_argument("--kl_coef", type=float, default=0.05)
    p.add_argument("--max_grad_norm", type=float, default=0.5)
    p.add_argument("--ucb_coef", type=float, default=0.0, help="λ_ucb for per-(s,a) bonus: λ_ucb * min(1, sqrt(1/N(s,a))). Set 0 to disable.")
    p.add_argument("--num_vines_at_state", type=int, default=8)
    p.add_argument("--polychrome_window", type=int, default=5)
    p.add_argument("--num_levels", type=int, default=3) # num_levels is the number of timesteps to sample from at each fractal sampling step.
    p.add_argument("--num_training_epochs", type=int, default=1500)
    p.add_argument("--max_steps_per_episode", type=int, default=100)
    p.add_argument("--eval_interval", type=int, default=100)
    p.add_argument("--heavy_eval_interval", type=int, default=200)
    p.add_argument("--num_eval_episodes", type=int, default=10)
    p.add_argument("--temperature", type=float, default=1.0)
    p.add_argument("--pretrain_dir", type=str, default="pretrained_models")
    p.add_argument("--pretrain_name_tmpl", type=str, default="pretrain_{game}.pt")
    p.add_argument("--save_path", type=str, default="")
    p.add_argument("--eval_log_path", type=str, default="")
    p.add_argument("--use_text", type=str, default="true", choices=["auto", "true", "false"])
    return p.parse_args()

def _env_bool(key: str, default: bool = False) -> bool:
    """Read environment variable ``key`` as a boolean flag.

    Returns ``default`` when the variable is unset. Otherwise the value is
    true exactly when, stripped and lower-cased, it is one of
    ``1``, ``true``, ``yes`` or ``on``.
    """
    raw = os.environ.get(key)
    if raw is None:
        return default
    normalized = str(raw).strip().lower()
    return normalized in ("1", "true", "yes", "on")

def start_wandb_auto(game_code: str, cfg: Dict[str, Any]):
    """Start a Weights & Biases run configured from environment variables.

    Returns ``None`` (after printing why) when ``wandb`` could not be imported
    or ``WANDB_DISABLED`` is truthy. Otherwise ``WANDB_MODE``,
    ``WANDB_PROJECT`` and ``WANDB_RUN_NAME`` override the defaults
    (``online``, ``polychromic_ppo`` and a timestamped name) and the result of
    ``wandb.init`` is returned.
    """
    if wandb is None:
        print("[W&B] Not installed; proceeding without logging.")
        return None
    if _env_bool("WANDB_DISABLED", False):
        print("[W&B] Disabled via WANDB_DISABLED.")
        return None

    env = os.environ
    mode = env.get("WANDB_MODE", "online")
    project = env.get("WANDB_PROJECT", "polychromic_ppo")
    fallback_name = f"poly_ppo_{game_code}_{int(time.time())}"
    run_name = env.get("WANDB_RUN_NAME", fallback_name)

    run = wandb.init(project=project, name=run_name, config=cfg, mode=mode, reinit=True)
    if run is None:
        return run
    print(f"[W&B] Initialized in {mode} mode → project={project} run={run.name}")
    return run

def wb_log(run, payload: Dict[str, Any]):
    """Log ``payload`` through ``wandb.log``, dropping NaN/inf float entries.

    Does nothing when ``run`` is ``None``.
    """
    if run is None:
        return

    def _is_non_finite_float(value) -> bool:
        return isinstance(value, float) and not math.isfinite(value)

    clean = {}
    for name, value in payload.items():
        if _is_non_finite_float(value):
            continue
        clean[name] = value
    wandb.log(clean)


# Main
def main():
    args = parse_args()
    device = torch.device(args.device)
    torch.manual_seed(args.seed % (2**32 - 1))
    np.random.seed(args.seed % (2**32 - 1))

    config_filename = f"{args.game_code}.yaml"
    config_path = os.path.join(args.config_dir, config_filename)
    print(f"--- Loading configuration from: {config_path} ---")
    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"Error: Configuration file not found at {config_path}")
        return 

    if args.use_text == "auto":
        use_text = GAME_CODE_TO_INFO.get(args.game_code, {"use_text": False})["use_text"]
    else:
        use_text = (args.use_text == "true")

    assert use_text, "Text must be used"

    cfg_for_wandb = {**vars(args), "use_text_resolved": use_text}
    wandb_run = start_wandb_auto(args.game_code, cfg_for_wandb)

    print("--- Creating a pool of 20 train and 20 eval environments... ---")
    train_env_pool = []
    eval_env_pool = []
    eval_heldout_env_pool = []

    eval_logs: List[Dict[str, Any]] = []   
    heldout_logs: List[Dict[str, Any]] = []   

    train_seeds = config['train_seeds']
    evaluate_seeds = config['evaluate_seeds']
    multiseed_seeds = train_seeds
    
    for s in multiseed_seeds:
        train_env_pool.append(setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=s))
        eval_env_pool.append(setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=s))

    heldout_seeds = evaluate_seeds[-100:]
    for s in heldout_seeds:
        eval_heldout_env_pool.append(setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=s))
    
    print(f"--- Environment pools created. Total environments: {len(train_env_pool)} train, {len(eval_env_pool)} eval. ---")
    
    def _checkpoint_dir_from(save_path: str) -> str:
        base = os.path.dirname(save_path)
        ckpt_dir = os.path.join(base, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        return ckpt_dir

    def _save_checkpoint(epoch: int, path: str, agent: PPOAgent, args, extra: Dict[str, Any] = None):
        ckpt = {
            "epoch": epoch,
            "args": vars(args),
            "actor_state": agent.actor.state_dict(),
            "critic_state": agent.critic.state_dict(),
            "actor_opt_state": agent.actor_opt.state_dict(),
            "critic_opt_state": agent.critic_opt.state_dict(),
            "rng_state": {
                "python": random.getstate(),
                "numpy": np.random.get_state(),
                "torch": torch.get_rng_state(),
                "torch_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
            },
        }
        if extra:
            ckpt.update(extra)
        torch.save(ckpt, path)

    pretrain_path = os.path.join(args.pretrain_dir, args.pretrain_name_tmpl.format(game=args.game_code))
    save_path = args.save_path or FINE_TUNED_MODEL_TMPL.format(game=args.game_code)
    ckpt_dir = _checkpoint_dir_from(save_path)

    eval_log_path = args.eval_log_path or EVAL_LOG_JSON_TMPL.format(game=args.game_code)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    os.makedirs(os.path.dirname(eval_log_path), exist_ok=True)

    agent = PPOAgent(
        n_actions=train_env_pool[0].action_space.n, device=device, use_text=use_text,
        actor_lr=args.actor_lr, critic_lr=args.critic_lr, ent_coef=args.entropy_coef,
        kl_coef=args.kl_coef, clip_eps=args.clip_epsilon, max_grad_norm=args.max_grad_norm,
        pretrained_weights=pretrain_path, ucb_coef=args.ucb_coef
    )

    eval_metrics_list = []
    for i, env in enumerate(eval_env_pool):
        metrics = evaluate_policy_simple(
            agent.actor, env, device, use_text, args.num_eval_episodes,
            args.max_steps_per_episode, args.temperature
        )
        eval_metrics_list.append(metrics)
        print(f"  Env {i+1} -> Avg Reward: {metrics['avg_reward']:.3f}, Success: {metrics['success_rate']*100:.2f}%")
    
    avg_reward = np.mean([m['avg_reward'] for m in eval_metrics_list])
    avg_success = np.mean([m['success_rate'] for m in eval_metrics_list])
    avg_diversity = np.mean([m['diversity'] for m in eval_metrics_list])
    print(f"OVERALL -> Avg Reward: {avg_reward:.3f}, Success: {avg_success*100:.2f}%, Diversity: {avg_diversity:.3f}")
    wb_log(wandb_run, {"heavy_eval/avg_reward": avg_reward, "heavy_eval/success_rate": avg_success * 100.0, "heavy_eval/diversity": avg_diversity})
    
    wb_eval_payload = {}
    for i, m in enumerate(eval_metrics_list):
        wb_eval_payload[f"heavy_eval_env_{i+1}/avg_reward"] = m['avg_reward']
        wb_eval_payload[f"heavy_eval_env_{i+1}/success_rate"] = m['success_rate'] * 100.0
        wb_eval_payload[f"heavy_eval_env_{i+1}/diversity"] = m['diversity']
    wb_log(wandb_run, {"heavy_eval/epoch": 0, **wb_eval_payload})
    epoch_logs = [{"heavy_eval/epoch": 0, "env_index": eval_env_pool.index(env), **metrics} for env, metrics in zip(eval_env_pool, eval_metrics_list)]
    eval_logs.extend(epoch_logs)


    print("\n--- Initial Evaluation of Pre-trained Policy (on held out 10 envs) ---")
    initial_heldout_metrics_list = []
    for i, env in enumerate(eval_heldout_env_pool):
        metrics = evaluate_policy_simple(
            agent.actor, env, device, use_text, args.num_eval_episodes,
            args.max_steps_per_episode, args.temperature
        )
        initial_heldout_metrics_list.append(metrics)
        print(f"  Env {i+1} -> Avg Reward: {metrics['avg_reward']:.3f}, Success: {metrics['success_rate']*100:.2f}%")
    
    avg_reward = np.mean([m['avg_reward'] for m in initial_heldout_metrics_list])
    avg_success = np.mean([m['success_rate'] for m in initial_heldout_metrics_list])
    avg_diversity = np.mean([m['diversity'] for m in initial_heldout_metrics_list])
    print(f"OVERALL -> Avg Reward: {avg_reward:.3f}, Success: {avg_success*100:.2f}%, Diversity: {avg_diversity:.3f}")

    heldout_logs = [{"epoch": 0, "env_index": eval_heldout_env_pool.index(env), **metrics} for env, metrics in zip(eval_heldout_env_pool, initial_heldout_metrics_list)]
    
    wb_payload = {}
    for i, m in enumerate(initial_heldout_metrics_list):
        wb_payload[f"heldout_eval_env_{i+1}/avg_reward"] = m['avg_reward']
        wb_payload[f"heldout_eval_env_{i+1}/success_rate"] = m['success_rate'] * 100.0
        wb_payload[f"heldout_eval_env_{i+1}/diversity"] = m['diversity']
    wb_log(wandb_run, {"eval/epoch": 0, **wb_payload})

    wb_payload = {
        "heldout_eval/avg_reward": avg_reward,
        "heldout_eval/success_rate": avg_success * 100.0,
        "heldout_eval/diversity": avg_diversity,
    }
    wb_log(wandb_run, wb_payload)

    
    for epoch in range(args.num_training_epochs):
        print(f"\n--- Epoch {epoch+1}/{args.num_training_epochs} ---")

        selected_train_envs = []
        selected_train_envs = random.sample(train_env_pool, 4)
        
        datasets = []
        for env in selected_train_envs:
            dataset = generate_fractal_dataset(
                policy=agent.actor, env=env, num_vines_at_state=args.num_vines_at_state,
                num_levels=args.num_levels, main_rollout_max_steps=args.max_steps_per_episode, device=device
            )
            datasets.append(dataset)

        
        combined_dataset = defaultdict(list)
        for d in datasets:
            for _, trajectories in d.items():
                for traj in trajectories:
                    if not traj:
                        continue
                    k2 = obs_to_key(traj[0]['observation'])  
                    combined_dataset[k2].append(traj)
        combined_dataset = dict(combined_dataset)

        for kx, trajs in combined_dataset.items():
            for i, t in enumerate(trajs):
                assert t and obs_to_key(t[0]['observation']) == kx, \
                    f"Key/content mismatch at traj #{i} for key mission={kx[1]!r}"
        
        print(f"Total size of combined dataset: {sum(len(v) for v in combined_dataset.values())} trajectories from 4 random envs")

        if combined_dataset:
            grid_size = (train_env_pool[0].unwrapped.width, train_env_pool[0].unwrapped.height)
            actor_loss, critic_loss, kl_div, group_div_list = agent.update_from_dataset(
                on_policy_dataset=combined_dataset, 
                env_grid_size=grid_size, ppo_epochs=args.ppo_epochs,
                minibatch_size=args.minibatch_size, gamma=args.gamma, gae_lambda=args.gae_lambda,
                value_loss_coef=args.value_loss_coef, polychrome_window=args.polychrome_window
            )
            if isinstance(group_div_list, (list, tuple)) and len(group_div_list) > 0:
                avg_group_diversity_val = float(np.mean(group_div_list))
                hist_payload = {"train/group_diversity_hist": wandb.Histogram(group_div_list)} if wandb is not None else {}
            else:
                avg_group_diversity_val = 0.0
                hist_payload = {}

            wb_log(wandb_run, {
                "train/epoch": epoch + 1,
                "train/actor_loss": actor_loss,
                "train/critic_loss": critic_loss,
                "train/kl_div": kl_div,
                "train/average_group_diversity": avg_group_diversity_val,
                **hist_payload
            })
        
        if (epoch + 1) % args.heavy_eval_interval == 0:
            print(f"\n--- Heavy Evaluation at Epoch {epoch+1} (on all {len(eval_env_pool)} eval envs) ---")
            eval_metrics_list = []
            for i, env in enumerate(eval_env_pool):
                metrics = evaluate_policy_simple(
                    agent.actor, env, device, use_text, args.num_eval_episodes,
                    args.max_steps_per_episode, args.temperature
                )
                eval_metrics_list.append(metrics)
                print(f"  Env {i+1} -> Avg Reward: {metrics['avg_reward']:.3f}, Success: {metrics['success_rate']*100:.2f}%")
            
            avg_reward = np.mean([m['avg_reward'] for m in eval_metrics_list])
            avg_success = np.mean([m['success_rate'] for m in eval_metrics_list])
            avg_diversity = np.mean([m['diversity'] for m in eval_metrics_list])
            print(f"OVERALL -> Avg Reward: {avg_reward:.3f}, Success: {avg_success*100:.2f}%, Diversity: {avg_diversity:.3f}")
            wb_log(wandb_run, {"heavy_eval/avg_reward": avg_reward, "heavy_eval/success_rate": avg_success * 100.0, "heavy_eval/diversity": avg_diversity})
            
            wb_eval_payload = {}
            for i, m in enumerate(eval_metrics_list):
                wb_eval_payload[f"heavy_eval_env_{i+1}/avg_reward"] = m['avg_reward']
                wb_eval_payload[f"heavy_eval_env_{i+1}/success_rate"] = m['success_rate'] * 100.0
                wb_eval_payload[f"heavy_eval_env_{i+1}/diversity"] = m['diversity']
            wb_log(wandb_run, {"heavy_eval/epoch": epoch + 1, **wb_eval_payload})
            
            epoch_logs = [{"heavy_eval/epoch": epoch + 1, "env_index": eval_env_pool.index(env), **metrics} for env, metrics in zip(eval_env_pool, eval_metrics_list)]
            eval_logs.extend(epoch_logs)

            print(f"\n--- Held-out Evaluation at Epoch {epoch+1} (on all {len(eval_heldout_env_pool)} held-out eval envs) ---")
            heldout_metrics_list = []
            for i, env in enumerate(eval_heldout_env_pool):
                metrics = evaluate_policy_simple(
                    agent.actor, env, device, use_text, args.num_eval_episodes,
                    args.max_steps_per_episode, args.temperature
                )
                heldout_metrics_list.append(metrics)
                print(f"  Env {i+1} -> Avg Reward: {metrics['avg_reward']:.3f}, Success: {metrics['success_rate']*100:.2f}%")

            avg_reward = np.mean([m['avg_reward'] for m in heldout_metrics_list])
            avg_success = np.mean([m['success_rate'] for m in heldout_metrics_list])
            avg_diversity = np.mean([m['diversity'] for m in heldout_metrics_list])
            print(f"OVERALL -> Avg Reward: {avg_reward:.3f}, Success: {avg_success*100:.2f}%, Diversity: {avg_diversity:.3f}")  

            heldout_logs_epoch = [{"epoch": epoch + 1, "env_index": eval_heldout_env_pool.index(env), **metrics} for env, metrics in zip(eval_heldout_env_pool, heldout_metrics_list)]
            heldout_logs.extend(heldout_logs_epoch)

            wb_payload = {}
            for i, m in enumerate(heldout_metrics_list):
                wb_payload[f"heldout_eval_env_{i+1}/avg_reward"] = m['avg_reward']
                wb_payload[f"heldout_eval_env_{i+1}/success_rate"] = m['success_rate'] * 100.0
                wb_payload[f"heldout_eval_env_{i+1}/diversity"] = m['diversity']
            wb_log(wandb_run, {"heldout_eval/epoch": epoch + 1, **wb_payload})

            wb_payload = {
                "heldout_eval/avg_reward": avg_reward,
                "heldout_eval/success_rate": avg_success * 100.0,
                "heldout_eval/diversity": avg_diversity,
            }
            wb_log(wandb_run, wb_payload)
        if (epoch + 1) in {250, 500, 1000, 1200}:
            ckpt_path = os.path.join(ckpt_dir, f"poly_ppo_{args.game_code}_epoch_{epoch+1}.pt")
            _save_checkpoint(epoch + 1, ckpt_path, agent, args)
            print(f"[CKPT] Saved checkpoint → {ckpt_path}")
            if wandb_run is not None:
                try:
                    art = wandb.Artifact(f"poly-ppo-{args.game_code}-ckpt-ep{epoch+1}", type="model")
                    art.add_file(ckpt_path)
                    wandb.log_artifact(art)
                except Exception as e:
                    print(f"[W&B] Checkpoint artifact logging failed: {e}")


    print("\nTraining complete.")
    torch.save(agent.actor.state_dict(), save_path)
    print(f"Final policy saved to {save_path}")

    torch.save(agent.critic.state_dict(), os.path.join(ckpt_dir, f"poly_ppo_{args.game_code}_critic.pt"))
    print(f"Critic weights saved to: {os.path.join(ckpt_dir, f'poly_ppo_{args.game_code}_critic.pt')}")
    
    with open(eval_log_path, "w") as f:
        json.dump(eval_logs, f, indent=2)
    print(f"Eval logs saved to: {eval_log_path}")
    
    if wandb_run is not None:
        try:
            art = wandb.Artifact(f"poly-ppo-{args.game_code}-actor", type="model")
            art.add_file(save_path)
            art.add_file(eval_log_path)
            wandb.log_artifact(art)
        except Exception as e:
            print(f"[W&B] Artifact logging failed: {e}")
        wandb.finish()

    print("--- Closing all environments... ---")
    for env in train_env_pool:
        env.close()
    for env in eval_env_pool:
        env.close()
    for env in eval_heldout_env_pool:
        env.close()
    print("All environments closed. Exiting.")

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
# NOTE(review): this file uses package-relative imports (`from ..utils import ...`),
# so it presumably has to be launched as a module (`python -m <package>.<module>`)
# rather than by file path — confirm against how the project invokes it.
if __name__ == "__main__":
    main()