# reinforce_ucb.py 
import os
import math
import json
import time
import argparse
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import Counter
import hashlib
import torch.nn.functional as F

try:
    import wandb  
except Exception:
    wandb = None

from ..utils import (
    prepare_obs,
    cluster_seq_diversity,
    HIDDEN_DIM,
)
from ..train_utils import setup_environment as setup_environment_multiseed
from pretrain import MLPPolicy  
# Output path templates; "{game}" is substituted with the --game_code value in main().
FINE_TUNED_MODEL_TMPL = "{game}/pretrained_models_multiseed/reinforce_{game}.pt"
EVAL_LOG_JSON_TMPL = "{game}/pretrained_models_multiseed/reinforce_eval_{game}.json"

# Per-game settings. NOTE(review): not referenced by the code visible in this file.
GAME_CODE_TO_INFO = {
    "minigrid": {
        "use_text": False,
    },
}


def _obs_to_feature(prepped: Dict[str, torch.Tensor]) -> torch.Tensor:
    img: torch.Tensor = prepped["image"]          
    x_img = img.reshape(1, -1).float()           

    dir_tensor: torch.Tensor = prepped.get("direction")
    if dir_tensor is None:
        dir_tensor = torch.zeros((1,), dtype=torch.long, device=img.device)
    dir_tensor = dir_tensor.clamp(0, 3)          
    dir_oh = F.one_hot(dir_tensor, num_classes=4).float()  
    return torch.cat([x_img, dir_oh], dim=1)    

def _batch_to_features(
    obs_list: List[Dict[str, Any]], device: torch.device
) -> torch.Tensor:
    """Turn a list of raw observations into one stacked feature tensor.

    Each observation goes through ``prepare_obs`` (text disabled) and then
    ``_obs_to_feature``. The resulting single-row tensors are concatenated along dim 0,
    so the output has one row per observation. An empty ``obs_list`` is not supported
    (``torch.cat`` of an empty list raises).
    """
    rows = [
        _obs_to_feature(prepare_obs(single_obs, device=device, use_text=False))
        for single_obs in obs_list
    ]
    return torch.cat(rows, dim=0)



# Critic Network
class ValueNetwork(nn.Module):
    """State-value baseline: a ReLU MLP mapping each feature row to one scalar.

    The layout is Linear -> ReLU -> Linear -> ReLU -> Linear(., 1). The trailing size-1
    dimension is squeezed away, so inputs of shape (..., input_dim) produce (...).
    """

    def __init__(self, input_dim: int, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        # A single nn.Sequential named ``net`` keeps the state_dict keys (net.0 / net.2 /
        # net.4) stable for saved "critic_state" checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        value = self.net(feats)
        return value.squeeze(-1)

def cat_kl_from_logits(p_logits: torch.Tensor, q_logits: torch.Tensor) -> torch.Tensor:
    """KL(P || Q) between categorical distributions parameterised by logits.

    The softmax is taken over the last axis, and the result has one value per leading
    index: ``sum_i p_i * (log p_i - log q_i)``.
    """
    log_p, log_q = (torch.log_softmax(z, dim=-1) for z in (p_logits, q_logits))
    return torch.sum(torch.exp(log_p) * (log_p - log_q), dim=-1)

def success_from_info_or_reward(terminated: bool, truncated: bool, info: Dict[str, Any], ep_reward: float) -> bool:
    """Decide whether an episode succeeded.

    An explicit ``info["success"]`` flag wins when ``info`` is a dict containing it.
    Otherwise, success means the episode terminated with a strictly positive accumulated
    reward. ``truncated`` is accepted for signature symmetry but not consulted.
    """
    has_flag = isinstance(info, dict) and "success" in info
    if not has_flag:
        return bool(terminated and ep_reward > 0.0)
    return bool(info["success"])

def _nan_to_none(x):
    if x is None: return None
    try:
        xf = float(x)
    except Exception:
        return None
    if math.isnan(xf) or math.isinf(xf): return None
    return x

class ReinforceAgent:
    """REINFORCE learner with a value baseline, a KL penalty and an optional count bonus.

    Components:
      * ``actor``: ``MLPPolicy`` producing action logits from feature rows, optionally
        initialised from ``pretrained_weights``.
      * ``ref_actor``: frozen copy of the actor taken right after construction; used
        only for the KL(actor || ref_actor) penalty.
      * ``critic``: ``ValueNetwork`` baseline trained with MSE against MC returns.

    Args:
        n_actions: Size of the discrete action space.
        device: Device all networks are placed on.
        input_dim: Width of the feature rows produced by ``_batch_to_features``.
        actor_lr: Adam learning rate of the actor.
        critic_lr: Adam learning rate of the critic.
        ent_coef: Weight of the entropy bonus in the actor loss.
        max_grad_norm: Gradient-norm clip applied to both networks.
        pretrained_weights: Optional path to an actor ``state_dict``.
        kl_coef: Weight of the KL penalty toward ``ref_actor``.
        ucb_coef: Scale of the per-batch count bonus; ``0`` disables it.
    """

    def __init__(
        self,
        n_actions: int,
        device: torch.device,
        input_dim: int,
        actor_lr: float,
        critic_lr: float,
        ent_coef: float,
        max_grad_norm: float,
        pretrained_weights: Optional[str] = None,
        kl_coef: float = 0.01,
        ucb_coef: float = 0.005,
    ):
        self.device = device
        self.ent_coef = ent_coef
        self.max_grad_norm = max_grad_norm
        self.kl_coef = kl_coef
        self.ucb_coef = ucb_coef

        self.actor = MLPPolicy(input_dim=input_dim, output_dim=n_actions, hidden_dim=HIDDEN_DIM).to(self.device)
        self._load_pretrained_actor(pretrained_weights)

        # Frozen snapshot of the (possibly pre-trained) actor: anchor for the KL penalty.
        self.ref_actor = MLPPolicy(input_dim=input_dim, output_dim=n_actions, hidden_dim=HIDDEN_DIM).to(self.device)
        self.ref_actor.load_state_dict(self.actor.state_dict())
        self.ref_actor.eval()
        for p in self.ref_actor.parameters():
            p.requires_grad = False

        self.critic = ValueNetwork(input_dim=input_dim, hidden_dim=HIDDEN_DIM).to(device)

        self.actor_opt = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=critic_lr)

    def _load_pretrained_actor(self, pretrained_weights: Optional[str]) -> None:
        """Best-effort load of an actor state_dict, reporting keys that did not match.

        ``strict=False`` silently skips mismatched keys. Without this report, a checkpoint
        with a different layout would "load" while leaving the actor at its random init.
        NOTE(review): ``torch.load`` unpickles the file, so only point it at trusted paths.
        """
        if not (pretrained_weights and os.path.isfile(pretrained_weights)):
            print(f"[REINFORCE] No pretrained weights at: {pretrained_weights} (actor from scratch)")
            return
        try:
            state = torch.load(pretrained_weights, map_location=self.device)
            result = self.actor.load_state_dict(state, strict=False)
        except Exception as e:
            print(f"[REINFORCE] WARNING: couldn't load pretrained weights: {e}")
            return
        print(f"[REINFORCE] Loaded pretrained actor weights from: {pretrained_weights}")
        missing = list(getattr(result, "missing_keys", []))
        unexpected = list(getattr(result, "unexpected_keys", []))
        if missing or unexpected:
            print(
                f"[REINFORCE] WARNING: partial load -- missing keys: {missing}, "
                f"unexpected keys: {unexpected}"
            )

    @torch.no_grad()
    def select_action(self, obs: Dict[str, Any], temperature: float = 1.0) -> Tuple[int, float, float]:
        """Sample an action from the temperature-scaled policy.

        Returns ``(action, log_prob, entropy)`` of the sampled action under that policy.
        """
        feats = _batch_to_features([obs], device=self.device)
        logits = self.actor(feats) / max(1e-6, float(temperature))
        dist = torch.distributions.Categorical(logits=logits)
        a = dist.sample()
        logp = dist.log_prob(a)
        return int(a.item()), float(logp.item()), float(dist.entropy().item())

    @torch.no_grad()
    def value(self, obs: Dict[str, Any]) -> float:
        """Critic estimate V(obs) as a Python float."""
        feats = _batch_to_features([obs], device=self.device)
        v = self.critic(feats)
        return float(v.squeeze(0).item())

    def _state_key(self, obs: Dict[str, Any]) -> Tuple[str, int, str]:
        """Compact key for counting N(s, a) within a batch.

        The key is ``(mission, direction, image_hash)``:
          * mission is ``obs["mission"]`` ("" if absent);
          * direction is the int value when ``obs["direction"]`` is a Python or NumPy
            integer, else -1;
          * image_hash is the first 8 hex chars of the SHA-1 of the image bytes
            ("noimg" if hashing raises).
        """
        mission = obs.get("mission", "")
        raw_dir = obs.get("direction")
        direction = int(raw_dir) if isinstance(raw_dir, (int, np.integer)) else -1
        img = obs.get("image", None)
        try:
            h = hashlib.sha1(np.asarray(img).tobytes()).hexdigest()[:8]
        except Exception:
            h = "noimg"
        return (mission, direction, h)

    @staticmethod
    def _discounted_returns(rewards: torch.Tensor, dones: torch.Tensor, gamma: float) -> torch.Tensor:
        """Monte-Carlo returns G_t = r_t + gamma * G_{t+1} * (1 - done_t), computed backwards.

        ``dones`` (0/1 floats) resets the running return, so each episode in a flat,
        time-ordered batch gets its own returns.
        """
        returns = torch.zeros_like(rewards)
        running_return = 0.0
        for t in reversed(range(len(rewards))):
            running_return = rewards[t] + gamma * running_return * (1.0 - dones[t])
            returns[t] = running_return
        return returns

    @staticmethod
    def _normalize_advantages(adv: torch.Tensor) -> torch.Tensor:
        """Standardise advantages to zero mean and unit (unbiased) std.

        With fewer than two elements the unbiased std is NaN, which would turn the whole
        actor loss into NaN. Those batches are only mean-centred instead.
        """
        centred = adv - adv.mean()
        if adv.numel() < 2:
            return centred
        return centred / (adv.std() + 1e-8)

    def _ucb_bonus(self, obs_list: List[Dict[str, Any]], action_ids: List[int]) -> torch.Tensor:
        """Per-transition bonus ``ucb_coef * min(1, 1 / sqrt(N(s, a)))``.

        N(s, a) is counted within this batch only. Each state key is hashed once.
        """
        keys = [(self._state_key(ob), int(a)) for ob, a in zip(obs_list, action_ids)]
        counts = Counter(keys)
        return torch.tensor(
            [self.ucb_coef * min(1.0, (1.0 / max(1, counts[k])) ** 0.5) for k in keys],
            device=self.device,
            dtype=torch.float,
        )

    def update(self,
               rollouts: List[Dict[str, Any]],
               gamma: float,
               value_loss_coef: float) -> Tuple[float, float]:
        """
        REINFORCE with baseline + optional UCB bonus + KL to ref.

        Each rollout dict must provide ``obs``, ``action``, ``reward`` and ``done``.
        Transitions are assumed to be stored in time order, episode after episode, with
        ``done`` set on the last transition of each episode (returns are reset there).

        Steps:
          1. Compute MC returns.
          2. Compute normalised advantages against the critic.
          3. Add the optional count bonus.
          4. Take one critic MSE step.
          5. Take one actor step on
             ``-mean(logp * adv) - ent_coef * H + kl_coef * KL(actor || ref)``.

        Returns:
            ``(actor_loss, critic_loss)`` as Python floats.
        """
        obs_list = [r["obs"] for r in rollouts]
        actions = torch.tensor([r["action"] for r in rollouts], dtype=torch.long, device=self.device)
        rewards = torch.tensor([r["reward"] for r in rollouts], dtype=torch.float, device=self.device)
        dones = torch.tensor([r["done"] for r in rollouts], dtype=torch.float, device=self.device)

        returns = self._discounted_returns(rewards, dones, gamma)

        feats = _batch_to_features(obs_list, device=self.device)
        values = self.critic(feats)
        advantages = self._normalize_advantages(returns - values)

        # The bonus is added after normalisation, so ucb_coef is on the scale of
        # standardised advantages.
        ucb_bonus = self._ucb_bonus(obs_list, actions.tolist()) if self.ucb_coef > 0.0 else None
        if ucb_bonus is not None:
            advantages = advantages + ucb_bonus

        v_loss = value_loss_coef * nn.functional.mse_loss(values, returns)
        self.critic_opt.zero_grad(set_to_none=True)
        v_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.critic_opt.step()

        logits = self.actor(feats)
        dist = torch.distributions.Categorical(logits=logits)
        logp = dist.log_prob(actions)
        ent = dist.entropy().mean()

        with torch.no_grad():
            ref_logits = self.ref_actor(feats)
        kl = cat_kl_from_logits(logits, ref_logits).mean()

        # Advantages are detached: the critic was already updated above and must not
        # receive gradients from the policy loss.
        a_loss = -(logp * advantages.detach()).mean() - self.ent_coef * ent + self.kl_coef * kl

        self.actor_opt.zero_grad(set_to_none=True)
        a_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.actor_opt.step()

        print(
            f"[REINFORCE] Update -> actor: {a_loss.item():.4f} | "
            f"critic: {v_loss.item():.4f} | KL(ref): {kl.item():.5f}"
            + (f" | UCB(avg): {float(ucb_bonus.mean().item()):.4f}" if ucb_bonus is not None else "")
        )
        return float(a_loss.item()), float(v_loss.item())


# Evaluation
@torch.no_grad()
def evaluate_policy_simple(
    actor: nn.Module,
    env,
    device: torch.device,
    use_text: bool,
    num_episodes: int,
    max_steps: Optional[int] = None,
    temperature: float = 1.0,
) -> Dict[str, float]:
    """Run ``num_episodes`` sampled episodes of ``actor`` on ``env`` and summarise them.

    Actions are sampled from ``Categorical(logits=actor(features) / temperature)``.

    Rewards are re-shaped:
      * A step that terminates the episode with a positive env reward earns
        ``1 - 0.5 * steps / max_steps``, where ``steps`` counts the actions taken
        before it. This becomes ``1.0`` when ``max_steps`` is not given.
      * Every other step earns 0, including a terminating step with a non-positive
        reward (a failed episode). Such a step no longer trips an assertion.

    Args:
        actor: Policy network mapping a feature row to action logits.
        env: Gymnasium-style env (``reset() -> (obs, info)``, 5-tuple ``step``).
        device: Device for the feature tensors.
        use_text: Unused; features are always built without text.
        num_episodes: Number of episodes to run.
        max_steps: Optional per-episode step cap, also used as the shaping denominator.
        temperature: Softmax temperature (clamped to >= 1e-6).

    Returns:
        Dict with ``avg_reward``, ``success_rate`` (fraction in [0, 1]),
        ``avg_entropy`` and trajectory ``diversity``. All are 0.0 when no episodes ran.
    """
    rewards, successes, entropies = [], [], []
    all_trajs: List[List[Tuple[int, int]]] = []

    def _get_xy(e) -> Tuple[int, int]:
        try:
            p = getattr(e.unwrapped, "agent_pos", None)
            if p is not None:
                return int(p[0]), int(p[1])
        except Exception:
            pass
        p = getattr(e, "agent_pos", None)
        return (int(p[0]), int(p[1])) if p is not None else (-1, -1)

    def _get_grid_size(e) -> Tuple[int, int]:
        try:
            g = getattr(e.unwrapped, "grid", None)
            if g is not None:
                return int(g.width), int(g.height)
        except Exception:
            pass
        return (19, 19)

    def _shaped_reward(raw_reward: float, terminated: bool, steps_taken: int) -> float:
        # Only a successful termination (positive env reward) is rewarded. The reward
        # decays linearly with the number of steps taken before the final action.
        if not (terminated and raw_reward > 0):
            return 0.0
        if not max_steps:
            return 1.0
        return 1.0 - 0.5 * (steps_taken / max_steps)

    grid_size = None

    for _ in range(num_episodes):
        obs, info = env.reset()
        done = False
        terminated = truncated = False
        ep_r = 0.0
        steps = 0
        ep_ent = []

        traj = []
        if grid_size is None:
            grid_size = _get_grid_size(env)
        traj.append(_get_xy(env))

        while not done:
            prepped = prepare_obs(obs, device=device, use_text=False)
            feats = _obs_to_feature(prepped)
            logits = actor(feats) / max(1e-6, float(temperature))
            dist = torch.distributions.Categorical(logits=logits)
            a = dist.sample()
            ep_ent.append(float(dist.entropy().item()))

            obs, reward, terminated, truncated, info = env.step(int(a.item()))
            ep_r += _shaped_reward(reward, terminated, steps)
            done = bool(terminated or truncated)
            steps += 1
            traj.append(_get_xy(env))
            if max_steps is not None and steps >= max_steps:
                break

        rewards.append(ep_r)
        successes.append(success_from_info_or_reward(terminated, truncated, info, ep_r))
        entropies.append(float(np.mean(ep_ent)) if ep_ent else 0.0)
        all_trajs.append(traj)

    diversity = 0.0
    if all_trajs:
        diversity = float(cluster_seq_diversity(all_trajs, grid_size=grid_size, grid_dims=(3, 3)))

    return {
        "avg_reward": float(np.mean(rewards)) if rewards else 0.0,
        "success_rate": float(np.mean(successes)) if successes else 0.0,
        "avg_entropy": float(np.mean(entropies)) if entropies else 0.0,
        "diversity": diversity,
    }


# Argparse
def parse_args(argv: Optional[List[str]] = None):
    """Parse CLI options for the REINFORCE(+UCB) trainer.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with all options, including ``config_dir``, which
        ``main`` reads. It was previously missing and crashed ``main`` with an
        AttributeError.
    """
    p = argparse.ArgumentParser(description="REINFORCE(+UCB) with MLP policy for MiniGrid.")
    p.add_argument("--game_code", type=str, default="open", help="open|pickup|goto|unlock|synthseq|bosslevel")
    p.add_argument("--seed", type=int, default=100)
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--config_dir", type=str, default="configs",
                   help="Directory holding <game_code>.yaml with train_seeds / evaluate_seeds.")

    p.add_argument("--gamma", type=float, default=1.0)
    p.add_argument("--actor_lr", type=float, default=5e-5)
    p.add_argument("--critic_lr", type=float, default=5e-4)
    p.add_argument("--value_loss_coef", type=float, default=0.5)
    p.add_argument("--entropy_coef", type=float, default=0.0)
    p.add_argument("--max_grad_norm", type=float, default=0.5)
    p.add_argument("--kl_coef", type=float, default=0.01)

    p.add_argument("--num_training_epochs", type=int, default=1500)
    p.add_argument("--episodes_per_collection", type=int, default=130)
    p.add_argument("--max_steps_per_episode", type=int, default=100)
    p.add_argument("--eval_interval", type=int, default=100)
    p.add_argument("--heavy_eval_interval", type=int, default=200)
    p.add_argument("--num_eval_episodes", type=int, default=10)
    p.add_argument("--temperature", type=float, default=1.0)
    p.add_argument("--ucb_coef", type=float, default=0.005, help="λ_UCB. Set 0 to disable. Same coefficient used in PPO.")

    p.add_argument("--pretrain_dir", type=str, default="pretrained_models")
    p.add_argument("--pretrain_name_tmpl", type=str, default="pretrain_{game}.pt")
    p.add_argument("--save_path", type=str, default="")
    p.add_argument("--eval_log_path", type=str, default="")

    p.add_argument("--use_text", type=str, default="false", choices=["auto", "true", "false"])

    return p.parse_args(argv)


def _env_bool(key: str, default: bool = False) -> bool:
    v = os.environ.get(key)
    if v is None:
        return default
    return str(v).strip().lower() in {"1", "true", "yes", "on"}

def start_wandb_auto(game_code: str, cfg: Dict[str, Any]):
    """Start a Weights & Biases run configured through WANDB_* environment variables.

    Returns ``None`` (after printing why) when wandb could not be imported or
    ``WANDB_DISABLED`` is truthy. Otherwise it calls ``wandb.init`` using these variables:
      * WANDB_MODE (default "online");
      * WANDB_PROJECT (default "opendoor-reinforce-sweep");
      * WANDB_ENTITY, WANDB_GROUP, WANDB_NOTES (empty -> None);
      * WANDB_TAGS (comma-separated, empty -> None);
      * WANDB_RUN_NAME (default "reinforce-<game_code>-<unix seconds>").

    ``cfg`` plus the resolved mode and project are stored as the run config.

    NOTE(review): "global/step" and "train/episode" are declared as step metrics here,
    but main() in this file never logs those keys. Confirm these declarations still do
    anything useful.
    """
    if wandb is None:
        print("[W&B] Not installed; proceeding without logging.")
        return None
    if _env_bool("WANDB_DISABLED", False):
        print("[W&B] Disabled via WANDB_DISABLED.")
        return None

    env = os.environ
    mode = env.get("WANDB_MODE", "online")
    project = env.get("WANDB_PROJECT", "opendoor-reinforce-sweep")
    entity = env.get("WANDB_ENTITY") or None
    group = env.get("WANDB_GROUP") or None
    notes = env.get("WANDB_NOTES") or None
    stripped_tags = (piece.strip() for piece in env.get("WANDB_TAGS", "").split(','))
    tags = [tag for tag in stripped_tags if tag] or None
    run_name = env.get("WANDB_RUN_NAME", f"reinforce-{game_code}-{int(time.time())}")

    run_config = {**cfg, "wandb_mode": mode, "wandb_project": project}
    run = wandb.init(
        mode=mode,
        project=project,
        entity=entity,
        name=run_name,
        group=group,
        tags=tags,
        notes=notes,
        config=run_config,
        reinit=True,
    )
    if run is None:
        return run

    wandb.define_metric("global/step")
    wandb.define_metric("train/episode")
    wandb.define_metric("eval/*", step_metric="train/episode")
    wandb.define_metric("loss/*", step_metric="global/step")
    print(f"[W&B] Initialized in {mode} mode → project={project} run={run.name}")
    return run

def wb_log(run, payload: Dict[str, Any], step: Optional[int] = None):
    """Log ``payload`` to W&B, dropping numeric entries that are NaN or infinite.

    The call is a no-op when ``run`` is None. ``None`` values and values that cannot be
    converted with ``float()`` (e.g. strings or media objects) are forwarded unchanged.
    """
    if run is None:
        return
    filtered: Dict[str, Any] = {}
    for key, value in payload.items():
        if value is None:
            filtered[key] = None
            continue
        try:
            as_float = float(value)
        except Exception:
            filtered[key] = value
            continue
        if math.isfinite(as_float):
            filtered[key] = value
    wandb.log(filtered, step=step)


def _shaped_train_reward(raw_reward: float, terminated: bool, steps: int, max_steps: int) -> float:
    """Training reward: ``1 - 0.5 * steps / max_steps`` on a successful termination, else 0.

    ``steps`` is the number of steps taken before the current action. A terminating step
    with a non-positive env reward (a failed episode) scores 0 instead of asserting.
    """
    if terminated and raw_reward > 0:
        return 1.0 - 0.5 * (steps / float(max_steps))
    return 0.0


def _collect_rollouts(agent: "ReinforceAgent", envs: List[Any], args) -> List[Dict[str, Any]]:
    """Sample ``episodes_per_collection // len(envs)`` episodes from each env in ``envs``.

    An episode stops when the env terminates or truncates, or after
    ``max_steps_per_episode`` steps. The last transition of every episode is flagged
    ``done``, so discounted returns never leak across episode boundaries.
    """
    rollouts: List[Dict[str, Any]] = []
    if not envs:
        return rollouts
    episodes_per_env = args.episodes_per_collection // len(envs)
    for env in envs:
        for _ in range(episodes_per_env):
            obs, _ = env.reset()
            done = False
            steps = 0
            while not done and steps < args.max_steps_per_episode:
                action, logp, _ = agent.select_action(obs, temperature=args.temperature)
                next_obs, reward, terminated, truncated, _info = env.step(action)
                shaped_r = _shaped_train_reward(reward, terminated, steps, args.max_steps_per_episode)
                # The episode must actually stop here. Previously `done` was never updated,
                # so collection kept stepping past termination.
                done = bool(terminated or truncated)
                steps += 1
                rollouts.append({
                    "obs": obs,
                    "action": action,
                    "log_prob": logp,
                    "reward": shaped_r,
                    "done": done or steps >= args.max_steps_per_episode,
                    "next_obs": next_obs,
                })
                obs = next_obs
    return rollouts


def _evaluate_pool(agent: "ReinforceAgent", env_pool: List[Any], device: torch.device, args):
    """Evaluate ``agent.actor`` on every env of ``env_pool``, printing per-env and overall lines.

    Returns:
        ``(metrics_list, avg_reward, avg_success, avg_diversity)``. The averages are the
        unweighted mean over envs; ``avg_success`` is a fraction.
    """
    metrics_list: List[Dict[str, float]] = []
    for i, env in enumerate(env_pool):
        metrics = evaluate_policy_simple(
            agent.actor, env, device, False, args.num_eval_episodes,
            args.max_steps_per_episode, args.temperature
        )
        metrics_list.append(metrics)
        print(f" Env {i + 1} -> Avg Reward: {metrics['avg_reward']:.3f}, Success: {metrics['success_rate'] * 100:.2f}%")
    avg_reward = np.mean([m['avg_reward'] for m in metrics_list])
    avg_success = np.mean([m['success_rate'] for m in metrics_list])
    avg_diversity = np.mean([m['diversity'] for m in metrics_list])
    print(f"OVERALL -> Avg Reward: {avg_reward:.3f}, Success: {avg_success * 100:.2f}%, Diversity: {avg_diversity:.3f}")
    return metrics_list, avg_reward, avg_success, avg_diversity


def _per_env_payload(prefix: str, metrics_list: List[Dict[str, float]]) -> Dict[str, float]:
    """W&B payload with ``<prefix>_env_<k>/{avg_reward,success_rate,diversity}`` (k from 1; success in %)."""
    payload: Dict[str, float] = {}
    for i, m in enumerate(metrics_list, start=1):
        payload[f"{prefix}_env_{i}/avg_reward"] = m['avg_reward']
        payload[f"{prefix}_env_{i}/success_rate"] = m['success_rate'] * 100.0
        payload[f"{prefix}_env_{i}/diversity"] = m['diversity']
    return payload


def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of ``path``; does nothing for a bare filename."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _checkpoint_dir_from(save_path: str) -> str:
    """Create and return ``<dirname(save_path)>/checkpoints``."""
    ckpt_dir = os.path.join(os.path.dirname(save_path), "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    return ckpt_dir


def _save_checkpoint(epoch: int, path: str, agent: "ReinforceAgent", args,
                     extra: Optional[Dict[str, Any]] = None) -> None:
    """Save networks, optimisers, CLI args and python/numpy/torch RNG states to ``path``."""
    ckpt = {
        "epoch": epoch,
        "args": vars(args),
        "actor_state": agent.actor.state_dict(),
        "critic_state": agent.critic.state_dict(),
        "actor_opt_state": agent.actor_opt.state_dict(),
        "critic_opt_state": agent.critic_opt.state_dict(),
        "rng_state": {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.get_rng_state(),
            "torch_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        },
    }
    if extra:
        ckpt.update(extra)
    torch.save(ckpt, path)


def _run_training(args, device: torch.device, wandb_run, train_env_pool: List[Any],
                  eval_env_pool: List[Any], eval_heldout_env_pool: List[Any]) -> None:
    """Build the agent, evaluate it, train it, and save the policy, checkpoints and logs.

    Every W&B log carries an explicit ``step``: 0 for the initial evaluation and
    ``epoch + 1`` afterwards. W&B ignores data logged to a step lower than its current
    one, so mixing implicit and explicit steps dropped early training logs.
    """
    tmp_obs, _ = train_env_pool[0].reset()
    input_dim = int(_batch_to_features([tmp_obs], device).shape[1])

    pretrain_path = os.path.join(args.pretrain_dir, args.pretrain_name_tmpl.format(game=args.game_code))
    save_path = args.save_path or FINE_TUNED_MODEL_TMPL.format(game=args.game_code)
    ckpt_dir = _checkpoint_dir_from(save_path)
    eval_log_path = args.eval_log_path or EVAL_LOG_JSON_TMPL.format(game=args.game_code)
    _ensure_parent_dir(save_path)
    _ensure_parent_dir(eval_log_path)

    agent = ReinforceAgent(
        n_actions=train_env_pool[0].action_space.n,
        device=device,
        input_dim=input_dim,
        actor_lr=args.actor_lr,
        critic_lr=args.critic_lr,
        ent_coef=args.entropy_coef,
        max_grad_norm=args.max_grad_norm,
        pretrained_weights=pretrain_path,
        kl_coef=args.kl_coef,
        ucb_coef=args.ucb_coef
    )

    eval_logs: List[Dict[str, Any]] = []
    heldout_logs: List[Dict[str, Any]] = []

    def _heavy_eval(ep: int) -> None:
        metrics_list, avg_reward, avg_success, avg_diversity = _evaluate_pool(agent, eval_env_pool, device, args)
        wb_log(wandb_run, {"heavy_eval/epoch": ep, **_per_env_payload("heavy_eval", metrics_list)}, step=ep)
        wb_log(wandb_run, {
            "heavy_eval/avg_reward": avg_reward,
            "heavy_eval/success_rate": avg_success * 100.0,
            "heavy_eval/diversity": avg_diversity,
        }, step=ep)
        eval_logs.extend({"heavy_eval/epoch": ep, "env_index": i, **m} for i, m in enumerate(metrics_list))

    def _heldout_eval(ep: int) -> None:
        metrics_list, avg_reward, avg_success, avg_diversity = _evaluate_pool(
            agent, eval_heldout_env_pool, device, args)
        heldout_logs.extend({"epoch": ep, "env_index": i, **m} for i, m in enumerate(metrics_list))
        wb_log(wandb_run, {"heldout_eval/epoch": ep, **_per_env_payload("heldout_eval", metrics_list)}, step=ep)
        wb_log(wandb_run, {
            "heldout_eval/avg_reward": avg_reward,
            "heldout_eval/success_rate": avg_success * 100.0,
            "heldout_eval/diversity": avg_diversity,
        }, step=ep)

    # Epoch-0 evaluation of the initial (possibly pre-trained) policy.
    _heavy_eval(0)
    print("\n--- Initial Evaluation of Pre-trained Policy (held-out envs) ---")
    _heldout_eval(0)

    n_selected = min(4, len(train_env_pool))
    for epoch in range(args.num_training_epochs):
        ep = epoch + 1
        print(f"\n--- Epoch {ep}/{args.num_training_epochs} ---")
        selected_train_envs = random.sample(train_env_pool, n_selected)
        rollouts = _collect_rollouts(agent, selected_train_envs, args)

        if rollouts:
            actor_loss, critic_loss = agent.update(
                rollouts=rollouts,
                gamma=args.gamma,
                value_loss_coef=args.value_loss_coef,
            )
            wb_log(wandb_run, {
                "train/epoch": ep,
                "train/actor_loss": _nan_to_none(actor_loss),
                "train/critic_loss": _nan_to_none(critic_loss),
            }, step=ep)

        # Held-out evaluation runs on the same 1-based schedule as the heavy evaluation.
        # Before, it used `epoch % interval`, i.e. epochs 1, 201, ...
        if ep % args.heavy_eval_interval == 0:
            print(f"\n--- Heavy Evaluation at Epoch {ep} (on all {len(eval_env_pool)} eval envs) ---")
            _heavy_eval(ep)
            print(f"\n--- Held-out Evaluation at Epoch {ep} (on all {len(eval_heldout_env_pool)} held-out eval envs) ---")
            _heldout_eval(ep)

        if ep in {250, 500, 1000, 1200}:
            ckpt_path = os.path.join(ckpt_dir, f"reinforce_{args.game_code}_epoch_{ep}.pt")
            _save_checkpoint(ep, ckpt_path, agent, args)
            print(f"[CKPT] Saved checkpoint → {ckpt_path}")
            if wandb_run is not None:
                try:
                    art = wandb.Artifact(f"reinforce-{args.game_code}-ckpt-ep{ep}", type="model")
                    art.add_file(ckpt_path)
                    wandb.log_artifact(art)
                except Exception as e:
                    print(f"[W&B] Checkpoint artifact logging failed: {e}")

    print("\nTraining complete.")
    torch.save(agent.actor.state_dict(), save_path)
    print(f"Final policy saved to {save_path}")
    with open(eval_log_path, "w") as f:
        json.dump(eval_logs, f, indent=2)
    print(f"Eval logs saved to: {eval_log_path}")
    # Held-out results were collected but never persisted. Write them next to the eval log.
    heldout_log_path = os.path.splitext(eval_log_path)[0] + "_heldout.json"
    with open(heldout_log_path, "w") as f:
        json.dump(heldout_logs, f, indent=2)
    print(f"Held-out eval logs saved to: {heldout_log_path}")
    if wandb_run is not None:
        try:
            art = wandb.Artifact(f"reinforce-{args.game_code}-actor", type="model")
            art.add_file(save_path)
            art.add_file(eval_log_path)
            art.add_file(heldout_log_path)
            wandb.log_artifact(art)
        except Exception as e:
            print(f"[W&B] Artifact logging failed: {e}")
        wandb.finish()


def main():
    """Train a REINFORCE(+UCB) policy over a pool of seeded envs.

    Flow:
      1. Parse CLI args and seed python, numpy and torch.
      2. Read ``<config_dir>/<game_code>.yaml``, which must define ``train_seeds`` and
         ``evaluate_seeds``.
      3. Build train/eval envs from ``train_seeds`` and held-out envs from the last 100
         ``evaluate_seeds``.
      4. Delegate the training run to ``_run_training``.

    All envs are closed even if training raises.
    """
    args = parse_args()
    device = torch.device(args.device)
    seed = args.seed % (2 ** 32 - 1)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Fall back to "configs" when the CLI does not define --config_dir.
    config_dir = getattr(args, "config_dir", "configs")
    config_path = os.path.join(config_dir, f"{args.game_code}.yaml")
    print(f"--- Loading configuration from: {config_path} ---")

    import yaml  # PyYAML: needed only here, to read the per-game seed configuration.

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"Error: Configuration file not found at {config_path}")
        return

    cfg_for_wandb = {**vars(args), "use_text_resolved": False}
    wandb_run = start_wandb_auto(args.game_code, cfg_for_wandb)

    print("--- Creating a pool of train/eval environments... ---")
    train_seeds = config["train_seeds"]
    heldout_seeds = config["evaluate_seeds"][-100:]

    def _make_env(env_seed):
        return setup_environment_multiseed(args.game_code, max_steps=args.max_steps_per_episode,
                                           meta_fixed_seed=env_seed)

    train_env_pool, eval_env_pool, eval_heldout_env_pool = [], [], []
    # Train and eval pools share the training seeds (separate env instances).
    for s in train_seeds:
        train_env_pool.append(_make_env(s))
        eval_env_pool.append(_make_env(s))
    for s in heldout_seeds:
        eval_heldout_env_pool.append(_make_env(s))

    print(f"--- Environment pools created. Total: {len(train_env_pool)} train, {len(eval_env_pool)} eval. ---")

    try:
        if not train_env_pool:
            raise ValueError(f"No training environments: 'train_seeds' in {config_path} is empty.")
        _run_training(args, device, wandb_run, train_env_pool, eval_env_pool, eval_heldout_env_pool)
    finally:
        print("--- Closing all environments... ---")
        for env in train_env_pool + eval_env_pool + eval_heldout_env_pool:
            env.close()

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
