"""
eval_fixedenvs_widecsv_like_evalpy.py

Wide CSV:
  metric,<xpid0>,<xpid1>,...
  solved_rate:ENV_A,...
  test_returns:ENV_A,...
  solved_rate:ENV_B,...
  test_returns:ENV_B,...

- Column name = xpid (run directory name).
- If duplicate, auto-suffix _2, _3, ...
- Appends to --wide_csv (creates if missing).

Adds screenshot saving:
- If --screenshot_dir exists -> DO NOT create screenshots.
- Else create it and save ONE png per env:
    <screenshot_dir>/<SANITIZED_ENV_NAME>.png
- Screenshots are taken from RAW env (gym/gym_make) with reset+render,
  NO ParallelAdversarialVecEnv, NO reset_agent dependency.

Screenshot style requirement (MultiGrid/MiniGrid):
- Prefer the "human-style" render (gray background, visible agent triangle,
  highlighted view mask) over the plain rgb_array render.
- We attempt to render via env.render(mode="human"), then read the frame back from env.window.
- Fall back to env.render(mode="rgb_array") (or plain env.render()) if window capture isn't available.

Also fixes:
- add --prefix handling
- remove undefined num_seeds usage
"""

import sys
import os
import csv
import json
import argparse
import fnmatch
import re
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from baselines.logger import HumanOutputFormat
from tqdm import tqdm

from envs.registration import make as gym_make
from envs.wrappers import (
    VecMonitor,
    VecPreprocessImageWrapper,
    ParallelAdversarialVecEnv,
    MultiGridFullyObsWrapper,
    CarRacingWrapper,
)
from util import DotDict, str2bool, make_agent, is_discrete_actions, save_images

try:
    import gymnasium as gym
except ImportError:
    import gym

# Ensure env registrations are imported (same style as eval.py)
from envs.multigrid.maze import *  # noqa: F401,F403
from envs.multigrid.crossing import *  # noqa
from envs.multigrid.fourrooms import *  # noqa
from envs.multigrid.mst_maze import *  # noqa
from envs.box2d import *  # noqa
from envs.bipedalwalker import *  # noqa

import torchvision.utils as vutils


# ---------------------- env name presets ----------------------


def _get_bipedal_env_names() -> List[str]:
    """Return the BipedalWalker suite: the two standard envs plus the four 'Med' variants."""
    standard = ["BipedalWalker-v3", "BipedalWalkerHardcore-v3"]
    medium = [
        f"BipedalWalker-Med-{variant}-v0"
        for variant in ("Stairs", "PitGap", "StumpHeight", "Roughness")
    ]
    return standard + medium


def _get_poet_rose_env_names() -> List[str]:
    """Return the six POET 'Rose' BipedalWalker levels, ordered 1a, 1b, ..., 3b."""
    names: List[str] = []
    for level in range(1, 4):
        for variant in "ab":
            names.append(f"BipedalWalker-POET-Rose-{level}{variant}-v0")
    return names


def _get_zs_minigrid_env_names() -> List[str]:
    """Return the zero-shot MultiGrid maze evaluation suite (``MultiGrid-<name>-v0``)."""
    layouts = (
        "SixteenRooms",
        "SixteenRoomsFewerDoors",
        "Labyrinth",
        "Labyrinth2",
        "Maze",
        "Maze2",
        "LargeCorridor",
        "PerfectMazeMedium",
        "PerfectMazeLarge",
        "PerfectMazeXL",
    )
    return [f"MultiGrid-{layout}-v0" for layout in layouts]


def _get_env_names_from_benchmark(benchmark: Optional[str]) -> List[str]:
    """Expand a ``--benchmark`` preset into its env ids.

    Args:
        benchmark: one of ``"bipedal"``, ``"poetrose"``, ``"maze"``, or ``None``.

    Returns:
        The preset's env list, or ``[]`` when ``benchmark`` is ``None``.

    Raises:
        ValueError: for any other preset name.
    """
    if benchmark is None:
        return []
    known_presets = ("bipedal", "poetrose", "maze")
    if benchmark not in known_presets:
        raise ValueError(f"Unknown benchmark: {benchmark}")
    if benchmark == "bipedal":
        return _get_bipedal_env_names()
    if benchmark == "poetrose":
        return _get_poet_rose_env_names()
    return _get_zs_minigrid_env_names()


# ---------------------- WIDE CSV helpers ----------------------


def _read_wide_csv(path: str) -> Tuple[List[str], Dict[str, List[str]]]:
    """Parse an existing wide CSV into its header and per-metric value rows.

    Args:
        path: CSV file whose first header cell must be ``metric``.

    Returns:
        ``(header, rows)``.
        - ``header`` is the full header row: ``metric`` followed by one
          column name per run.
        - ``rows`` maps each metric name to its value cells. Each row is
          padded with ``""`` or truncated to exactly ``len(header) - 1``
          cells.
        - Blank lines are skipped. If a metric appears more than once, the
          last occurrence wins.

    Raises:
        ValueError: if the file is empty or its first header cell is not ``metric``.
    """
    with open(path, "r", newline="") as f:
        reader = csv.reader(f)
        # next(reader) on an empty file would leak a bare StopIteration.
        header = next(reader, None)
        if header is None:
            raise ValueError(f"Bad wide CSV header in {path}: file is empty.")
        if not header or header[0] != "metric":
            raise ValueError(
                f"Bad wide CSV header in {path}: first column must be 'metric'."
            )
        width = len(header) - 1
        rows: Dict[str, List[str]] = {}
        for parts in reader:
            if not parts:
                continue
            # Truncate extra cells, then pad missing ones, so every metric has
            # exactly one cell per run column.
            vals = parts[1 : 1 + width]
            vals += [""] * (width - len(vals))
            rows[parts[0]] = vals
    return header, rows


def _safe_unique_col_name(existing: List[str], desired: str) -> str:
    """Return ``desired`` if unused, else the first free ``desired_<k>`` with ``k >= 2``."""
    taken = set(existing)
    candidate = desired
    suffix = 1
    while candidate in taken:
        suffix += 1
        candidate = f"{desired}_{suffix}"
    return candidate


def append_xpid_results_to_wide_csv(
    *,
    out_csv: str,
    col_name: str,
    env_names: List[str],
    solved_rate: Dict[str, float],
    mean_return: Dict[str, float],
) -> str:
    """Append one run's results as a new column of the wide CSV ``out_csv``.

    The file is created if missing. The column name is made unique by
    suffixing ``_2``, ``_3``, ... if it already exists.

    Row layout:
    - The ``solved_rate:<env>`` / ``test_returns:<env>`` rows for
      ``env_names`` come first, in env order.
    - Other existing rows follow in file order, with an empty cell for this
      run.
    - Envs missing from ``solved_rate`` / ``mean_return`` are written as
      ``0.0``.
    - Repeated env names produce a single pair of rows.

    The file is written to ``<out_csv>.tmp`` and then atomically moved into
    place with ``os.replace``. The temp file is removed if writing fails.

    Returns:
        The column name actually written (may carry a ``_<k>`` suffix).
    """
    out_csv = os.path.expandvars(os.path.expanduser(out_csv))
    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)

    # This run's cell per metric. Dict insertion order gives the desired row
    # order and de-duplicates repeated env names.
    run_vals: Dict[str, str] = {}
    for env in env_names:
        run_vals[f"solved_rate:{env}"] = str(float(solved_rate.get(env, 0.0)))
        run_vals[f"test_returns:{env}"] = str(float(mean_return.get(env, 0.0)))

    if os.path.exists(out_csv):
        header, rows = _read_wide_csv(out_csv)
        col_name = _safe_unique_col_name(header[1:], col_name)
    else:
        header, rows = ["metric"], {}
    old_width = len(header) - 1

    # Defensive: force every existing row to the old width, then add empty
    # history for metrics this file has not seen yet.
    rows = {m: (vals + [""] * old_width)[:old_width] for m, vals in rows.items()}
    for metric in run_vals:
        rows.setdefault(metric, [""] * old_width)

    tmp_path = out_csv + ".tmp"
    try:
        with open(tmp_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(header + [col_name])
            for metric, value in run_vals.items():
                writer.writerow([metric] + rows[metric] + [value])
            for metric, vals in rows.items():
                if metric not in run_vals:
                    writer.writerow([metric] + vals + [""])
        os.replace(tmp_path, out_csv)
    except BaseException:
        # Leave the original CSV untouched and do not leave a stray temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return col_name


# ---------------------- screenshot helpers ----------------------


_UNSAFE_FILENAME_RUN = re.compile(r"[^a-zA-Z0-9._-]+")


def _sanitize_filename(s: str) -> str:
    """Replace each run of characters outside ``[A-Za-z0-9._-]`` with a single ``_``."""
    return _UNSAFE_FILENAME_RUN.sub("_", s)


def _make_raw_env(env_name: str, record_video: bool = False, **kwargs):
    """Build one un-vectorized env for screenshots (no ParallelAdversarialVecEnv).

    - Unlike ``make_env`` there is no CarRacing or video wrapping.
      ``record_video`` is accepted for signature parity and ignored.
    - If the project registry (``gym_make``) raises anything, the plain
      ``gym.make`` registry is tried instead.
    - MultiGrid/MiniGrid envs get ``MultiGridFullyObsWrapper`` when
      ``use_global_policy`` is truthy.
    """
    if env_name not in ("BipedalWalker-v3", "BipedalWalkerHardcore-v3"):
        try:
            env = gym_make(env_name)
        except Exception:
            env = gym.make(env_name)
    else:
        env = gym.make(env_name)

    grid_env = env_name.startswith("MultiGrid") or env_name.startswith("MiniGrid")
    if grid_env and kwargs.get("use_global_policy"):
        env = MultiGridFullyObsWrapper(env, is_adversarial=False)
    return env


def _reset_compat(env, seed: int):
    """Reset ``env`` deterministically under either the Gymnasium or legacy Gym API.

    First tries ``env.reset(seed=seed)``. If that raises anything, falls back
    to best-effort ``env.seed(seed)`` (errors ignored) followed by a plain
    ``env.reset()``, whose errors propagate.
    """
    seeded_reset_worked = True
    try:
        env.reset(seed=seed)
    except Exception:
        # Includes TypeError from legacy reset() signatures without `seed`.
        seeded_reset_worked = False
    if seeded_reset_worked:
        return

    try:
        env.seed(seed)
    except Exception:
        pass
    env.reset()


def _render_rgb_fallback(env) -> Optional[np.ndarray]:
    """Last-resort frame capture via ``render(mode="rgb_array")``, then plain ``render()``.

    This path has no window/highlight styling, so MultiGrid frames may look
    plainer than the human-mode capture.

    Returns:
        An ``(H, W, 3)`` uint8 array, or ``None`` if neither call produced an
        image-like array. Conversion rules:
        - RGBA frames drop alpha.
        - Single-channel frames (``(H, W)`` or ``(H, W, 1)``) are replicated
          to 3 channels.
        - Non-uint8 frames are clipped to ``[0, 255]``.
        - Float frames whose values all lie in ``[0, 1]`` are first scaled by
          255, so they do not collapse to black.
    """
    frame = None
    # Legacy gym accepts mode=; Gymnasium raises TypeError and needs plain render().
    for render_call in (lambda: env.render(mode="rgb_array"), lambda: env.render()):
        try:
            frame = render_call()
        except Exception:
            frame = None
        if frame is not None:
            break
    if frame is None:
        return None

    frame = np.asarray(frame)
    if frame.ndim == 2:
        frame = frame[..., None]
    if frame.ndim != 3:
        return None
    if frame.shape[-1] == 1:
        frame = np.repeat(frame, 3, axis=-1)
    elif frame.shape[-1] == 4:
        frame = frame[..., :3]
    if frame.dtype != np.uint8:
        # Heuristic: a float image with max <= 1 is assumed to be in [0, 1].
        if (
            np.issubdtype(frame.dtype, np.floating)
            and frame.size
            and float(np.nanmax(frame)) <= 1.0
        ):
            frame = frame * 255.0
        frame = np.clip(frame, 0, 255).astype(np.uint8)
    return frame


def _render_multigrid_human_and_grab(env, highlight: bool = True) -> Optional[np.ndarray]:
    """Render a MultiGrid env in "human" mode and read the frame back from its window.

    Goal: the interactive-viewer look (gray background, agent triangle,
    highlighted view mask).

    Strategy:
      1. Call ``env.render(mode="human", highlight=highlight)``. On
         ``TypeError`` (signature without ``highlight``), retry with
         ``env.render(mode="human")``. Other render errors are ignored.
      2. Read a frame from ``env.window``. Try the getter methods
         ``get_frame``, ``get_img``, ``render_to_rgb_array``, ``get_rgb`` and
         ``get_image`` first, then the attributes ``img``, ``image``,
         ``frame`` and ``buffer``.
         NOTE(review): which of these the MultiGrid window actually exposes
         is not visible here.

    Returns:
        The first candidate convertible to an ``(H, W, 3)`` uint8 array, or
        ``None`` if there is no window or no usable frame.
    """

    def _as_rgb_uint8(raw) -> Optional[np.ndarray]:
        # Normalize one candidate; None means "not an image, try the next one".
        if raw is None:
            return None
        arr = np.asarray(raw)
        if arr.ndim == 2:
            arr = arr[..., None]
        if arr.ndim != 3:
            return None
        if arr.shape[-1] == 1:
            arr = np.repeat(arr, 3, axis=-1)
        elif arr.shape[-1] == 4:
            arr = arr[..., :3]
        if arr.dtype != np.uint8:
            # Heuristic: a float image with max <= 1 is assumed to be in [0, 1].
            if (
                np.issubdtype(arr.dtype, np.floating)
                and arr.size
                and float(np.nanmax(arr)) <= 1.0
            ):
                arr = arr * 255.0
            arr = np.clip(arr, 0, 255).astype(np.uint8)
        return arr

    # Drive the window render.
    try:
        env.render(mode="human", highlight=bool(highlight))
    except TypeError:
        try:
            env.render(mode="human")
        except Exception:
            pass
    except Exception:
        pass

    win = getattr(env, "window", None)
    if win is None:
        return None

    for attr in ("get_frame", "get_img", "render_to_rgb_array", "get_rgb", "get_image"):
        getter = getattr(win, attr, None)
        if not callable(getter):
            continue
        try:
            frame = _as_rgb_uint8(getter())
        except Exception:
            continue
        if frame is not None:
            return frame

    for field in ("img", "image", "frame", "buffer"):
        if not hasattr(win, field):
            continue
        try:
            frame = _as_rgb_uint8(getattr(win, field))
        except Exception:
            continue
        if frame is not None:
            return frame

    return None


def _render_rgb(env, env_name: str) -> np.ndarray:
    """Capture one RGB frame from ``env`` for a screenshot.

    MultiGrid/MiniGrid envs first try the human-mode window capture (the
    preferred, highlighted look). Every env falls back to the rgb_array /
    plain ``render()`` capture.

    Raises:
        RuntimeError: if no frame could be captured.
    """
    frame: Optional[np.ndarray] = None
    if env_name.startswith("MultiGrid") or env_name.startswith("MiniGrid"):
        frame = _render_multigrid_human_and_grab(env, highlight=True)
    if frame is None:
        frame = _render_rgb_fallback(env)
    if frame is None:
        raise RuntimeError("render() returned None / cannot capture frame.")
    return frame


def maybe_save_screenshots_for_envs(
    env_names: List[str],
    screenshot_dir: Optional[str],
    seed: int,
    wrapper_kwargs: Dict[str, Any],
):
    """Save one PNG per env into ``screenshot_dir`` unless that directory already exists.

    - Empty or ``None`` ``screenshot_dir``: no-op.
    - Existing directory: skip everything (never overwrite).
    - Otherwise create the directory. For each env:
      1. build a raw env (``_make_raw_env``);
      2. reset it with ``seed`` (``_reset_compat``);
      3. capture a frame with ``_render_rgb``;
      4. write ``<screenshot_dir>/<sanitized env name>.png``.

    A failure for one env is logged and skipped, so the remaining envs still
    get their screenshots.
    """
    if screenshot_dir is None or str(screenshot_dir).strip() == "":
        return

    screenshot_dir = os.path.expandvars(os.path.expanduser(str(screenshot_dir)))

    # KEY REQUIREMENT: exists => skip everything (no overwrite)
    if os.path.isdir(screenshot_dir):
        print(
            f"[Screenshot] dir exists, skip all screenshots: {screenshot_dir}", flush=True
        )
        return

    os.makedirs(screenshot_dir, exist_ok=True)
    print(f"[Screenshot] creating screenshots under: {screenshot_dir}", flush=True)

    for env_name in env_names:
        out_path = os.path.join(screenshot_dir, f"{_sanitize_filename(env_name)}.png")
        try:
            env = _make_raw_env(env_name, record_video=False, **wrapper_kwargs)
        except Exception as e:
            print(f"[Screenshot] failed to create env {env_name}: {e!r}", flush=True)
            continue
        try:
            # Many envs cannot render before the first reset.
            _reset_compat(env, int(seed))
            img = _render_rgb(env, env_name)
            # NOTE(review): save_images comes from util; how it treats an
            # (H, W, 3) uint8 numpy array with normalize=True is not visible here.
            save_images(img, out_path, normalize=True, channels_first=False)
            print(f"[Screenshot] saved: {out_path}", flush=True)
        except Exception as e:
            print(f"[Screenshot] failed for {env_name}: {e!r}", flush=True)
        finally:
            try:
                env.close()
            except Exception:
                pass


# ---------------------- argument parsing ----------------------


def parse_args():
    """Parse CLI options for the fixed-env wide-CSV evaluator.

    Boolean options use ``str2bool`` with ``nargs="?"``. A bare ``--flag``
    means True; ``--flag <value>`` is parsed by ``str2bool``.
    """
    parser = argparse.ArgumentParser(
        description="Eval fixed env list -> append WIDE CSV by xpid column"
    )

    def add_bool(name: str, default: bool) -> None:
        parser.add_argument(name, type=str2bool, nargs="?", const=True, default=default)

    # Which runs to evaluate.
    parser.add_argument("--base_path", type=str, default="~/logs/dcd")
    parser.add_argument("--xpid", type=str, default="latest")
    parser.add_argument(
        "--prefix",
        type=str,
        default=None,
        help="Evaluate all xpids matching this prefix (like eval.py).",
    )
    parser.add_argument(
        "--max_seeds",
        type=int,
        default=None,
        help="Max number of matched xpids when using --prefix.",
    )

    # Which envs to evaluate on.
    parser.add_argument(
        "--env_names", type=str, default="", help="CSV string of evaluation envs."
    )
    parser.add_argument(
        "--benchmark", type=str, default=None, choices=["maze", "bipedal", "poetrose"]
    )

    # Checkpoint selection.
    parser.add_argument("--model_tar", type=str, default="model")
    parser.add_argument(
        "--model_name", type=str, default="agent", choices=["agent", "adversary_agent"]
    )

    # Evaluation budget.
    for int_name, int_default in (
        ("--num_processes", 2),
        ("--max_num_processes", 10),
        ("--num_episodes", 100),
    ):
        parser.add_argument(int_name, type=int, default=int_default)

    for bool_name, bool_default in (
        ("--deterministic", True),
        ("--verbose", False),
        ("--render", False),
        ("--record_video", False),
    ):
        add_bool(bool_name, bool_default)

    parser.add_argument("--device", type=str, default="cpu")

    # Wrapper fallbacks (if absent in meta.json).
    parser.add_argument("--frame_stack", type=int, default=1)
    for wrapper_flag in ("--grayscale", "--use_global_policy", "--use_global_critic", "--crop_frame"):
        add_bool(wrapper_flag, False)
    parser.add_argument("--num_action_repeat", type=int, default=8)

    # Output.
    parser.add_argument(
        "--wide_csv",
        type=str,
        default="./eval_results/bw_wide.csv",
        help="Wide CSV path to append to.",
    )

    # Screenshots.
    parser.add_argument(
        "--screenshot_dir",
        type=str,
        default="",
        help="If this directory exists -> skip screenshots. Else create and save one png per env.",
    )
    parser.add_argument(
        "--screenshot_seed", type=int, default=0, help="Seed used for screenshot reset()."
    )

    return parser.parse_args()


def _resolve_xpids(base_path: str, xpid: str, prefix: Optional[str]) -> List[str]:
    """Return the run directory names (xpids) to evaluate.

    Without ``prefix`` this is just ``[xpid]``. With ``prefix``:
    - Entries of ``base_path`` matching the glob ``<prefix>*`` are selected.
    - Entries matching ``.*_[0-9]*$`` (an ``_`` followed only by digits, e.g.
      per-seed ``foo_0``, ``foo_1``) are preferred. If none match, all prefix
      matches are kept.
    - The result is in natural order (``foo_2`` before ``foo_10``), so
      truncating it with ``--max_seeds`` keeps the lowest seed numbers.
    """
    base_path = os.path.expandvars(os.path.expanduser(base_path))
    if prefix is None:
        return [xpid]

    all_xpids = fnmatch.filter(os.listdir(base_path), f"{prefix}*")
    filter_re = re.compile(r".*_[0-9]*$")
    xpids = [xp for xp in all_xpids if filter_re.match(xp)] or list(all_xpids)

    def _natural_key(name: str):
        # re.split with a capture group alternates text (even idx) / digits
        # (odd idx), so ints and strs never get compared against each other.
        parts = re.split(r"(\d+)", name)
        key = [int(tok) if idx % 2 else tok for idx, tok in enumerate(parts)]
        # Tie-break on the raw name ("x_01" vs "x_1") for a deterministic order.
        return key, name

    xpids.sort(key=_natural_key)
    return xpids


# ---------------------- env build / wrap ----------------------


def make_env(env_name: str, record_video: bool = False, **kwargs):
    """Build one env instance for a ParallelAdversarialVecEnv worker.

    - The two standard BipedalWalker ids come from ``gym.make``; everything
      else comes from the project registry (``gym_make``).
    - CarRacing envs are wrapped in ``CarRacingWrapper``, configured from
      ``grayscale``, ``num_action_repeat``, ``frame_stack`` (default 4) and
      ``crop_frame``. With ``record_video`` they are additionally wrapped in
      a gym ``Monitor`` writing to ``videos/``.
    - MultiGrid/MiniGrid envs get ``MultiGridFullyObsWrapper`` when
      ``use_global_policy`` is truthy.
    """
    if env_name in ("BipedalWalker-v3", "BipedalWalkerHardcore-v3"):
        env = gym.make(env_name)
    else:
        env = gym_make(env_name)

    if env_name.startswith("CarRacing"):
        env = CarRacingWrapper(
            env=env,
            grayscale=kwargs.get("grayscale", False),
            reward_shaping=False,
            num_action_repeat=kwargs.get("num_action_repeat", 8),
            nstack=kwargs.get("frame_stack", 4),
            crop=kwargs.get("crop_frame", False),
            eval_=True,
        )
        if record_video:
            from gym.wrappers.monitor import Monitor  # type: ignore

            env = Monitor(env, "videos/", force=True)
            print("Recording video!", flush=True)

    grid_env = env_name.startswith("MultiGrid") or env_name.startswith("MiniGrid")
    if grid_env and kwargs.get("use_global_policy"):
        env = MultiGridFullyObsWrapper(env, is_adversarial=False)

    return env


def wrap_venv(venv, env_name: str, device: str = "cpu"):
    """Wrap a vectorized env with ``VecMonitor`` and ``VecPreprocessImageWrapper``.

    - MultiGrid/MiniGrid: ``obs_key="image"`` and ``scale=10.0`` are passed
      to the preprocessor. Other envs pass ``None`` for both.
    - BipedalWalker: no transpose. Every other env uses ``transpose_order=[2, 0, 1]``.
    """
    grid_env = env_name.startswith("MultiGrid") or env_name.startswith("MiniGrid")
    obs_key, scale = ("image", 10.0) if grid_env else (None, None)
    transpose_order = None if env_name.startswith("BipedalWalker") else [2, 0, 1]

    monitored = VecMonitor(venv=venv, filename=None, keep_buf=100)
    return VecPreprocessImageWrapper(
        venv=monitored,
        obs_key=obs_key,
        transpose_order=transpose_order,
        scale=scale,
        device=device,
    )


def build_parallel_env(
    env_name: str, num_processes: int, device: str, record_video: bool, **kwargs
):
    """Create ``num_processes`` copies of ``env_name`` and return the wrapped venv.

    The copies run in a non-adversarial, eval-mode ``ParallelAdversarialVecEnv``,
    which is then wrapped by ``wrap_venv``.
    """
    worker_count = int(num_processes)
    # Arguments are bound as lambda defaults so each thunk carries its own config.
    env_thunks = [
        (
            lambda name=env_name, rv=record_video, kw=kwargs: make_env(
                name, record_video=rv, **kw
            )
        )
        for _ in range(worker_count)
    ]
    parallel = ParallelAdversarialVecEnv(env_thunks, adversary=False, is_eval=True)
    return wrap_venv(parallel, env_name=env_name, device=device)


# ---------------------- evaluator ----------------------


class Evaluator:
    """Roll out an agent on a fixed list of envs and report per-env metrics.

    - One vectorized env (``build_parallel_env`` with ``num_processes``
      workers) is built per env name at construction time.
    - ``evaluate`` collects ``num_episodes`` completed episodes per env. It
      returns the solved rate (fraction of episode returns above that env's
      threshold) and the mean episode return.
    - Call ``close`` to release the vectorized envs.
    """

    def __init__(
        self,
        env_names: List[str],
        num_processes: int,
        num_episodes: int,
        device: str,
        record_video: bool,
        **kwargs,
    ):
        if not env_names:
            raise ValueError("Evaluator requires at least one env name.")

        self.env_names = env_names
        self.num_processes = int(num_processes)
        self.num_episodes = int(num_episodes)
        self.device = device
        self.kwargs = kwargs

        # Threshold of the first env, kept for callers reading this attribute;
        # evaluate() uses the per-env threshold so mixed env lists are scored correctly.
        self.solved_threshold = self._solved_threshold_for(env_names[0])

        self.venv: Dict[str, Any] = {}
        try:
            for env_name in env_names:
                self.venv[env_name] = build_parallel_env(
                    env_name=env_name,
                    num_processes=self.num_processes,
                    device=self.device,
                    record_video=record_video,
                    **kwargs,
                )
            self.is_discrete_actions = is_discrete_actions(self.venv[env_names[0]])
        except BaseException:
            # The caller never receives this object, so release the venvs
            # built so far before propagating the error.
            self._close_quietly()
            raise

    @staticmethod
    def _solved_threshold_for(env_name: str) -> float:
        """Return the value an episode return must exceed to count as solved."""
        return 230.0 if env_name.startswith("BipedalWalker") else 0.0

    def _close_quietly(self) -> None:
        """Best-effort close of every venv built so far, ignoring close errors."""
        for v in self.venv.values():
            try:
                v.close()
            except Exception:
                pass

    def close(self):
        """Close every vectorized env owned by this evaluator."""
        for v in self.venv.values():
            v.close()

    def evaluate(self, agent, deterministic: bool, show_progress: bool, render: bool):
        """Evaluate ``agent`` on every env.

        Args:
            agent: object exposing ``act(obs, hidden, masks, deterministic=...)``,
                ``process_action`` and ``algo.actor_critic`` (as used below).
            deterministic: forwarded to ``agent.act``.
            show_progress: show a tqdm bar of completed episodes.
            render: call ``venv.render_to_screen()`` after every step.

        Returns:
            ``(solved_rate, mean_return)`` dicts keyed by env name. If no
            episode completed (``num_episodes <= 0``), both values are ``0.0``.
        """
        solved_rate: Dict[str, float] = {}
        mean_return: Dict[str, float] = {}
        actor_critic = agent.algo.actor_critic

        for env_name, venv in self.venv.items():
            threshold = self._solved_threshold_for(env_name)
            returns: List[float] = []
            solved = 0

            obs = venv.reset()
            rh = torch.zeros(
                self.num_processes,
                actor_critic.recurrent_hidden_state_size,
                device=self.device,
            )
            if actor_critic.is_recurrent and actor_critic.rnn.arch == "lstm":
                # LSTM hidden state is an (h, c) pair.
                rh = (rh, torch.zeros_like(rh))
            masks = torch.ones(self.num_processes, 1, device=self.device)

            pbar = tqdm(total=self.num_episodes, disable=not show_progress)
            try:
                while len(returns) < self.num_episodes:
                    with torch.no_grad():
                        _, action, _, rh = agent.act(
                            obs, rh, masks, deterministic=bool(deterministic)
                        )

                    action_np = action.cpu().numpy()
                    if not self.is_discrete_actions:
                        action_np = agent.process_action(action_np)

                    obs, _reward, done, infos = venv.step(action_np)

                    # mask 0.0 for workers whose episode just ended.
                    masks = torch.tensor(
                        [[0.0] if d else [1.0] for d in done],
                        dtype=torch.float32,
                        device=self.device,
                    )

                    for i, info in enumerate(infos):
                        if "episode" not in info:
                            continue
                        r = float(info["episode"]["r"])
                        returns.append(r)
                        if r > threshold:
                            solved += 1
                        pbar.update(1)

                        # Same recurrence check as the hidden-state setup above.
                        if actor_critic.is_recurrent:
                            if isinstance(rh, tuple):
                                rh[0][i].zero_()
                                rh[1][i].zero_()
                            else:
                                rh[i].zero_()

                        if len(returns) >= self.num_episodes:
                            break

                    if render:
                        venv.render_to_screen()
            finally:
                pbar.close()

            solved_rate[env_name] = solved / float(len(returns) or 1)
            mean_return[env_name] = float(np.mean(returns)) if returns else 0.0

        return solved_rate, mean_return


# ---------------------- checkpoint loading ----------------------


def load_agent_once(
    base_path: str,
    xpid: str,
    model_tar: str,
    model_name: str,
    dummy_env_name: str,
    device: str,
    wrapper_kwargs: Dict[str, Any],
):
    """Build an agent for run ``xpid`` and load its checkpoint weights.

    Inputs read:
    - ``<base_path>/<xpid>/meta.json`` (its ``"args"`` entry).
    - ``<base_path>/<xpid>/<model_tar>.tar``.

    Dummy venv:
    - A one-worker venv of ``dummy_env_name`` is built only so ``make_agent``
      can size the network.
    - It is always closed before returning, including on failure.
    - ``wrapper_kwargs`` are fallbacks. Any key also present (and not
      ``None``) in the run's meta.json args overrides its value, coerced to
      the fallback's type. This way the network is built against the run's
      own wrapper settings.

    Args:
        base_path: log root (``~`` and env vars are expanded).
        xpid: run directory name under ``base_path``.
        model_tar: checkpoint file stem (``<model_tar>.tar``).
        model_name: key under ``runner_state_dict["agent_state_dict"]`` for
            runner-style checkpoints.
        dummy_env_name: env used to build the dummy venv.
        device: torch device for the agent.
        wrapper_kwargs: fallback env-wrapper options.

    Returns:
        ``(agent, xpid_flags, checkpoint_path)``.

    Raises:
        FileNotFoundError: if meta.json or the checkpoint is missing.
    """
    base_path = os.path.expandvars(os.path.expanduser(base_path))
    xpid_dir = os.path.join(base_path, xpid)
    meta_json_path = os.path.join(xpid_dir, "meta.json")
    checkpoint_path = os.path.join(xpid_dir, f"{model_tar}.tar")

    if not os.path.exists(meta_json_path):
        raise FileNotFoundError(f"meta.json not found: {meta_json_path}")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"checkpoint not found: {checkpoint_path}")

    with open(meta_json_path, "r") as f:
        xpid_flags = DotDict(json.load(f)["args"])

    env_kwargs: Dict[str, Any] = dict(wrapper_kwargs)
    for key, default in wrapper_kwargs.items():
        try:
            value = xpid_flags[key]
        except (KeyError, TypeError):
            continue
        if value is None:
            continue
        if isinstance(default, (bool, int, float)):
            value = type(default)(value)
        env_kwargs[key] = value

    dummy_venv = build_parallel_env(
        env_name=dummy_env_name,
        num_processes=1,
        device=device,
        record_video=False,
        **env_kwargs,
    )
    try:
        # NOTE(review): the agent is always built with name="agent", even when
        # model_name="adversary_agent"; confirm the architectures match.
        agent = make_agent(name="agent", env=dummy_venv, args=xpid_flags, device=device)

        # torch.load unpickles arbitrary objects: only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        if isinstance(checkpoint, dict) and "runner_state_dict" in checkpoint:
            state_dict = checkpoint["runner_state_dict"]["agent_state_dict"][model_name]
        else:
            state_dict = checkpoint
        agent.algo.actor_critic.load_state_dict(state_dict)
    finally:
        # Release the dummy venv even if agent construction or loading failed.
        dummy_venv.close()

    return agent, xpid_flags, checkpoint_path


# ---------------------- main ----------------------


def _flag_value(flags: Any, key: str) -> Any:
    """Look up ``key`` in a run's meta.json flags, returning ``None`` when absent.

    Item access is tried first, then attribute access. This works whether the
    flags object raises ``KeyError`` or ``AttributeError`` (or returns
    ``None``) for a missing key.
    """
    try:
        return flags[key]
    except (KeyError, TypeError):
        pass
    try:
        return getattr(flags, key)
    except (AttributeError, KeyError):
        return None


def _resolve_wrapper_kwargs(xpid_flags: Any, defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Per-run wrapper kwargs: the meta.json value when present, else the CLI default.

    Present values are coerced to the type of the corresponding default
    (``int`` / ``bool`` / ``float``).
    """
    resolved: Dict[str, Any] = {}
    for key, default in defaults.items():
        value = _flag_value(xpid_flags, key)
        if value is None:
            resolved[key] = default
        elif isinstance(default, (bool, int, float)):
            resolved[key] = type(default)(value)
        else:
            resolved[key] = value
    return resolved


def _plan_env_chunks(
    num_envs: int, num_processes: int, max_num_processes: int
) -> Tuple[int, int]:
    """Split envs into chunks so ``chunk_size * num_processes <= max_num_processes``.

    Every chunk holds at least one env, even if that alone exceeds the cap.

    Returns:
        ``(chunk_size, num_chunks)``.
    """
    num_processes = max(1, int(num_processes))
    if num_envs * num_processes > int(max_num_processes):
        chunk_size = max(1, int(max_num_processes) // num_processes)
    else:
        chunk_size = max(1, num_envs)
    num_chunks = -(-num_envs // chunk_size)  # ceil division
    return chunk_size, num_chunks


def _run_name(base_path: str, xpid: str) -> str:
    """Wide-CSV column name for ``xpid``: the run directory name.

    A symlinked xpid (e.g. ``latest``) is resolved to its target's name.
    """
    xpid_dir = os.path.join(base_path, xpid)
    if os.path.islink(xpid_dir):
        return os.path.basename(os.path.realpath(xpid_dir)) or xpid
    return xpid


def _run_main() -> None:
    """Resolve envs/xpids from the CLI, evaluate each xpid, and append its column."""
    args = DotDict(vars(parse_args()))
    args.num_processes = min(int(args.num_processes), int(args.num_episodes))

    # Determine env names
    if args.benchmark is not None:
        env_names = _get_env_names_from_benchmark(args.benchmark)
    elif str(args.env_names).strip():
        env_names = [x.strip() for x in str(args.env_names).split(",") if x.strip()]
    else:
        raise ValueError("Must provide either --benchmark or --env_names")
    if not env_names:
        raise ValueError("No evaluation envs resolved from --benchmark / --env_names.")

    base_path = os.path.expandvars(os.path.expanduser(args.base_path))
    xpids = _resolve_xpids(base_path, args.xpid, args.prefix)
    if args.max_seeds is not None:
        xpids = xpids[: int(args.max_seeds)]
    if len(xpids) == 0:
        raise ValueError("No xpids matched.")

    # Chunk envs so the total number of worker processes stays capped.
    chunk_size, num_chunks = _plan_env_chunks(
        len(env_names), int(args.num_processes), int(args.max_num_processes)
    )
    if args.record_video:
        # Video mode evaluates only the first env with one worker; the other
        # envs are not evaluated and end up as 0.0 in the output.
        chunk_size, num_chunks = 1, 1
        args.num_processes = 1
        if len(env_names) > 1:
            print(
                f"[Warn] --record_video evaluates only {env_names[0]}; "
                "other envs are reported as 0.0.",
                flush=True,
            )

    wrapper_defaults: Dict[str, Any] = dict(
        frame_stack=int(args.frame_stack),
        grayscale=bool(args.grayscale),
        use_global_policy=bool(args.use_global_policy),
        use_global_critic=bool(args.use_global_critic),
        crop_frame=bool(args.crop_frame),
        num_action_repeat=int(args.num_action_repeat),
    )

    # Screenshots (global, independent of xpid). No-op unless --screenshot_dir is set.
    try:
        maybe_save_screenshots_for_envs(
            env_names=env_names,
            screenshot_dir=args.screenshot_dir,
            seed=int(args.screenshot_seed),
            wrapper_kwargs=wrapper_defaults,
        )
    except Exception as e:
        # Screenshots are a side product; they must never block evaluation.
        print(f"[Screenshot] failed: {e!r}", flush=True)

    # Evaluate each xpid and append one wide column per xpid
    for xpid in xpids:
        checkpoint_path = os.path.join(base_path, xpid, f"{args.model_tar}.tar")
        if not os.path.exists(checkpoint_path):
            print(f"[Skip] No model path {checkpoint_path}", flush=True)
            continue

        try:
            agent, xpid_flags, _ = load_agent_once(
                base_path=args.base_path,
                xpid=xpid,
                model_tar=args.model_tar,
                model_name=args.model_name,
                dummy_env_name=env_names[0],
                device=str(args.device),
                wrapper_kwargs=wrapper_defaults,
            )
        except Exception as e:
            print(f"[Skip] Failed to load xpid={xpid}: {e}", flush=True)
            continue

        wrapper_kwargs = _resolve_wrapper_kwargs(xpid_flags, wrapper_defaults)

        solved_rate_all: Dict[str, float] = {}
        mean_return_all: Dict[str, float] = {}

        for chunk_idx in range(num_chunks):
            start_idx = chunk_idx * chunk_size
            env_chunk = env_names[start_idx : start_idx + chunk_size]

            evaluator = Evaluator(
                env_chunk,
                num_processes=int(args.num_processes),
                num_episodes=int(args.num_episodes),
                device=str(args.device),
                record_video=bool(args.record_video),
                **wrapper_kwargs,
            )
            try:
                sr, mr = evaluator.evaluate(
                    agent,
                    deterministic=bool(args.deterministic),
                    show_progress=bool(args.verbose),
                    render=bool(args.render),
                )
            finally:
                # Always release the worker processes, even if evaluation fails.
                evaluator.close()
            solved_rate_all.update(sr)
            mean_return_all.update(mr)

        output_results: Dict[str, str] = {}
        for env in env_names:
            key_sr = f"solved_rate:{env}"
            key_mr = f"test_returns:{env}"
            output_results[key_sr] = f"{solved_rate_all.get(env, 0.0):.4f}"
            output_results[key_mr] = f"{mean_return_all.get(env, 0.0):.2f}"
            print(
                f"[Eval:{xpid}] {env}: solved_rate={output_results[key_sr]}, mean_return={output_results[key_mr]}",
                flush=True,
            )
        HumanOutputFormat(sys.stdout).writekvs(output_results)

        # Column = run directory name (documented contract), not the log root's name.
        col_name = _run_name(base_path, xpid)
        append_xpid_results_to_wide_csv(
            out_csv=str(args.wide_csv),
            col_name=col_name,
            env_names=env_names,
            solved_rate=solved_rate_all,
            mean_return=mean_return_all,
        )
        print(f"[OK] Appended '{col_name}' to: {args.wide_csv}", flush=True)


def main():
    """CLI entry point: evaluate each xpid and append one wide-CSV column per xpid."""
    os.environ["OMP_NUM_THREADS"] = "1"

    display = None
    if sys.platform.startswith("linux"):
        try:
            import pyvirtualdisplay  # type: ignore

            display = pyvirtualdisplay.Display(visible=0, size=(1400, 900), color_depth=24)
            display.start()
        except Exception:
            # A virtual display is optional; run without one if unavailable.
            display = None

    try:
        _run_main()
    finally:
        # Stop the virtual display even when evaluation raised.
        if display:
            try:
                display.stop()
            except Exception:
                pass


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
