# -*- coding: utf-8 -*-
"""
eval_bw_fixedseeds_widecsv_xpid.py

BipedalWalker fixed-seed evaluation that:
  1) Locates checkpoint via (base_path, xpid) exactly like the original eval.py:
       xpid_dir = {base_path}/{xpid}
       meta.json = xpid_dir/meta.json
       checkpoint = xpid_dir/{model_tar}.tar
  2) Loads ONE checkpoint ONCE (per xpid)
  3) Evaluates across a list of BipedalWalker envs with fixed seeds using:
       - num_processes = 16
       - num_seeds = 128 seeds: [seed_start, ..., seed_start+num_seeds-1]
       - episodes_per_seed repeats per seed (default 1)
       - deterministic policy (default True)
  4) Appends results as a NEW COLUMN into a "wide" CSV:

      metric,<col0>,<col1>,...
      solved_rate:ENV_A,...
      test_returns:ENV_A,...
      solved_rate:ENV_B,...
      test_returns:ENV_B,...

Column name defaults to the xpid (i.e., last directory name). If duplicate, auto-suffix _2, _3, ...

Assumptions (same as your codebase):
  - ParallelAdversarialVecEnv supports venv.reset_with_seeds(seeds_batch).
  - agent.act(obs, rh, masks, deterministic=...) exists.
  - For continuous actions, agent.process_action(...) exists.
"""

import sys
import os
import csv
import json
import argparse
import fnmatch
import re
from typing import Any, Dict, List, Tuple

import numpy as np
import torch

from envs.registration import make as gym_make
from envs.wrappers import (
    VecMonitor,
    VecPreprocessImageWrapper,
    ParallelAdversarialVecEnv,
    MultiGridFullyObsWrapper,
    CarRacingWrapper,
)
from util import DotDict, str2bool, make_agent, is_discrete_actions

try:
    import gymnasium as gym
except ImportError:
    import gym


from envs.bipedalwalker import *


# ---------------------- env name presets ----------------------


def _get_bipedal_env_names() -> List[str]:
    return [
        "BipedalWalker-v3",
        "BipedalWalkerHardcore-v3",
        "BipedalWalker-Med-Stairs-v0",
        "BipedalWalker-Med-PitGap-v0",
        "BipedalWalker-Med-StumpHeight-v0",
        "BipedalWalker-Med-Roughness-v0",
    ]


# ---------------------- env build (match your Evaluator) ----------------------


def _make_eval_env(env_name: str, device: str, record_video: bool, **kwargs):
    """Build one (unvectorised) evaluation environment.

    Args:
        env_name: Environment id. The two stock BipedalWalker ids are created
            through ``gym.make``; every other id goes through the project's
            ``envs.registration.make``.
        device: Accepted for signature symmetry with the vectorised builder;
            not used by this function.
        record_video: When true AND the env is a CarRacing env, wrap it in a
            gym ``Monitor`` writing to ``videos/``. Ignored for other envs.
        **kwargs: Optional wrapper settings. Read here: ``grayscale``,
            ``num_action_repeat``, ``frame_stack``, ``crop_frame`` (CarRacing
            only) and ``use_global_policy`` (MultiGrid/MiniGrid only).

    Returns:
        ``(env, is_multigrid, is_car_racing, is_bipedal, is_lava)`` where the
        four booleans are derived purely from the prefix of ``env_name``.
    """
    # Env family is decided by name prefix only.
    is_multigrid = env_name.startswith(("MultiGrid", "MiniGrid"))
    is_car_racing = env_name.startswith("CarRacing")
    is_bipedal = env_name.startswith("BipedalWalker")
    is_lava = env_name.startswith("Lava")

    stock_gym_ids = ("BipedalWalker-v3", "BipedalWalkerHardcore-v3")
    if env_name in stock_gym_ids:
        env = gym.make(env_name)
    else:
        env = gym_make(env_name)

    if is_car_racing:
        env = CarRacingWrapper(
            env=env,
            grayscale=kwargs.get("grayscale", False),
            reward_shaping=False,
            num_action_repeat=kwargs.get("num_action_repeat", 8),
            nstack=kwargs.get("frame_stack", 4),
            crop=kwargs.get("crop_frame", False),
            eval_=True,
        )

        if record_video:
            # Imported lazily: only needed when recording is requested.
            from gym.wrappers.monitor import Monitor  # type: ignore

            env = Monitor(env, "videos/", force=True)
            print("Recording video!", flush=True)

    wants_global_obs = is_multigrid and kwargs.get("use_global_policy")
    if wants_global_obs:
        env = MultiGridFullyObsWrapper(env, is_adversarial=False)

    return env, is_multigrid, is_car_racing, is_bipedal, is_lava


def _build_parallel_env(
    env_name: str,
    num_processes: int,
    device: str,
    record_video: bool,
    **kwargs,
):
    """Build the vectorised, preprocessed evaluation env for ``env_name``.

    ``num_processes`` env factories are handed to ``ParallelAdversarialVecEnv``
    (``adversary=False, is_eval=True``); the result is wrapped in ``VecMonitor``
    and then ``VecPreprocessImageWrapper``.

    Preprocessing settings are chosen from the name prefix:
      * MultiGrid/MiniGrid: ``obs_key="image"``, ``scale=10.0``.
      * BipedalWalker/Lava: no transpose (``transpose_order=None``).
      * everything else: transpose to channels-first ``[2, 0, 1]``.

    Args:
        env_name: Environment id passed to ``_make_eval_env``.
        num_processes: Number of env factories to create.
        device: Forwarded to the env factory and to the preprocessing wrapper.
        record_video: Forwarded to the env factory.
        **kwargs: Wrapper settings forwarded to ``_make_eval_env``.

    Returns:
        The wrapped vectorised environment.
    """
    factories = []
    for _ in range(num_processes):

        # Defaults pin the current values on each factory; ``device`` is taken
        # from the enclosing scope.
        def _make(env_name_=env_name, record_video_=record_video, kwargs_=kwargs):
            built = _make_eval_env(
                env_name_, device=device, record_video=record_video_, **kwargs_
            )
            return built[0]

        factories.append(_make)

    venv = ParallelAdversarialVecEnv(factories, adversary=False, is_eval=True)

    is_multigrid = env_name.startswith(("MultiGrid", "MiniGrid"))
    vector_obs = env_name.startswith("BipedalWalker") or env_name.startswith("Lava")

    obs_key, scale = ("image", 10.0) if is_multigrid else (None, None)
    # Image observations are moved to channels-first; vector observations are left alone.
    transpose_order = None if vector_obs else [2, 0, 1]

    monitored = VecMonitor(venv=venv, filename=None, keep_buf=100)
    return VecPreprocessImageWrapper(
        venv=monitored,
        obs_key=obs_key,
        transpose_order=transpose_order,
        scale=scale,
        device=device,
    )


# ---------------------- agent load (xpid-based) ----------------------


def _load_flags_and_checkpoint_paths(
    base_path: str, xpid: str, model_tar: str
) -> Tuple[str, str, str]:
    base_path = os.path.expandvars(os.path.expanduser(base_path))
    xpid_dir = os.path.join(base_path, xpid)
    meta_json_path = os.path.join(xpid_dir, "meta.json")
    checkpoint_path = os.path.join(xpid_dir, f"{model_tar}.tar")
    return xpid_dir, meta_json_path, checkpoint_path


def load_agent_from_xpid_once(
    *,
    base_path: str,
    xpid: str,
    model_tar: str,
    model_name: str,
    dummy_env_name: str,
    device: str,
    wrapper_kwargs: Dict[str, Any],
) -> Tuple[Any, DotDict, str]:
    """Load ``meta.json`` and the checkpoint for one xpid and build its agent.

    A throw-away single-process venv for ``dummy_env_name`` is created so that
    ``make_agent`` can be constructed against it; it is always closed again
    before this function returns or raises.

    Args:
        base_path: Directory containing the xpid directories.
        xpid: Experiment id (directory name under ``base_path``).
        model_tar: Checkpoint base name; ``{model_tar}.tar`` is loaded.
        model_name: Key selecting the weights inside
            ``checkpoint["runner_state_dict"]["agent_state_dict"]`` when the
            checkpoint has that layout.
        dummy_env_name: Environment id used for the throw-away venv.
        device: Device string forwarded to the venv and to ``make_agent``.
        wrapper_kwargs: Wrapper settings forwarded to ``_build_parallel_env``.

    Returns:
        ``(agent, xpid_flags, checkpoint_path)`` where ``xpid_flags`` is the
        ``"args"`` entry of ``meta.json`` wrapped in a ``DotDict``.

    Raises:
        FileNotFoundError: If ``meta.json`` or the checkpoint file is missing.
    """
    _xpid_dir, meta_json_path, checkpoint_path = _load_flags_and_checkpoint_paths(
        base_path, xpid, model_tar
    )

    if not os.path.exists(meta_json_path):
        raise FileNotFoundError(f"meta.json not found: {meta_json_path}")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"checkpoint not found: {checkpoint_path}")

    with open(meta_json_path, "r") as f:
        xpid_flags = DotDict(json.load(f)["args"])

    # dummy venv (1 env) for spaces/shapes, keep wrappers consistent
    dummy_venv = _build_parallel_env(
        env_name=dummy_env_name,
        num_processes=1,
        device=device,
        record_video=False,
        **wrapper_kwargs,
    )
    try:
        agent = make_agent(
            name="agent", env=dummy_venv, args=xpid_flags, device=device
        )

        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        if isinstance(checkpoint, dict) and "runner_state_dict" in checkpoint:
            # Full runner checkpoint: pick the requested agent's weights.
            state_dict = checkpoint["runner_state_dict"]["agent_state_dict"][
                model_name
            ]
        else:
            # Otherwise the file is treated as a bare actor-critic state dict.
            state_dict = checkpoint
        agent.algo.actor_critic.load_state_dict(state_dict)
    finally:
        # Release the dummy venv even when agent construction or weight
        # loading fails; callers catch the error and move on to the next
        # xpid, so a leak here would accumulate across runs.
        dummy_venv.close()

    return agent, xpid_flags, checkpoint_path


# ---------------------- evaluation core (fixed seeds) ----------------------


def _init_recurrent_state(agent, num_envs: int, device: torch.device):
    actor_critic = agent.algo.actor_critic
    hidden_size = actor_critic.recurrent_hidden_state_size
    rh = torch.zeros(num_envs, hidden_size, device=device)

    is_recurrent = getattr(actor_critic, "is_recurrent", False) or getattr(
        agent, "is_recurrent", False
    )
    arch = getattr(getattr(actor_critic, "rnn", None), "arch", None)
    if is_recurrent and arch == "lstm":
        rh = (rh, torch.zeros_like(rh))
    return rh


def evaluate_env_fixed_seeds(
    *,
    env_name: str,
    venv,
    agent,
    num_processes: int,
    seed_start: int,
    num_seeds: int,
    episodes_per_seed: int,
    device: str,
    deterministic: bool,
    render_first_env: bool = False,
) -> Tuple[float, float]:
    """Roll out ``agent`` on ``venv`` for a fixed range of seeds.

    Seeds ``seed_start .. seed_start + num_seeds - 1`` are processed in
    consecutive batches of ``num_processes``; the whole sweep is repeated
    ``episodes_per_seed`` times. For each batch the venv is reset with
    ``venv.reset_with_seeds(batch)`` and stepped until every slot has reported
    one finished episode (an ``"episode"`` entry in its info dict). Only the
    first finished episode of each slot counts; slots that finish early keep
    being stepped but their later episodes are ignored.

    An episode counts as solved when its return is >= 230.0 for env names
    starting with ``"BipedalWalker"`` and >= 0.0 otherwise.

    Args:
        env_name: Environment id, used only to pick the solved threshold.
        venv: Vectorised env exposing ``reset_with_seeds``, ``step`` and
            (when rendering) ``render_to_screen``.
        agent: Agent exposing ``act`` and, for continuous actions,
            ``process_action``.
        num_processes: Batch size; must divide ``num_seeds``.
        seed_start: First seed of the range.
        num_seeds: Number of consecutive seeds.
        episodes_per_seed: Number of passes over the seed range.
        device: Torch device string for masks and recurrent state.
        deterministic: Forwarded to ``agent.act``.
        render_first_env: Call ``venv.render_to_screen()`` after every step.

    Returns:
      solved_rate, mean_return
    Both computed across ALL evaluation episodes = num_seeds * episodes_per_seed.
    ``(0.0, 0.0)`` is returned when no episode was recorded.
    """
    assert (
        num_seeds % num_processes == 0
    ), f"num_seeds({num_seeds}) must be divisible by num_processes({num_processes})."

    success_bar = 230.0 if env_name.startswith("BipedalWalker") else 0.0

    torch_device = torch.device(device)
    discrete = is_discrete_actions(venv)

    seed_pool = list(range(seed_start, seed_start + num_seeds))
    n_batches = num_seeds // num_processes

    episode_returns: List[float] = []

    for _pass in range(episodes_per_seed):
        for batch_idx in range(n_batches):
            lo = batch_idx * num_processes
            batch_seeds = seed_pool[lo : lo + num_processes]
            n_slots = len(batch_seeds)

            obs = venv.reset_with_seeds(batch_seeds)

            rh = _init_recurrent_state(agent, num_envs=n_slots, device=torch_device)
            masks = torch.ones(n_slots, 1, device=torch_device)

            # One flag per slot: has this slot already contributed its episode?
            slot_done = [False] * n_slots
            while not all(slot_done):
                with torch.no_grad():
                    _, action, _, rh = agent.act(
                        obs, rh, masks, deterministic=bool(deterministic)
                    )

                env_action = action.cpu().numpy()
                if not discrete:
                    env_action = agent.process_action(env_action)

                obs, _reward, done, infos = venv.step(env_action)

                # Mask is 0 for slots whose episode just ended, 1 otherwise.
                masks = torch.tensor(
                    [[0.0 if d else 1.0] for d in done],
                    dtype=torch.float32,
                    device=torch_device,
                )

                if render_first_env:
                    venv.render_to_screen()

                for slot, info in enumerate(infos):
                    if slot_done[slot] or "episode" not in info:
                        continue

                    episode_returns.append(float(info["episode"]["r"]))
                    slot_done[slot] = True

                    # zero hidden state for that slot
                    if isinstance(rh, tuple):
                        rh[0][slot].zero_()
                        rh[1][slot].zero_()
                    elif getattr(agent, "is_recurrent", False):
                        rh[slot].zero_()

    if not episode_returns:
        return 0.0, 0.0

    mean_return = float(np.mean(episode_returns))
    n_solved = sum(1 for ret in episode_returns if ret >= success_bar)
    solved_rate = n_solved / float(len(episode_returns))
    return solved_rate, mean_return


# ---------------------- wide CSV (append one column per xpid) ----------------------


def _read_wide_csv(path: str) -> Tuple[List[str], Dict[str, List[str]]]:
    with open(path, "r", newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        if not header or header[0] != "metric":
            raise ValueError(
                f"Bad wide CSV header in {path}: first column must be 'metric'."
            )
        rows: Dict[str, List[str]] = {}
        for parts in reader:
            if not parts:
                continue
            metric = parts[0]
            vals = parts[1:]
            # normalize length to existing columns
            need = max(0, (len(header) - 1) - len(vals))
            if need > 0:
                vals = vals + [""] * need
            if len(vals) > (len(header) - 1):
                vals = vals[: (len(header) - 1)]
            rows[metric] = vals
    return header, rows


def _safe_unique_col_name(existing: List[str], desired: str) -> str:
    if desired not in existing:
        return desired
    k = 2
    while f"{desired}_{k}" in existing:
        k += 1
    return f"{desired}_{k}"


def append_xpid_results_to_wide_csv(
    *,
    out_csv: str,
    col_name: str,
    env_names: List[str],
    solved_rate: Dict[str, float],
    mean_return: Dict[str, float],
):
    """Add one results column to the wide CSV at ``out_csv``.

    Each env contributes two rows, ``solved_rate:<env>`` and
    ``test_returns:<env>``; envs missing from ``solved_rate`` / ``mean_return``
    are written as ``0.0``. The file is written to ``<out_csv>.tmp`` first and
    then moved over the target with ``os.replace``.

    * If ``out_csv`` does not exist, it is created with a single value column.
    * If it exists, the new column is appended on the right. A clashing column
      name gets a ``_2``, ``_3``, ... suffix. Rows for the current envs come
      first (in ``env_names`` order), followed by any other rows already in
      the file, which receive an empty cell in the new column.

    Args:
        out_csv: Target path; ``~`` and environment variables are expanded and
            the parent directory is created if needed.
        col_name: Desired header for the new column.
        env_names: Environments to write, in row order.
        solved_rate: Env name -> solved rate.
        mean_return: Env name -> mean episode return.
    """
    out_csv = os.path.expandvars(os.path.expanduser(out_csv))
    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
    staging_path = out_csv + ".tmp"

    def _cell(table: Dict[str, float], env: str) -> str:
        return str(float(table.get(env, 0.0)))

    # --- fresh file: header + two rows per env ---------------------------
    if not os.path.exists(out_csv):
        with open(staging_path, "w", newline="") as fh:
            writer = csv.writer(fh)
            writer.writerow(["metric", col_name])
            for env in env_names:
                writer.writerow([f"solved_rate:{env}", _cell(solved_rate, env)])
                writer.writerow([f"test_returns:{env}", _cell(mean_return, env)])
        os.replace(staging_path, out_csv)
        return

    # --- existing file: append a column ----------------------------------
    metric_order: List[str] = []
    for env in env_names:
        metric_order += [f"solved_rate:{env}", f"test_returns:{env}"]

    header, table = _read_wide_csv(out_csv)
    prior_cols = header[1:]
    prior_width = len(header) - 1

    col_name = _safe_unique_col_name(prior_cols, col_name)

    # Metrics new to this file start out blank in all earlier columns.
    for metric in metric_order:
        if metric not in table:
            table[metric] = [""] * prior_width

    # Force every row to the old width (pad short rows, cut long ones).
    for metric, cells in list(table.items()):
        table[metric] = (cells + [""] * (prior_width - len(cells)))[:prior_width]

    new_cells: Dict[str, str] = {}
    for env in env_names:
        new_cells[f"solved_rate:{env}"] = _cell(solved_rate, env)
        new_cells[f"test_returns:{env}"] = _cell(mean_return, env)

    table = {
        metric: cells + [new_cells.get(metric, "")]
        for metric, cells in table.items()
    }

    with open(staging_path, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["metric"] + prior_cols + [col_name])
        # stable ordering: desired first, then others
        for metric in metric_order:
            writer.writerow([metric] + table[metric])
        for metric, cells in table.items():
            if metric not in metric_order:
                writer.writerow([metric] + cells)
    os.replace(staging_path, out_csv)


# ---------------------- args / main ----------------------


def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        description="BipedalWalker fixed-seed eval -> append wide CSV (xpid-based)"
    )
    add = parser.add_argument

    def add_bool(flag: str, default: bool) -> None:
        # Boolean option: a bare ``--flag`` yields True (const); an explicit
        # value is converted by str2bool.
        add(flag, type=str2bool, nargs="?", const=True, default=default)

    # which run(s) to evaluate
    add(
        "--base_path",
        type=str,
        default="~/logs/dcd",
        help="Base path containing xpid directories.",
    )
    add(
        "--xpid",
        type=str,
        default="latest",
        help="Experiment ID (directory name) to evaluate.",
    )
    add(
        "--prefix",
        type=str,
        default=None,
        help="Evaluate all xpids matching this prefix (like old eval.py).",
    )
    add(
        "--max_seeds",
        type=int,
        default=None,
        help="Max number of matched xpids when using --prefix.",
    )

    # which environments
    add(
        "--benchmark",
        type=str,
        default="bipedal",
        choices=["bipedal"],
        help="Benchmark env preset.",
    )
    add(
        "--env_names",
        type=str,
        default="",
        help="CSV string of env names (overrides benchmark).",
    )

    # which checkpoint / model
    add(
        "--model_tar",
        type=str,
        default="model",
        help="Checkpoint base name (loads {model_tar}.tar).",
    )
    add(
        "--model_name",
        type=str,
        default="agent",
        choices=["agent", "adversary_agent"],
        help="Which agent to evaluate inside runner_state_dict.",
    )

    # eval protocol
    add("--num_processes", type=int, default=16)
    add("--num_seeds", type=int, default=128)
    add("--seed_start", type=int, default=0)
    add("--episodes_per_seed", type=int, default=1)
    add_bool("--deterministic", True)

    # runtime
    add("--device", type=str, default="cuda")
    add_bool("--render", False)
    add_bool("--record_video", False)

    # wrapper fallbacks (if meta.json lacks them)
    add("--frame_stack", type=int, default=1)
    add_bool("--grayscale", False)
    add_bool("--use_global_policy", False)
    add_bool("--use_global_critic", False)
    add_bool("--crop_frame", False)
    add("--num_action_repeat", type=int, default=8)

    # output
    add(
        "--wide_csv",
        type=str,
        default="./eval_results/bw_wide.csv",
        help="Wide CSV path to append to.",
    )

    return parser.parse_args()


def _resolve_xpids(base_path: str, xpid: str, prefix: str | None) -> List[str]:
    base_path = os.path.expandvars(os.path.expanduser(base_path))
    if prefix is None:
        return [xpid]
    all_xpids = fnmatch.filter(os.listdir(base_path), f"{prefix}*")
    filter_re = re.compile(r".*_[0-9]*$")
    return [xp for xp in all_xpids if filter_re.match(xp)]


def _flag_or_default(flags: Any, key: str, default: Any) -> Any:
    """Return ``flags.<key>``, or ``default`` when the flags lack that entry.

    Plain ``getattr(flags, key, default)`` only falls back on AttributeError.
    NOTE(review): DotDict is defined in ``util`` and not visible here; if its
    attribute access is backed by dict lookup, a missing key raises KeyError
    instead, so both are treated as "not present".
    """
    try:
        return getattr(flags, key)
    except (AttributeError, KeyError):
        return default


def _evaluate_xpid(args: Any, xpid: str, env_names: List[str]) -> bool:
    """Evaluate one xpid on every env and append its column to the wide CSV.

    Returns:
        ``True`` if results were appended, ``False`` if the xpid was skipped
        (missing checkpoint or agent load failure).
    """
    # check checkpoint exists (like old eval.py)
    _xpid_dir, _meta_json_path, checkpoint_path = _load_flags_and_checkpoint_paths(
        args.base_path, xpid, args.model_tar
    )
    if not os.path.exists(checkpoint_path):
        print(f"[Skip] No checkpoint: {checkpoint_path}")
        return False

    # wrapper kwargs default from CLI (overridden below by the run's own flags)
    cli_kwargs: Dict[str, Any] = dict(
        frame_stack=int(args.frame_stack),
        grayscale=bool(args.grayscale),
        use_global_policy=bool(args.use_global_policy),
        use_global_critic=bool(args.use_global_critic),
        crop_frame=bool(args.crop_frame),
        num_action_repeat=int(args.num_action_repeat),
    )

    # load checkpoint ONCE for this xpid
    try:
        agent, xpid_flags, _ = load_agent_from_xpid_once(
            base_path=args.base_path,
            xpid=xpid,
            model_tar=args.model_tar,
            model_name=args.model_name,
            dummy_env_name=env_names[0],
            device=args.device,
            wrapper_kwargs=cli_kwargs,
        )
    except Exception as e:
        # Deliberate best effort: one broken run must not abort the sweep.
        print(f"[Skip] Failed to load xpid={xpid}: {e}")
        return False

    # Prefer the values the run was trained with; fall back to the CLI value.
    # Each CLI default is already an int or a bool, so its type doubles as the
    # cast applied to the stored flag.
    wrapper_kwargs: Dict[str, Any] = {
        key: type(cli_value)(_flag_or_default(xpid_flags, key, cli_value))
        for key, cli_value in cli_kwargs.items()
    }

    venvs: Dict[str, Any] = {}
    try:
        # build venvs (one per env) for this xpid
        for env_name in env_names:
            venvs[env_name] = _build_parallel_env(
                env_name=env_name,
                num_processes=int(args.num_processes),
                device=str(args.device),
                record_video=bool(args.record_video),
                **wrapper_kwargs,
            )

        # eval
        solved_rate: Dict[str, float] = {}
        mean_return: Dict[str, float] = {}
        for i, env_name in enumerate(env_names):
            sr, mr = evaluate_env_fixed_seeds(
                env_name=env_name,
                venv=venvs[env_name],
                agent=agent,
                num_processes=int(args.num_processes),
                seed_start=int(args.seed_start),
                num_seeds=int(args.num_seeds),
                episodes_per_seed=int(args.episodes_per_seed),
                device=str(args.device),
                deterministic=bool(args.deterministic),
                render_first_env=(bool(args.render) and i == 0),
            )
            solved_rate[env_name] = float(sr)
            mean_return[env_name] = float(mr)
            print(
                f"[Eval:{xpid}] {env_name}: solved_rate={sr:.4f}, mean_return={mr:.4f}"
            )

        # Column name = xpid (its last path component), not the base directory:
        # using base_path here would give every run the same column name.
        col_name = os.path.basename(xpid.rstrip("/")) or xpid

        append_xpid_results_to_wide_csv(
            out_csv=str(args.wide_csv),
            col_name=col_name,
            env_names=env_names,
            solved_rate=solved_rate,
            mean_return=mean_return,
        )
        print(f"[OK] Appended xpid '{xpid}' to: {args.wide_csv}")
    finally:
        # cleanup venvs, including partially built sets after a failure
        for v in venvs.values():
            v.close()

    return True


def main():
    """Script entry point: evaluate the selected xpid(s) and fill the wide CSV.

    Sets ``OMP_NUM_THREADS=1``, starts a virtual display on Linux when
    ``pyvirtualdisplay`` is importable, parses the CLI, resolves the env list
    and the xpids, then evaluates each xpid in turn. The virtual display is
    stopped on the way out, whether or not evaluation succeeded.

    Raises:
        ValueError: If no xpid was selected.
    """
    os.environ["OMP_NUM_THREADS"] = "1"

    # Virtual display (match your original pattern)
    display = None
    if sys.platform.startswith("linux"):
        try:
            import pyvirtualdisplay  # type: ignore

            display = pyvirtualdisplay.Display(
                visible=0, size=(1400, 900), color_depth=24
            )
            display.start()
        except Exception:
            # Best effort: run without a virtual display if it cannot start.
            display = None

    try:
        args = DotDict(vars(parse_args()))
        # Never spawn more workers than there are seeds to hand out.
        args.num_processes = min(int(args.num_processes), int(args.num_seeds))

        # env list
        if args.env_names.strip():
            env_names = [x.strip() for x in args.env_names.split(",") if x.strip()]
        else:
            env_names = _get_bipedal_env_names()

        # resolve xpids
        xpids = _resolve_xpids(args.base_path, args.xpid, args.prefix)
        if args.max_seeds is not None:
            xpids = xpids[: int(args.max_seeds)]

        if not xpids:
            raise ValueError("No xpids matched.")

        # Create the output directory up front so a bad path fails before any
        # expensive evaluation. Expand the path the same way the CSV writer
        # does (``~`` and environment variables).
        wide_csv_path = os.path.expandvars(os.path.expanduser(args.wide_csv))
        os.makedirs(os.path.dirname(wide_csv_path) or ".", exist_ok=True)

        for xpid in xpids:
            _evaluate_xpid(args, xpid, env_names)
    finally:
        if display:
            try:
                display.stop()
            except Exception:
                # Best effort: nothing useful to do if the display won't stop.
                pass


if __name__ == "__main__":
    main()
