# try:
#     from pyvirtualdisplay import Display

#     display = Display(visible=0, size=(1400, 900))
#     display.start()
# except Exception as e:
#     print("⚠️ Virtual display not started:", e)

import sys
import os
import time
import timeit
import logging
import csv
import ast
import numpy as np
from arguments import parser

import torch
import gym
import matplotlib as mpl
import matplotlib.pyplot as plt
from baselines.logger import HumanOutputFormat

# Handle for an optional virtual display. It stays None here because the
# pyvirtualdisplay start-up code at the top of this file is commented out.
display = None

# --- Optional Weights & Biases integration ---
# wandb is an optional dependency: if it cannot be imported the name is bound
# to None so later code can test `wandb is not None` instead of failing here.
try:
    import wandb
except ImportError:
    wandb = None

# Active wandb run; assigned in the __main__ block when wandb is importable.
wandb_run = None

from envs.multigrid import *
from envs.multigrid.adversarial import *
from envs.bipedalwalker import *

# from envs.runners.adversarial_runner_nov import AdversarialRunner
from envs.runners.adversarial_runner_es import AdversarialRunner

from util import (
    make_agent,
    FileWriter,
    safe_checkpoint,
    create_parallel_env,
    make_plr_args,
    save_images,
    array_to_csv_append,
)


########helper by AED##############


def bw_array_to_params_helper(arr):
    """
    Build the BipedalWalker params dict from a flat initial-level array.

    The array is laid out as
      [rough, pit1, pit2, stump1, stump2, stair1, stair2, steps, seed]
    and only the first eight entries are used; the seed at position 8 is
    ignored. Each used entry must be convertible to float. "steps" is
    converted to float first and then truncated to int; every other value
    is returned as a float.

    Raises:
        ValueError: if the array has fewer than 9 elements, or if one of the
            first eight entries cannot be converted to float.
    """
    if len(arr) < 9:
        raise ValueError(f"Expected array of length 9 (8 params + seed), got {len(arr)}")

    # Position in the array -> key in the resulting dict.
    ordered_keys = (
        "roughness",
        "pit1",
        "pit2",
        "stump1",
        "stump2",
        "stair1",
        "stair2",
        "steps",
    )
    params = {key: float(arr[position]) for position, key in enumerate(ordered_keys)}

    # "steps" is a count, so it is stored as an int (float -> int truncation).
    params["steps"] = int(params["steps"])
    return params


#####################################
from eval import Evaluator


# ----------------------------- Utils -----------------------------
def _sanitize_csv(path: str):
    """
    Clean up a possibly corrupted CSV file in place.

    Strips every NUL byte and, if the file does not end with a newline,
    drops the trailing partial line (everything after the last newline; the
    whole content when there is no newline at all). The file is rewritten
    only when the cleaned bytes differ from what was read. A missing file is
    silently ignored.
    """
    if not os.path.exists(path):
        return

    with open(path, "rb") as src:
        original = src.read()

    cleaned = original.replace(b"\x00", b"")
    if cleaned and not cleaned.endswith(b"\n"):
        # rfind() returns -1 when there is no newline, which makes the slice
        # end at 0 and leaves an empty file.
        cleaned = cleaned[: cleaned.rfind(b"\n") + 1]

    if cleaned != original:
        with open(path, "wb") as dst:
            dst.write(cleaned)


def _ensure_eval_csv(eval_csv_path: str, fieldnames):
    """
    Make sure the eval CSV exists and starts with a header row.

    Creates the parent directory if needed, then writes a header built from
    ``fieldnames`` when the file is missing or empty. A file that already
    has content is left untouched.

    Args:
        eval_csv_path: Path of the CSV file. May be a bare filename, in which
            case no directory is created.
        fieldnames: Column names for the header row.
    """
    parent_dir = os.path.dirname(eval_csv_path)
    # A bare filename has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    need_header = (not os.path.exists(eval_csv_path)) or (
        os.path.getsize(eval_csv_path) == 0
    )
    if need_header:
        with open(eval_csv_path, "w", newline="") as f:
            csv.DictWriter(f, fieldnames=fieldnames).writeheader()


def _append_eval_row(eval_csv_path: str, fieldnames, row: dict):
    """
    Append a single row to the eval CSV.

    Only the columns listed in ``fieldnames`` are written: keys of ``row``
    that are not in ``fieldnames`` are dropped, and columns missing from
    ``row`` are written as None (an empty cell).
    """
    with open(eval_csv_path, "a", newline="") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=fieldnames)
        record = {}
        for column in fieldnames:
            record[column] = row.get(column)
        csv_writer.writerow(record)


def _inject_new_algo_args(args):
    """
    Hook for injecting curriculum-related arguments into ``args``.

    Deliberately does nothing for now, so no attributes are added to ``args``
    and no new logging fields appear. Always returns None.
    """
    return None


# --------- Replay buffer (LevelStore) helpers ---------
def _get_level_store_from_runner(train_runner):
    """
    Locate the LevelStore (replay buffer) reachable from a runner.

    The following places are tried in order and the first one that has a
    ``level_store`` attribute wins (its value is returned as-is, even if it
    is None):

      1. ``train_runner.level_store``
      2. ``train_runner.level_sampler.level_store``
      3. ``train_runner.plr_args.level_store``
      4. ``train_runner.agents["agent"].algo.level_sampler.level_store``

    Returns None when none of these locations exists.
    """
    # Sentinel that distinguishes "attribute absent" from "attribute is None".
    absent = object()

    def _walk(start, *attr_path):
        """Follow attr_path from start; a missing link yields None."""
        node = start
        for attr in attr_path:
            node = getattr(node, attr, None)
        return node

    # Holders of a potential `level_store`, in priority order. A None holder
    # simply has no `level_store` attribute and is skipped.
    holders = (
        train_runner,
        _walk(train_runner, "level_sampler"),
        _walk(train_runner, "plr_args"),
    )
    for holder in holders:
        found = getattr(holder, "level_store", absent)
        if found is not absent:
            return found

    # Last resort: go through the protagonist agent's algorithm.
    agents = getattr(train_runner, "agents", absent)
    if agents is not absent:
        sampler = _walk(agents.get("agent", None), "algo", "level_sampler")
        found = getattr(sampler, "level_store", absent)
        if found is not absent:
            return found

    return None


def dump_replay_buffer_snapshot(
    train_runner,
    current_steps: int | None = None,
):
    """
    Print a human-readable snapshot of the replay buffer to stdout.

    The LevelStore is looked up through ``_get_level_store_from_runner``.
    After a banner and a summary line, one line is printed per stored seed
    (in ascending seed order) with its buffer index, edit count, number of
    times it was sampled, whether it is marked as dropped, and the level's
    string form.

    Args:
        train_runner: Runner that owns the LevelStore. Its optional
            ``level_sample_counts`` and ``dropped_seeds`` attributes are
            used when present.
        current_steps: Step value shown in the summary line; -1 is shown
            when it is None.
    """
    level_store = _get_level_store_from_runner(train_runner)
    if level_store is None:
        print("[replay_dump] LevelStore not found on runner.")
        return

    seeds = sorted(level_store.seed2level.keys())
    num_levels = len(seeds)
    step_value = -1 if current_steps is None else int(current_steps)

    sample_counts = getattr(train_runner, "level_sample_counts", {})
    # A missing or None `dropped_seeds` means "nothing dropped"; falling back
    # to an empty set keeps the membership test below from raising.
    dropped_seeds = getattr(train_runner, "dropped_seeds", None) or set()

    print("\n================ REPLAY BUFFER SNAPSHOT ================")
    print(f"[replay_dump] steps={step_value}, num_levels={num_levels}")
    for buffer_idx, seed in enumerate(seeds):
        # Prefer the store's accessor; fall back to the raw mapping if it fails.
        try:
            level = level_store.get_level(seed)
        except Exception:
            level = level_store.seed2level[seed]

        # The length of the parent list is reported as the edit count.
        edit_count = len(level_store.seed2parent.get(seed, []))

        if isinstance(sample_counts, dict):
            times_sampled = int(sample_counts.get(seed, 0))
        else:
            # Best effort for non-dict containers (e.g. indexable by seed).
            try:
                times_sampled = int(sample_counts[seed])
            except Exception:
                times_sampled = 0

        is_dropped = seed in dropped_seeds

        # Previously these values were computed and thrown away, so the
        # snapshot contained only the header; emit them one level per line.
        print(
            f"[replay_dump] idx={buffer_idx} seed={seed} "
            f"edit_count={edit_count} times_sampled={times_sampled} "
            f"dropped={is_dropped} level={level}"
        )

    print("========================================================\n")


def periodic_dump_dropped_levels_csv(
    train_runner,
    dropped_csv_path: str,
    logged_dropped_seeds: set,
    current_steps: int | None = None,
):
    """
    Append levels that were dropped since the last call to a CSV file.

    Every seed in ``train_runner.dropped_seeds`` that is not yet in
    ``logged_dropped_seeds`` is written as one row (in ascending seed order)
    and then added to ``logged_dropped_seeds``, so each dropped seed is
    logged at most once across calls. A header row is written first when the
    file is missing or empty.

    Returns without touching the file when no LevelStore can be located or
    when the runner has no (or an empty) ``dropped_seeds``.

    Args:
        train_runner: Runner providing the LevelStore, ``dropped_seeds`` and
            (optionally) ``level_sample_counts``.
        dropped_csv_path: Destination CSV; parent directories are created.
        logged_dropped_seeds: Seeds already written; mutated in place.
        current_steps: Value for the ``steps_dropped`` column; -1 when None.
    """
    level_store = _get_level_store_from_runner(train_runner)
    if level_store is None:
        return

    dropped_seeds = getattr(train_runner, "dropped_seeds", None)
    if not dropped_seeds:
        return

    sample_counts = getattr(train_runner, "level_sample_counts", {})

    fieldnames = [
        "steps_dropped",
        "seed",
        "level",
        "edit_count",
        "times_sampled",
    ]
    parent_dir = os.path.dirname(dropped_csv_path)
    # A bare filename has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    need_header = (not os.path.exists(dropped_csv_path)) or (
        os.path.getsize(dropped_csv_path) == 0
    )

    steps_value = -1 if current_steps is None else int(current_steps)

    with open(dropped_csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if need_header:
            writer.writeheader()

        for seed in sorted(dropped_seeds):
            if seed in logged_dropped_seeds:
                continue

            # Prefer the store's accessor; fall back to the raw mapping
            # (None if the seed is no longer stored there).
            try:
                level = level_store.get_level(seed)
            except Exception:
                level = level_store.seed2level.get(seed, None)

            # The length of the parent list is reported as the edit count.
            edit_count = len(level_store.seed2parent.get(seed, []))

            if isinstance(sample_counts, dict):
                times_sampled = int(sample_counts.get(seed, 0))
            else:
                # Best effort for non-dict containers indexable by seed.
                try:
                    times_sampled = int(sample_counts[seed])
                except Exception:
                    times_sampled = 0

            writer.writerow(
                {
                    "steps_dropped": steps_value,
                    "seed": int(seed),
                    "level": str(level),
                    "edit_count": edit_count,
                    "times_sampled": times_sampled,
                }
            )
            logged_dropped_seeds.add(seed)


def periodic_dump_added_levels_csv(
    train_runner,
    added_csv_path: str,
    known_added_seeds: set,
    current_steps: int | None = None,
):
    """
    Append levels that entered the replay buffer since the last call.

    Every seed in the LevelStore's ``seed2level`` mapping that is not yet in
    ``known_added_seeds`` is written as one row (in ascending seed order)
    and then added to ``known_added_seeds``. A header row is written first
    when the file is missing or empty.

    Returns without touching the file when no LevelStore can be located or
    when there are no new seeds.

    Args:
        train_runner: Runner providing the LevelStore and (optionally)
            ``level_sample_counts``.
        added_csv_path: Destination CSV; parent directories are created.
        known_added_seeds: Seeds already written; mutated in place.
        current_steps: Value for the ``steps_added`` column; -1 when None.
    """
    level_store = _get_level_store_from_runner(train_runner)
    if level_store is None:
        return

    all_seeds = sorted(level_store.seed2level.keys())
    new_seeds = [s for s in all_seeds if s not in known_added_seeds]
    if not new_seeds:
        return

    sample_counts = getattr(train_runner, "level_sample_counts", {})

    # NOTE: the index_0..index_4 columns are part of the header but are not
    # populated by this function, so they are written as empty cells.
    fieldnames = [
        "steps_added",
        "seed",
        "level",
        "edit_count",
        "times_sampled",
        "index_0",
        "index_1",
        "index_2",
        "index_3",
        "index_4",
    ]
    parent_dir = os.path.dirname(added_csv_path)
    # A bare filename has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    need_header = (not os.path.exists(added_csv_path)) or (
        os.path.getsize(added_csv_path) == 0
    )

    steps_value = -1 if current_steps is None else int(current_steps)

    with open(added_csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if need_header:
            writer.writeheader()

        for seed in new_seeds:
            # Prefer the store's accessor; fall back to the raw mapping.
            try:
                level = level_store.get_level(seed)
            except Exception:
                level = level_store.seed2level.get(seed, None)

            # The length of the parent list is reported as the edit count.
            edit_count = len(level_store.seed2parent.get(seed, []))

            if isinstance(sample_counts, dict):
                times_sampled = int(sample_counts.get(seed, 0))
            else:
                # Best effort for non-dict containers indexable by seed.
                try:
                    times_sampled = int(sample_counts[seed])
                except Exception:
                    times_sampled = 0

            writer.writerow(
                {
                    "steps_added": steps_value,
                    "seed": int(seed),
                    "level": str(level),
                    "edit_count": edit_count,
                    "times_sampled": times_sampled,
                }
            )
            known_added_seeds.add(seed)


def periodic_dump_buffer_state_csv(
    train_runner,
    buffer_csv_path: str,
    current_steps: int | None = None,
):
    """
    Append a full snapshot of the replay buffer to a CSV file.

    One row is written per seed in the LevelStore's ``seed2level`` mapping
    (in ascending seed order); ``buffer_index`` is the seed's position in
    that sorted order. A header row is written first when the file is
    missing or empty. Successive calls append further snapshots to the same
    file, distinguished by the ``steps`` column.

    Returns without touching the file when no LevelStore can be located or
    when the store holds no seeds.

    Args:
        train_runner: Runner providing the LevelStore and (optionally)
            ``level_sample_counts`` and ``dropped_seeds``.
        buffer_csv_path: Destination CSV; parent directories are created.
        current_steps: Value for the ``steps`` column; -1 when None.
    """
    level_store = _get_level_store_from_runner(train_runner)
    if level_store is None:
        return

    seeds = sorted(level_store.seed2level.keys())
    if not seeds:
        return

    sample_counts = getattr(train_runner, "level_sample_counts", {})
    # A missing or None `dropped_seeds` means "nothing dropped"; falling back
    # to an empty set keeps the membership test below from raising.
    dropped_seeds = getattr(train_runner, "dropped_seeds", None) or set()

    fieldnames = [
        "steps",
        "buffer_index",
        "seed",
        "level",
        "edit_count",
        "times_sampled",
        "dropped",
    ]
    parent_dir = os.path.dirname(buffer_csv_path)
    # A bare filename has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    need_header = (not os.path.exists(buffer_csv_path)) or (
        os.path.getsize(buffer_csv_path) == 0
    )

    steps_value = -1 if current_steps is None else int(current_steps)

    with open(buffer_csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if need_header:
            writer.writeheader()

        for buffer_idx, seed in enumerate(seeds):
            # Prefer the store's accessor; fall back to the raw mapping.
            try:
                level = level_store.get_level(seed)
            except Exception:
                level = level_store.seed2level.get(seed, None)

            # The length of the parent list is reported as the edit count.
            edit_count = len(level_store.seed2parent.get(seed, []))

            if isinstance(sample_counts, dict):
                times_sampled = int(sample_counts.get(seed, 0))
            else:
                # Best effort for non-dict containers indexable by seed.
                try:
                    times_sampled = int(sample_counts[seed])
                except Exception:
                    times_sampled = 0

            is_dropped = seed in dropped_seeds

            # idx0, idx1, idx2, idx3, idx4 = train_runner.seed2partition[seed]

            writer.writerow(
                {
                    "steps": steps_value,
                    "buffer_index": buffer_idx,
                    "seed": int(seed),
                    "level": str(level),
                    "edit_count": edit_count,
                    "times_sampled": times_sampled,
                    "dropped": bool(is_dropped),
                }
            )


def periodic_dump_partition_overview_csv(
    train_runner,
    partition_csv_path: str,
    current_steps: int,
):
    """
    Append an overview of all partitions to a CSV file.

    Steps:
      1. Return immediately if the runner has no ``partition_manager``.
      2. If the runner exposes ``_flush_partition_events``, call it with
         ``force=True``; on TypeError it is called again without arguments.
      3. Ask the partition manager for its stats via ``dump_stats_bfs``
         starting at index (0, 0, 0, 0, 0); on any exception a message is
         printed and nothing is written.
      4. Turn every stats entry that has an "index" into one "overview" row
         and hand the rows to ``array_to_csv_append`` in append mode.

    Nothing is written when the stats list is empty or no entry carries an
    index.
    """
    manager = getattr(train_runner, "partition_manager", None)
    if manager is None:
        return

    flush_events = getattr(train_runner, "_flush_partition_events", None)
    if flush_events is not None:
        try:
            flush_events(force=True)
        except TypeError:
            flush_events()

    try:
        stats_list = manager.dump_stats_bfs(start=(0, 0, 0, 0, 0), max_depth=None)
    except Exception as e:
        print(f"[partition_overview] dump_stats_bfs error: {e}")
        return

    if not stats_list:
        return

    integer_columns = ("depth", "total_levels", "success", "fail", "resolved_visits")
    float_columns = ("p_hat", "uncertainty")

    def _overview_row(entry, index):
        """Flatten one stats entry into a CSV row (column order matters)."""
        record = {"steps": int(current_steps), "record_type": "overview"}
        for axis in range(5):
            record[f"index_{axis}"] = int(index[axis])
        for column in integer_columns:
            record[column] = int(entry.get(column, 0))
        for column in float_columns:
            record[column] = float(entry.get(column, 0.0))
        return record

    rows = [
        _overview_row(entry, index)
        for entry in stats_list
        if (index := entry.get("index")) is not None
    ]

    if rows:
        array_to_csv_append(partition_csv_path, rows, append=True)


# # ----------------------------- Main -----------------------------
# if __name__ == "__main__":
#     os.environ["OMP_NUM_THREADS"] = "1"
#     args = parser.parse_args()

#     # === Inject/derive our new algorithm arguments (currently no-op) ===
#     _inject_new_algo_args(args)

#     # === Configure logging ==
#     if args.xpid is None:
#         args.xpid = "lr-%s" % time.strftime("%Y%m%d-%H%M%S")
#     log_dir = os.path.expandvars(os.path.expanduser(args.log_dir))

#     os.makedirs(log_dir, exist_ok=True)
#     run_dir = os.path.join(log_dir, args.xpid)
#     os.makedirs(run_dir, exist_ok=True)

#     print(f"[logging] log_dir={log_dir}")
#     print(f"[logging] run_dir={run_dir}")

#     _sanitize_csv(os.path.join(run_dir, "logs.csv"))
#     _sanitize_csv(os.path.join(run_dir, "fields.csv"))

#     # dropped / added / buffer levels CSV path
#     dropped_levels_csv_path = os.path.join(run_dir, "dropped_levels.csv")
#     added_levels_csv_path = os.path.join(run_dir, "added_levels.csv")
#     buffer_levels_csv_path = os.path.join(run_dir, "buffer_snapshots.csv")

#     args.dropped_levels_csv_path = dropped_levels_csv_path
#     args.added_levels_csv_path = added_levels_csv_path

#     partition_bandit_csv_path = os.path.join(run_dir, "partition_bandit.csv")
#     occational_partition_bandit_csv_path = os.path.join(
#         run_dir, "occational_partition_bandit.csv"
#     )

#     _sanitize_csv(partition_bandit_csv_path)
#     _sanitize_csv(occational_partition_bandit_csv_path)

#     if getattr(args, "use_plr", False):

#         default_updates_interval = 10

#         if not hasattr(args, "partition_log_interval"):
#             args.partition_log_interval = 10

#         if not hasattr(args, "partition_flush_interval"):
#             args.partition_flush_interval = 10

#         args.partition_log_csv = occational_partition_bandit_csv_path

#     # === Initialize FileWriter (unchanged behavior) ===
#     filewriter = FileWriter(xpid=args.xpid, xp_args=args.__dict__, rootdir=log_dir)
#     screenshot_dir = os.path.join(log_dir, args.xpid, "screenshots")
#     if not os.path.exists(screenshot_dir):
#         os.makedirs(screenshot_dir, exist_ok=True)

#     # === Initialize wandb ===
#     if wandb is not None:
#         project = getattr(args, "env_name", "default")
#         entity = getattr(args, "wandb_entity", None)
#         # Use the last segment of log_dir as stable run_id
#         run_id = args.log_dir.rstrip("/").split("/")[-1]
#         run_name = run_id

#         wandb_kwargs = dict(
#             project=project,
#             name=run_name,
#             id=run_id,
#             resume="allow",
#             config=args.__dict__,
#         )
#         if entity is not None:
#             wandb_kwargs["entity"] = entity

#         wandb_run = wandb.init(**wandb_kwargs)

#     def log_stats(stats):
#         # Log to local CSV logs (unchanged behavior)
#         filewriter.log(stats)

#         # Mirror to wandb so every CSV field becomes a wandb metric
#         if wandb_run is not None:
#             wandb.log(stats)

#         # Optional human-readable console output
#         if args.verbose:
#             HumanOutputFormat(sys.stdout).writekvs(stats)

#     if args.verbose:
#         logging.getLogger().setLevel(logging.INFO)
#     else:
#         logging.disable(logging.CRITICAL)

#     # === Determine device ====
#     args.cuda = not args.no_cuda and torch.cuda.is_available()
#     device = torch.device("cuda:0" if args.cuda else "cpu")
#     if "cuda" in device.type:
#         torch.backends.cudnn.benchmark = True
#         print("Using CUDA\n")

#     # === Create parallel envs ===
#     venv, ued_venv = create_parallel_env(args)

#     is_training_env = args.ued_algo in ["paired", "flexible_paired", "minimax"]
#     is_paired = args.ued_algo in ["paired", "flexible_paired"]

#     agent = make_agent(name="agent", env=venv, args=args, device=device)
#     adversary_agent, adversary_env = None, None
#     if is_paired or args.use_accel_paired:
#         adversary_agent = make_agent(
#             name="adversary_agent", env=venv, args=args, device=device
#         )
#     if is_training_env:
#         adversary_env = make_agent(name="adversary_env", env=venv, args=args, device=device)
#     if (
#         args.ued_algo == "domain_randomization"
#         and args.use_plr
#         and not args.use_reset_random_dr
#     ):
#         adversary_env = make_agent(name="adversary_env", env=venv, args=args, device=device)
#         adversary_env.random()

#     # === Create runner ===
#     plr_args = None
#     if args.use_plr:
#         plr_args = make_plr_args(args, venv.observation_space, venv.action_space)
#     train_runner = AdversarialRunner(
#         args=args,
#         venv=venv,
#         agent=agent,
#         ued_venv=ued_venv,
#         adversary_agent=adversary_agent,
#         adversary_env=adversary_env,
#         flexible_protagonist=False,
#         train=True,
#         plr_args=plr_args,
#         device=device,
#     )

#     # === Configure checkpointing ===
#     timer = timeit.default_timer
#     initial_update_count = 0
#     last_logged_update_at_restart = -1
#     checkpoint_path = os.path.expandvars(
#         os.path.expanduser("%s/%s/%s" % (log_dir, args.xpid, "model.tar"))
#     )
#     if args.xpid_finetune:
#         model_fname = f"{args.model_finetune}.tar"
#         base_checkpoint_path = os.path.expandvars(
#             os.path.expanduser("%s/%s/%s" % (log_dir, args.xpid_finetune, model_fname))
#         )

#     def checkpoint(index=None):
#         if args.disable_checkpoint:
#             return
#         safe_checkpoint(
#             {"runner_state_dict": train_runner.state_dict()},
#             checkpoint_path,
#             index=index,
#             archive_interval=args.archive_interval,
#         )
#         logging.info("Saved checkpoint to %s", checkpoint_path)

#     # === Load checkpoint ===
#     if args.checkpoint and os.path.exists(checkpoint_path):
#         checkpoint_states = torch.load(
#             checkpoint_path, map_location=lambda storage, loc: storage
#         )
#         last_logged_update_at_restart = filewriter.latest_tick()
#         train_runner.load_state_dict(checkpoint_states["runner_state_dict"])
#         initial_update_count = train_runner.num_updates
#         logging.info(f"Resuming preempted job after {initial_update_count} updates\n")
#     elif args.xpid_finetune and not os.path.exists(checkpoint_path):
#         checkpoint_states = torch.load(base_checkpoint_path)
#         state_dict = checkpoint_states["runner_state_dict"]
#         agent_state_dict = state_dict.get("agent_state_dict")
#         optimizer_state_dict = state_dict.get("optimizer_state_dict")
#         train_runner.agents["agent"].algo.actor_critic.load_state_dict(
#             agent_state_dict["agent"]
#         )
#         train_runner.agents["agent"].algo.optimizer.load_state_dict(
#             optimizer_state_dict["agent"]
#         )

#     #######################
#     ### load partitions

#     # partitions = make_partitions("bipedalwalker")
#     #######################

#     # === Evaluator ===
#     evaluator = None
#     eval_keys = []
#     if args.test_env_names:
#         evaluator = Evaluator(
#             args.test_env_names.split(","),
#             num_processes=args.test_num_processes,
#             num_episodes=args.test_num_episodes,
#             frame_stack=args.frame_stack,
#             grayscale=args.grayscale,
#             num_action_repeat=args.num_action_repeat,
#             use_global_critic=args.use_global_critic,
#             use_global_policy=args.use_global_policy,
#             eval_csv_dir=run_dir,
#             device=device,
#         )

#         eval_keys = evaluator.get_stats_keys()

#     # === Eval CSV setup ===
#     run_dir_eval = os.path.join(log_dir, args.xpid)
#     eval_csv_path = os.path.join(run_dir_eval, "eval_tests.csv")
#     eval_fieldnames = ["update"] + list(eval_keys)
#     _ensure_eval_csv(eval_csv_path, eval_fieldnames)
#     print(f"[eval_csv] {eval_csv_path}")

#     # === Train ===
#     last_checkpoint_idx = getattr(train_runner, args.checkpoint_basis)
#     update_start_time = timer()
#     num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes

#     replay_print_interval_steps = 1000
#     buffer_dump_interval_steps = 1000

#     logged_dropped_seeds = set()
#     known_added_seeds = set()

#     for j in range(initial_update_count, num_updates):
#         stats = train_runner.run()

#         current_step = j
#         current_global_steps = (
#             (train_runner.num_updates + train_runner.total_num_edits)
#             * args.num_processes
#             * args.num_steps
#         )

#         if args.use_plr and (j % replay_print_interval_steps == 0):
#             dump_replay_buffer_snapshot(
#                 train_runner,
#                 current_steps=current_step,
#             )

#         if args.use_plr and (j % buffer_dump_interval_steps == 0):
#             periodic_dump_buffer_state_csv(
#                 train_runner,
#                 buffer_levels_csv_path,
#                 current_steps=current_step,
#             )

#         if train_runner.num_updates <= last_logged_update_at_restart:
#             continue

#         log = (j % args.log_interval == 0) or j == num_updates - 1
#         save_screenshot = args.screenshot_interval > 0 and (
#             j % args.screenshot_interval == 0
#         )

#         test_stats = {}
#         if log:
#             # Eval
#             if evaluator is not None and (
#                 j % args.test_interval == 0 or j == num_updates - 1
#             ):
#                 test_stats = evaluator.evaluate(train_runner.agents["agent"])
#                 stats.update(test_stats)
#                 if args.use_accel_paired:
#                     adv_test_stats = evaluator.evaluate(
#                         train_runner.agents["adversary_agent"]
#                     )
#                     curr_keys = list(adv_test_stats.keys())
#                     for curr_key in curr_keys:
#                         adv_test_stats[f"advagent_{curr_key}"] = adv_test_stats[curr_key]
#                         adv_test_stats.pop(curr_key, None)
#                     stats.update(adv_test_stats)
#                 # append eval row
#                 row = {"update": j}
#                 for k in eval_fieldnames:
#                     if k == "update":
#                         continue
#                     row[k] = test_stats.get(k, None)
#                 _append_eval_row(eval_csv_path, eval_fieldnames, row)
#             else:
#                 if evaluator is not None:
#                     stats.update({k: None for k in eval_keys})

#             update_end_time = timer()
#             num_incremental_updates = 1 if j == 0 else args.log_interval
#             sps = (
#                 num_incremental_updates
#                 * (args.num_processes * args.num_steps)
#                 / (update_end_time - update_start_time)
#             )
#             update_start_time = update_end_time
#             stats.update({"sps": sps})
#             stats.update(test_stats)
#             log_stats(stats)

#         checkpoint_idx = getattr(train_runner, args.checkpoint_basis)

#         if checkpoint_idx != last_checkpoint_idx:
#             is_last_update = j == num_updates - 1
#             if is_last_update or (
#                 train_runner.num_updates > 0
#                 and checkpoint_idx % args.checkpoint_interval == 0
#             ):
#                 checkpoint(checkpoint_idx)
#                 logging.info(f"\nSaved checkpoint after update {j}")
#                 logging.info(f"\nLast update: {is_last_update}")

#                 if args.use_plr and hasattr(train_runner, "save_top_k_seeds_csv"):
#                     seeds_csv_path = os.path.join(
#                         run_dir, f"top_seeds_{checkpoint_idx}.csv"
#                     )
#                     num_saved = train_runner.save_top_k_seeds_csv(seeds_csv_path)
#                     if num_saved > 0:
#                         logging.info(f"Saved {num_saved} top seeds to {seeds_csv_path}")

#                 last_checkpoint_idx = checkpoint_idx

#             elif (
#                 train_runner.num_updates > 0
#                 and args.archive_interval > 0
#                 and checkpoint_idx % args.archive_interval == 0
#             ):
#                 checkpoint(checkpoint_idx)
#                 logging.info(f"\nArchived checkpoint after update {j}")

#                 if args.use_plr and hasattr(train_runner, "save_top_k_seeds_csv"):
#                     seeds_csv_path = os.path.join(
#                         run_dir, f"top_seeds_archive_{checkpoint_idx}.csv"
#                     )
#                     num_saved = train_runner.save_top_k_seeds_csv(seeds_csv_path)
#                     if num_saved > 0:
#                         logging.info(f"Archived {num_saved} top seeds to {seeds_csv_path}")

#                 last_checkpoint_idx = checkpoint_idx

#         if save_screenshot:
#             level_info = train_runner.sampled_level_info
#             if args.env_name.startswith("BipedalWalker"):
#                 encodings = venv.get_level()
#             else:
#                 venv.reset_agent()
#                 images = venv.get_images()
#                 if args.use_editor and level_info:
#                     save_images(
#                         images[: args.screenshot_batch_size],
#                         os.path.join(
#                             screenshot_dir,
#                             f"update{j}-replay{level_info['level_replay']}-n_edits{level_info['num_edits'][0]}.png",
#                         ),
#                         normalize=True,
#                         channels_first=False,
#                     )
#                 else:
#                     save_images(
#                         images[: args.screenshot_batch_size],
#                         os.path.join(screenshot_dir, f"update{j}.png"),
#                         normalize=True,
#                         channels_first=False,
#                     )
#                 plt.close()

#     if args.use_plr:
#         last_steps = train_runner.num_updates

#         periodic_dump_buffer_state_csv(
#             train_runner,
#             buffer_levels_csv_path,
#             current_steps=last_steps,
#         )

#         # Last round: force flush partition events + BFS overview
#         last_global_steps = (
#             (train_runner.num_updates + train_runner.total_num_edits)
#             * args.num_processes
#             * args.num_steps
#         )

#     if evaluator is not None:
#         evaluator.close()
#     venv.close()

#     if display:
#         display.stop()

#     if wandb_run is not None:
#         wandb_run.finish()

# ----------------------------- Main -----------------------------
if __name__ == "__main__":
    os.environ["OMP_NUM_THREADS"] = "1"
    args = parser.parse_args()

    # === Inject/derive our new algorithm arguments (currently no-op) ===
    _inject_new_algo_args(args)

    # === Configure logging ==
    if args.xpid is None:
        args.xpid = "lr-%s" % time.strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.expandvars(os.path.expanduser(args.log_dir))

    os.makedirs(log_dir, exist_ok=True)
    run_dir = os.path.join(log_dir, args.xpid)
    os.makedirs(run_dir, exist_ok=True)

    print(f"[logging] log_dir={log_dir}")
    print(f"[logging] run_dir={run_dir}")

    _sanitize_csv(os.path.join(run_dir, "logs.csv"))
    _sanitize_csv(os.path.join(run_dir, "fields.csv"))

    # === Initialize FileWriter (unchanged behavior) ===
    filewriter = FileWriter(xpid=args.xpid, xp_args=args.__dict__, rootdir=log_dir)
    screenshot_dir = os.path.join(log_dir, args.xpid, "screenshots")
    if not os.path.exists(screenshot_dir):
        os.makedirs(screenshot_dir, exist_ok=True)

    # Path for replay buffer CSV (always overwritten with latest snapshot)
    replay_csv_path = os.path.join(run_dir, "replay_buffer.csv")

    # === Initialize wandb ===
    # - "folder" = env_name  -> wandb project
    # - "job name" = log_dir -> wandb run name & id
    if wandb is not None:
        project = getattr(args, "env_name", "default")
        entity = getattr(args, "wandb_entity", None)
        run_id = args.log_dir.rstrip("/").split("/")[
            -1
        ]  # stable, so same log_dir -> same wandb run

        run_name = run_id

        wandb_kwargs = dict(
            project=project,
            name=run_name,
            id=run_id,
            resume="allow",
            config=args.__dict__,
        )
        if entity is not None:
            wandb_kwargs["entity"] = entity

        wandb_run = wandb.init(**wandb_kwargs)

    def log_stats(stats):
        # Log to local CSV logs (unchanged behavior)
        filewriter.log(stats)

        # Mirror to wandb so every CSV field becomes a wandb metric
        if wandb_run is not None:
            wandb.log(stats)

        # Optional human-readable console output
        if args.verbose:
            HumanOutputFormat(sys.stdout).writekvs(stats)

    if args.verbose:
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.disable(logging.CRITICAL)

    # === Determine device ====
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")
    if "cuda" in device.type:
        torch.backends.cudnn.benchmark = True
        print("Using CUDA\n")

    # === Create parallel envs ===
    venv, ued_venv = create_parallel_env(args)

    is_training_env = args.ued_algo in ["paired", "flexible_paired", "minimax"]
    is_paired = args.ued_algo in ["paired", "flexible_paired"]

    agent = make_agent(name="agent", env=venv, args=args, device=device)
    adversary_agent, adversary_env = None, None
    if is_paired or args.use_accel_paired:
        adversary_agent = make_agent(
            name="adversary_agent", env=venv, args=args, device=device
        )
    if is_training_env:
        adversary_env = make_agent(name="adversary_env", env=venv, args=args, device=device)
    if (
        args.ued_algo == "domain_randomization"
        and args.use_plr
        and not args.use_reset_random_dr
    ):
        adversary_env = make_agent(name="adversary_env", env=venv, args=args, device=device)
        adversary_env.random()

    # === Create runner ===
    plr_args = None
    if args.use_plr:
        plr_args = make_plr_args(args, venv.observation_space, venv.action_space)
    train_runner = AdversarialRunner(
        args=args,
        venv=venv,
        agent=agent,
        ued_venv=ued_venv,
        adversary_agent=adversary_agent,
        adversary_env=adversary_env,
        flexible_protagonist=False,
        train=True,
        plr_args=plr_args,
        device=device,
    )

    # === Configure checkpointing ===
    timer = timeit.default_timer
    initial_update_count = 0
    last_logged_update_at_restart = -1
    checkpoint_path = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (log_dir, args.xpid, "model.tar"))
    )
    if args.xpid_finetune:
        model_fname = f"{args.model_finetune}.tar"
        base_checkpoint_path = os.path.expandvars(
            os.path.expanduser("%s/%s/%s" % (log_dir, args.xpid_finetune, model_fname))
        )

    def checkpoint(index=None):
        """Write the runner state to ``checkpoint_path`` unless disabled.

        Does nothing when ``args.disable_checkpoint`` is truthy. Otherwise the
        runner's ``state_dict()`` is wrapped under ``"runner_state_dict"`` and
        passed to ``safe_checkpoint`` together with ``index`` and
        ``args.archive_interval``; an info-level log line follows.
        """
        if not args.disable_checkpoint:
            # Same top-level key the resume path reads back when loading.
            payload = {"runner_state_dict": train_runner.state_dict()}
            safe_checkpoint(
                payload,
                checkpoint_path,
                index=index,
                archive_interval=args.archive_interval,
            )
            logging.info("Saved checkpoint to %s", checkpoint_path)

    # === Load checkpoint ===
    if args.checkpoint and os.path.exists(checkpoint_path):
        checkpoint_states = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage
        )
        # last_logged_update_at_restart = filewriter.latest_tick()
        train_runner.load_state_dict(checkpoint_states["runner_state_dict"])
        initial_update_count = train_runner.num_updates
        logging.info(f"Resuming preempted job after {initial_update_count} updates\n")
    elif args.xpid_finetune and not os.path.exists(checkpoint_path):
        checkpoint_states = torch.load(base_checkpoint_path)
        state_dict = checkpoint_states["runner_state_dict"]
        agent_state_dict = state_dict.get("agent_state_dict")
        optimizer_state_dict = state_dict.get("optimizer_state_dict")
        train_runner.agents["agent"].algo.actor_critic.load_state_dict(
            agent_state_dict["agent"]
        )
        train_runner.agents["agent"].algo.optimizer.load_state_dict(
            optimizer_state_dict["agent"]
        )

    # === Evaluator ===
    evaluator = None
    eval_keys = []
    if args.test_env_names:
        evaluator = Evaluator(
            args.test_env_names.split(","),
            num_processes=args.test_num_processes,
            num_episodes=args.test_num_episodes,
            frame_stack=args.frame_stack,
            grayscale=args.grayscale,
            num_action_repeat=args.num_action_repeat,
            use_global_critic=args.use_global_critic,
            use_global_policy=args.use_global_policy,
            eval_csv_dir=run_dir,
            device=device,
        )
        try:
            eval_keys = evaluator.get_stats_keys()
        except Exception:
            eval_keys = []

    # === Eval CSV setup ===
    run_dir_eval = os.path.join(log_dir, args.xpid)
    eval_csv_path = os.path.join(run_dir_eval, "eval_tests.csv")
    eval_fieldnames = ["update"] + list(eval_keys)
    _ensure_eval_csv(eval_csv_path, eval_fieldnames)
    print(f"[eval_csv] {eval_csv_path}")

    # === Train ===
    last_checkpoint_idx = getattr(train_runner, args.checkpoint_basis)
    update_start_time = timer()
    num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes

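    # Intended cadence: dump the replay buffer every 1000 environment steps,
    # using the "steps" key from stats and moving this threshold forward.
    # NOTE(review): the training loop below never reads this threshold, so no
    # periodic dump is triggered from here — confirm whether that is intended.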
    next_replay_dump_at = 1000

    for j in range(initial_update_count, num_updates):
        stats = train_runner.run()

        if train_runner.num_updates <= last_logged_update_at_restart:
            continue

        log = (j % args.log_interval == 0) or j == num_updates - 1
        save_screenshot = args.screenshot_interval > 0 and (
            j % args.screenshot_interval == 0
        )

        test_stats = {}

        if evaluator is not None and (
            train_runner.student_grad_updates > 0
            and train_runner.student_grad_updates % args.test_interval == 0
        ):
            test_stats = evaluator.evaluate(train_runner.agents["agent"])
            stats.update(test_stats)
            if args.use_accel_paired:
                adv_test_stats = evaluator.evaluate(train_runner.agents["adversary_agent"])
                curr_keys = list(adv_test_stats.keys())
                for curr_key in curr_keys:
                    adv_test_stats[f"advagent_{curr_key}"] = adv_test_stats[curr_key]
                    adv_test_stats.pop(curr_key, None)
                stats.update(adv_test_stats)
            # append eval row
            row = {"update": train_runner.student_grad_updates}
            for k in eval_fieldnames:
                if k == "update":
                    continue
                row[k] = test_stats.get(k, None)
            _append_eval_row(eval_csv_path, eval_fieldnames, row)
        else:
            if evaluator is not None:
                stats.update({k: None for k in eval_keys})
        if log:

            update_end_time = timer()
            num_incremental_updates = 1 if j == 0 else args.log_interval
            sps = (
                num_incremental_updates
                * (args.num_processes * args.num_steps)
                / (update_end_time - update_start_time)
            )
            update_start_time = update_end_time
            stats.update({"sps": sps})
            stats.update(test_stats)
            log_stats(stats)

        checkpoint_idx = getattr(train_runner, args.checkpoint_basis)

        if checkpoint_idx != last_checkpoint_idx:
            is_last_update = j == num_updates - 1
            if is_last_update or (
                train_runner.num_updates > 0
                and checkpoint_idx % args.checkpoint_interval == 0
            ):
                checkpoint(checkpoint_idx)
                logging.info(f"\nSaved checkpoint after update {j}")
                logging.info(f"\nLast update: {is_last_update}")

                if args.use_plr and hasattr(train_runner, "save_top_k_seeds_csv"):
                    seeds_csv_path = os.path.join(
                        run_dir, f"top_seeds_{checkpoint_idx}.csv"
                    )
                    num_saved = train_runner.save_top_k_seeds_csv(seeds_csv_path)
                    if num_saved > 0:
                        logging.info(f"Saved {num_saved} top seeds to {seeds_csv_path}")

            elif (
                train_runner.num_updates > 0
                and args.archive_interval > 0
                and checkpoint_idx % args.archive_interval == 0
            ):
                checkpoint(checkpoint_idx)
                logging.info(f"\nArchived checkpoint after update {j}")

                if args.use_plr and hasattr(train_runner, "save_top_k_seeds_csv"):
                    seeds_csv_path = os.path.join(
                        run_dir, f"top_seeds_archive_{checkpoint_idx}.csv"
                    )
                    num_saved = train_runner.save_top_k_seeds_csv(seeds_csv_path)
                    if num_saved > 0:
                        logging.info(f"Archived {num_saved} top seeds to {seeds_csv_path}")

        if save_screenshot:
            level_info = train_runner.sampled_level_info
            if args.env_name.startswith("BipedalWalker"):
                encodings = venv.get_level()
            else:
                venv.reset_agent()
                images = venv.get_images()
                if args.use_editor and level_info:
                    save_images(
                        images[: args.screenshot_batch_size],
                        os.path.join(
                            screenshot_dir,
                            f"update{j}-replay{level_info['level_replay']}-n_edits{level_info['num_edits'][0]}.png",
                        ),
                        normalize=True,
                        channels_first=False,
                    )
                else:
                    save_images(
                        images[: args.screenshot_batch_size],
                        os.path.join(screenshot_dir, f"update{j}.png"),
                        normalize=True,
                        channels_first=False,
                    )
                plt.close()

    if evaluator is not None:
        evaluator.close()
    venv.close()

    if display:
        display.stop()

    if wandb_run is not None:
        wandb_run.finish()
