{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6d30aa9d",
   "metadata": {},
   "source": [
    "# PPO Forward Pass Wall-Clock Benchmark\n",
    "\n",
    "Benchmarking the forward pass (policy + value) of `DiscretePPOAgent` using a dummy MiniGrid-like batch (batch size 2048).\n",
    "\n",
    "Steps:\n",
    "1. Create a `DiscretePPOAgent`.\n",
    "4. Generate a dummy observation with ProcGen-like shapes and batch of size 2048.\n",
    "5. Run a warm-up forward pass.\n",
    "6. Benchmark with `%%timeit`.\n",
    "\n",
    "The benchmark measures the time to compute: `get_action_and_value`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf1a9a7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import time\n",
    "\n",
    "import torch\n",
    "from hydra import compose, initialize\n",
    "from hydra.core.global_hydra import GlobalHydra\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from src.rl.agents.ppo_discrete import DiscretePPOAgent\n",
    "from src.rl.environments.make_functions import make_procgen\n",
    "from src.rl.utils.train import set_cuda_configuration, set_seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6308c587",
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG_DIR = \"config/procgen_paper\"\n",
    "CONFIG_NAME = \"euclidean_baseline\"  # hyper_paper, hyperpp\n",
    "GPU = 1\n",
    "WARMUP_STEPS = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f86c449",
   "metadata": {},
   "outputs": [],
   "source": [
    "if GlobalHydra.instance().is_initialized():\n",
    "    GlobalHydra.instance().clear()\n",
    "initialize(version_base=None, config_path=CONFIG_DIR, job_name=\"agent_timing\")\n",
    "cfg = compose(\n",
    "    config_name=CONFIG_NAME,\n",
    "    overrides=[\n",
    "        \"experiment.seed=23\",\n",
    "        \"hydra.searchpath=[config]\",\n",
    "    ],\n",
    ")\n",
    "print(OmegaConf.to_yaml(cfg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1cdc8f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logging setup\n",
    "logging.basicConfig(level=cfg.logging_level, format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n",
    "\n",
    "# Derived fields\n",
    "cfg.batch_size = int(cfg.num_envs * cfg.num_steps)\n",
    "cfg.minibatch_size = int(cfg.batch_size // cfg.num_minibatches)\n",
    "cfg.num_iterations = cfg.total_timesteps // cfg.batch_size\n",
    "run_name = f\"{cfg.env_id}__{cfg.experiment.exp_name}__{cfg.experiment.seed}__{int(time.time())}\"\n",
    "cfg.experiment.run_name = run_name\n",
    "\n",
    "# Seeds and device\n",
    "set_seeds(cfg.experiment.seed, torch_deterministic=cfg.experiment.torch_deterministic)\n",
    "device = set_cuda_configuration(GPU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe588418",
   "metadata": {},
   "outputs": [],
   "source": [
    "envs = make_procgen(\n",
    "    env_id=cfg.env_id,\n",
    "    num_envs=cfg.num_envs,\n",
    "    level_distribution=cfg.level_distribution,\n",
    "    start_level=0,\n",
    "    num_levels=cfg.num_levels,\n",
    "    capture_video=cfg.experiment.capture_video,\n",
    "    gamma=cfg.gamma,\n",
    "    run_name=run_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c630ca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = DiscretePPOAgent(\n",
    "    env_type=cfg.env_type,\n",
    "    envs=envs,\n",
    "    gamma=cfg.gamma,\n",
    "    num_steps=cfg.num_steps,\n",
    "    gae_lambda=cfg.gae_lambda,\n",
    "    batch_size=cfg.batch_size,\n",
    "    minibatch_size=cfg.minibatch_size,\n",
    "    update_epochs=cfg.update_epochs,\n",
    "    clip_coef=cfg.clip_coef,\n",
    "    ent_coef=cfg.ent_coef,\n",
    "    vf_coef=cfg.vf_coef,\n",
    "    max_grad_norm=cfg.max_grad_norm,\n",
    "    target_kl=cfg.target_kl,\n",
    "    norm_adv=cfg.norm_adv,\n",
    "    embedding_dim=cfg.embedding_dim,\n",
    "    shared_encoder=cfg.shared_encoder,\n",
    "    last_layer_tanh=cfg.last_layer_tanh,\n",
    "    feat_reg_coef=cfg.feat_reg_coef,\n",
    "    compute_embedding_metrics=cfg.compute_embedding_metrics,\n",
    "    actor_cfg=cfg.policy,\n",
    "    critic_cfg=cfg.value_fn,\n",
    "    optim_cfg=cfg.optimizer,\n",
    "    device=device,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "656622c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_shape = envs.observation_space.sample()[\"rgb\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60e9ab59",
   "metadata": {},
   "outputs": [],
   "source": [
    "test = torch.randint(0, 255, (cfg.minibatch_size, *obs_shape), dtype=torch.uint8).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc76bf8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for _ in range(WARMUP_STEPS):\n",
    "    _ = agent.get_action_and_value(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7d66869",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a19d4ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "_ = agent.get_action_and_value(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ed8ceb1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
