{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6eb763ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from diffusion_co_design.common import get_latest_model, cuda\n",
    "\n",
    "from diffusion_co_design.common.design import DesignerParams\n",
    "from diffusion_co_design.vmas.model.rl import create_critic\n",
    "from diffusion_co_design.vmas.schema import (\n",
    "    TrainingConfig,\n",
    "    EnvCriticConfig,\n",
    "    DiffusionOperation,\n",
    ")\n",
    "from diffusion_co_design.vmas.design import DicodeDesigner, RandomDesigner\n",
    "from diffusion_co_design.vmas.scenario.env import create_env, render_layout\n",
    "\n",
    "device = cuda\n",
    "\n",
    "# Directory of the training run to inspect (a local or a shared/global run).\n",
    "# It must contain `.hydra/config.yaml`, which is read just below, and a\n",
    "# `checkpoints/` folder, which is read when the critic weights are loaded.\n",
    "# Export DICODE_TRAINING_DIR before starting the kernel, or assign the path here.\n",
    "training_dir = os.environ.get(\"DICODE_TRAINING_DIR\", \"\")\n",
    "if not training_dir:\n",
    "    # An unset path would otherwise surface as an opaque error inside os.path.join.\n",
    "    raise ValueError(\n",
    "        \"training_dir is not set: export DICODE_TRAINING_DIR or assign the run directory\"\n",
    "    )\n",
    "\n",
    "cfg = TrainingConfig.from_file(os.path.join(training_dir, \".hydra\", \"config.yaml\"))\n",
    "\n",
    "# Baseline designer that the guided designer is compared against below.\n",
    "baseline = RandomDesigner(DesignerParams.placeholder(cfg.scenario))\n",
    "\n",
    "# Classifier hyperparameters are copied from the critic section of the training config.\n",
    "classifier_cfg = EnvCriticConfig(\n",
    "    depth=cfg.policy.critic.depth,\n",
    "    hidden_size=cfg.policy.critic.hidden_size,\n",
    "    k=cfg.policy.critic.k,\n",
    ")\n",
    "\n",
    "# Diffusion guidance settings used for this evaluation run.\n",
    "diffusion_op = DiffusionOperation(\n",
    "    num_recurrences=8,\n",
    "    backward_lr=0.01,\n",
    "    backward_steps=16,\n",
    "    forward_guidance_wt=5,\n",
    "    forward_guidance_annealing=False,\n",
    ")\n",
    "\n",
    "designer = DicodeDesigner(\n",
    "    designer_setting=DesignerParams.placeholder(cfg.scenario),\n",
    "    classifier=classifier_cfg,\n",
    "    diffusion=diffusion_op,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Reference environment; in this cell it is only used to construct the critic.\n",
    "ref_env = create_env(\n",
    "    mode=\"reference\",\n",
    "    scenario=cfg.scenario,\n",
    "    designer=RandomDesigner(DesignerParams.placeholder(cfg.scenario)).get_placeholder(),\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "critic = create_critic(\n",
    "    env=ref_env, scenario=cfg.scenario, cfg=cfg.policy.critic, device=device\n",
    ")\n",
    "\n",
    "# Latest `critic_*` checkpoint of the run. map_location keeps the load working when\n",
    "# the checkpoint was written on a different device than the one used here.\n",
    "state_dict = torch.load(\n",
    "    get_latest_model(dir=os.path.join(training_dir, \"checkpoints\"), prefix=\"critic_\"),\n",
    "    map_location=device,\n",
    ")\n",
    "\n",
    "# Restore the trained weights into the critic before copying them to the designer.\n",
    "# Without this step the designer receives the freshly constructed critic's parameters\n",
    "# and the checkpoint loaded above is never used.\n",
    "# NOTE(review): assumes the checkpoint stores `critic.state_dict()` -- confirm against\n",
    "# the save call in the training script.\n",
    "critic.load_state_dict(state_dict)\n",
    "\n",
    "designer.model.load_state_dict(critic.module.state_dict())\n",
    "designer.model = designer.model.to(device)\n",
    "\n",
    "NUM_LAYOUTS = 9  # layouts generated per designer; also the number of plot columns\n",
    "\n",
    "baseline_envs = torch.stack(baseline.generate_layout_batch(batch_size=NUM_LAYOUTS)).to(device)\n",
    "# Moved to `device` like the baseline batch so both are fed to the model the same way.\n",
    "envs = torch.stack(designer.generate_layout_batch(batch_size=NUM_LAYOUTS)).to(device)\n",
    "\n",
    "# Mean model output per batch. This is evaluation only, so no autograd graph is needed.\n",
    "with torch.no_grad():\n",
    "    print(designer.model(baseline_envs).mean())\n",
    "    print(designer.model(envs).mean())\n",
    "\n",
    "# Row 0 shows the baseline layouts, row 1 the layouts from `designer`.\n",
    "fig, axes = plt.subplots(2, NUM_LAYOUTS, figsize=(18, 4))\n",
    "for row, batch in enumerate((baseline_envs, envs)):\n",
    "    for col in range(NUM_LAYOUTS):\n",
    "        axes[row, col].imshow(render_layout(x=batch[col], scenario=cfg.scenario))\n",
    "        axes[row, col].axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "627de771",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
