# poly_ppo_ucb.py

import os
import math
import json
import time
import argparse
from typing import Dict, Any, List, Tuple, Optional
from collections import defaultdict, Counter
import random
import hashlib

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F 

try:
    import wandb  
except Exception:
    wandb = None

from pretrain import MLPPolicy  
from ..utils import prepare_obs, HIDDEN_DIM, cluster_seq_diversity
from ..train_utils import setup_environment as setup_environment_multiseed

from fractal_sampling import generate_dataset as generate_fractal_dataset
from fractal_sampling import obs_to_key

# Default path template for the fine-tuned actor weights; "{game}" is filled
# in with the game code when --save_path is not given.
FINE_TUNED_MODEL_TMPL = "{game}/pretrained_models_multiseed/polychrome_ppo_{game}.pt"
# Default path template for the JSON evaluation log; "{game}" is filled in
# with the game code when --eval_log_path is not given.
EVAL_LOG_JSON_TMPL = "{game}/pretrained_models_multiseed/polychrome_ppo_eval_{game}.json"

# Per-game settings; "use_text" is looked up when --use_text is "auto".
GAME_CODE_TO_INFO = {
    "minigrid": {"use_text": False},
}

def cat_kl_from_logits(p_logits: torch.Tensor, q_logits: torch.Tensor) -> torch.Tensor:
    """Return KL(P || Q) for categorical distributions given as logits.

    Both inputs are normalised with a log-softmax over the last dimension,
    and that same dimension is summed out of the result.
    """
    log_p = torch.log_softmax(p_logits, dim=-1)
    log_q = torch.log_softmax(q_logits, dim=-1)
    log_ratio = log_p - log_q
    return (log_p.exp() * log_ratio).sum(-1)

def cat_entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Return the entropy (in nats) of each categorical distribution in logits.

    The distribution axis is the last dimension; it is summed out.
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    weighted = log_probs.exp() * log_probs
    return -weighted.sum(-1)

def success_from_info_or_reward(terminated: bool, truncated: bool, info: Dict[str, Any], ep_reward: float) -> bool:
    """Decide whether an episode counts as a success.

    An explicit "success" entry in info takes precedence. Without one, the
    episode is a success only if it terminated with a strictly positive
    return. The truncated flag is accepted but not consulted.
    """
    has_explicit_flag = isinstance(info, dict) and "success" in info
    if not has_explicit_flag:
        return bool(terminated and ep_reward > 0.0)
    return bool(info["success"])

def _nan_to_none(x):
    """Coerce x to a finite float, returning None when that is not possible.

    None, values that float() rejects, NaN and +/-infinity all map to None.
    """
    if x is None:
        return None
    try:
        as_float = float(x)
    except Exception:
        return None
    return as_float if math.isfinite(as_float) else None

def _state_key_from_obs(obs: Dict[str, Any]) -> Tuple[str, int, str]:
    """
    Build a small hashable key for N(s,a) bookkeeping from a raw obs dict.

    The key is (str(mission), int(direction), digest) where digest is the
    first 8 hex characters of the SHA-1 of the image bytes. Fallbacks:
    mission "" when absent, direction -1 when absent or not convertible to
    int, and digest "noimg" when the image has no tobytes() or hashing fails.
    """
    mission_text = str(obs.get("mission", ""))

    raw_direction = obs.get("direction", -1)
    if isinstance(raw_direction, (int, np.integer)):
        direction_int = int(raw_direction)
    else:
        try:
            direction_int = int(raw_direction)
        except Exception:
            direction_int = -1

    image = obs.get("image", None)
    digest = "noimg"
    try:
        if hasattr(image, "tobytes"):
            digest = hashlib.sha1(image.tobytes()).hexdigest()[:8]
    except Exception:
        digest = "noimg"

    return (mission_text, direction_int, digest)

def _obs_to_feature(prepped: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Convert one prepared observation into a single feature row.

    The image tensor is flattened to shape (1, numel) and cast to float. The
    direction is clamped to [0, 3] and appended as a 4-way one-hot, so the
    result has shape (1, image.numel() + 4).

    Args:
        prepped: Mapping with an "image" tensor and, optionally, a
            "direction" tensor holding exactly one value. 0-dim, (1,) and
            (1, 1) shapes and any numeric dtype are accepted. A missing
            direction is treated as 0.

    Returns:
        A float tensor of shape (1, image.numel() + 4) on the image's device.

    Raises:
        RuntimeError: If "direction" does not hold exactly one element.
    """
    img: torch.Tensor = prepped["image"]
    x_img = img.reshape(1, -1).float()

    dir_tensor: Optional[torch.Tensor] = prepped.get("direction")
    if dir_tensor is None:
        dir_tensor = torch.zeros((1,), dtype=torch.long, device=img.device)

    # one_hot only accepts int64 indices and the concatenation below needs
    # exactly one 2-D row on the image's device, so normalise shape, dtype
    # and device before encoding instead of relying on the caller.
    dir_idx = dir_tensor.reshape(1).to(device=img.device, dtype=torch.long).clamp(0, 3)
    dir_oh = torch.nn.functional.one_hot(dir_idx, num_classes=4).float()

    return torch.cat([x_img, dir_oh], dim=1)

def _batch_to_features(
    obs_list: List[Dict[str, Any]],
    device: torch.device,
    use_text: bool
) -> torch.Tensor:
    """Stack per-observation feature rows into one [B, input_dim] tensor.

    Every raw env observation goes through prepare_obs and then
    _obs_to_feature; the resulting rows are concatenated along dim 0.
    """
    rows = [
        _obs_to_feature(prepare_obs(raw_obs, device=device, use_text=use_text))
        for raw_obs in obs_list
    ]
    return torch.cat(rows, dim=0)

# Critic Network
class ValueNetwork(nn.Module):
    """State-value critic: a three-layer MLP mapping feature rows to scalars."""

    def __init__(self, input_dim: int, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        # Keep the Linear/ReLU/Linear/ReLU/Linear order: it fixes both the
        # parameter initialisation order and the "net.<index>" state-dict keys.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Return one value per input row with the trailing unit dim removed."""
        values = self.net(feats)
        return values.squeeze(-1)

# PPO Agent 
class PPOAgent:
    """
    PPO learner built around an MLP actor, a frozen reference copy of that
    actor (used for a KL penalty) and a separate value critic.

    On top of the clipped PPO objective, update_from_dataset adds group-level
    "polychrome" advantages (mean group reward scaled by trajectory
    diversity) and an optional count-based UCB bonus on the advantages.
    """

    def __init__(
        self,
        n_actions: int,
        device: torch.device,
        input_dim: int,
        actor_lr: float,
        critic_lr: float,
        ent_coef: float,
        kl_coef: float,
        clip_eps: float,
        max_grad_norm: float,
        use_text: bool,
        pretrained_weights: Optional[str] = None,
        hidden_dim: int = HIDDEN_DIM,
        ucb_coef: float = 0.0,
    ):
        """
        Build the actor, reference actor, critic and their Adam optimizers.

        Args:
            n_actions: Size of the discrete action space (actor output dim).
            device: Device every network and tensor is placed on.
            input_dim: Width of a feature row produced by _batch_to_features.
            actor_lr: Learning rate of the actor optimizer.
            critic_lr: Learning rate of the critic optimizer.
            ent_coef: Weight of the entropy bonus in the actor loss.
            kl_coef: Weight of the KL(actor || reference) penalty.
            clip_eps: PPO probability-ratio clipping range.
            max_grad_norm: Gradient-norm clip applied to both networks.
            use_text: Forwarded to prepare_obs when featurizing observations.
            pretrained_weights: Optional path to an actor state dict. A path
                that is not a file is ignored; a load failure is printed and
                training continues from the fresh initialisation.
            hidden_dim: Hidden width of the actor and critic.
            ucb_coef: Scale of the per-(state, action) UCB bonus; 0 disables it.
        """
        self.device = device
        self.ent_coef = ent_coef
        self.kl_coef = kl_coef
        self.clip_eps = clip_eps
        self.max_grad_norm = max_grad_norm
        self.use_text = use_text
        self.input_dim = input_dim
        self.ucb_coef = ucb_coef

        self.actor = MLPPolicy(input_dim=input_dim, output_dim=n_actions, hidden_dim=hidden_dim).to(self.device)
        if pretrained_weights and os.path.isfile(pretrained_weights):
            # NOTE: torch.load unpickles the file, so only point this at
            # checkpoints from a trusted source.
            try:
                state = torch.load(pretrained_weights, map_location=self.device)
                self.actor.load_state_dict(state, strict=True)
                print(f"[Poly-PPO] Loaded pretrained actor weights from: {pretrained_weights}")
            except Exception as e:
                # Deliberate best effort: fall back to the random initialisation.
                print(f"[Poly-PPO] Failed to load pretrained weights ({pretrained_weights}): {e}")

        # The reference policy is a frozen snapshot of the actor as it stands
        # right now (i.e. after any pretrained weights were loaded).
        self.ref_actor = MLPPolicy(input_dim=input_dim, output_dim=n_actions, hidden_dim=hidden_dim).to(self.device)
        self.ref_actor.load_state_dict(self.actor.state_dict())
        self.ref_actor.eval()
        for p in self.ref_actor.parameters():
            p.requires_grad = False

        self.critic = ValueNetwork(input_dim=input_dim, hidden_dim=hidden_dim).to(self.device)

        self.actor_opt = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=critic_lr)

    @torch.no_grad()
    def select_action(self, obs: Dict[str, Any], temperature: float = 1.0) -> Tuple[int, float, float]:
        """
        Sample one action for a raw observation.

        Logits are divided by temperature (floored at 1e-6) when it is
        positive; a non-positive temperature samples from the raw logits.

        Returns:
            (action index, log-probability of that action, entropy of the
            sampling distribution).
        """
        feats = _batch_to_features([obs], self.device, self.use_text)
        logits = self.actor(feats)
        if temperature > 0:
            logits = logits / max(1e-6, float(temperature))
        dist = torch.distributions.Categorical(logits=logits)
        a = dist.sample()
        return int(a.item()), float(dist.log_prob(a).item()), float(dist.entropy().item())

    @torch.no_grad()
    def value(self, obs: Dict[str, Any]) -> float:
        """Return the critic's value estimate for a single raw observation."""
        feats = _batch_to_features([obs], self.device, self.use_text)
        v = self.critic(feats)
        return float(v.squeeze(0).item())

    def _batchify_flat(self, obs_list: List[Dict[str, Any]]) -> torch.Tensor:
        """Featurize a list of raw observations with this agent's settings."""
        return _batch_to_features(obs_list, self.device, self.use_text)

    def compute_gae(self, rewards, values, dones, next_value, gamma, gae_lambda):
        """
        Generalised advantage estimation over one trajectory segment.

        Args:
            rewards, values, dones: 1-D tensors of equal length.
            next_value: Scalar tensor used to bootstrap after the last step.
            gamma: Discount factor.
            gae_lambda: GAE smoothing factor.

        Returns:
            (advantages, returns), where returns = advantages + values.
        """
        advantages = torch.zeros_like(rewards, device=self.device)
        last_gae = 0.0
        values = torch.cat([values, next_value.view(1)])
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
            last_gae = delta + gae_lambda * gamma * not_done * last_gae
            advantages[t] = last_gae
        returns = advantages + values[:-1]
        return advantages, returns

    @staticmethod
    def _normalize_advantages(adv: torch.Tensor) -> torch.Tensor:
        """Standardise advantages; a single sample is only mean-centred."""
        centered = adv - adv.mean()
        if adv.numel() < 2:
            # The unbiased std of one element is NaN, which would poison the
            # whole update; a lone sample simply becomes a zero advantage.
            return centered
        return centered / (adv.std() + 1e-8)

    def update_from_dataset(
        self,
        on_policy_dataset: Dict,
        env_grid_size: Tuple[int, int],
        ppo_epochs: int,
        minibatch_size: int,
        gamma: float,
        gae_lambda: float,
        value_loss_coef: float,
        polychrome_window: int = 5,
    ) -> Tuple[float, float, float, List[float]]:
        """
        Run PPO updates on a dataset of collected trajectories.

        Args:
            on_policy_dataset: Mapping whose values are lists of trajectories
                (the keys are not inspected here). A trajectory is a sequence
                of step dicts with keys 'observation', 'action', 'log_prob',
                'reward', 'done' and optionally 'env_state'.
            env_grid_size: Passed to cluster_seq_diversity as grid_size.
            ppo_epochs: Number of passes over the assembled batch.
            minibatch_size: Minibatch size within each pass.
            gamma: Discount factor for GAE.
            gae_lambda: GAE smoothing factor.
            value_loss_coef: Multiplier on the critic's MSE loss.
            polychrome_window: Only the first this-many steps of each sampled
                trajectory receive the group-level advantage.

        Returns:
            (mean actor loss, mean critic loss, mean KL to the reference
            policy, list of per-group diversity scores). All zeros and an
            empty list when the dataset holds no steps.
        """
        self.actor.train()
        self.critic.train()
        self.ref_actor.eval()

        # ---- 1. Flatten every non-empty trajectory into parallel lists ----
        all_obs, all_actions, all_old_logps = [], [], []
        all_rewards, all_dones = [], []
        traj_spans = []
        for trajectories in on_policy_dataset.values():
            for traj in trajectories:
                if not traj:
                    continue
                start = len(all_obs)
                for step in traj:
                    all_obs.append(step['observation'])
                    all_actions.append(step['action'])
                    all_old_logps.append(step['log_prob'])
                    all_rewards.append(step['reward'])
                    all_dones.append(step['done'])
                traj_spans.append((start, len(all_obs)))

        if not all_obs:
            return 0.0, 0.0, 0.0, []

        # ---- 2. Per-trajectory GAE with a zero bootstrap value ----
        feats_all = self._batchify_flat(all_obs)
        with torch.no_grad():
            values_all = self.critic(feats_all).float()

        rewards_all = torch.tensor(all_rewards, device=self.device, dtype=torch.float)
        dones_all = torch.tensor(all_dones, device=self.device, dtype=torch.float)
        zero_bootstrap = torch.tensor(0.0, device=self.device)

        adv_list, ret_list = [], []
        for s, e in traj_spans:
            adv, ret = self.compute_gae(
                rewards_all[s:e], values_all[s:e],
                dones_all[s:e], zero_bootstrap,
                gamma, gae_lambda
            )
            adv_list.append(adv)
            ret_list.append(ret)

        # Returns are taken before normalisation so value targets stay unscaled.
        returns = torch.cat(ret_list)
        advantages = self._normalize_advantages(torch.cat(adv_list))
        actions = torch.tensor(all_actions, device=self.device, dtype=torch.long)
        old_logps = torch.tensor(all_old_logps, device=self.device, dtype=torch.float)

        # ---- 3. Polychrome augmentation: group-level advantages ----
        aug_obs, aug_actions, aug_logps, aug_adv = [], [], [], []
        group_div_history = []

        for trajectories in on_policy_dataset.values():
            if len(trajectories) < 4:
                continue
            # Four random groups of four trajectories per dataset entry.
            groups = [random.sample(trajectories, 4) for _ in range(4)]
            group_values = []
            for group in groups:
                total_reward = sum(sum(s['reward'] for s in t) for t in group)
                pos_seqs = []
                for t in group:
                    seq = []
                    for s in t:
                        state = s.get('env_state', None)
                        # presumably state[0] is the agent position — confirm
                        # against the data collector.
                        if state is not None and isinstance(state, (list, tuple)) and len(state) > 0:
                            seq.append(state[0])
                    pos_seqs.append(seq if seq else [(0, 0)])
                div = cluster_seq_diversity(pos_seqs, grid_size=env_grid_size)
                group_values.append((total_reward / max(1, len(group))) * div)
                group_div_history.append(div)

            if not group_values:
                continue
            baseline = sum(group_values) / len(group_values)

            for group, gval in zip(groups, group_values):
                advantage = float(gval - baseline)
                for traj in group:
                    if not traj:
                        continue
                    for t in range(min(polychrome_window, len(traj))):
                        step = traj[t]
                        aug_obs.append(step['observation'])
                        aug_actions.append(step['action'])
                        aug_logps.append(step['log_prob'])
                        aug_adv.append(advantage)

        if aug_obs:
            # Featurize once and evaluate the critic in a single batched pass.
            # The value targets of augmented samples are the critic's own
            # current predictions for those observations.
            feats_aug = self._batchify_flat(aug_obs)
            with torch.no_grad():
                aug_returns_t = self.critic(feats_aug).float()

            feats_all = torch.cat([feats_all, feats_aug], dim=0)
            actions = torch.cat([actions, torch.tensor(aug_actions, device=self.device, dtype=torch.long)], dim=0)
            old_logps = torch.cat([old_logps, torch.tensor(aug_logps, device=self.device, dtype=torch.float)], dim=0)
            returns = torch.cat([returns, aug_returns_t], dim=0)

            # Augmented advantages are normalised separately from the GAE ones.
            aug_adv_t = self._normalize_advantages(
                torch.tensor(aug_adv, device=self.device, dtype=torch.float)
            )
            advantages = torch.cat([advantages, aug_adv_t], dim=0)

        # ---- 4. UCB bonus, computed once per sample ----
        # The bonus depends only on (state, action) visit counts, so hashing
        # the observations inside every minibatch would repeat identical work.
        # Counts cover the base transitions only, not the augmented copies.
        ucb_bonus = None
        if self.ucb_coef != 0.0:
            base_keys = [
                (_state_key_from_obs(ob), int(a))
                for ob, a in zip(all_obs, all_actions)
            ]
            aug_keys = [
                (_state_key_from_obs(ob), int(a))
                for ob, a in zip(aug_obs, aug_actions)
            ]
            sa_counts = Counter(base_keys)
            bonus_vals = [
                self.ucb_coef * min(1.0, (1.0 / max(1, sa_counts[k])) ** 0.5)
                for k in base_keys + aug_keys
            ]
            ucb_bonus = torch.tensor(bonus_vals, device=self.device, dtype=torch.float)

        # ---- 5. PPO epochs over shuffled minibatches ----
        actor_losses, critic_losses, kl_hist = [], [], []
        N = actions.size(0)
        for _ in range(ppo_epochs):
            idx = torch.randperm(N, device=self.device)
            for start in range(0, N, minibatch_size):
                mb = idx[start:start + minibatch_size]
                mb_feats = feats_all[mb]
                mb_actions = actions[mb]
                mb_old_lp = old_logps[mb]
                mb_ret = returns[mb]
                mb_adv = advantages[mb]
                if ucb_bonus is not None:
                    mb_adv = mb_adv + ucb_bonus[mb]

                # Critic step.
                v = self.critic(mb_feats)
                v_loss = value_loss_coef * nn.functional.mse_loss(v, mb_ret)
                self.critic_opt.zero_grad(set_to_none=True)
                v_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.critic_opt.step()

                # Actor step: clipped surrogate - entropy bonus + KL penalty.
                logits = self.actor(mb_feats)
                dist = torch.distributions.Categorical(logits=logits)
                logp = dist.log_prob(mb_actions)
                ratio = torch.exp(logp - mb_old_lp)
                ent = dist.entropy().mean()

                with torch.no_grad():
                    ref_logits = self.ref_actor(mb_feats)
                kl = cat_kl_from_logits(logits, ref_logits).mean()

                surr1 = ratio * mb_adv
                surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * mb_adv
                a_loss = -torch.min(surr1, surr2).mean() - self.ent_coef * ent + self.kl_coef * kl

                self.actor_opt.zero_grad(set_to_none=True)
                a_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.actor_opt.step()

                actor_losses.append(float(a_loss.item()))
                critic_losses.append(float(v_loss.item()))
                kl_hist.append(float(kl.item()))

        avg_actor = float(np.mean(actor_losses)) if actor_losses else 0.0
        avg_critic = float(np.mean(critic_losses)) if critic_losses else 0.0
        avg_kl = float(np.mean(kl_hist)) if kl_hist else 0.0
        print(f"[Poly-PPO] Update -> actor: {avg_actor:.4f} | critic: {avg_critic:.4f} | KL(ref): {avg_kl:.5f}")
        return avg_actor, avg_critic, avg_kl, group_div_history

# Evaluation 
@torch.no_grad()
def evaluate_policy_simple(
    actor: MLPPolicy,
    env,
    device: torch.device,
    use_text: bool,
    num_episodes: int,
    max_steps: Optional[int] = None,
    temperature: float = 1.0,
) -> Dict[str, float]:
    """
    Roll out actor in env for num_episodes and summarise the results.

    Actions are sampled from the actor's categorical distribution; logits are
    divided by temperature (floored at 1e-6) when it is positive. The
    environment's own reward is discarded: a terminating step scores
    1 - 0.5 * (steps taken before it / max_steps) and every other step 0.

    The actor is switched to eval mode for the duration of the rollouts and
    put back into its previous mode afterwards, so that layers which behave
    differently in training mode do not perturb the evaluation.

    Args:
        actor: Policy network mapping feature rows to action logits.
        env: Environment with reset() -> (obs, info) and a 5-tuple step().
        device: Device used when featurizing observations.
        use_text: Forwarded to the observation featurizer.
        num_episodes: Number of episodes to run.
        max_steps: Optional hard cap on steps per episode; also the
            denominator of the synthetic reward.
        temperature: Sampling temperature.

    Returns:
        Dict with "avg_reward", "success_rate" (terminated with positive
        return), "avg_entropy" and "diversity" (cluster_seq_diversity over
        the agent-position trajectories on a 3x3 grid partition).
    """
    rewards, successes, entropies = [], [], []
    all_trajs: List[List[Tuple[int, int]]] = []

    def _get_xy(e) -> Tuple[int, int]:
        # Prefer the unwrapped env's agent_pos, then the wrapper's own.
        try:
            p = getattr(e.unwrapped, "agent_pos", None)
            if p is not None:
                return int(p[0]), int(p[1])
        except Exception:
            pass
        p = getattr(e, "agent_pos", None)
        return (int(p[0]), int(p[1])) if p is not None else (-1, -1)

    def _get_grid_size(e) -> Tuple[int, int]:
        # Falls back to 19x19 when the grid cannot be read.
        try:
            g = getattr(e.unwrapped, "grid", None)
            if g is not None:
                return int(g.width), int(g.height)
        except Exception:
            pass
        return (19, 19)

    grid_size = None

    # Training leaves the actor in train() mode; evaluate in eval() mode and
    # restore afterwards. getattr keeps plain callables working too.
    was_training = bool(getattr(actor, "training", False))
    if was_training:
        actor.eval()
    try:
        for _ in range(num_episodes):
            obs, info = env.reset()
            done = False
            ep_r = 0.0
            steps = 0
            ep_ent = []

            traj = []
            if grid_size is None:
                grid_size = _get_grid_size(env)
            traj.append(_get_xy(env))

            while not done:
                feats = _batch_to_features([obs], device, use_text)
                logits = actor(feats)
                if temperature > 0:
                    logits = logits / max(1e-6, float(temperature))
                dist = torch.distributions.Categorical(logits=logits)
                a = dist.sample()
                ep_ent.append(float(dist.entropy().item()))

                obs, reward, terminated, truncated, info = env.step(int(a.item()))
                # NOTE(review): with max_steps=None the denominator is 1, so
                # the reward goes non-positive after two steps — confirm
                # whether callers ever rely on that case.
                if terminated:
                    reward = 1.0 - 0.5 * (steps / float(max_steps if max_steps else 1))
                else:
                    reward = 0.0
                done = bool(terminated or truncated)
                ep_r += float(reward)
                steps += 1
                traj.append(_get_xy(env))
                if max_steps is not None and steps >= max_steps:
                    break

            rewards.append(ep_r)
            successes.append(bool(terminated and ep_r > 0.0))
            entropies.append(float(np.mean(ep_ent)) if ep_ent else 0.0)
            all_trajs.append(traj)
    finally:
        if was_training:
            actor.train()

    diversity = 0.0
    if all_trajs:
        diversity = float(cluster_seq_diversity(all_trajs, grid_size=grid_size, grid_dims=(3, 3)))
    return {
        "avg_reward": float(np.mean(rewards)) if rewards else 0.0,
        "success_rate": float(np.mean(successes)) if successes else 0.0,
        "avg_entropy": float(np.mean(entropies)) if entropies else 0.0,
        "diversity": diversity,
    }

def parse_args():
    """
    Parse the command-line options of the Polychromic PPO training script.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    p = argparse.ArgumentParser(description="Polychromic PPO (MLP) for MiniGrid.")

    # --- Run identity ---
    p.add_argument("--game_code", type=str, default="minigrid", help="minigrid|open|pickup|goto|unlock|synthseq|bosslevel")
    p.add_argument("--seed", type=int, default=100)
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    # main() loads "<config_dir>/<game_code>.yaml"; without this option that
    # lookup fails with AttributeError before training starts.
    # NOTE(review): the default directory name is an assumption — adjust it
    # to wherever the per-game YAML files actually live.
    p.add_argument("--config_dir", type=str, default="configs",
                   help="Directory containing the per-game <game_code>.yaml configuration files.")

    # --- PPO optimisation ---
    p.add_argument("--ppo_epochs", type=int, default=2)
    p.add_argument("--minibatch_size", type=int, default=64)
    p.add_argument("--gamma", type=float, default=1.0)
    p.add_argument("--gae_lambda", type=float, default=0.95)
    p.add_argument("--clip_epsilon", type=float, default=0.2)
    p.add_argument("--actor_lr", type=float, default=5e-5)
    p.add_argument("--critic_lr", type=float, default=5e-4)
    p.add_argument("--value_loss_coef", type=float, default=0.5)
    p.add_argument("--entropy_coef", type=float, default=0.0)
    p.add_argument("--kl_coef", type=float, default=0.05)
    p.add_argument("--max_grad_norm", type=float, default=0.5)
    p.add_argument("--ucb_coef", type=float, default=0.005,
                   help="λ_ucb for per-(s,a) bonus: λ_ucb * min(1, sqrt(1/N(s,a))). Set 0 to disable.")

    # --- Data collection ---
    p.add_argument("--num_vines_at_state", type=int, default=8)
    p.add_argument("--polychrome_window", type=int, default=5)
    p.add_argument("--num_levels", type=int, default=3)

    # --- Training / evaluation schedule ---
    p.add_argument("--num_training_epochs", type=int, default=1500)
    p.add_argument("--episodes_per_collection", type=int, default=130)
    p.add_argument("--max_steps_per_episode", type=int, default=100)
    p.add_argument("--eval_interval", type=int, default=100)
    p.add_argument("--heavy_eval_interval", type=int, default=200)
    p.add_argument("--num_eval_episodes", type=int, default=10)
    # Float default so the parsed value has the declared type even when the
    # flag is omitted (argparse does not convert non-string defaults).
    p.add_argument("--temperature", type=float, default=1.0)

    # --- Paths ---
    p.add_argument("--pretrain_dir", type=str, default="pretrained_models")
    p.add_argument("--pretrain_name_tmpl", type=str, default="pretrain_{game}.pt")
    p.add_argument("--save_path", type=str, default="")
    p.add_argument("--eval_log_path", type=str, default="")

    p.add_argument("--use_text", type=str, default="false", choices=["auto", "true", "false"])
    return p.parse_args()

def _env_bool(key: str, default: bool = False) -> bool:
    """Interpret environment variable key as a boolean flag.

    Returns default when the variable is unset. Otherwise the value is
    stripped and lower-cased, and only "1", "true", "yes" and "on" are True.
    """
    raw = os.environ.get(key)
    if raw is None:
        return default
    truthy_values = {"1", "true", "yes", "on"}
    normalized = str(raw).strip().lower()
    return normalized in truthy_values

def start_wandb_auto(game_code: str, cfg: Dict[str, Any]):
    """Start a Weights & Biases run configured from environment variables.

    Returns None (after printing why) when wandb is not importable or when
    WANDB_DISABLED is truthy. Otherwise the run is initialised from
    WANDB_MODE, WANDB_PROJECT, WANDB_ENTITY and WANDB_RUN_NAME, the metric
    axes are declared, and the value returned by wandb.init is handed back.
    """
    if wandb is None:
        print("[W&B] Not installed; proceeding without logging.")
        return None
    if _env_bool("WANDB_DISABLED", False):
        print("[W&B] Disabled via WANDB_DISABLED.")
        return None

    env = os.environ
    mode = env.get("WANDB_MODE", "online")
    project = env.get("WANDB_PROJECT", "polychromic_ppo")
    entity = env.get("WANDB_ENTITY") or None
    run_name = env.get("WANDB_RUN_NAME", f"poly_ppo_ucb_{game_code}_{int(time.time())}")

    run = wandb.init(project=project, entity=entity, name=run_name, config=cfg, mode=mode, reinit=True)
    if run is None:
        return run

    print(f"[W&B] Initialized in {mode} mode → project={project} run={run.name}")
    # The two step axes are declared first, then the families plotted on them.
    wandb.define_metric("global/step")
    wandb.define_metric("train/epoch")
    metric_axes = (
        ("eval/*", "train/epoch"),
        ("loss/*", "global/step"),
        ("regularization/*", "global/step"),
    )
    for pattern, axis in metric_axes:
        wandb.define_metric(pattern, step_metric=axis)
    return run

def wb_log(run, payload: Dict[str, Any], step: Optional[int] = None):
    """Log payload to W&B, dropping entries whose value is NaN or infinite.

    Does nothing when run is None. None values and values that cannot be
    converted to float (for example histogram objects) are passed through
    unchanged.
    """
    if run is None:
        return

    def _should_log(value) -> bool:
        # Keep None and non-numeric values; drop only non-finite numbers.
        if value is None:
            return True
        try:
            as_float = float(value)
        except Exception:
            return True
        return not (math.isnan(as_float) or math.isinf(as_float))

    clean = {key: value for key, value in payload.items() if _should_log(value)}
    wandb.log(clean, step=step)

# Main
def main():
    args = parse_args()
    device = torch.device(args.device)
    torch.manual_seed(args.seed % (2**32 - 1))
    np.random.seed(args.seed % (2**32 - 1))

    config_filename = f"{args.game_code}.yaml"
    config_path = os.path.join(args.config_dir, config_filename)
    print(f"--- Loading configuration from: {config_path} ---")

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"Error: Configuration file not found at {config_path}")
        return 

    if args.use_text == "auto":
        use_text = GAME_CODE_TO_INFO.get(args.game_code, {"use_text": False})["use_text"]
    else:
        use_text = (args.use_text == "true")

    cfg_for_wandb = {**vars(args)}
    wandb_run = start_wandb_auto(args.game_code, cfg_for_wandb)

    print("--- Creating a pool of train/eval environments ---")
    train_env_pool, eval_env_pool, eval_heldout_env_pool = [], [], []
    eval_logs: List[Dict[str, Any]] = []
    heldout_logs: List[Dict[str, Any]] = []

    train_seeds = config["train_seeds"]
    evaluate_seeds = config["evaluate_seeds"]
    multiseed_seeds = train_seeds
    for s in multiseed_seeds:
        try:
            train_env_pool.append(setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=s))
            eval_env_pool.append(setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=s))
        except Exception as e:
            print(f"[WARN] Failed to build env for seed {s}: {e}")

    heldout_seeds = evaluate_seeds[-100:] if len(evaluate_seeds) >= 100 else evaluate_seeds
    if not heldout_seeds:
        heldout_seeds = [args.seed * 1000 + i for i in range(10)]
    for s in heldout_seeds:
        try:
            eval_heldout_env_pool.append(setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=s))
        except Exception as e:
            print(f"[WARN] Failed to build held-out env for seed {s}: {e}")

    if not train_env_pool:
        print("[WARN] No train envs; creating a fallback.")
        fallback = setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=args.seed)
        train_env_pool = [fallback]
        eval_env_pool  = [setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode, meta_fixed_seed=args.seed+1)]

    print(f"--- Pools: {len(train_env_pool)} train, {len(eval_env_pool)} eval, {len(eval_heldout_env_pool)} held-out ---")

    # --- Paths & Agent ---
    def _checkpoint_dir_from(save_path: str) -> str:
        base = os.path.dirname(save_path)
        ckpt_dir = os.path.join(base, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        return ckpt_dir

    def _save_checkpoint(epoch: int, path: str, agent: PPOAgent, args, extra: Dict[str, Any] = None):
        ckpt = {
            "epoch": epoch,
            "args": vars(args),
            "actor_state": agent.actor.state_dict(),
            "critic_state": agent.critic.state_dict(),
            "actor_opt_state": agent.actor_opt.state_dict(),
            "critic_opt_state": agent.critic_opt.state_dict(),
            "rng_state": {
                "python": random.getstate(),
                "numpy": np.random.get_state(),
                "torch": torch.get_rng_state(),
                "torch_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
            },
        }
        if extra:
            ckpt.update(extra)
        torch.save(ckpt, path)

    pretrain_path = os.path.join(args.pretrain_dir, args.pretrain_name_tmpl.format(game=args.game_code))
    save_path = args.save_path or FINE_TUNED_MODEL_TMPL.format(game=args.game_code)
    ckpt_dir = _checkpoint_dir_from(save_path)

    eval_log_path = args.eval_log_path or EVAL_LOG_JSON_TMPL.format(game=args.game_code)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    os.makedirs(os.path.dirname(eval_log_path), exist_ok=True)

    _tmp_obs, _ = train_env_pool[0].reset()
    _tmp_feats = _batch_to_features([_tmp_obs], device, use_text)
    input_dim = int(_tmp_feats.shape[1])

    agent = PPOAgent(
        n_actions=train_env_pool[0].action_space.n,
        device=device,
        input_dim=input_dim,
        actor_lr=args.actor_lr,
        critic_lr=args.critic_lr,
        ent_coef=args.entropy_coef,
        kl_coef=args.kl_coef,
        clip_eps=args.clip_epsilon,
        max_grad_norm=args.max_grad_norm,
        use_text=use_text,
        pretrained_weights=pretrain_path,
        hidden_dim=HIDDEN_DIM,
        ucb_coef=args.ucb_coef,  
    )

    eval_metrics_list = []
    for i, env in enumerate(eval_env_pool):
        metrics = evaluate_policy_simple(
            actor=agent.actor, env=env, device=device, use_text=use_text,
            num_episodes=args.num_eval_episodes, max_steps=args.max_steps_per_episode,
            temperature=args.temperature
        )
        eval_metrics_list.append(metrics)
        print(f"  Env {i+1} -> Avg Reward: {metrics['avg_reward']:.3f}, Success: {metrics['success_rate']*100:.2f}%")
    avg_reward = float(np.mean([m['avg_reward'] for m in eval_metrics_list])) if eval_metrics_list else 0.0
    avg_success = float(np.mean([m['success_rate'] for m in eval_metrics_list])) if eval_metrics_list else 0.0
    avg_diversity = float(np.mean([m['diversity'] for m in eval_metrics_list])) if eval_metrics_list else 0.0
    print(f"OVERALL -> Avg Reward: {avg_reward:.3f}, Success: {avg_success*100:.2f}%, Diversity: {avg_diversity:.3f}")
    wb_log(wandb_run, {"heavy_eval/avg_reward": avg_reward, "heavy_eval/success_rate": avg_success * 100.0, "heavy_eval/diversity": avg_diversity})
    wb_eval_payload = {}
    for i, m in enumerate(eval_metrics_list):
        wb_eval_payload[f"heavy_eval_env_{i+1}/avg_reward"] = m['avg_reward']
        wb_eval_payload[f"heavy_eval_env_{i+1}/success_rate"] = m['success_rate'] * 100.0
        wb_eval_payload[f"heavy_eval_env_{i+1}/diversity"] = m['diversity']
    wb_log(wandb_run, {"heavy_eval/epoch": 0, **wb_eval_payload})
    epoch_logs = [{"heavy_eval/epoch": 0, "env_index": i, **m} for i, m in enumerate(eval_metrics_list)]
    eval_logs.extend(epoch_logs)

    print("\n--- Initial Evaluation on held-out envs ---")
    initial_heldout_metrics_list = []
    for i, env in enumerate(eval_heldout_env_pool):
        metrics = evaluate_policy_simple(
            actor=agent.actor, env=env, device=device, use_text=use_text,
            num_episodes=args.num_eval_episodes, max_steps=args.max_steps_per_episode,
            temperature=args.temperature
        )
        initial_heldout_metrics_list.append(metrics)
        print(f"  Env {i+1} -> Avg Reward: {metrics['avg_reward']:.3f}, Success: {metrics['success_rate']*100:.2f}%")
    avg_reward = float(np.mean([m['avg_reward'] for m in initial_heldout_metrics_list])) if initial_heldout_metrics_list else 0.0
    avg_success = float(np.mean([m['success_rate'] for m in initial_heldout_metrics_list])) if initial_heldout_metrics_list else 0.0
    avg_diversity = float(np.mean([m['diversity'] for m in initial_heldout_metrics_list])) if initial_heldout_metrics_list else 0.0
    print(f"OVERALL -> Avg Reward: {avg_reward:.3f}, Success: {avg_success*100:.2f}%, Diversity: {avg_diversity:.3f}")
    heldout_logs = [{"epoch": 0, "env_index": i, **m} for i, m in enumerate(initial_heldout_metrics_list)]

    wb_payload = {}
    for i, m in enumerate(initial_heldout_metrics_list):
        wb_payload[f"heldout_eval_env_{i+1}/avg_reward"] = m['avg_reward']
        wb_payload[f"heldout_eval_env_{i+1}/success_rate"] = m['success_rate'] * 100.0
        wb_payload[f"heldout_eval_env_{i+1}/diversity"] = m['diversity']
    wb_log(wandb_run, {"eval/epoch": 0, **wb_payload})
    wb_log(wandb_run, {
        "heldout_eval/avg_reward": avg_reward,
        "heldout_eval/success_rate": avg_success * 100.0,
        "heldout_eval/diversity": avg_diversity,
    })

    for epoch in range(args.num_training_epochs):
        print(f"\n--- Epoch {epoch+1}/{args.num_training_epochs} ---")

        selected_train_envs = random.sample(train_env_pool, min(4, len(train_env_pool)))

        datasets = []
        for env in selected_train_envs:
            dataset = generate_fractal_dataset(
                policy=agent.actor,
                env=env,
                num_vines_at_state=args.num_vines_at_state,
                num_levels=args.num_levels,
                main_rollout_max_steps=args.max_steps_per_episode,
                device=device
            )
            datasets.append(dataset)

        combined_dataset = defaultdict(list)
        for d in datasets:
            for _, trajectories in d.items():
                for traj in trajectories:
                    if not traj:
                        continue
                    k2 = obs_to_key(traj[0]['observation'])
                    combined_dataset[k2].append(traj)
        combined_dataset = dict(combined_dataset)

        for kx, trajs in combined_dataset.items():
            for i, t in enumerate(trajs):
                assert t and obs_to_key(t[0]['observation']) == kx, f"Key/content mismatch at traj #{i}"

        print(f"Combined dataset: {sum(len(v) for v in combined_dataset.values())} trajectories from {len(selected_train_envs)} envs")

        if combined_dataset:
            grid_size = (train_env_pool[0].unwrapped.width, train_env_pool[0].unwrapped.height)
            actor_loss, critic_loss, kl_div, group_div_list = agent.update_from_dataset(
                on_policy_dataset=combined_dataset,
                env_grid_size=grid_size,
                ppo_epochs=args.ppo_epochs,
                minibatch_size=args.minibatch_size,
                gamma=args.gamma,
                gae_lambda=args.gae_lambda,
                value_loss_coef=args.value_loss_coef,
                polychrome_window=args.polychrome_window
            )

            if isinstance(group_div_list, (list, tuple)) and len(group_div_list) > 0:
                avg_group_div = float(np.mean(group_div_list))
                hist_payload = {"train/group_diversity_hist": wandb.Histogram(group_div_list)} if wandb is not None else {}
            else:
                avg_group_div = 0.0
                hist_payload = {}

            wb_log(wandb_run, {
                "train/epoch": epoch + 1,
                "train/actor_loss": _nan_to_none(actor_loss),
                "train/critic_loss": _nan_to_none(critic_loss),
                "train/kl_div": _nan_to_none(kl_div),
                "train/average_group_diversity": _nan_to_none(avg_group_div),
                "regularization/ucb_coef": _nan_to_none(args.ucb_coef),
            **hist_payload}, step=epoch + 1)

        if (epoch + 1) % args.heavy_eval_interval == 0:
            print(f"\n--- Heavy Evaluation at Epoch {epoch+1} (on all {len(eval_env_pool)} eval envs) ---")
            eval_metrics_list = []
            for i, env in enumerate(eval_env_pool):
                metrics = evaluate_policy_simple(
                    actor=agent.actor, env=env, device=device, use_text=use_text,
                    num_episodes=args.num_eval_episodes, max_steps=args.max_steps_per_episode,
                    temperature=args.temperature
                )
                eval_metrics_list.append(metrics)
                print(f"  Env {i+1} -> Avg Reward: {metrics['avg_reward']:.3f}, Success: {metrics['success_rate']*100:.2f}%")

            avg_reward = float(np.mean([m['avg_reward'] for m in eval_metrics_list]))
            avg_success = float(np.mean([m['success_rate'] for m in eval_metrics_list]))
            avg_diversity = float(np.mean([m['diversity'] for m in eval_metrics_list]))
            print(f"OVERALL -> Avg Reward: {avg_reward:.3f}, Success: {avg_success*100:.2f}%, Diversity: {avg_diversity:.3f}")

            wb_eval_payload = {}
            for i, m in enumerate(eval_metrics_list):
                wb_eval_payload[f"heavy_eval_env_{i+1}/avg_reward"] = m['avg_reward']
                wb_eval_payload[f"heavy_eval_env_{i+1}/success_rate"] = m['success_rate'] * 100.0
                wb_eval_payload[f"heavy_eval_env_{i+1}/diversity"] = m['diversity']
            wb_log(wandb_run, {"heavy_eval/epoch": epoch + 1, **wb_eval_payload})
            wb_log(wandb_run, {
                "heavy_eval/avg_reward": avg_reward,
                "heavy_eval/success_rate": avg_success * 100.0,
                "heavy_eval/diversity": avg_diversity,
            })

            epoch_logs = [{"heavy_eval/epoch": epoch + 1, "env_index": i, **m} for i, m in enumerate(eval_metrics_list)]
            eval_logs.extend(epoch_logs)

            print(f"\n--- Held-out Evaluation at Epoch {epoch+1} (on all {len(eval_heldout_env_pool)} held-out envs) ---")
            heldout_metrics_list = []
            for i, env in enumerate(eval_heldout_env_pool):
                metrics = evaluate_policy_simple(
                    actor=agent.actor, env=env, device=device, use_text=use_text,
                    num_episodes=args.num_eval_episodes, max_steps=args.max_steps_per_episode,
                    temperature=args.temperature
                )
                heldout_metrics_list.append(metrics)
                print(f"  Env {i+1} -> Avg Reward: {metrics['avg_reward']:.3f}, Success: {metrics['success_rate']*100:.2f}%")

            avg_reward = float(np.mean([m['avg_reward'] for m in heldout_metrics_list]))
            avg_success = float(np.mean([m['success_rate'] for m in heldout_metrics_list]))
            avg_diversity = float(np.mean([m['diversity'] for m in heldout_metrics_list]))
            print(f"OVERALL -> Avg Reward: {avg_reward:.3f}, Success: {avg_success*100:.2f}%, Diversity: {avg_diversity:.3f}")

            heldout_logs_epoch = [{"epoch": epoch + 1, "env_index": i, **m} for i, m in enumerate(heldout_metrics_list)]
            heldout_logs.extend(heldout_logs_epoch)

            wb_payload = {}
            for i, m in enumerate(heldout_metrics_list):
                wb_payload[f"heldout_eval_env_{i+1}/avg_reward"] = m['avg_reward']
                wb_payload[f"heldout_eval_env_{i+1}/success_rate"] = m['success_rate'] * 100.0
                wb_payload[f"heldout_eval_env_{i+1}/diversity"] = m['diversity']
            wb_log(wandb_run, {"heldout_eval/epoch": epoch + 1, **wb_payload})
            wb_log(wandb_run, {
                "heldout_eval/avg_reward": avg_reward,
                "heldout_eval/success_rate": avg_success * 100.0,
                "heldout_eval/diversity": avg_diversity,
            })

        if (epoch + 1) in {250, 500, 1000, 1200}:
            ckpt_path = os.path.join(ckpt_dir, f"poly_ppo_mlp_{args.game_code}_epoch_{epoch+1}.pt")
            _save_checkpoint(epoch + 1, ckpt_path, agent, args)
            print(f"[CKPT] Saved checkpoint → {ckpt_path}")
            if wandb_run is not None:
                try:
                    art = wandb.Artifact(f"poly-ppo-mlp-{args.game_code}-ckpt-ep{epoch+1}", type="model")
                    art.add_file(ckpt_path)
                    wandb.log_artifact(art)
                except Exception as e:
                    print(f"[W&B] Checkpoint artifact logging failed: {e}")

    print("\nTraining complete.")
    torch.save(agent.actor.state_dict(), save_path)
    print(f"Final policy saved to {save_path}")

    with open(eval_log_path, "w") as f:
        json.dump(eval_logs, f, indent=2)
    print(f"Eval logs saved to: {eval_log_path}")

    if wandb_run is not None:
        try:
            art = wandb.Artifact(f"poly-ppo-mlp-{args.game_code}-actor", type="model")
            art.add_file(save_path)
            art.add_file(eval_log_path)
            wandb.log_artifact(art)
        except Exception as e:
            print(f"[W&B] Artifact logging failed: {e}")
        wandb.finish()

    print("--- Closing envs ---")
    for env in train_env_pool: env.close()
    for env in eval_env_pool: env.close()
    for env in eval_heldout_env_pool: env.close()

# Script entry point: run main() only when this file is executed as the
# top-level program, not when it is imported from another module.
# NOTE(review): the file uses package-relative imports (`from ..utils ...`),
# so direct `python poly_ppo_ucb.py` execution presumably fails at import time;
# it likely needs to be launched with `python -m <package path>` — confirm.
if __name__ == "__main__":
    main()
