import os
import csv
from typing import Any, Dict, List, Optional
from envs.registration import make as gym_make

import numpy as np
import torch

import json


try:
    from tqdm.auto import tqdm
except Exception:
    tqdm = None

from envs.wrappers import (
    VecMonitor,
    VecPreprocessImageWrapper,
    ParallelAdversarialVecEnv,
)

try:
    from util import is_discrete_actions  # type: ignore
except Exception:
    is_discrete_actions = None


from envs.bipedalwalker.walker_test_envs import (
    SeededBipedalWalker,
    get_config,
)


import sys
import os
import time
import timeit
import logging
import csv
from arguments import parser

import torch
import gym
import matplotlib as mpl
import matplotlib.pyplot as plt
from baselines.logger import HumanOutputFormat

display = None


# from envs.multigrid import *
# from envs.multigrid.adversarial import *
from envs.bipedalwalker import *

# from envs.runners.new_adversarial_runner import AdversarialRunner
# from envs.runners.adversarial_runner_drop import AdversarialRunner
from envs.runners.adversarial_runner import AdversarialRunner


from util import (
    make_agent,
    # FileWriter,
    safe_checkpoint,
    create_parallel_env,
    make_plr_args,
    save_images,
)
from eval import Evaluator


def generate_param_csv(
    out_csv_path: str,
    n: int = 1000,
    rng_seed: int = 0,
    param_ranges: Optional[Dict[str, Any]] = None,
) -> None:
    """Write ``n`` randomly sampled BipedalWalker parameter rows to a CSV file.

    Each row holds ``idx`` plus four float parameters drawn uniformly from
    their ``(low, high)`` range and an integer ``stair_steps`` drawn from
    ``[low, high]`` with the upper bound included.

    Args:
        out_csv_path: Destination file; missing parent directories are created.
        n: Number of rows to write.
        rng_seed: Seed for ``np.random.default_rng``; the same seed and ranges
            produce the same file.
        param_ranges: Optional mapping from parameter name to ``(low, high)``.
            Must contain ``ground_roughness``, ``pit_gap``, ``stump_height``,
            ``stair_height`` and ``stair_steps``. Defaults to the module-level
            ``PARAM_RANGES_FULL``.

    Raises:
        KeyError: If a required parameter is missing from the ranges mapping.
    """
    ranges = PARAM_RANGES_FULL if param_ranges is None else param_ranges

    float_keys = ("ground_roughness", "pit_gap", "stump_height", "stair_height")
    # Resolve every bound before touching the filesystem so that an incomplete
    # mapping fails early instead of leaving a truncated CSV behind.
    float_bounds = {key: ranges[key] for key in float_keys}
    steps_low, steps_high = ranges["stair_steps"]

    os.makedirs(os.path.dirname(out_csv_path) or ".", exist_ok=True)

    rng = np.random.default_rng(rng_seed)

    with open(out_csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["idx", *float_keys, "stair_steps"])
        writer.writeheader()

        for i in range(n):
            row: Dict[str, Any] = {"idx": i}
            # Draw order (four uniforms, then the integer) is fixed so that a
            # given seed always reproduces the same parameter set.
            for key in float_keys:
                row[key] = float(rng.uniform(*float_bounds[key]))
            # ``Generator.integers`` excludes the upper bound by default,
            # hence the +1 to make ``steps_high`` reachable.
            row["stair_steps"] = int(rng.integers(steps_low, steps_high + 1))
            writer.writerow(row)

    print(f"[OK] wrote {n} rows -> {out_csv_path}")


def load_param_csv(csv_path: str) -> List[Dict[str, Any]]:
    """Read a parameter CSV and return one typed dict per data row.

    ``idx`` and ``stair_steps`` are parsed as ``int``; ``ground_roughness``,
    ``pit_gap``, ``stump_height`` and ``stair_height`` are parsed as ``float``.
    Any other columns present in the file are ignored.

    Raises:
        KeyError: If a required column is absent from the file.
        ValueError: If a cell cannot be converted to its target type.
    """
    # Column name -> converter, in the order the output keys are emitted.
    column_casts = (
        ("idx", int),
        ("ground_roughness", float),
        ("pit_gap", float),
        ("stump_height", float),
        ("stair_height", float),
        ("stair_steps", int),
    )

    with open(csv_path, "r", newline="") as handle:
        return [
            {column: cast(record[column]) for column, cast in column_casts}
            for record in csv.DictReader(handle)
        ]


def _build_parallel_env_bipedal_full(
    param_batch: List[Dict[str, Any]],
    device: str,
):
    """Build a wrapped vectorized env with one ``BipedalWalker-Full-v0`` per entry.

    Each dict in ``param_batch`` supplies ``ground_roughness``, ``pit_gap``,
    ``stump_height``, ``stair_height`` and ``stair_steps``. The scalar pit,
    stump and stair-height values are passed as two-element lists holding the
    same value twice, and ``stair_steps`` as a single-element list.

    Returns:
        The ``ParallelAdversarialVecEnv`` wrapped in ``VecMonitor`` and then
        ``VecPreprocessImageWrapper`` (no transpose, no scaling) on ``device``.
    """

    def _thunk_for(cfg):
        # A dedicated factory gives every thunk its own ``cfg`` binding, so
        # the environments do not all end up sharing the last loop value.
        def _create():
            return gym_make(
                "BipedalWalker-Full-v0",
                ground_roughness=cfg["ground_roughness"],
                pit_gap=[cfg["pit_gap"]] * 2,
                stump_height=[cfg["stump_height"]] * 2,
                stair_height=[cfg["stair_height"]] * 2,
                stair_steps=[cfg["stair_steps"]],
            )

        return _create

    env_fns = [_thunk_for(cfg) for cfg in param_batch]

    base_env = ParallelAdversarialVecEnv(env_fns, adversary=False, is_eval=True)
    monitored_env = VecMonitor(venv=base_env, filename=None, keep_buf=100)

    # bipedal: transpose_order=None
    return VecPreprocessImageWrapper(
        venv=monitored_env,
        obs_key=None,
        transpose_order=None,
        scale=None,
        device=device,
    )


def _zero_recurrent_hidden(rh, i: int) -> None:
    if isinstance(rh, tuple):
        rh[0][i].zero_()
        rh[1][i].zero_()
    else:
        rh[i].zero_()


def _evaluate_bw_full_param_csv(
    *,
    agent,
    csv_path: str,
    num_processes: int,
    device: str,
    episodes_per_env: int,
    solved_threshold: float,
    deterministic: bool = True,
    show_progress: bool = False,
    accumulator: str = "mean",
):
    """Evaluate ``agent`` on every environment configuration listed in a CSV.

    The configurations are loaded with ``load_param_csv`` and processed in
    batches of ``num_processes`` parallel environments. The list is padded with
    copies of the last configuration so that every batch is full; padded
    entries are rolled out but excluded from all reported statistics.

    Args:
        agent: Policy wrapper exposing ``algo.actor_critic``, ``act`` and
            ``process_action``.
        csv_path: Path to the parameter CSV.
        num_processes: Number of environments stepped in parallel (>= 1).
        device: Torch device spec used for hidden states, masks and the env
            wrapper.
        episodes_per_env: Episodes recorded per configuration (>= 1).
        solved_threshold: A configuration counts as solved when its mean
            return is greater than or equal to this value.
        deterministic: Forwarded to ``agent.act``.
        show_progress: Show a tqdm bar when tqdm is importable.
        accumulator: Only ``"mean"`` is supported.

    Returns:
        ``(per_env_returns, mean_ret, std_ret, solved_rate)`` where
        ``per_env_returns`` is a list with the mean episode return of each
        non-padded configuration, ``mean_ret`` / ``std_ret`` are ``np.mean`` /
        ``np.std`` over that list, and ``solved_rate`` is the fraction of
        configurations whose mean return reaches ``solved_threshold``.

    Raises:
        AssertionError: On an unsupported ``accumulator``, an empty CSV, or
            non-positive ``num_processes`` / ``episodes_per_env``.
    """
    assert accumulator in ["mean"], "Only mean is implemented for now."

    params = load_param_csv(csv_path)
    N = len(params)
    assert N > 0
    assert num_processes >= 1
    assert episodes_per_env >= 1

    # Pad to a multiple of num_processes so every batch fills the vec env.
    pad = (-N) % num_processes
    valid_N = N
    if pad > 0:
        last = params[-1]
        for k in range(pad):
            dup = dict(last)
            dup["idx"] = valid_N + k  # index past the valid range marks padding
            params.append(dup)

    N_pad = len(params)
    batches = N_pad // num_processes

    device_t = torch.device(device)

    actor_critic = agent.algo.actor_critic
    hidden_size = actor_critic.recurrent_hidden_state_size
    is_recurrent = getattr(actor_critic, "is_recurrent", False) or getattr(
        agent, "is_recurrent", False
    )

    # Per-configuration accumulators (supports episodes_per_env > 1).
    ret_sum = np.zeros(N_pad, dtype=np.float32)
    ret_cnt = np.zeros(N_pad, dtype=np.int32)

    disc_checker = is_discrete_actions

    # The progress bar counts only episodes from non-padded configurations.
    pbar = None
    if show_progress and (tqdm is not None):
        pbar = tqdm(
            total=valid_N * episodes_per_env,
            desc="BW full-param eval (episodes)",
            dynamic_ncols=True,
            leave=True,
        )

    finished_valid_eps = 0
    running_ret_sum = 0.0

    try:
        for b in range(batches):
            batch = params[b * num_processes : (b + 1) * num_processes]
            venv = _build_parallel_env_bipedal_full(batch, device=device)

            try:
                obs = venv.reset()

                rh = torch.zeros(num_processes, hidden_size, device=device_t)
                if is_recurrent:
                    # NOTE(review): a 2-tuple state presumably matches an
                    # LSTM-style (h, c) policy — confirm for other RNN types.
                    rh = (rh, torch.zeros_like(rh))

                masks = torch.ones(num_processes, 1, device=device_t)

                done_eps = [0] * num_processes
                target_total = num_processes * episodes_per_env

                # Best effort: fall back to continuous-action handling when
                # the checker is unavailable or fails on this env.
                is_disc = False
                if disc_checker is not None:
                    try:
                        is_disc = bool(disc_checker(venv))
                    except Exception:
                        is_disc = False

                while sum(done_eps) < target_total:
                    with torch.no_grad():
                        _, action, _, rh = agent.act(
                            obs,
                            rh,
                            masks,
                            deterministic=deterministic,
                        )

                    action_np = action.cpu().numpy()
                    if not is_disc:
                        action_np = agent.process_action(action_np)

                    obs, reward, done, infos = venv.step(action_np)

                    masks = torch.tensor(
                        [[0.0] if d else [1.0] for d in done],
                        dtype=torch.float32,
                        device=device_t,
                    )

                    for i, info in enumerate(infos):
                        # Only steps carrying an "episode" entry are recorded
                        # (presumably added by VecMonitor when an episode ends).
                        if "episode" not in info:
                            continue
                        # Envs that already hit their quota keep stepping with
                        # the batch, but their extra episodes are ignored.
                        if done_eps[i] >= episodes_per_env:
                            continue

                        global_i = b * num_processes + i
                        is_valid = global_i < valid_N

                        ep_ret = float(info["episode"]["r"])

                        ret_sum[global_i] += ep_ret
                        ret_cnt[global_i] += 1
                        done_eps[i] += 1

                        if is_valid:
                            finished_valid_eps += 1
                            running_ret_sum += ep_ret

                            if pbar is not None:
                                pbar.update(1)
                                avg_ep_ret = running_ret_sum / max(1, finished_valid_eps)
                                pbar.set_postfix(
                                    batch=f"{b+1}/{batches}",
                                    avg_ep_ret=f"{avg_ep_ret:.1f}",
                                    last_ret=f"{ep_ret:.1f}",
                                )

                        if is_recurrent:
                            _zero_recurrent_hidden(rh, i)

            finally:
                venv.close()
    finally:
        # Close the bar even when a rollout raises, so it is not left open.
        if pbar is not None:
            pbar.close()

    # Per-configuration mean return, with the padded tail dropped.
    ret_sum = ret_sum[:valid_N]
    ret_cnt = ret_cnt[:valid_N]
    per_env_returns = ret_sum / np.maximum(ret_cnt, 1)

    mean_ret = float(np.mean(per_env_returns))
    std_ret = float(np.std(per_env_returns))
    solved_rate = float(np.mean(per_env_returns >= solved_threshold))

    return per_env_returns.tolist(), mean_ret, std_ret, solved_rate


# (low, high) sampling bounds per environment parameter. generate_param_csv
# draws the four float parameters uniformly from their bounds and stair_steps
# as an integer with the upper bound included.
PARAM_RANGES_FULL = dict(
    ground_roughness=(0.0, 10.0),
    pit_gap=(0.0, 10.0),
    stump_height=(0.0, 5.0),
    stair_height=(0.0, 5.0),
    stair_steps=(1, 9),
)


# ----------------------------- Main -----------------------------
if __name__ == "__main__":
    # Script entry point: rebuild the training runner from CLI args, restore a
    # checkpoint when requested, evaluate the protagonist agent on a fixed set
    # of BipedalWalker configurations and write a one-row summary CSV.
    os.environ["OMP_NUM_THREADS"] = "1"
    args = parser.parse_args()

    # === Inject/derive our new algorithm arguments (currently no-op) ===

    # === Configure logging ==
    if args.xpid is None:
        args.xpid = "lr-%s" % time.strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.expandvars(os.path.expanduser(args.log_dir))

    # === Initialize FileWriter (unchanged behavior) ===
    # filewriter = FileWriter(xpid=args.xpid, xp_args=args.__dict__, rootdir=log_dir)
    # === Determine device ====
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")
    if "cuda" in device.type:
        torch.backends.cudnn.benchmark = True
        print("Using CUDA\n")

    # === Create parallel envs ===
    venv, ued_venv = create_parallel_env(args)

    is_training_env = args.ued_algo in ["paired", "flexible_paired", "minimax"]
    is_paired = args.ued_algo in ["paired", "flexible_paired"]

    agent = make_agent(name="agent", env=venv, args=args, device=device)
    adversary_agent, adversary_env = None, None
    if is_paired or args.use_accel_paired:
        adversary_agent = make_agent(
            name="adversary_agent", env=venv, args=args, device=device
        )
    if is_training_env:
        adversary_env = make_agent(name="adversary_env", env=venv, args=args, device=device)
    if (
        args.ued_algo == "domain_randomization"
        and args.use_plr
        and not args.use_reset_random_dr
    ):
        adversary_env = make_agent(name="adversary_env", env=venv, args=args, device=device)
        adversary_env.random()

    # === Create runner ===
    plr_args = None
    if args.use_plr:
        plr_args = make_plr_args(args, venv.observation_space, venv.action_space)
    train_runner = AdversarialRunner(
        args=args,
        venv=venv,
        agent=agent,
        ued_venv=ued_venv,
        adversary_agent=adversary_agent,
        adversary_env=adversary_env,
        flexible_protagonist=False,
        train=True,
        plr_args=plr_args,
        device=device,
    )

    # === Configure checkpointing ===

    checkpoint_path = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (log_dir, args.xpid, "model_20000.tar"))
    )

    # === Load checkpoint ===
    if args.checkpoint and os.path.exists(checkpoint_path):
        checkpoint_states = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage
        )
        # last_logged_update_at_restart = filewriter.latest_tick()
        train_runner.load_state_dict(checkpoint_states["runner_state_dict"])
        initial_update_count = train_runner.num_updates
        logging.info(f"Resuming preempted job after {initial_update_count} updates\n")
    else:
        # Without a restored checkpoint the evaluation below measures the
        # freshly initialised agent; make that visible instead of silent.
        logging.warning(
            "No checkpoint loaded (checkpoint flag=%s, path=%s); "
            "evaluating the agent as initialised.",
            args.checkpoint,
            checkpoint_path,
        )

    # === Evaluation settings ===
    # Evaluate on the device the agent was created on. A hard-coded "cuda:0"
    # fails on CPU-only hosts and when --no_cuda is given.
    device = "cuda:0" if args.cuda else "cpu"
    num_processes = 16
    episodes_per_env = 1
    solved_threshold = 230.0

    gen_n = 1000
    gen_seed = 0

    csv_path = "evaluate_result/bw_full_params.csv"

    # The last path component of --log_dir names the results file.
    run_id = args.log_dir.rstrip("/").split("/")[-1]

    results_csv = f"evaluate_result/{run_id}_eval.csv"

    show_progress = True

    deterministic = True
    # ---------------------------------------

    agent = train_runner.agents["agent"]
    # ================================================================

    # Reuse an existing non-empty parameter file so that runs are compared on
    # the same set of configurations.
    if os.path.isfile(csv_path) and os.path.getsize(csv_path) > 0:
        print(f"[skip] param csv already exists: {csv_path}")
    else:
        os.makedirs(os.path.dirname(csv_path), exist_ok=True)
        generate_param_csv(csv_path, n=gen_n, rng_seed=gen_seed)
        print(f"[ok] generated param csv: {csv_path}")

    assert agent is not None, "Please load your agent before running this script."

    per_env_returns, mean_ret, std_ret, solved_rate = _evaluate_bw_full_param_csv(
        agent=agent,
        csv_path=csv_path,
        num_processes=num_processes,
        device=device,
        episodes_per_env=episodes_per_env,
        solved_threshold=solved_threshold,
        deterministic=deterministic,
        show_progress=show_progress,
        accumulator="mean",
    )

    # One 0/1 flag per configuration, stored as a JSON list in the summary row.
    solved_flags = [1 if float(r) >= solved_threshold else 0 for r in per_env_returns]

    os.makedirs(os.path.dirname(results_csv) or ".", exist_ok=True)
    with open(results_csv, "w", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "mean_return",
                "solved_rate",
                "solved_flags_json",
            ],
        )
        writer.writeheader()
        writer.writerow(
            {
                "mean_return": float(mean_ret),
                "solved_rate": float(solved_rate),
                "solved_flags_json": json.dumps(solved_flags),
            }
        )

    print(f"[OK] wrote eval summary -> {results_csv}")
    print(f"Mean return: {mean_ret:.3f}")
    print(f"Std return : {std_ret:.3f}")
    print(f"Solved rate: {solved_rate:.4f} (threshold={solved_threshold})")
# -*- coding: utf-8 -*-
