{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6eb763ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from experiments.train_rware.main import TrainingConfig\n",
    "from diffusion_co_design.common import OUTPUT_DIR, get_latest_model, cuda\n",
    "\n",
    "from diffusion_co_design.rware.model.classifier import make_model\n",
    "from diffusion_co_design.rware.diffusion.transform import (\n",
    "    storage_to_layout,\n",
    "    graph_projection_constraint,\n",
    "    image_projection_constraint,\n",
    ")\n",
    "from diffusion_co_design.rware.diffusion.generator import Generator, OptimizerDetails\n",
    "from rware.warehouse import Warehouse\n",
    "\n",
    "from dataset import load_dataset, make_dataloader\n",
    "\n",
    "\n",
    "device = cuda\n",
    "representation = \"graph\"\n",
    "\n",
    "# NOTE(review): `training_dir` is never assigned in this notebook, so this cell\n",
    "# raises NameError on a fresh kernel. Define it (the directory that contains\n",
    "# `.hydra/config.yaml`) before this point.\n",
    "# Load latest model and config\n",
    "cfg = TrainingConfig.from_file(os.path.join(training_dir, \".hydra\", \"config.yaml\"))\n",
    "diffusion_dir = pretrain_dir = os.path.join(\n",
    "    OUTPUT_DIR, \"rware\", \"diffusion\", representation, cfg.scenario.name\n",
    ")\n",
    "latest_diffusion_checkpoint = get_latest_model(diffusion_dir, \"model\")\n",
    "\n",
    "# Make Dataset\n",
    "train_dataset, eval_dataset = load_dataset(\n",
    "    scenario=cfg.scenario,\n",
    "    training_dir=training_dir,\n",
    "    dataset_size=10_000,\n",
    "    num_workers=25,\n",
    "    test_proportion=0.2,\n",
    "    recompute=False,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Single source of truth for the batch size of both dataloaders below.\n",
    "BATCH_SIZE = 128\n",
    "\n",
    "train_dataloader = make_dataloader(\n",
    "    train_dataset,\n",
    "    scenario=cfg.scenario,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    representation=representation,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "eval_dataloader = make_dataloader(\n",
    "    eval_dataset,\n",
    "    scenario=cfg.scenario,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    representation=representation,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Build the classifier that matches the chosen representation.\n",
    "# NOTE(review): `torch.load()` is called without a checkpoint path in both\n",
    "# branches, which raises TypeError; pass the classifier checkpoint file here.\n",
    "match representation:\n",
    "    case \"graph\":\n",
    "        model = make_model(\n",
    "            \"gnn-cnn\",\n",
    "            cfg.scenario,\n",
    "            model_kwargs={\"add_goal_positions\": False},\n",
    "            device=device,\n",
    "        )\n",
    "\n",
    "        model.load_state_dict(torch.load())\n",
    "    case \"image\":\n",
    "        model = make_model(\n",
    "            \"cnn\",\n",
    "            cfg.scenario,\n",
    "            device=device,\n",
    "        )\n",
    "\n",
    "        model.load_state_dict(torch.load())\n",
    "\n",
    "\n",
    "def show_batch(\n",
    "    environment_batch,\n",
    "    representation,\n",
    "    n: int = 8,\n",
    "):\n",
    "    \"\"\"Render up to `n` environments from a batch on a 3x3 grid of axes.\n",
    "\n",
    "    Each environment is converted with `storage_to_layout` (using the global\n",
    "    `cfg.scenario`), rendered through a temporary `Warehouse` created with\n",
    "    `render_mode=\"rgb_array\"`, and drawn with `imshow`.\n",
    "\n",
    "    Args:\n",
    "        environment_batch: Iterable of environments accepted by\n",
    "            `storage_to_layout`.\n",
    "        representation: Representation name forwarded to `storage_to_layout`.\n",
    "        n: Maximum number of environments to draw. It is clamped to the batch\n",
    "            length and to the 9 available axes instead of raising IndexError.\n",
    "\n",
    "    Returns:\n",
    "        The `(fig, axs)` pair, where `axs` is the flattened array of 9 axes.\n",
    "    \"\"\"\n",
    "    grid_rows, grid_cols = 3, 3\n",
    "    panel_limit = max(0, min(n, grid_rows * grid_cols))\n",
    "\n",
    "    # Only render the environments that will actually be displayed.\n",
    "    frames = []\n",
    "    for _, theta in zip(range(panel_limit), environment_batch):\n",
    "        layout = storage_to_layout(theta, cfg.scenario, representation=representation)\n",
    "        warehouse = Warehouse(layout=layout, render_mode=\"rgb_array\")\n",
    "        try:\n",
    "            frames.append(warehouse.render())\n",
    "        finally:\n",
    "            # Close the warehouse even if rendering fails.\n",
    "            warehouse.close()\n",
    "\n",
    "    fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(12, 12))\n",
    "    axs = axs.ravel()\n",
    "    for ax in axs:\n",
    "        ax.axis(\"off\")\n",
    "    for ax, frame in zip(axs, frames):\n",
    "        ax.imshow(frame)\n",
    "    return fig, axs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a0f514b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from diffusion_co_design.rware.design import ClassifierConfig, GradientDescentDesigner\n",
    "\n",
    "# classifier = ClassifierConfig(name=\"gnn-cnn\", representation=\"graph\")\n",
    "# grad_designer = GradientDescentDesigner(\n",
    "#     scenario=cfg.scenario,\n",
    "#     classifier=classifier,\n",
    "#     device=\"cuda\",\n",
    "#     epochs=20,\n",
    "#     gradient_lr=0.03,\n",
    "# )\n",
    "# grad_designer.model = model\n",
    "\n",
    "# generated_environments = grad_designer._reset_env_buffer(9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc9b5368",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show_batch(generated_environments, representation=\"graph\")\n",
    "# pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3620e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from diffusion_co_design.rware.design import PolicyDesigner\n",
    "\n",
    "# designer = PolicyDesigner(scenario=cfg.scenario)\n",
    "\n",
    "# envs = designer.generate_environment_batch(9)\n",
    "# env = envs[0]\n",
    "# env.sum(dim=0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ebff17a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): unconditional failure; presumably a deliberate stop so that\n",
    "# \"Run All\" halts before the cells below -- confirm intent.\n",
    "assert False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "663cf930",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histograms and printed means all use the evaluation split, so the summary\n",
    "# statistics describe the same data as the plots.\n",
    "eval_returns = eval_dataset.dataset.env_returns.storage\n",
    "\n",
    "fig, axs = plt.subplots(1, 3, figsize=(15, 5))\n",
    "axs[0].hist(\n",
    "    eval_returns[\"episode_reward\"],\n",
    "    bins=100,\n",
    ")\n",
    "axs[0].set_title(\"Episode Reward\")\n",
    "axs[1].hist(\n",
    "    eval_returns[\"expected_reward\"],\n",
    "    bins=100,\n",
    ")\n",
    "axs[1].set_title(\"Expected Reward (Critic)\")\n",
    "\n",
    "# Inference only: put the model in eval mode and disable autograd.\n",
    "# NOTE(review): this loop unpacks three items per batch, while the last cell of\n",
    "# the notebook unpacks `(pos, color), y` from `train_dataloader` -- confirm\n",
    "# which batch structure `make_dataloader` actually yields.\n",
    "model.eval()\n",
    "classifier_returns = []\n",
    "with torch.no_grad():\n",
    "    for x, _, _ in eval_dataloader:\n",
    "        classifier_returns.append(model(x.to(device)))\n",
    "classifier_returns = torch.cat(classifier_returns, dim=0)\n",
    "axs[2].hist(\n",
    "    classifier_returns.cpu().numpy(),\n",
    "    bins=100,\n",
    ")\n",
    "axs[2].set_title(\"Classifier Returns\")\n",
    "\n",
    "\n",
    "print(\n",
    "    \"Mean Episode Reward: \",\n",
    "    eval_returns[\"episode_reward\"].mean(),\n",
    ")\n",
    "print(\n",
    "    \"Mean Expected Reward: \",\n",
    "    eval_returns[\"expected_reward\"].mean(),\n",
    ")\n",
    "print(\n",
    "    \"Mean Classifier Returns: \",\n",
    "    classifier_returns.mean(),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a99fd0ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample environments from the diffusion generator, guided by the classifier.\n",
    "generator = Generator(\n",
    "    batch_size=10,\n",
    "    generator_model_path=latest_diffusion_checkpoint,\n",
    "    scenario=cfg.scenario,\n",
    "    guidance_wt=200 if representation == \"image\" else 5.0,\n",
    "    representation=representation,\n",
    "    device=device,\n",
    ")\n",
    "guidance_model = model\n",
    "guidance_model.eval()\n",
    "\n",
    "# Per-representation guidance settings. The \"image\" branch does not set\n",
    "# `operation.lr`, so it keeps whatever `OptimizerDetails()` initialises.\n",
    "operation = OptimizerDetails()\n",
    "match representation:\n",
    "    case \"graph\":\n",
    "        operation.lr = 0.01\n",
    "        operation.num_recurrences = 8\n",
    "        operation.backward_steps = 16\n",
    "        operation.projection_constraint = graph_projection_constraint(cfg.scenario)\n",
    "    case \"image\":\n",
    "        operation.num_recurrences = 8\n",
    "        operation.backward_steps = 0\n",
    "        operation.projection_constraint = image_projection_constraint(cfg.scenario)\n",
    "\n",
    "\n",
    "environment_batch = generator.generate_batch(\n",
    "    value=guidance_model,\n",
    "    use_operation=True,\n",
    "    operation_override=operation,\n",
    ")\n",
    "\n",
    "\n",
    "# Print `len(layout.reset_shelves())` for each generated environment\n",
    "# (presumably the shelf count -- confirm against the layout API).\n",
    "for env in environment_batch:\n",
    "    layout = storage_to_layout(env, cfg.scenario, representation=representation)\n",
    "    print(len(layout.reset_shelves()))\n",
    "fig, axs = show_batch(environment_batch, representation)\n",
    "fig.suptitle(\"Guided Generation\")\n",
    "fig.tight_layout()\n",
    "\n",
    "# Rescale the generated batch before scoring it with the classifier.\n",
    "# NOTE(review): presumably maps raw values to [-1, 1] (grid coordinates in\n",
    "# [0, size - 1] for \"graph\", values in [0, 1] for \"image\") -- confirm.\n",
    "X_batch = torch.from_numpy(environment_batch).to(device=device, dtype=torch.float32)\n",
    "match representation:\n",
    "    case \"graph\":\n",
    "        X_batch = (X_batch / (cfg.scenario.size - 1)) * 2 - 1\n",
    "    case \"image\":\n",
    "        X_batch = X_batch * 2 - 1\n",
    "\n",
    "print(X_batch.shape)\n",
    "# Pure inference: no autograd graph is needed to print the predictions.\n",
    "with torch.no_grad():\n",
    "    print(guidance_model(X_batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5ff26d2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34c677a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FIGURE_SIZE_CNST = 2.5\n",
    "\n",
    "# test_layout = [next(iter(train_dataset))]\n",
    "# test_layout, _ = collate_fn(test_layout)\n",
    "# pos, color = test_layout\n",
    "\n",
    "# pos.requires_grad = True\n",
    "# pos_optim = torch.optim.Adam([pos], lr=0.01)\n",
    "\n",
    "# constraint = graph_projection_constraint(cfg.scenario)\n",
    "\n",
    "# n_iterations = 1000\n",
    "# for iteration in range(n_iterations):\n",
    "#     pos.requires_grad = True\n",
    "#     pos_optim.zero_grad()\n",
    "#     y_pred = model.predict((pos, color))\n",
    "#     loss = -y_pred.mean()\n",
    "#     loss.backward()\n",
    "#     pos_optim.step()\n",
    "\n",
    "#     if iteration % (n_iterations // 10) == 0:\n",
    "#         print(f\"Iteration {iteration} Loss: {loss.item()}\")\n",
    "#         # pos = constraint(pos.detach())\n",
    "\n",
    "#         fig, ax = plt.subplots(figsize=(FIGURE_SIZE_CNST, FIGURE_SIZE_CNST))\n",
    "\n",
    "#         show_pos = (pos.squeeze() + 1) / 2\n",
    "#         show_pos = show_pos * cfg.scenario.size\n",
    "#         layout = storage_to_layout(\n",
    "#             features=show_pos.numpy(force=True),\n",
    "#             config=cfg.scenario,\n",
    "#             representation_override=\"graph\",\n",
    "#         )\n",
    "#         warehouse = Warehouse(layout=layout, render_mode=\"rgb_array\")\n",
    "#         print(len(warehouse.shelves))\n",
    "#         im = warehouse.render()\n",
    "#         ax.imshow(im)\n",
    "#         ax.axis(\"off\")\n",
    "#         plt.show()\n",
    "#         warehouse.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "487ad4cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusion_co_design.rware.model.classifier import GNNClassifier\n",
    "\n",
    "# Smoke test for a freshly constructed (untrained) GNN classifier. It is bound\n",
    "# to its own name so that re-running earlier cells still sees the loaded `model`.\n",
    "gnn_classifier = GNNClassifier(cfg=cfg.scenario).to(device=device)\n",
    "\n",
    "# Test\n",
    "(pos, color), y = next(iter(train_dataloader))\n",
    "\n",
    "number_parameters = sum(p.numel() for p in gnn_classifier.parameters())\n",
    "print(f\"Number of parameters: {number_parameters}\")\n",
    "# The prediction must have the same shape as the target batch.\n",
    "assert gnn_classifier.predict((pos, color)).shape == y.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
