# # -*- coding: utf-8 -*-
# try:
#     from pyvirtualdisplay import Display

#     display = Display(visible=0, size=(1400, 900))
#     display.start()
# except Exception as e:
#     print("⚠️ Virtual display not started:", e)

import sys
import os
import time
import timeit
import logging
import csv
from arguments import parser

import torch
import gym
import matplotlib as mpl
import matplotlib.pyplot as plt
from baselines.logger import HumanOutputFormat

# Handle for an optional virtual display. It is only ever set to None in this
# file (the pyvirtualdisplay start-up at the top is commented out), so the
# `display.stop()` guard at the end of the script is currently never taken.
display = None

# wandb is an optional dependency: when the import fails, `wandb` is bound to
# None and the script checks for that before initialising a run.
try:
    import wandb
except ImportError:
    wandb = None

# Active wandb run object; stays None unless the __main__ section calls
# wandb.init().
wandb_run = None

from envs.multigrid import *
from envs.multigrid.adversarial import *
from envs.bipedalwalker import *

from envs.runners.adversarial_runner_hybrid import AdversarialRunner

from util import (
    make_agent,
    FileWriter,
    safe_checkpoint,
    create_parallel_env,
    make_plr_args,
    save_images,
)
from eval import Evaluator


# ----------------------------- Utils -----------------------------
def _sanitize_csv(path: str):
    """Remove NULs and trim any partial last line, in-place."""
    if not os.path.exists(path):
        return
    with open(path, "rb") as f:
        data = f.read()
    changed = False
    if b"\x00" in data:
        data = data.replace(b"\x00", b"")
        changed = True
    if data and not data.endswith(b"\n"):
        last = data.rfind(b"\n")
        if last == -1:
            data = b""
        else:
            data = data[: last + 1]
        changed = True
    if changed:
        with open(path, "wb") as f:
            f.write(data)


def _ensure_eval_csv(eval_csv_path: str, fieldnames):
    os.makedirs(os.path.dirname(eval_csv_path), exist_ok=True)
    need_header = (not os.path.exists(eval_csv_path)) or (
        os.path.getsize(eval_csv_path) == 0
    )
    if need_header:
        with open(eval_csv_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()


def _append_eval_row(eval_csv_path: str, fieldnames, row: dict):
    with open(eval_csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        clean = {k: row.get(k, None) for k in fieldnames}
        writer.writerow(clean)


def _inject_new_algo_args(args):
    """
    Placeholder for curriculum-related argument injection.

    Currently a no-op to avoid adding new attributes or logging fields.
    """
    return


# --------- Replay buffer (LevelStore) helpers ---------
def _get_level_store_from_runner(train_runner):
    """
    Best-effort helper to locate the LevelStore (replay buffer)
    inside the runner / agents.

    You may adjust this if your LevelStore lives somewhere else.
    """
    # Direct attribute on runner
    if hasattr(train_runner, "level_store"):
        return train_runner.level_store

    # Common pattern: runner.level_sampler.level_store
    sampler = getattr(train_runner, "level_sampler", None)
    if sampler is not None and hasattr(sampler, "level_store"):
        return sampler.level_store

    # Sometimes stored on plr_args
    if hasattr(train_runner, "plr_args"):
        plr_args = train_runner.plr_args
        if hasattr(plr_args, "level_store"):
            return plr_args.level_store

    # Via agent.algo.level_sampler.level_store
    if hasattr(train_runner, "agents"):
        agent = train_runner.agents.get("agent", None)
        if agent is not None:
            algo = getattr(agent, "algo", None)
            if algo is not None:
                sampler = getattr(algo, "level_sampler", None)
                if sampler is not None and hasattr(sampler, "level_store"):
                    return sampler.level_store

    # Fallback: not found
    return None


def dump_replay_buffer_snapshot(
    train_runner,
    replay_csv_path: str,
    current_steps: int | None = None,
):
    """
    Write the current replay buffer (LevelStore) to ``replay_csv_path``.

    The file is replaced on every call with one row per stored level. If no
    LevelStore can be located on ``train_runner`` the function returns
    without touching the file.

    Columns:
        steps       : ``current_steps`` as an int, or -1 when it is None
        num_levels  : total number of levels currently in the store
        seed        : level's seed/index (a key of ``seed2level``)
        edit_count  : length of ``seed2parent[seed]`` (0 when absent)
        level       : ``str()`` of the level returned by
                      ``LevelStore.get_level(seed)``

    All rows are built before any file is opened, and the snapshot is
    written to ``<replay_csv_path>.tmp`` and then moved into place with
    ``os.replace``. A failure while reading the store, or an interruption
    while writing, therefore leaves the previous snapshot intact instead of
    a truncated file.

    Args:
        train_runner: Runner from which the LevelStore is looked up.
        replay_csv_path: Destination CSV path.
        current_steps: Env-step count to record in the ``steps`` column.
    """
    level_store = _get_level_store_from_runner(train_runner)
    if level_store is None:
        # Nothing to dump
        return

    # Collect all seeds (keys) in a stable order
    seeds = sorted(level_store.seed2level.keys())
    num_levels = len(seeds)
    fieldnames = ["steps", "num_levels", "seed", "edit_count", "level"]
    step_value = -1 if current_steps is None else int(current_steps)

    rows = []
    for seed in seeds:
        try:
            level = level_store.get_level(seed)
        except Exception:
            # Deliberate best-effort: fall back to the raw stored object
            # when decoding through get_level fails.
            level = level_store.seed2level[seed]

        rows.append(
            {
                "steps": step_value,
                "num_levels": num_levels,
                "seed": seed,
                # Edit count = length of the level's parent chain
                "edit_count": len(level_store.seed2parent.get(seed, [])),
                "level": str(level),
            }
        )

    tmp_path = f"{replay_csv_path}.tmp"
    try:
        with open(tmp_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
        os.replace(tmp_path, replay_csv_path)
    finally:
        # After a successful os.replace the temp file no longer exists;
        # anything still there is the residue of a failed attempt.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# ----------------------------- Main -----------------------------
if __name__ == "__main__":
    # Set OMP_NUM_THREADS=1 in the environment before anything else runs.
    os.environ["OMP_NUM_THREADS"] = "1"
    args = parser.parse_args()

    # Hook for algorithm-specific argument injection (currently a no-op).
    _inject_new_algo_args(args)

    # === Configure logging ===
    # Without an explicit experiment id, fall back to a timestamped one.
    if args.xpid is None:
        args.xpid = f"lr-{time.strftime('%Y%m%d-%H%M%S')}"

    # log_dir is the (expanded) root; each experiment gets <log_dir>/<xpid>.
    log_dir = os.path.expandvars(os.path.expanduser(args.log_dir))
    run_dir = os.path.join(log_dir, args.xpid)
    for directory in (log_dir, run_dir):
        os.makedirs(directory, exist_ok=True)

    print(f"[logging] log_dir={log_dir}")
    print(f"[logging] run_dir={run_dir}")

    # Strip NULs / unterminated last lines from CSVs already in the run dir
    # before the FileWriter is created.
    for csv_name in ("logs.csv", "fields.csv"):
        _sanitize_csv(os.path.join(run_dir, csv_name))

    # === Initialize FileWriter ===
    filewriter = FileWriter(xpid=args.xpid, xp_args=args.__dict__, rootdir=log_dir)

    screenshot_dir = os.path.join(log_dir, args.xpid, "screenshots")
    if not os.path.exists(screenshot_dir):
        os.makedirs(screenshot_dir, exist_ok=True)

    # Replay buffer CSV: every dump overwrites it with the latest snapshot.
    replay_csv_path = os.path.join(run_dir, "replay_buffer.csv")

    # === Initialize wandb ===
    # Mapping used here:
    #   wandb project       <- args.env_name ("default" when absent)
    #   wandb run name / id <- last path component of the raw --log_dir
    # The id depends only on log_dir, so re-running with the same log_dir
    # passes the same id together with resume="allow".
    if wandb is not None:
        run_id = args.log_dir.rstrip("/").split("/")[-1]

        wandb_kwargs = {
            "project": getattr(args, "env_name", "default"),
            "name": run_id,
            "id": run_id,
            "resume": "allow",
            "config": args.__dict__,
        }

        # Only forward an entity when one was configured.
        entity = getattr(args, "wandb_entity", None)
        if entity is not None:
            wandb_kwargs["entity"] = entity

        wandb_run = wandb.init(**wandb_kwargs)

    def log_stats(stats):
        """Send one stats dict to every configured sink.

        The dict always goes to ``filewriter``. It is additionally passed to
        ``wandb.log`` when a wandb run was initialised, and printed to stdout
        through ``HumanOutputFormat`` when ``args.verbose`` is set.
        """
        filewriter.log(stats)

        wandb_active = wandb_run is not None
        if wandb_active:
            wandb.log(stats)

        if not args.verbose:
            return
        console = HumanOutputFormat(sys.stdout)
        console.writekvs(stats)

    # Verbose runs lower the root logger to INFO; otherwise logging is
    # disabled for every level up to and including CRITICAL.
    if not args.verbose:
        logging.disable(logging.CRITICAL)
    else:
        logging.getLogger().setLevel(logging.INFO)

    # === Determine device ====
    # CUDA is used only when it is not disabled by flag and is available.
    args.cuda = (not args.no_cuda) and torch.cuda.is_available()
    if args.cuda:
        device = torch.device("cuda:0")
        torch.backends.cudnn.benchmark = True
        print("Using CUDA\n")
    else:
        device = torch.device("cpu")

    # === Create parallel envs ===
    venv, ued_venv = create_parallel_env(args)

    is_paired = args.ued_algo in ("paired", "flexible_paired")
    is_training_env = is_paired or args.ued_algo == "minimax"

    def _build_agent(agent_name):
        # Every agent is built against the same env, args and device.
        return make_agent(name=agent_name, env=venv, args=args, device=device)

    # Construction order is kept as agent -> adversary_agent -> adversary_env.
    agent = _build_agent("agent")

    adversary_agent = None
    if is_paired or args.use_accel_paired:
        adversary_agent = _build_agent("adversary_agent")

    adversary_env = None
    if is_training_env:
        adversary_env = _build_agent("adversary_env")

    # Domain randomization combined with PLR (and without
    # use_reset_random_dr) also gets an adversary_env, switched via .random().
    needs_random_adversary_env = (
        args.ued_algo == "domain_randomization"
        and args.use_plr
        and not args.use_reset_random_dr
    )
    if needs_random_adversary_env:
        adversary_env = _build_agent("adversary_env")
        adversary_env.random()

    # === Create runner ===
    plr_args = (
        make_plr_args(args, venv.observation_space, venv.action_space)
        if args.use_plr
        else None
    )

    runner_kwargs = dict(
        args=args,
        venv=venv,
        agent=agent,
        ued_venv=ued_venv,
        adversary_agent=adversary_agent,
        adversary_env=adversary_env,
        flexible_protagonist=False,
        train=True,
        plr_args=plr_args,
        device=device,
    )
    train_runner = AdversarialRunner(**runner_kwargs)

    # === Configure checkpointing ===
    timer = timeit.default_timer
    initial_update_count = 0
    # The training loop skips its logging/checkpoint step while
    # train_runner.num_updates <= this value. It is only ever set to -1 here;
    # restoring it from the log file is disabled (see the commented line in
    # the resume branch below).
    last_logged_update_at_restart = -1

    checkpoint_path = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (log_dir, args.xpid, "model.tar"))
    )

    # Fine-tuning starts from <log_dir>/<xpid_finetune>/<model_finetune>.tar.
    if args.xpid_finetune:
        model_fname = f"{args.model_finetune}.tar"
        base_checkpoint_path = os.path.expandvars(
            os.path.expanduser("%s/%s/%s" % (log_dir, args.xpid_finetune, model_fname))
        )

    def checkpoint(index=None):
        """Save the runner state to ``checkpoint_path`` via ``safe_checkpoint``.

        Does nothing when ``args.disable_checkpoint`` is set. ``index`` and
        ``args.archive_interval`` are forwarded to ``safe_checkpoint``.
        """
        if args.disable_checkpoint:
            return
        safe_checkpoint(
            {"runner_state_dict": train_runner.state_dict()},
            checkpoint_path,
            index=index,
            archive_interval=args.archive_interval,
        )
        logging.info("Saved checkpoint to %s", checkpoint_path)

    # === Load checkpoint ===
    if args.checkpoint and os.path.exists(checkpoint_path):
        # Resume: restore the full runner state from this run's checkpoint.
        checkpoint_states = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage
        )
        # last_logged_update_at_restart = filewriter.latest_tick()
        train_runner.load_state_dict(checkpoint_states["runner_state_dict"])
        initial_update_count = train_runner.num_updates
        logging.info(f"Resuming preempted job after {initial_update_count} updates\n")
    elif args.xpid_finetune and not os.path.exists(checkpoint_path):
        # Fine-tune: only the student's network and optimizer are restored.
        # Storages are kept on CPU while deserializing (same as the resume
        # branch) so a checkpoint saved on a GPU host also loads on a host
        # without that device; load_state_dict then copies into the model.
        checkpoint_states = torch.load(
            base_checkpoint_path, map_location=lambda storage, loc: storage
        )
        state_dict = checkpoint_states["runner_state_dict"]
        agent_state_dict = state_dict.get("agent_state_dict")
        optimizer_state_dict = state_dict.get("optimizer_state_dict")
        train_runner.agents["agent"].algo.actor_critic.load_state_dict(
            agent_state_dict["agent"]
        )
        train_runner.agents["agent"].algo.optimizer.load_state_dict(
            optimizer_state_dict["agent"]
        )

    # === Evaluator ===
    # An evaluator exists only when test environments were requested;
    # eval_keys then holds the stat names it reports.
    if not args.test_env_names:
        evaluator, eval_keys = None, []
    else:
        evaluator = Evaluator(
            args.test_env_names.split(","),
            num_processes=args.test_num_processes,
            num_episodes=args.test_num_episodes,
            frame_stack=args.frame_stack,
            grayscale=args.grayscale,
            num_action_repeat=args.num_action_repeat,
            use_global_critic=args.use_global_critic,
            use_global_policy=args.use_global_policy,
            eval_csv_dir=run_dir,
            device=device,
        )
        eval_keys = evaluator.get_stats_keys()

    # === Eval CSV setup ===
    # One row per evaluation: the "update" column followed by the eval stats.
    eval_csv_path = os.path.join(log_dir, args.xpid, "eval_tests.csv")
    eval_fieldnames = ["update", *eval_keys]
    _ensure_eval_csv(eval_csv_path, eval_fieldnames)
    print(f"[eval_csv] {eval_csv_path}")

    # === Train ===
    # Value of the runner attribute named by --checkpoint_basis at loop start.
    last_checkpoint_idx = getattr(train_runner, args.checkpoint_basis)
    update_start_time = timer()
    num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes

    # Env steps credited to one loop iteration when computing "sps".
    env_steps_per_update = args.num_processes * args.num_steps
    # Iterations run since the last logged row. Counting them (instead of
    # assuming args.log_interval) keeps "sps" correct for the final update
    # and for the first log after resuming from a checkpoint.
    updates_since_log = 0

    # The replay buffer is dumped whenever stats["steps"] reaches the next
    # multiple of this many env steps.
    replay_dump_interval = 1000
    next_replay_dump_at = replay_dump_interval

    for j in range(initial_update_count, num_updates):
        # The runner's warm_stage flag selects which phase to run.
        if train_runner.warm_stage:
            stats = train_runner.run_warm()
        else:
            stats = train_runner.run_accel()
        updates_since_log += 1

        # === Periodic replay buffer snapshot ===
        current_steps = stats.get("steps", None)
        if current_steps is not None and current_steps >= next_replay_dump_at:
            dump_replay_buffer_snapshot(
                train_runner,
                replay_csv_path,
                current_steps=current_steps,
            )
            # Advance the threshold past current_steps in whole intervals so
            # a large jump in steps triggers a single dump, not one per
            # interval crossed.
            while next_replay_dump_at <= current_steps:
                next_replay_dump_at += replay_dump_interval

        # === Perform logging ===
        if train_runner.num_updates <= last_logged_update_at_restart:
            continue

        log = (j % args.log_interval == 0) or j == num_updates - 1
        save_screenshot = args.screenshot_interval > 0 and (
            j % args.screenshot_interval == 0
        )

        test_stats = {}

        # Evaluate when the student's gradient-update count is a positive
        # multiple of test_interval.
        # NOTE(review): if student_grad_updates does not advance on every
        # iteration, this can hold on consecutive iterations and re-run the
        # evaluation - confirm against the runner.
        is_eval_step = (
            evaluator is not None
            and train_runner.student_grad_updates > 0
            and train_runner.student_grad_updates % args.test_interval == 0
        )
        if is_eval_step:
            test_stats = evaluator.evaluate(train_runner.agents["agent"])
            stats.update(test_stats)
            if args.use_accel_paired:
                # The adversary agent's results are logged under an
                # "advagent_" prefix so they do not collide with the student's.
                adv_test_stats = evaluator.evaluate(train_runner.agents["adversary_agent"])
                stats.update(
                    {f"advagent_{key}": value for key, value in adv_test_stats.items()}
                )
            # Append the student's eval results to the eval CSV; the "update"
            # column always holds the gradient-update count.
            eval_row = {
                key: test_stats.get(key, None)
                for key in eval_fieldnames
                if key != "update"
            }
            eval_row["update"] = train_runner.student_grad_updates
            _append_eval_row(eval_csv_path, eval_fieldnames, eval_row)
        elif evaluator is not None:
            # Keep the eval columns present (as None) on non-eval iterations.
            stats.update({k: None for k in eval_keys})

        if log:
            update_end_time = timer()
            sps = (
                updates_since_log
                * env_steps_per_update
                / (update_end_time - update_start_time)
            )
            update_start_time = update_end_time
            updates_since_log = 0
            stats["sps"] = sps
            log_stats(stats)

        checkpoint_idx = getattr(train_runner, args.checkpoint_basis)

        # NOTE(review): last_checkpoint_idx is never reassigned inside this
        # loop, so this compares against the value captured before training
        # started - confirm that is intended.
        if checkpoint_idx != last_checkpoint_idx:
            is_last_update = j == num_updates - 1
            if is_last_update or (
                train_runner.num_updates > 0
                and checkpoint_idx % args.checkpoint_interval == 0
            ):
                checkpoint(checkpoint_idx)
                logging.info(f"\nSaved checkpoint after update {j}")
                logging.info(f"\nLast update: {is_last_update}")

                if args.use_plr and hasattr(train_runner, "save_top_k_seeds_csv"):
                    seeds_csv_path = os.path.join(
                        run_dir, f"top_seeds_{checkpoint_idx}.csv"
                    )
                    num_saved = train_runner.save_top_k_seeds_csv(seeds_csv_path)
                    if num_saved > 0:
                        logging.info(f"Saved {num_saved} top seeds to {seeds_csv_path}")

            elif (
                train_runner.num_updates > 0
                and args.archive_interval > 0
                and checkpoint_idx % args.archive_interval == 0
            ):
                checkpoint(checkpoint_idx)
                logging.info(f"\nArchived checkpoint after update {j}")

                if args.use_plr and hasattr(train_runner, "save_top_k_seeds_csv"):
                    seeds_csv_path = os.path.join(
                        run_dir, f"top_seeds_archive_{checkpoint_idx}.csv"
                    )
                    num_saved = train_runner.save_top_k_seeds_csv(seeds_csv_path)
                    if num_saved > 0:
                        logging.info(f"Archived {num_saved} top seeds to {seeds_csv_path}")

        if save_screenshot:
            level_info = train_runner.sampled_level_info
            if args.env_name.startswith("BipedalWalker"):
                # NOTE(review): the encodings are fetched but neither used
                # nor saved, so no artifact is written for BipedalWalker.
                encodings = venv.get_level()
            else:
                venv.reset_agent()
                images = venv.get_images()
                if args.use_editor and level_info:
                    # With the editor on, the filename also records the
                    # replay flag and the first edit count of the level.
                    save_images(
                        images[: args.screenshot_batch_size],
                        os.path.join(
                            screenshot_dir,
                            f"update{j}-replay{level_info['level_replay']}-n_edits{level_info['num_edits'][0]}.png",
                        ),
                        normalize=True,
                        channels_first=False,
                    )
                else:
                    save_images(
                        images[: args.screenshot_batch_size],
                        os.path.join(screenshot_dir, f"update{j}.png"),
                        normalize=True,
                        channels_first=False,
                    )
                plt.close()

    # === Teardown ===
    # Close the evaluator (only created when test envs were requested) and
    # the training envs.
    if evaluator is not None:
        evaluator.close()
    venv.close()

    # `display` is only ever None in this file, so this is currently a no-op.
    if display:
        display.stop()

    # Finish the wandb run if one was started.
    if wandb_run is not None:
        wandb_run.finish()
