# -*- coding: utf-8 -*-
"""
minimal_eval_seed_stats.py

Goal (updated):
For each seed, save (num_blocks, shortest_path_length, passed) ONLY if valid,
where valid := shortest_path_length != 170.0

New requirements implemented:
1) Save pass index:
   - passed = 1 if the agent "solves" the episode, else 0
2) For each valid seed, save a render image named by seed index:
   - <screenshot_dir>/seed_<SEED>.png

Screenshot method (required):
Images MUST be saved through the vec-env, following this pattern:
    venv.reset_agent()
    images = venv.get_images()
    save_images(images[:args.screenshot_batch_size], ...)

Because every seed needs its own file, a one-image slice is saved per seed:
    save_images(images[i:i+1], f"seed_{seed}.png", normalize=True, channels_first=False)

Notes:
- Keeps "dummy env uses same wrappers" invariant when loading the agent.
- Evaluates ONE episode per seed (in parallel batches).
- Screenshots are taken from the venv images (NOT raw env render / window grab).

Project dependencies:
- util.DotDict, util.str2bool, util.make_agent (+ optional util.is_discrete_actions, util.save_images)
- envs.wrappers.ParallelAdversarialVecEnv, VecMonitor, VecPreprocessImageWrapper, MultiGridFullyObsWrapper
- envs.registration.make as gym_make
"""

import os
import csv
import json
import argparse
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt

from envs.registration import make as gym_make
from envs.wrappers import (
    VecMonitor,
    VecPreprocessImageWrapper,
    ParallelAdversarialVecEnv,
    MultiGridFullyObsWrapper,
)

# ensure env registrations side-effects (same as your full script)
from envs.multigrid.adversarial import SeededGoalLastAdversarialEnv  # noqa: F401

try:
    import gymnasium as gym
except ImportError:
    import gym

from util import DotDict, str2bool, make_agent

# Optional utilities (present in your big eval script)
try:
    from util import is_discrete_actions  # type: ignore
except Exception:
    is_discrete_actions = None

try:
    from util import save_images  # type: ignore
except Exception:
    save_images = None


# Sentinel shortest_path_length: a seed whose value equals this number is
# treated as invalid (valid := shortest_path_length != 170.0) and is skipped.
INVALID_SPL = 170.0


def parse_args():
    """Parse the command-line options of the per-seed statistics script.

    Returns:
        argparse.Namespace holding the run/checkpoint selection, the seed
        range, the wrapper fallbacks and the output locations.
    """
    # The sentence is passed as ``description``. The first positional
    # parameter of ArgumentParser is ``prog``, so passing it positionally
    # would print the sentence in place of the program name in usage/help.
    p = argparse.ArgumentParser(
        description=(
            "Minimal per-seed stats: (num_blocks, shortest_path, passed) for valid seeds"
        )
    )
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument("--env_name", type=str, required=True)
    p.add_argument("--checkpoint_name", type=str, default="model_20000.tar")
    p.add_argument(
        "--model_name", type=str, default="agent", choices=["agent", "adversary_agent"]
    )
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--deterministic", type=str2bool, nargs="?", const=True, default=True)

    # seed range and number of parallel environment slots
    p.add_argument("--seed_start", type=int, default=1)
    p.add_argument("--num_seeds", type=int, default=100)
    p.add_argument("--num_processes", type=int, default=10)

    # wrapper fallbacks (if absent in meta.json)
    p.add_argument("--frame_stack", type=int, default=1)
    p.add_argument("--grayscale", type=str2bool, nargs="?", const=True, default=False)
    p.add_argument(
        "--use_global_policy", type=str2bool, nargs="?", const=True, default=False
    )
    p.add_argument(
        "--use_global_critic", type=str2bool, nargs="?", const=True, default=False
    )
    p.add_argument("--crop_frame", type=str2bool, nargs="?", const=True, default=False)
    p.add_argument("--num_action_repeat", type=int, default=8)

    # outputs
    p.add_argument("--out_csv", type=str, default="./eval_results/seed_stats_valid.csv")
    p.add_argument(
        "--screenshot_dir",
        type=str,
        default="./eval_results/seed_screenshots",
        help="Directory to save seed screenshots: seed_<SEED>.png",
    )
    p.add_argument(
        "--screenshot_batch_size",
        type=int,
        default=1,
        help="How many images to pass to save_images at once (per file). Usually 1.",
    )

    # Accepted for command-line compatibility; the code visible in this file
    # never reads it.
    p.add_argument("--use_editor", type=str2bool, nargs="?", const=True, default=False)

    return p.parse_args()


def make_env(env_name: str, **kwargs):
    """Create a single environment instance for ``env_name``.

    The two BipedalWalker ids are built with ``gym.make``; every other name
    goes through the project registry (``gym_make``). When the name starts
    with "MultiGrid" or "MiniGrid" and ``use_global_policy`` is truthy in
    ``kwargs``, the environment is wrapped in ``MultiGridFullyObsWrapper``.
    All other keyword arguments are accepted and ignored here.
    """
    bipedal_ids = ("BipedalWalker-v3", "BipedalWalkerHardcore-v3")
    factory = gym.make if env_name in bipedal_ids else gym_make
    env = factory(env_name)

    grid_like = env_name.startswith(("MultiGrid", "MiniGrid"))
    if grid_like and kwargs.get("use_global_policy"):
        return MultiGridFullyObsWrapper(env, is_adversarial=False)
    return env


def wrap_venv(venv, env_name: str, device: str):
    """Stack the monitoring and preprocessing wrappers on a vec-env.

    ``VecMonitor`` is applied first, then ``VecPreprocessImageWrapper``.
    The preprocessing settings depend only on the environment name:
      - MultiGrid / MiniGrid names: obs_key "image" and scale 10.0,
        otherwise both are None;
      - BipedalWalker names: no transpose, otherwise order [2, 0, 1].
    """
    if env_name.startswith(("MultiGrid", "MiniGrid")):
        obs_key, scale = "image", 10.0
    else:
        obs_key, scale = None, None

    transpose_order = [2, 0, 1]
    if env_name.startswith("BipedalWalker"):
        transpose_order = None

    monitored = VecMonitor(venv=venv, filename=None, keep_buf=100)
    return VecPreprocessImageWrapper(
        venv=monitored,
        obs_key=obs_key,
        transpose_order=transpose_order,
        scale=scale,
        device=device,
    )


def build_parallel_env(env_name: str, num_processes: int, device: str, **kwargs):
    """Build a wrapped ``ParallelAdversarialVecEnv`` with ``num_processes`` slots.

    One zero-argument factory per slot is handed to the vec-env; each factory
    calls ``make_env(env_name, **kwargs)``. The result is passed through
    ``wrap_venv`` before being returned.
    """

    def _spawn_env():
        return make_env(env_name, **kwargs)

    factories = [_spawn_env for _ in range(int(num_processes))]

    parallel = ParallelAdversarialVecEnv(factories, adversary=False, is_eval=True)
    return wrap_venv(parallel, env_name=env_name, device=device)


def _unwrap_to_base_venv(v):
    """Follow the ``.venv`` chain and return the innermost object.

    An object without a ``venv`` attribute is returned unchanged.
    """
    innermost = v
    while True:
        try:
            innermost = innermost.venv
        except AttributeError:
            return innermost


def load_agent_from_run_dir(
    run_dir: str,
    env_name: str,
    checkpoint_name: str,
    model_name: str,
    device: str,
    wrapper_kwargs: Dict[str, Any],
):
    """Build an agent and load its weights from a checkpoint inside ``run_dir``.

    Args:
        run_dir: Run directory; ``~`` and environment variables are expanded.
            It must contain ``meta.json`` (with an "args" entry) and the
            checkpoint file.
        env_name: Environment name used to build the temporary vec-env.
        checkpoint_name: File name of the checkpoint inside ``run_dir``.
        model_name: Key selected inside
            ``checkpoint["runner_state_dict"]["agent_state_dict"]`` when the
            checkpoint has that layout.
        device: Device passed to the vec-env wrappers and to ``make_agent``.
        wrapper_kwargs: Keyword arguments forwarded to ``build_parallel_env``
            for the temporary vec-env.

    Returns:
        ``(agent, xpid_flags)`` where ``xpid_flags`` is the ``DotDict`` built
        from ``meta.json["args"]``.
    """
    run_dir = os.path.expandvars(os.path.expanduser(run_dir))
    meta_path = os.path.join(run_dir, "meta.json")
    ckpt_path = os.path.join(run_dir, checkpoint_name)

    with open(meta_path, "r") as f:
        xpid_flags = DotDict(json.load(f)["args"])

    # dummy env uses SAME wrappers (important)
    dummy_venv = build_parallel_env(
        env_name, num_processes=1, device=device, **wrapper_kwargs
    )
    # try/finally: the temporary vec-env is closed even when building the
    # agent or reading the checkpoint fails, so it is never left open.
    try:
        # NOTE(review): the agent is always created with name="agent", also
        # when model_name is "adversary_agent" - confirm this is intended.
        agent = make_agent(name="agent", env=dummy_venv, args=xpid_flags, device=device)

        checkpoint = torch.load(ckpt_path, map_location="cpu")
        if isinstance(checkpoint, dict) and "runner_state_dict" in checkpoint:
            # Full runner checkpoint: pick the requested model's weights.
            state_dict = checkpoint["runner_state_dict"]["agent_state_dict"][model_name]
        else:
            # Anything else is handed to load_state_dict as-is.
            state_dict = checkpoint
        agent.algo.actor_critic.load_state_dict(state_dict)
    finally:
        dummy_venv.close()

    return agent, xpid_flags


def is_invalid_spl(x: Optional[float]) -> bool:
    """Return True only when ``x`` converts to a float equal to INVALID_SPL.

    ``None`` and values that cannot be converted to float yield False.
    """
    if x is not None:
        try:
            return float(x) == float(INVALID_SPL)
        except Exception:
            pass
    return False


def _get_solved_threshold(env_name: str) -> float:
    """Return the threshold associated with ``env_name``.

    Names starting with "BipedalWalker" map to 230.0; all others to 0.0.
    """
    return 230.0 if env_name.startswith("BipedalWalker") else 0.0


def _infer_is_discrete(venv) -> bool:
    """Tell whether the environment's action space is treated as discrete.

    With ``is_discrete_actions`` available, its result is converted to bool
    and any exception it raises is answered with True. Without it, the
    answer comes from ``venv.env_name`` (empty string when missing): names
    starting with "BipedalWalker" count as continuous, all others as discrete.
    """
    if is_discrete_actions is not None:
        try:
            answer = bool(is_discrete_actions(venv))
        except Exception:
            answer = True
        return answer

    name = str(getattr(venv, "env_name", ""))
    return not name.startswith("BipedalWalker")


def evaluate_one_episode_per_seed_batch(
    *,
    venv,
    agent,
    device: str,
    deterministic: bool,
    active_mask: List[bool],
) -> Tuple[List[float], List[int]]:
    """
    Step all slots in parallel until every active slot has reported one episode.

    The vec-env is reset once, then the agent acts on every slot each step.
    A slot counts as finished the first time its info dict carries an
    "episode" entry; the return is read from ``info["episode"]["r"]`` (0.0 if
    that value cannot be read as a float). Inactive slots are still stepped,
    but their results are ignored.

    Returns:
      - returns_vec: length P, episode return (0.0 if never collected)
      - got_mask: length P, 1 if collected an episode return, else 0
    """
    num_slots = len(active_mask)
    episode_returns: List[float] = [0.0] * num_slots
    finished: List[bool] = [False] * num_slots

    obs = venv.reset()

    policy = agent.algo.actor_critic
    hidden = torch.zeros(num_slots, policy.recurrent_hidden_state_size, device=device)
    uses_lstm = (
        policy.is_recurrent
        and getattr(policy, "rnn", None) is not None
        and getattr(policy.rnn, "arch", "") == "lstm"
    )
    if uses_lstm:
        # an LSTM carries a pair of state tensors of identical shape
        hidden = (hidden, torch.zeros_like(hidden))
    masks = torch.ones(num_slots, 1, device=device)

    discrete = _infer_is_discrete(venv)
    greedy = bool(deterministic)

    # keep stepping while at least one active slot has not finished
    while any(active_mask[s] and not finished[s] for s in range(num_slots)):
        with torch.no_grad():
            _, action, _, hidden = agent.act(obs, hidden, masks, deterministic=greedy)

        action_np = action.detach().cpu().numpy()
        if not discrete:
            action_np = agent.process_action(action_np)

        obs, _reward, done, infos = venv.step(action_np)

        # mask is 0.0 for slots whose episode just ended, 1.0 otherwise
        masks = torch.tensor(
            [[0.0 if d else 1.0] for d in done], dtype=torch.float32, device=device
        )

        for slot, info in enumerate(infos):
            if finished[slot] or not active_mask[slot]:
                continue
            if not (isinstance(info, dict) and "episode" in info):
                continue

            try:
                episode_returns[slot] = float(info["episode"]["r"])
            except Exception:
                episode_returns[slot] = 0.0
            finished[slot] = True

            # clear the recurrent state of the slot that just finished
            if getattr(policy, "is_recurrent", False):
                try:
                    states = hidden[:2] if isinstance(hidden, tuple) else (hidden,)
                    for state in states:
                        state[slot].zero_()
                except Exception:
                    pass

    got_mask = [1 if flag else 0 for flag in finished]
    return episode_returns, got_mask


def save_seed_images_via_venv(
    *,
    venv,
    seeds_vec: List[int],
    valid_slot: List[bool],
    screenshot_dir: str,
    screenshot_batch_size: int,
):
    """
    Write one image per valid slot, taken from the vec-env.

    The images come from the vec-env itself:
        venv.reset_agent()
        images = venv.get_images()

    Each valid slot ``i`` is written to its own file
        <screenshot_dir>/seed_<SEED>.png
    by passing the one-image slice ``images[i:i+1]`` (further limited to
    ``screenshot_batch_size`` entries) to ``save_images``.

    Raises:
        RuntimeError: if ``util.save_images`` could not be imported.
    """
    if save_images is None:
        raise RuntimeError("util.save_images is not available, cannot save screenshots.")

    os.makedirs(screenshot_dir, exist_ok=True)

    venv.reset_agent()
    frames = venv.get_images()

    # frames may be a list or an array; slicing works for both
    selected = [slot for slot, keep in enumerate(valid_slot) if keep]
    for slot in selected:
        seed_value = int(seeds_vec[slot])
        target = os.path.join(screenshot_dir, f"seed_{seed_value}.png")

        single = frames[slot : slot + 1]
        save_images(
            single[: int(screenshot_batch_size)],
            target,
            normalize=True,
            channels_first=False,
        )

    plt.close()


def collect_seed_stats(
    *,
    venv,
    agent,
    env_name: str,
    device: str,
    deterministic: bool,
    seed_start: int,
    num_seeds: int,
    num_processes: int,
    out_csv: str,
    screenshot_dir: str,
    screenshot_batch_size: int,
):
    """
    Evaluate seeds in batches of ``num_processes`` and record the valid ones.

    For each batch:
      - reset with per-slot seeds (unused slots of the last batch get seed 0
        and are marked inactive)
      - read, when the innermost vec-env offers them:
          base.get_num_blocks()
          base.get_shortest_path_length()
      - for every valid slot (active and spl != 170):
          - save a screenshot named by seed via save_seed_images_via_venv
          - run 1 episode to get return -> passed = 1/0 via threshold
          - write a row to the CSV

    Args:
        venv: Wrapped vec-env with ``num_processes`` slots.
        agent: Agent forwarded to evaluate_one_episode_per_seed_batch.
        env_name: Environment name; selects the "passed" threshold.
        device: Torch device forwarded to the evaluation.
        deterministic: Whether the agent acts deterministically.
        seed_start: First seed (inclusive).
        num_seeds: Number of consecutive seeds to process.
        num_processes: Slots per batch; must match the vec-env size.
        out_csv: Output CSV path; parent directories are created.
        screenshot_dir: Directory receiving seed_<SEED>.png files.
        screenshot_batch_size: Forwarded to save_seed_images_via_venv.

    Raises:
        RuntimeError: from save_seed_images_via_venv when util.save_images
            is unavailable and at least one seed is valid.
    """
    base = _unwrap_to_base_venv(venv)
    P = int(num_processes)
    seeds = list(range(int(seed_start), int(seed_start) + int(num_seeds)))
    solved_threshold = float(_get_solved_threshold(env_name))

    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
    os.makedirs(screenshot_dir, exist_ok=True)

    with open(out_csv, "w", newline="") as f:
        w = csv.DictWriter(
            f,
            fieldnames=[
                "seed",
                "n_clutter_placed",
                "shortest_path_length",
                "passed",
                "episode_return",
            ],
        )
        w.writeheader()

        for batch_start in range(0, len(seeds), P):
            batch = seeds[batch_start : batch_start + P]
            padding = P - len(batch)
            active = [True] * len(batch) + [False] * padding
            seeds_vec = [int(s) for s in batch] + [0] * padding

            # IMPORTANT: reset with per-slot seeds
            venv.reset_with_seeds(seeds_vec)
            venv.reset()

            n_blocks_vec: List[Optional[int]] = [None] * P
            spl_vec: List[Optional[float]] = [None] * P

            # Pull per-slot stats if supported. This stays best-effort, but a
            # failure is reported instead of being silently swallowed: with
            # no stats, every active slot is treated as valid.
            try:
                if hasattr(base, "get_num_blocks"):
                    nb = base.get_num_blocks()
                    if isinstance(nb, (list, tuple)) and len(nb) == P:
                        n_blocks_vec = [int(x) if x is not None else None for x in nb]
                if hasattr(base, "get_shortest_path_length"):
                    sp = base.get_shortest_path_length()
                    if isinstance(sp, (list, tuple)) and len(sp) == P:
                        spl_vec = [float(x) if x is not None else None for x in sp]
            except Exception as err:
                print(
                    f"[WARN] could not read level stats for seeds {batch}: {err!r}",
                    flush=True,
                )

            # A slot is valid when it is active and its shortest path length
            # is not the invalid sentinel.
            valid_slot = [
                bool(is_active and not is_invalid_spl(spl))
                for is_active, spl in zip(active, spl_vec)
            ]
            if not any(valid_slot):
                continue

            # Screenshots are taken before the evaluation episode runs, while
            # the slots still hold the freshly seeded state.
            # NOTE(review): the evaluation below calls venv.reset() again;
            # confirm that this keeps the seeded level for each slot.
            save_seed_images_via_venv(
                venv=venv,
                seeds_vec=seeds_vec,
                valid_slot=valid_slot,
                screenshot_dir=screenshot_dir,
                screenshot_batch_size=screenshot_batch_size,
            )

            # Evaluate 1 episode per valid slot
            returns_vec, _ = evaluate_one_episode_per_seed_batch(
                venv=venv,
                agent=agent,
                device=device,
                deterministic=deterministic,
                active_mask=valid_slot,
            )

            for i, is_valid in enumerate(valid_slot):
                if not is_valid:
                    continue

                r_i = float(returns_vec[i])
                w.writerow(
                    {
                        "seed": int(seeds_vec[i]),
                        "n_clutter_placed": n_blocks_vec[i],
                        "shortest_path_length": spl_vec[i],
                        "passed": 1 if r_i > solved_threshold else 0,
                        "episode_return": r_i,
                    }
                )

            # One flush per batch: all rows of a batch are written together.
            f.flush()


def main():
    """Script entry point: load the agent, evaluate the seeds, write outputs.

    Wrapper settings are resolved once (value from ``meta.json["args"]`` when
    present and not null, otherwise the command-line value) and the SAME
    settings are used for the dummy env that shapes the agent and for the
    evaluation env, so both see identical wrappers. The evaluation vec-env is
    closed even when the evaluation raises.
    """
    args = DotDict(vars(parse_args()))
    os.environ["OMP_NUM_THREADS"] = "1"

    run_dir = os.path.expandvars(os.path.expanduser(args.run_dir))
    meta_path = os.path.join(run_dir, "meta.json")
    with open(meta_path, "r") as f:
        # Plain dict on purpose: ``.get`` gives a well-defined answer for
        # missing keys regardless of how DotDict handles absent attributes.
        meta_flags = json.load(f)["args"]

    cli_fallbacks: Dict[str, Any] = dict(
        frame_stack=int(args.frame_stack),
        grayscale=bool(args.grayscale),
        use_global_policy=bool(args.use_global_policy),
        use_global_critic=bool(args.use_global_critic),
        crop_frame=bool(args.crop_frame),
        num_action_repeat=int(args.num_action_repeat),
    )

    def _resolve(name: str, cast):
        """Return the meta.json value for ``name``, else the CLI fallback."""
        value = meta_flags.get(name)
        if value is None:
            value = cli_fallbacks[name]
        return cast(value)

    # wrapper kwargs: prefer the run's recorded flags, fall back to CLI values
    wrapper_kwargs: Dict[str, Any] = dict(
        frame_stack=_resolve("frame_stack", int),
        grayscale=_resolve("grayscale", bool),
        use_global_policy=_resolve("use_global_policy", bool),
        use_global_critic=_resolve("use_global_critic", bool),
        crop_frame=_resolve("crop_frame", bool),
        num_action_repeat=_resolve("num_action_repeat", int),
    )

    # The dummy env inside load_agent_from_run_dir receives the very same
    # kwargs as the evaluation env built below ("same wrappers" invariant).
    agent, _ = load_agent_from_run_dir(
        run_dir=args.run_dir,
        env_name=args.env_name,
        checkpoint_name=args.checkpoint_name,
        model_name=args.model_name,
        device=str(args.device),
        wrapper_kwargs=wrapper_kwargs,
    )

    venv = build_parallel_env(
        env_name=args.env_name,
        num_processes=int(args.num_processes),
        device=str(args.device),
        **wrapper_kwargs,
    )

    try:
        collect_seed_stats(
            venv=venv,
            agent=agent,
            env_name=str(args.env_name),
            device=str(args.device),
            deterministic=bool(args.deterministic),
            seed_start=int(args.seed_start),
            num_seeds=int(args.num_seeds),
            num_processes=int(args.num_processes),
            out_csv=str(args.out_csv),
            screenshot_dir=str(args.screenshot_dir),
            screenshot_batch_size=int(args.screenshot_batch_size),
        )
    finally:
        # release the vec-env whether or not the evaluation succeeded
        venv.close()

    print(f"[OK] wrote valid seed stats -> {args.out_csv}", flush=True)
    print(f"[OK] saved seed screenshots -> {args.screenshot_dir}", flush=True)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
