"""
FPVR-DQN Main Training Script
Pure FPVR exploration with experience replay (off-policy).

References:
- DQN framework (experience replay, 1-step TD updates)
"""
import gymnasium as gym
from gymnasium.wrappers import TimeLimit
import numpy as np
import torch
import os
import sys
import datetime
import json
import cv2
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import math

# Register ALE environments (optional dependency: allow `python main.py --help` without ale_py)
try:
    import ale_py  # type: ignore

    gym.register_envs(ale_py)
except Exception as e:
    ale_py = None
    if "--help" not in sys.argv and "-h" not in sys.argv:
        print(f"[Warning] ale_py import failed: {e}")
        print("[Warning] Atari envs may be unavailable until you install ale_py (and ROMs).")

# Ensure local modules are importable
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
if _THIS_DIR not in sys.path:
    sys.path.insert(0, _THIS_DIR)

from config import get_params
from agent import FPVRAgent
from atari_wrappers import NoopResetEnv, AtariMaxPoolWrapper

try:
    import imageio
except ImportError:
    imageio = None

try:
    import matplotlib
    matplotlib.use('Agg')  # Non-interactive backend
    import matplotlib.pyplot as plt
except (ImportError, AttributeError) as e:
    print(f"[Warning] matplotlib import failed: {e}")
    print("[Warning] matplotlib import failed; reconstruction visualization is disabled (reconstruction branch removed)")
    plt = None


def preprocessing(img):
    """Preprocess image: RGB -> grayscale 84x84."""
    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    img = cv2.resize(img, (84, 84), interpolation=cv2.INTER_AREA)
    return img


def stack_frames(stacked_frames, state, is_new_episode, num_stack):
    """Stack frames for observation."""
    frame = preprocessing(state)
    
    if is_new_episode:
        stacked_frames = np.stack([frame for _ in range(num_stack)], axis=0)
    else:
        stacked_frames = stacked_frames[1:, ...]
        stacked_frames = np.concatenate([stacked_frames, np.expand_dims(frame, axis=0)], axis=0)
    return stacked_frames


def _gym_make_strict(env_name: str, **kwargs):
    """
    Strict gym.make for Atari.
    We must *not* silently drop kwargs like sticky-actions or full_action_space, because that
    changes the experimental protocol.
    """
    try:
        return gym.make(env_name, **kwargs)
    except TypeError as e:
        msg = (
            f"[Env Error] Failed to create env={env_name!r} with required kwargs={kwargs}.\n"
            f"This code requires strict control over sticky actions and full action space.\n"
            f"Original error: {e}\n"
            f"Suggested fixes: upgrade `gymnasium`/`ale-py`, and ensure you're using the ALE v5 environments."
        )
        raise TypeError(msg) from e


def make_atari(
    env_name,
    max_episode_steps=None,
    use_max_pool=False,
    frame_skip=5,
    sticky_action_prob=0.25,
    full_action_space=True,
    noop_on_reset=True,
    noop_max=30,
):
    """
    Create Atari environment
    
    Args:
        env_name: Environment ID (e.g., 'ALE/MontezumaRevenge-v5')
        max_episode_steps: Maximum episode length in *agent steps* (number of env.step() calls)
        use_max_pool: If True, wrap with AtariMaxPoolWrapper for max pooling
        frame_skip: Frame skip value (used if use_max_pool=True)
    
    Returns:
        Wrapped or unwrapped environment
    """
    # Always enable DQN-style action repeat + max-over-last-2-frames preprocessing.
    # Ignore `use_max_pool` to avoid protocol drift.
    env = _gym_make_strict(
        env_name,
        frameskip=1,
        repeat_action_probability=float(sticky_action_prob),
        full_action_space=full_action_space,
    )
    env = AtariMaxPoolWrapper(env, frame_skip=frame_skip)
    # TimeLimit counts env.step() calls (agent steps), NOT underlying ALE frames.
    actual_max_steps = max_episode_steps

    if actual_max_steps is None:
        # DQN-style: 18,000 ALE frames (~5 minutes) per episode.
        actual_max_steps = int(math.ceil(18000 / max(1, int(frame_skip))))
    
    # Wrap with TimeLimit to ensure max_episode_steps is enforced
    # This is more reliable than just setting _max_episode_steps attribute
    if not isinstance(env, TimeLimit):
        env = TimeLimit(env, max_episode_steps=actual_max_steps)
    else:
        # If already wrapped with TimeLimit, update the max_episode_steps
        env._max_episode_steps = actual_max_steps

    # Apply random NOOP reset outermost so that NOOP steps count toward TimeLimit.
    if bool(noop_on_reset):
        env = NoopResetEnv(env, noop_max=int(noop_max))
    
    return env


def save_gif(frames, path, fps=30):
    """Save frames as GIF"""
    if imageio is None:
        print(f"[GIF Error] imageio is not available, cannot save GIF to {path}")
        return False
    if len(frames) == 0:
        print(f"[GIF Error] No frames to save to {path}")
        return False
    try:
        # Ensure frames are in the correct format (numpy arrays with uint8 dtype)
        processed_frames = []
        first_shape = None
        for i, frame in enumerate(frames):
            # Convert to numpy if needed
            if isinstance(frame, torch.Tensor):
                frame = frame.cpu().numpy()
            
            # Ensure it's a numpy array
            frame = np.asarray(frame)
            
            # Check and record first frame's shape
            if first_shape is None:
                first_shape = frame.shape
            else:
                # Verify all frames have the same shape
                if frame.shape != first_shape:
                    print(f"[GIF Error] Frame {i} has shape {frame.shape}, but first frame has shape {first_shape}. "
                          f"All frames must have the same shape!")
                    return False
            
            # Convert to uint8 if needed
            if frame.dtype != np.uint8:
                if frame.max() <= 1.0:
                    frame = (frame * 255).astype(np.uint8)
                else:
                    frame = frame.astype(np.uint8)
            
            processed_frames.append(frame)
        
        # Save GIF
        imageio.mimsave(path, processed_frames, fps=fps)
        print(f"[GIF] Successfully saved {len(processed_frames)} frames (shape: {first_shape}) to {path}")
        return True
    except Exception as e:
        print(f"[GIF Error] Failed to save GIF to {path}: {e}")
        import traceback
        traceback.print_exc()
        return False


#
# Note: FPVRNetwork reconstruction branch removed; no reconstruction visualization.
#


@torch.no_grad()
def evaluate_q_eps_greedy(
    q_net,
    env_name: str,
    *,
    eps: float,
    n_episodes: int,
    seed: int,
    frame_stack: int,
    max_episode_steps: int,
    use_max_pool: bool,
    frame_skip: int,
    sticky_action_prob: float,
    full_action_space: bool,
    noop_on_reset: bool,
    noop_max: int,
    device,
):
    """Run evaluation with Q-network only (ignore FPVR). Returns (mean_return, mean_len)."""
    if q_net is None:
        return None

    q_was_training = q_net.training
    q_net.eval()

    returns = []
    lengths = []
    rng = np.random.default_rng(seed)

    env = make_atari(
        env_name,
        max_episode_steps=max_episode_steps,
        use_max_pool=use_max_pool,
        frame_skip=frame_skip,
        sticky_action_prob=sticky_action_prob,
        full_action_space=full_action_space,
        noop_on_reset=noop_on_reset,
        noop_max=noop_max,
    )
    try:
        for ep in range(n_episodes):
            obs, _ = env.reset(seed=int(rng.integers(0, 2**31 - 1)))
            stacked = stack_frames(None, obs, is_new_episode=True, num_stack=frame_stack)
            done = False
            ep_ret = 0.0
            ep_len = 0

            while not done:
                if rng.random() < eps:
                    action = int(env.action_space.sample())
                else:
                    state_t = torch.from_numpy(stacked[None]).to(device)
                    if state_t.dtype != torch.uint8:
                        state_t = state_t.to(torch.uint8)
                    q_vals = q_net(state_t)
                    action = int(torch.argmax(q_vals, dim=1).item())

                step_result = env.step(action)
                if len(step_result) == 5:
                    next_obs, reward, terminated, truncated, _info = step_result
                    done = bool(terminated or truncated)
                else:
                    next_obs, reward, done, _info = step_result

                ep_ret += float(reward)
                ep_len += 1
                if not done:
                    stacked = stack_frames(stacked, next_obs, is_new_episode=False, num_stack=frame_stack)

            returns.append(ep_ret)
            lengths.append(ep_len)
    finally:
        env.close()
        if q_was_training:
            q_net.train()

    return float(np.mean(returns)), float(np.mean(lengths))


def _make_json_serializable(obj):
    """Convert numpy types to Python native types for JSON serialization."""
    if isinstance(obj, dict):
        return {k: _make_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_make_json_serializable(item) for item in obj]
    if isinstance(obj, (np.int64, np.int32, np.int16, np.int8)):
        return int(obj)
    if isinstance(obj, (np.float64, np.float32, np.float16)):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj


def _compute_total_timesteps_from_frames(config: dict) -> int:
    """Translate paper-style total_frames into env.step() calls."""
    if config.get("total_timesteps") is not None:
        return int(config["total_timesteps"])
    total_frames = int(config.get("total_frames", int(100e6)))
    frame_skip = int(config.get("frame_skip", 5))
    return int(np.ceil(total_frames / max(1, frame_skip)))


def run_one_seed(base_config: dict, run_seed: int, *, outer_run_dir: str):
    config = dict(base_config)
    config["seed"] = int(run_seed)

    # Set random seeds
    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])

    # Create env to determine action space
    test_env = make_atari(
        config["env_name"],
        max_episode_steps=int(config.get("max_episode_steps", 4500)),
        use_max_pool=bool(config.get("use_max_pool", False)),
        frame_skip=int(config.get("frame_skip", 5)),
        sticky_action_prob=float(config.get("sticky_action_prob", 0.25)),
        full_action_space=bool(config.get("full_action_space", True)),
        noop_on_reset=bool(config.get("noop_on_reset", True)),
        noop_max=int(config.get("noop_max", 30)),
    )
    config["n_actions"] = int(test_env.action_space.n)
    test_env.close()

    # Determine training horizon
    config["total_timesteps"] = _compute_total_timesteps_from_frames(config)

    # Create per-seed run directory under a shared outer directory.
    # Layout:
    # runs/<outer_time>-<env_tag>/seed<seed>-<seed_start_time>/{config,checkpoint,log,GIF}
    seed_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    seed_dir = os.path.join(outer_run_dir, f"seed{config['seed']}-{seed_time}")
    cfg_dir = os.path.join(seed_dir, "config")
    ckpt_dir = os.path.join(seed_dir, "checkpoint")
    log_dir = os.path.join(seed_dir, "log")
    gif_dir = os.path.join(seed_dir, "GIF")
    for d in (cfg_dir, ckpt_dir, log_dir, gif_dir):
        os.makedirs(d, exist_ok=True)

    # Save config
    try:
        config_serializable = _make_json_serializable(config)
        with open(os.path.join(cfg_dir, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config_serializable, f, indent=2)
    except Exception as e:
        print(f"[Warning] Failed to save config: {e}")

    writer = SummaryWriter(os.path.join(log_dir, "tb"))
    agent = FPVRAgent(config["state_shape"], config["n_actions"], config)

    # Load checkpoint if continuing
    start_step = 0
    episode_count = 0
    positive_reward_count = 0
    running_reward = 0.0
    last_episode_reward = 0.0
    if not config.get("train_from_scratch", True):
        ckpt_path = os.path.join(ckpt_dir, "checkpoint.pth")
        if os.path.exists(ckpt_path):
            loaded_config = agent.load(ckpt_path)
            start_step = int(agent.train_step_count)
            episode_count = int(loaded_config.get("episode_count", 0))
            positive_reward_count = int(loaded_config.get("positive_reward_count", 0))
            running_reward = float(loaded_config.get("running_reward", 0.0))
        else:
            print(f"[Resume] No checkpoint found at {ckpt_path}; starting from scratch.")

    # Note: fixed-state c mode removed.

    # Initialize GIF recording state
    recording_gif = False
    if int(config.get("gif_interval", 0)) > 0:
        next_episode = episode_count + 1
        recording_gif = (next_episode % int(config["gif_interval"]) == 0)

    # Create training environment
        env = make_atari(
        config["env_name"],
        max_episode_steps=int(config.get("max_episode_steps", 4500)),
        use_max_pool=bool(config.get("use_max_pool", False)),
        frame_skip=int(config.get("frame_skip", 5)),
        sticky_action_prob=float(config.get("sticky_action_prob", 0.25)),
        full_action_space=bool(config.get("full_action_space", True)),
        noop_on_reset=bool(config.get("noop_on_reset", True)),
        noop_max=int(config.get("noop_max", 30)),
    )

    try:
        reset_result = env.reset(seed=config["seed"])
        obs = reset_result[0] if isinstance(reset_result, tuple) else reset_result
        frame_stack = int(config["frame_stack"])
        stacked_state = stack_frames(
            np.zeros((frame_stack, 84, 84), dtype=np.uint8),
            obs,
            is_new_episode=True,
            num_stack=frame_stack,
        )

        if config.get("reset_c", False):
            agent.reset_c()

        episode_reward = 0.0
        episode_length = 0

        gif_frames = []
        if recording_gif:
            gif_frames.append(stacked_state[-1].copy())

        print("\n" + "=" * 80)
        print(f"Starting FPVR-DQN Training (seed={config['seed']}, total_timesteps={config['total_timesteps']})")
        print("=" * 80)

        last_losses = None

        for step in tqdm(range(start_step, int(config["total_timesteps"])), desc=f"Training(seed={config['seed']})"):
            # ========== Select Action ==========
            eps_start = float(config.get("eps_start", 1.0))
            eps_end = float(config.get("eps_end", 0.1))
            # Epsilon decay is defined in ALE frames (paper: 1M frames), not agent steps.
            # One agent step corresponds to ~frame_skip ALE frames (either via wrapper repeats or ALE frameskip).
            frame_skip = int(config.get("frame_skip", 5))
            step_frames = int(step) * max(1, frame_skip)
            eps_decay_frames = config.get("eps_decay_frames", int(1e6))
            # Backward-compat: if eps_decay_steps is provided, treat it as frames.
            if config.get("eps_decay_steps", None) is not None:
                eps_decay_frames = int(config["eps_decay_steps"])
            eps_decay_frames = int(eps_decay_frames)

            if step_frames < eps_decay_frames:
                epsilon = eps_start - (eps_start - eps_end) * (step_frames / max(1, eps_decay_frames))
            else:
                epsilon = eps_end

            action, phi_tilde = agent.select_action(stacked_state, epsilon=epsilon)

            # GIF frame collection
            if recording_gif:
                gif_frames.append(stacked_state[-1].copy())

            # ========== Environment Step ==========
            step_result = env.step(action)
            if len(step_result) == 5:
                next_obs, reward, terminated, truncated, info = step_result
                done = bool(terminated or truncated)
            else:
                next_obs, reward, done, info = step_result

            # DQN-style reward clipping for learning targets (store clipped reward in replay buffer).
            # Keep `reward` as the raw environment reward for episode returns/logging comparability.
            reward_clipping = str(config.get("reward_clipping", "sign"))
            if reward_clipping not in ("none", "sign", "clip"):
                reward_clipping = "sign"
            if reward_clipping == "none":
                reward_store = float(reward)
            elif reward_clipping == "clip":
                reward_store = float(np.clip(float(reward), -1.0, 1.0))
            else:
                # sign: {-1, 0, +1}
                r = float(reward)
                reward_store = 1.0 if r > 0.0 else (-1.0 if r < 0.0 else 0.0)

            if reward > 0:
                positive_reward_count += 1

            next_stacked_state = stack_frames(stacked_state, next_obs, is_new_episode=False, num_stack=frame_stack)

            # ========== Store Experience ==========
            agent.replay_buffer.store(stacked_state, action, reward_store, next_stacked_state, done)

            # ========== Update c vector ==========
            agent.update_c(phi_tilde)

            # ========== Train ==========
            if step >= int(config["learning_starts"]) and step % int(config["train_freq"]) == 0:
                for _ in range(int(config["gradient_steps"])):
                    last_losses = agent.train()
                    if last_losses and step % int(config["interval"]) == 0:
                        # Log all available loss components (SR/Q/etc.) if provided by agent.train()
                        for k, v in last_losses.items():
                            if isinstance(v, (int, float, np.floating, np.integer)):
                                writer.add_scalar(f"Loss/{k}", float(v), step)

            # ========== Periodic Evaluation (Q-network only, epsilon-greedy) ==========
            eval_interval = int(config.get("eval_interval", 0))
            if (
                eval_interval > 0
                and (step % eval_interval == 0)
                and (step > 0)
                and getattr(agent, "q_net", None) is not None
            ):
                eval_eps = float(config.get("eval_epsilon", 0.01))
                eval_n = int(config.get("eval_episodes", 10))
                res = evaluate_q_eps_greedy(
                    agent.q_net,
                    config["env_name"],
                    eps=eval_eps,
                    n_episodes=eval_n,
                    seed=int(config.get("seed", 47)) + 12345 + int(step),
                    frame_stack=int(config.get("frame_stack", 4)),
                    max_episode_steps=int(config.get("max_episode_steps", 4500)),
                    use_max_pool=bool(config.get("use_max_pool", False)),
                    frame_skip=int(config.get("frame_skip", 5)),
                    sticky_action_prob=float(config.get("sticky_action_prob", 0.25)),
                    full_action_space=bool(config.get("full_action_space", True)),
                    noop_on_reset=bool(config.get("noop_on_reset", True)),
                    noop_max=int(config.get("noop_max", 30)),
                    device=agent.device,
                )
                if res is not None:
                    mean_ret, mean_len = res
                    print(
                        f"\n[Eval@Step {step}] Q ε-greedy (eps={eval_eps:.3f}) -> "
                        f"Return: {mean_ret:.2f}, Length: {mean_len:.1f} (episodes={eval_n})"
                    )
                    writer.add_scalar("Eval/Q_eps_return", mean_ret, step)
                    writer.add_scalar("Eval/Q_eps_length", mean_len, step)
                    writer.add_scalar("Eval/Q_eval_epsilon", eval_eps, step)
                    # Also log evaluation curves against frames (paper reports in frames).
                    frame_skip = int(config.get("frame_skip", 5))
                    step_frames = int(step) * max(1, frame_skip)
                    writer.add_scalar("EvalFrames/Q_eps_return", mean_ret, step_frames)
                    writer.add_scalar("EvalFrames/Q_eps_length", mean_len, step_frames)

            # ========== Episode Tracking ==========
            episode_reward += float(reward)
            episode_length += 1

            if done:
                episode_count += 1
                running_reward = 0.9 * running_reward + 0.1 * episode_reward
                last_episode_reward = float(episode_reward)

                if recording_gif:
                    if len(gif_frames) > 0:
                        gif_frames.append(next_stacked_state[-1].copy())
                        gif_path = os.path.join(gif_dir, f"episode_{episode_count:06d}.gif")
                        _ = save_gif(gif_frames, gif_path, fps=int(config["gif_fps"]))
                    recording_gif = False
                    gif_frames = []

                writer.add_scalar("Episode/Reward", episode_reward, episode_count)
                writer.add_scalar("Episode/Length", episode_length, episode_count)
                writer.add_scalar("Episode/Running_Reward", running_reward, episode_count)
                writer.add_scalar("Metrics/Positive_Reward_Count", positive_reward_count, episode_count)

                if episode_count % 10 == 0:
                    print(
                        f"\n[Episode {episode_count}] Reward: {episode_reward:.2f}, Length: {episode_length}, "
                        f"Running: {running_reward:.2f}, Positive Rewards: {positive_reward_count}"
                    )

                if config.get("reset_c", False):
                    agent.reset_c()

                reset_result = env.reset()
                obs = reset_result[0] if isinstance(reset_result, tuple) else reset_result
                stacked_state = stack_frames(
                    np.zeros((frame_stack, 84, 84), dtype=np.uint8),
                    obs,
                    is_new_episode=True,
                    num_stack=frame_stack,
                )
                episode_reward = 0.0
                episode_length = 0

                next_episode = episode_count + 1
                if int(config.get("gif_interval", 0)) > 0 and (next_episode % int(config["gif_interval"]) == 0):
                    recording_gif = True
                    gif_frames = [stacked_state[-1].copy()]
            else:
                stacked_state = next_stacked_state

            # ========== Logging ==========
            if step % int(config["interval"]) == 0 and step > 0:
                recording_str = f", Recording GIF: {recording_gif}" if int(config.get("gif_interval", 0)) > 0 else ""
                last_reward_str = f", Last Episode Reward: {last_episode_reward:.2f}" if episode_count > 0 else ""
                print(
                    f"\n[Step {step}] Episodes: {episode_count}, Buffer: {len(agent.replay_buffer)}, "
                    f"Epsilon: {epsilon:.3f}, Running Reward: {running_reward:.2f}, "
                    f"Positive Rewards: {positive_reward_count}{last_reward_str}{recording_str}"
                )
            # Note: adaptive alpha removed. sum_zscore/filtered_zscore use fixed policy_alpha.

                with torch.no_grad():
                    state_np = stacked_state if stacked_state.ndim == 4 else stacked_state[None]
                    state_t = torch.from_numpy(state_np).to(agent.device)
                    state_t_sf = agent._sf_input(state_t)
                    phi_raw_log, _ = agent.network(state_t_sf)
                    phi_tilde_log = agent._apply_whitening(phi_raw_log)
                    _, psi_all_log = agent.network(state_t_sf, phi_whitened=phi_tilde_log)
                    c_used_log = agent.c_vec[0:1]
                    redundancy_raw = agent._compute_redundancy(psi_all_log, c_used_log)
                    redundancy_values = redundancy_raw.squeeze(0).cpu().numpy()
                print(f"  Raw redundancy (no z-score): {np.array2string(redundancy_values, precision=4)}")

                if getattr(agent, "q_net", None) is not None:
                    with torch.no_grad():
                        q_vals = agent.q_net(state_t).squeeze(0).cpu().numpy()
                    print(f"  Q-values: {np.array2string(q_vals, precision=4)}")

                writer.add_scalar("Metrics/Epsilon", epsilon, step)
                writer.add_scalar("Metrics/Buffer_Size", len(agent.replay_buffer), step)
                writer.add_scalar("Metrics/Positive_Reward_Count", positive_reward_count, step)

                if last_losses is not None and step >= int(config["learning_starts"]):
                    # SF/SR network loss (primary)
                    if "sr_loss" in last_losses:
                        print(f"  SF(SR) Loss: {float(last_losses['sr_loss']):.4f}")
                    # Optionally print any other losses returned by agent.train()
                    extra_loss_keys = [k for k in last_losses.keys() if k not in ("sr_loss", "total_loss")]
                    if "total_loss" in last_losses:
                        print(f"  Total Loss: {float(last_losses['total_loss']):.4f}")
                    for k in extra_loss_keys:
                        v = last_losses.get(k)
                        if isinstance(v, (int, float, np.floating, np.integer)):
                            print(f"  {k}: {float(v):.4f}")

            # ========== Save Checkpoint ==========
            # Note: `step` is 0-based loop index. After executing this iteration, we've completed (step + 1) env steps.
            # So to save at exactly N env steps (e.g., 1e6, 2e6), we should check (step + 1) % save_interval == 0.
            save_every = int(config["save_interval"])
            env_steps_done = int(step) + 1
            if save_every > 0 and (env_steps_done % save_every == 0):
                extra_state = {
                    "episode_count": episode_count,
                    "positive_reward_count": positive_reward_count,
                    "running_reward": running_reward,
                }
                ckpt_path = os.path.join(ckpt_dir, f"checkpoint_step_{env_steps_done}.pth")
                agent.save(ckpt_path, extra_state=extra_state)
                agent.save(os.path.join(ckpt_dir, "checkpoint.pth"), extra_state=extra_state)
    finally:
        env.close()
        writer.close()
        print("\n" + "=" * 80)
        print(f"Training Complete! (seed={config['seed']})")
        print("=" * 80)


def main():
    base_config = get_params()
    base_seed = int(base_config.get("seed", 47))
    num_seeds = int(base_config.get("num_seeds", 1))

    # Outer run directory: time + env name
    outer_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    env_tag = base_config["env_name"].replace("/", "_")
    base_dir = os.path.dirname(os.path.abspath(__file__))
    outer_run_dir = os.path.join(base_dir, "runs", f"{outer_time}-{env_tag}")
    os.makedirs(outer_run_dir, exist_ok=True)

    for seed_idx in range(num_seeds):
        run_seed = base_seed + seed_idx
        run_one_seed(base_config, run_seed, outer_run_dir=outer_run_dir)


if __name__ == "__main__":
    main()

