{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from IPython.display import Video\n",
    "from gymnasium.utils.save_video import save_video\n",
    "from matplotlib import pyplot as plt\n",
    "from tensordict import TensorDict\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement\n",
    "from torchrl.envs import EnvBase\n",
    "from tqdm import tqdm\n",
    "\n",
    "from guided_diffusion.script_util import classifier_defaults, create_classifier\n",
    "from rware.warehouse import Warehouse\n",
    "from diffusion_co_design.common import OUTPUT_DIR, cuda, get_latest_model\n",
    "from diffusion_co_design.rware.design import DesignerRegistry\n",
    "from diffusion_co_design.rware.diffusion.generator import Generator, OptimizerDetails\n",
    "from diffusion_co_design.rware.diffusion.transform import storage_to_layout\n",
    "from diffusion_co_design.rware.env import create_env, create_batched_env\n",
    "from diffusion_co_design.rware.model.rl import rware_models\n",
    "from diffusion_co_design.rware.schema import TrainingConfig, DesignerConfig\n",
    "\n",
    "FIGURE_SIZE_CNST = 3\n",
    "CACHE_ID = \"train_visualisation\"\n",
    "RECOMPUTE = False\n",
    "\n",
    "# Parameters\n",
    "# `training_dir` is the output directory of a training run (it must contain\n",
    "# `checkpoints/` and `.hydra/config.yaml`). It is expected to be injected as a\n",
    "# parameter (e.g. by papermill) or set manually before running this cell.\n",
    "if \"training_dir\" not in globals():\n",
    "    raise NameError(\n",
    "        \"`training_dir` is not set: point it at a training run output \"\n",
    "        \"directory before running this notebook.\"\n",
    "    )\n",
    "device = cuda\n",
    "\n",
    "# Get latest policy and critic checkpoints\n",
    "checkpoint_dir = os.path.join(training_dir, \"checkpoints\")\n",
    "latest_policy = get_latest_model(checkpoint_dir, \"policy_\")\n",
    "latest_critic = get_latest_model(checkpoint_dir, \"critic_\")\n",
    "\n",
    "# Get config\n",
    "hydra_dir = os.path.join(training_dir, \".hydra\")\n",
    "training_config = os.path.join(hydra_dir, \"config.yaml\")\n",
    "cfg = TrainingConfig.from_file(training_config)\n",
    "\n",
    "# Cache directory for collected data and videos; wiped when RECOMPUTE is set.\n",
    "cache_dir = os.path.join(OUTPUT_DIR, \".tmp\", CACHE_ID)\n",
    "if RECOMPUTE and os.path.exists(cache_dir):\n",
    "    shutil.rmtree(cache_dir)\n",
    "os.makedirs(cache_dir, exist_ok=True)\n",
    "\n",
    "# Create environment and load the trained policy / critic\n",
    "master_designer, env_designer = DesignerRegistry.get(\n",
    "    DesignerConfig(type=\"random\"),\n",
    "    cfg.scenario,\n",
    "    cache_dir,\n",
    "    environment_batch_size=32,\n",
    "    device=device,\n",
    ")\n",
    "agent_idxs = cfg.scenario.agent_idxs\n",
    "goal_idxs = cfg.scenario.goal_idxs\n",
    "env = create_env(cfg.scenario, env_designer, render=True, device=device)\n",
    "policy, critic = rware_models(env, cfg.policy, device=device)\n",
    "policy.load_state_dict(torch.load(latest_policy, map_location=device))\n",
    "critic.load_state_dict(torch.load(latest_critic, map_location=device))\n",
    "# The critic is evaluated on rollouts from the CPU collection env.\n",
    "critic.to(\"cpu\")\n",
    "\n",
    "\n",
    "def view_video(env: EnvBase, policy, fps: int = 10):\n",
    "    \"\"\"Roll out `policy` once in the rendering `env` and save the episode as MP4.\n",
    "\n",
    "    A frame is captured with `env.render()` on every step (via the rollout\n",
    "    callback) and the episode is written to\n",
    "    `<cache_dir>/video/rl-video-episode-0.mp4`.\n",
    "\n",
    "    Args:\n",
    "        env: environment whose `render()` returns frames.\n",
    "        policy: policy module passed to `env.rollout`.\n",
    "        fps: frame rate of the saved video.\n",
    "\n",
    "    Returns:\n",
    "        A zero-argument callable returning an embedded `IPython.display.Video`\n",
    "        of the saved file; call it (e.g. `view_video(env, policy)()`) to display.\n",
    "    \"\"\"\n",
    "    frames = []\n",
    "    video_dir = os.path.join(cache_dir, \"video\")\n",
    "    video_out = os.path.join(video_dir, \"rl-video-episode-0.mp4\")\n",
    "\n",
    "    def append_frames(env, td):\n",
    "        frames.append(env.render())\n",
    "\n",
    "    # Inference only: no_grad avoids storing activations for backward.\n",
    "    with torch.no_grad():\n",
    "        env.rollout(\n",
    "            max_steps=cfg.scenario.max_steps,\n",
    "            policy=policy,\n",
    "            callback=append_frames,\n",
    "            auto_cast_to_device=True,\n",
    "        )\n",
    "\n",
    "    # Pin the file-naming arguments so the saved file is exactly `video_out`.\n",
    "    save_video(\n",
    "        frames=frames,\n",
    "        video_folder=video_dir,\n",
    "        fps=fps,\n",
    "        name_prefix=\"rl-video\",\n",
    "        episode_index=0,\n",
    "    )\n",
    "\n",
    "    return lambda: Video(filename=video_out, embed=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test if the value function is able to discern good environments!\n",
    "# Collect (layout, observed discounted return, critic estimate) triples from\n",
    "# rollouts of the trained policy on randomly designed environments.\n",
    "\n",
    "NUM_PARALLEL_COLLECTION = 25\n",
    "DATASET_SIZE = 10_000\n",
    "BATCH_SIZE = 128\n",
    "TEST_PROPORTION = 0.2\n",
    "\n",
    "env_returns = ReplayBuffer(\n",
    "    storage=LazyTensorStorage(max_size=DATASET_SIZE),\n",
    "    sampler=SamplerWithoutReplacement(),\n",
    "    batch_size=BATCH_SIZE,\n",
    ")\n",
    "env_returns_path = os.path.join(cache_dir, \"env_returns\")\n",
    "\n",
    "\n",
    "def discounted_return(reward, gamma):\n",
    "    \"\"\"Return sum_t gamma**t * reward[:, t], also summed over all trailing dims.\n",
    "\n",
    "    `reward` is indexed as (batch, time, ...). The discount vector is sized from\n",
    "    the actual rollout length, so any episode length works.\n",
    "    \"\"\"\n",
    "    num_steps = reward.shape[1]\n",
    "    discount = gamma ** torch.arange(num_steps, dtype=reward.dtype, device=reward.device)\n",
    "    discount = discount.view(1, num_steps, *([1] * (reward.ndim - 2)))\n",
    "    return (reward * discount).sum(dim=tuple(range(1, reward.ndim)))\n",
    "\n",
    "\n",
    "# Collect when asked to, or when there is no cached dataset to load.\n",
    "if RECOMPUTE or not os.path.exists(env_returns_path):\n",
    "    collection_env = create_batched_env(\n",
    "        num_environments=NUM_PARALLEL_COLLECTION,\n",
    "        scenario=cfg.scenario,\n",
    "        designer=env_designer,\n",
    "        is_eval=False,\n",
    "        device=\"cpu\",\n",
    "    )\n",
    "    for _ in tqdm(range(DATASET_SIZE // NUM_PARALLEL_COLLECTION)):\n",
    "        with torch.no_grad():\n",
    "            rollout = collection_env.rollout(\n",
    "                max_steps=cfg.scenario.max_steps,\n",
    "                policy=policy,\n",
    "                auto_cast_to_device=True,\n",
    "            )\n",
    "            # First `n_colors` state channels at t=0: the layout features.\n",
    "            layouts = rollout.get(\"state\")[:, 0, : cfg.scenario.n_colors]\n",
    "            ep_reward = discounted_return(\n",
    "                rollout.get((\"next\", \"agents\", \"reward\")), cfg.ppo.gamma\n",
    "            )\n",
    "            # Critic value of the initial state summed over dim -2 (presumably\n",
    "            # agents); squeeze only the trailing value dim so a batch of one\n",
    "            # keeps its batch dimension.\n",
    "            first_obs = rollout[:, 0]\n",
    "            expected_reward = (\n",
    "                critic(first_obs)[\"agents\", \"state_value\"].sum(dim=-2).squeeze(-1)\n",
    "            )\n",
    "            env_returns.extend(\n",
    "                TensorDict(\n",
    "                    {\n",
    "                        \"env\": layouts,\n",
    "                        \"episode_reward\": ep_reward,\n",
    "                        \"expected_reward\": expected_reward,\n",
    "                    },\n",
    "                    batch_size=len(ep_reward),\n",
    "                )\n",
    "            )\n",
    "        del rollout\n",
    "    collection_env.close()\n",
    "    env_returns.dumps(env_returns_path)\n",
    "else:\n",
    "    env_returns.loads(env_returns_path)\n",
    "\n",
    "\n",
    "class EnvReturnsDataset(Dataset):\n",
    "    \"\"\"Map-style Dataset view of the `env_returns` replay buffer.\n",
    "\n",
    "    Items are `(env, episode_reward, expected_reward)` tuples, each cast to\n",
    "    float32 and moved to the notebook-level `device`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env_returns):\n",
    "        self.env_returns = env_returns\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.env_returns)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        record = self.env_returns[idx]\n",
    "        return tuple(\n",
    "            record.get(key).to(dtype=torch.float32, device=device)\n",
    "            for key in (\"env\", \"episode_reward\", \"expected_reward\")\n",
    "        )\n",
    "\n",
    "\n",
    "env_returns_dataset = EnvReturnsDataset(env_returns)\n",
    "\n",
    "# Hold out TEST_PROPORTION of the data for evaluation. A seeded generator makes\n",
    "# the split, and everything derived from `eval_dataset`, reproducible.\n",
    "SPLIT_SEED = 0\n",
    "train_size = int((1 - TEST_PROPORTION) * len(env_returns))\n",
    "eval_size = len(env_returns) - train_size\n",
    "train_dataset, eval_dataset = torch.utils.data.random_split(\n",
    "    env_returns_dataset,\n",
    "    [train_size, eval_size],\n",
    "    generator=torch.Generator().manual_seed(SPLIT_SEED),\n",
    ")\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MSE between the observed discounted return and the critic's prediction,\n",
    "# over every stored environment.\n",
    "(\n",
    "    env_returns.storage[\"episode_reward\"]\n",
    "    .sub(env_returns.storage[\"expected_reward\"])\n",
    "    .square()\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare observed discounted returns with the critic's predictions and\n",
    "# render the layouts at both ends of each ranking.\n",
    "# Materialise the dataset once: each `env_returns_dataset[:]` re-reads the\n",
    "# whole replay buffer and copies it to `device`.\n",
    "all_envs, episode_reward, expected_reward = env_returns_dataset[:]\n",
    "critic_mse = torch.nn.functional.mse_loss(episode_reward, expected_reward)\n",
    "print(f\"Critic MSE: {critic_mse.item()}\")\n",
    "\n",
    "N_EXTREMES = 5\n",
    "environments_episode_order = torch.argsort(episode_reward, descending=True)\n",
    "environments_expected_order = torch.argsort(expected_reward, descending=True)\n",
    "best_environments_episode = environments_episode_order[:N_EXTREMES]\n",
    "best_environments_expected = environments_expected_order[:N_EXTREMES]\n",
    "worst_environments_episode = environments_episode_order[-N_EXTREMES:]\n",
    "worst_environments_expected = environments_expected_order[-N_EXTREMES:]\n",
    "selections = {\n",
    "    \"Best episode\": best_environments_episode,\n",
    "    \"Best expected\": best_environments_expected,\n",
    "    \"Worst episode\": worst_environments_episode,\n",
    "    \"Worst expected\": worst_environments_expected,\n",
    "}\n",
    "\n",
    "for title, envs in selections.items():\n",
    "    print(title)\n",
    "    print(episode_reward[envs])\n",
    "    print(expected_reward[envs])\n",
    "\n",
    "fig, axs = plt.subplots(len(selections), N_EXTREMES)\n",
    "fig.set_size_inches(N_EXTREMES * FIGURE_SIZE_CNST, len(selections) * FIGURE_SIZE_CNST)\n",
    "\n",
    "for i, (title, envs) in enumerate(selections.items()):\n",
    "    for j, idx in enumerate(envs):\n",
    "        ax = axs[i, j]\n",
    "        layout = storage_to_layout(\n",
    "            features=all_envs[idx].numpy(force=True),\n",
    "            config=cfg.scenario,\n",
    "            representation=\"image\",\n",
    "        )\n",
    "        warehouse = Warehouse(layout=layout, render_mode=\"rgb_array\")\n",
    "        ax.imshow(warehouse.render())\n",
    "        warehouse.close()\n",
    "        ax.set_title(\n",
    "            f\"{title}: {episode_reward[idx].item():.1f} / {expected_reward[idx].item():.1f}\",\n",
    "            fontsize=8,\n",
    "        )\n",
    "        ax.axis(\"off\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train (or load) a standalone value classifier that predicts the observed\n",
    "# discounted return directly from an environment layout image.\n",
    "TRAIN_NUM_EPOCHS = 80\n",
    "VALUE_LR = 3e-5\n",
    "VALUE_WEIGHT_DECAY = 0.05\n",
    "USE_GOAL_MAP = False\n",
    "# Independent of the global RECOMPUTE flag; training also runs when no saved\n",
    "# classifier exists yet.\n",
    "RETRAIN_CLASSIFIER = False\n",
    "classifier_path = os.path.join(cache_dir, \"train_visualisation_classifier.pt\")\n",
    "\n",
    "# One-hot goal locations per colour, optionally appended as extra input channels.\n",
    "goal_map = np.zeros((cfg.scenario.n_colors, cfg.scenario.size, cfg.scenario.size))\n",
    "for goal, color in zip(cfg.scenario.goal_idxs, cfg.scenario.goal_colors):\n",
    "    goal_map[color, goal // cfg.scenario.size, goal % cfg.scenario.size] = 1\n",
    "goal_map = (\n",
    "    torch.from_numpy(goal_map).to(device=device, dtype=torch.float32).unsqueeze(0)\n",
    ")\n",
    "\n",
    "# Latest pretrained diffusion checkpoint (also used by the generation cells).\n",
    "pretrain_dir = os.path.join(\n",
    "    OUTPUT_DIR, \"diffusion_pretrain\", \"image\", cfg.scenario.name\n",
    ")\n",
    "latest_checkpoint = get_latest_model(pretrain_dir, \"model\")\n",
    "\n",
    "# Value model architecture (also reused to load the designer classifier).\n",
    "model_dict = classifier_defaults()\n",
    "model_dict[\"image_size\"] = cfg.scenario.size\n",
    "model_dict[\"image_channels\"] = (\n",
    "    cfg.scenario.n_colors * 2 if USE_GOAL_MAP else cfg.scenario.n_colors\n",
    ")\n",
    "model_dict[\"classifier_width\"] = 128\n",
    "model_dict[\"classifier_depth\"] = 2\n",
    "model_dict[\"classifier_attention_resolutions\"] = \"16, 8, 4\"\n",
    "model_dict[\"output_dim\"] = 1\n",
    "model = create_classifier(**model_dict).to(device)\n",
    "\n",
    "\n",
    "def prepare_classifier_input(X_batch):\n",
    "    \"\"\"Optionally append the goal map, then rescale inputs with `x * 2 - 1`.\"\"\"\n",
    "    X_batch = X_batch.to(dtype=torch.float32, device=device)\n",
    "    if USE_GOAL_MAP:\n",
    "        goal_map_batch = goal_map.expand(X_batch.shape[0], -1, -1, -1)\n",
    "        X_batch = torch.cat([X_batch, goal_map_batch], dim=1)\n",
    "    return X_batch * 2 - 1\n",
    "\n",
    "\n",
    "if RETRAIN_CLASSIFIER or not os.path.exists(classifier_path):\n",
    "    optim = torch.optim.Adam(\n",
    "        model.parameters(), lr=VALUE_LR, weight_decay=VALUE_WEIGHT_DECAY\n",
    "    )\n",
    "    criterion = torch.nn.MSELoss()\n",
    "    train_losses = []\n",
    "    eval_losses = []\n",
    "    pbar = tqdm(range(TRAIN_NUM_EPOCHS))\n",
    "    for epoch in pbar:\n",
    "        model.train()\n",
    "        running_train_loss = 0.0\n",
    "        # Items are (layout, observed return, critic estimate); the critic\n",
    "        # estimate is not a training target here.\n",
    "        for X_batch, y_batch, _critic_value in train_loader:\n",
    "            optim.zero_grad()\n",
    "            y_pred = model(prepare_classifier_input(X_batch)).squeeze(-1)\n",
    "            loss = criterion(y_pred, y_batch)\n",
    "            loss.backward()\n",
    "            optim.step()\n",
    "            running_train_loss += loss.item()\n",
    "        running_train_loss /= len(train_loader)\n",
    "\n",
    "        model.eval()\n",
    "        running_eval_loss = 0.0\n",
    "        with torch.no_grad():\n",
    "            for X_batch, y_batch, _critic_value in eval_loader:\n",
    "                y_pred = model(prepare_classifier_input(X_batch)).squeeze(-1)\n",
    "                running_eval_loss += criterion(y_pred, y_batch).item()\n",
    "        running_eval_loss /= len(eval_loader)\n",
    "\n",
    "        train_losses.append(running_train_loss)\n",
    "        eval_losses.append(running_eval_loss)\n",
    "        pbar.set_description(\n",
    "            f\" Train Loss {running_train_loss} Eval Loss {running_eval_loss}\"\n",
    "        )\n",
    "\n",
    "    torch.save(model.state_dict(), classifier_path)\n",
    "else:\n",
    "    model.load_state_dict(torch.load(classifier_path, map_location=device))\n",
    "\n",
    "model.eval()\n",
    "sum(p.numel() for p in model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning curves of the value classifier; they only exist when its training\n",
    "# loop ran in this session (not when the weights were loaded from disk).\n",
    "if \"train_losses\" in globals():\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(train_losses, label=\"train\")\n",
    "    ax.plot(eval_losses, label=\"eval\")\n",
    "    ax.set(title=\"Value Function loss\", xlabel=\"epoch\", ylabel=\"MSE\")\n",
    "    ax.legend()\n",
    "    print(f\"Minimum train loss: {min(train_losses)}\")\n",
    "else:\n",
    "    print(\"No training curves: the classifier was not trained in this session.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()\n",
    "\n",
    "# Held-out layouts with the highest / lowest observed discounted return.\n",
    "# Materialise the split once instead of re-indexing `eval_dataset[:]` per panel.\n",
    "eval_envs, eval_episode_reward, eval_expected_reward = eval_dataset[:]\n",
    "env_returns_sorted_index = torch.argsort(eval_episode_reward, descending=True)\n",
    "best_5 = env_returns_sorted_index[:5]\n",
    "worst_5 = env_returns_sorted_index[-5:]\n",
    "\n",
    "fig, axs = plt.subplots(2, 5)\n",
    "fig.set_size_inches(5 * FIGURE_SIZE_CNST, 2 * FIGURE_SIZE_CNST)\n",
    "\n",
    "for row, (label, indices) in enumerate(((\"Best\", best_5), (\"Worst\", worst_5))):\n",
    "    if row > 0:\n",
    "        print(\"===\")\n",
    "    for col, idx in enumerate(indices):\n",
    "        ax = axs[row, col]\n",
    "        layout = storage_to_layout(eval_envs[idx].numpy(force=True), cfg.scenario)\n",
    "        print(eval_episode_reward[idx])\n",
    "        warehouse = Warehouse(layout=layout, render_mode=\"rgb_array\")\n",
    "        ax.imshow(warehouse.render())\n",
    "        warehouse.close()\n",
    "        ax.set_title(f\"{label}: {eval_episode_reward[idx].item():.1f}\")\n",
    "        ax.axis(\"off\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()\n",
    "N = 50\n",
    "\n",
    "# Mean classifier score of the N held-out layouts with the highest / lowest\n",
    "# observed return; a useful value model should score the first group higher.\n",
    "model.eval()\n",
    "held_out_envs = eval_dataset[:][0]  # materialise the split once\n",
    "good_envs = held_out_envs[env_returns_sorted_index[:N]].to(\n",
    "    device=device, dtype=torch.float32\n",
    ")\n",
    "bad_envs = held_out_envs[env_returns_sorted_index[-N:]].to(\n",
    "    device=device, dtype=torch.float32\n",
    ")\n",
    "# Inference only: no_grad avoids storing activations for backward.\n",
    "with torch.no_grad():\n",
    "    print(f\"Good environments: {model(good_envs * 2 - 1).mean()}\")\n",
    "    print(f\"Bad environments: {model(bad_envs * 2 - 1).mean()}\")\n",
    "\n",
    "del held_out_envs\n",
    "del good_envs\n",
    "del bad_envs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace `model` with the latest designer value model checkpointed during\n",
    "# training. `model_dict` must describe the architecture the checkpoint was\n",
    "# saved with, otherwise `load_state_dict` fails on mismatched keys/shapes.\n",
    "model = create_classifier(**model_dict).to(device)\n",
    "latest_classifier = get_latest_model(checkpoint_dir, \"designer_\")\n",
    "model.load_state_dict(torch.load(latest_classifier))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def projection_constraint(x, target=(13, 13, 12, 12)):\n",
    "    \"\"\"Project a batch of images onto a fixed number of 'on' cells per channel.\n",
    "\n",
    "    For each channel `c`, the `target[c]` largest entries of `x` (per sample)\n",
    "    become 1 and every other entry becomes -1.\n",
    "\n",
    "    Args:\n",
    "        x: tensor indexed as (B, C, H, W); may be non-contiguous.\n",
    "        target: number of cells to switch on in each channel, one entry per\n",
    "            channel. The default presumably matches the scenario's per-colour\n",
    "            shelf counts - TODO confirm against `cfg.scenario`.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with the same shape and dtype as `x`, values in {-1, 1}.\n",
    "    \"\"\"\n",
    "    B, C, H, W = x.shape\n",
    "    # reshape rather than view, so non-contiguous inputs work too\n",
    "    x_flat = x.reshape(B, C, H * W)\n",
    "    mask = torch.full_like(x_flat, -1)\n",
    "    for c, k in enumerate(target):\n",
    "        top_indices = x_flat[:, c, :].topk(k, dim=1).indices\n",
    "        mask[:, c, :].scatter_(1, top_indices, 1)\n",
    "    return mask.reshape(B, C, H, W)\n",
    "\n",
    "\n",
    "# Sanity check: every (sample, channel) has exactly target[c] active cells.\n",
    "test = torch.rand(8, 4, 16, 16)\n",
    "out = projection_constraint(test)\n",
    "assert ((out == 1).sum(dim=(-2, -1)) == torch.tensor((13, 13, 12, 12))).all()\n",
    "(out[0, 2] == 1).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guided generation: `operation` carries the settings handed to\n",
    "# `generate_batch` below, including the top-k `projection_constraint`.\n",
    "GUIDED_BATCH_SIZE = 10\n",
    "GUIDANCE_WEIGHT = 5000\n",
    "\n",
    "generator = Generator(\n",
    "    batch_size=GUIDED_BATCH_SIZE,\n",
    "    generator_model_path=latest_checkpoint,\n",
    "    scenario=cfg.scenario,\n",
    "    guidance_wt=GUIDANCE_WEIGHT,\n",
    "    representation=\"image\",\n",
    ")\n",
    "model.eval()\n",
    "\n",
    "operation = OptimizerDetails()\n",
    "operation.num_recurrences = 64\n",
    "operation.backward_steps = 0\n",
    "operation.projection_constraint = projection_constraint\n",
    "\n",
    "\n",
    "def show_batch(environment_batch, n: int = 8):\n",
    "    \"\"\"Render up to `n` layouts from `environment_batch` on a 3x3 grid.\n",
    "\n",
    "    Each item is converted with `storage_to_layout` and drawn through a\n",
    "    temporary `Warehouse`. `n` is clamped to the batch length and to the 9\n",
    "    grid cells, so short batches are handled instead of raising `IndexError`.\n",
    "\n",
    "    Returns:\n",
    "        `(fig, axs)` with `axs` flattened to 1-D.\n",
    "    \"\"\"\n",
    "    fig, axs = plt.subplots(3, 3, figsize=(12, 12))\n",
    "    axs = axs.ravel()\n",
    "    for ax in axs:\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    n = min(n, len(environment_batch), len(axs))\n",
    "    # Only render the layouts that will actually be shown.\n",
    "    for ax, image in zip(axs[:n], environment_batch[:n]):\n",
    "        layout = storage_to_layout(image, cfg.scenario)\n",
    "        warehouse = Warehouse(layout=layout, render_mode=\"rgb_array\")\n",
    "        ax.imshow(warehouse.render())\n",
    "        warehouse.close()\n",
    "    return fig, axs\n",
    "\n",
    "\n",
    "environment_batch = generator.generate_batch(\n",
    "    value=model,\n",
    "    use_operation=True,\n",
    "    operation_override=operation,\n",
    ")\n",
    "\n",
    "# Sanity check: size of `reset_shelves()` for each generated layout. The loop\n",
    "# variable is not called `env`, which would clobber the rendering env.\n",
    "for layout_image in environment_batch:\n",
    "    layout = storage_to_layout(layout_image, cfg.scenario)\n",
    "    print(len(layout.reset_shelves()))\n",
    "fig, axs = show_batch(environment_batch)\n",
    "fig.suptitle(\"Guided Generation\")\n",
    "fig.tight_layout()\n",
    "\n",
    "# Score the batch with the same `* 2 - 1` rescaling applied to classifier inputs\n",
    "# elsewhere; inference only, so no autograd state is needed.\n",
    "X_batch = torch.from_numpy(environment_batch).to(device=device, dtype=torch.float32)\n",
    "X_batch = (X_batch * 2) - 1\n",
    "with torch.no_grad():\n",
    "    print(model(X_batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the training run's checkpoint directory (source of the policy,\n",
    "# critic and designer checkpoints loaded in this notebook).\n",
    "checkpoint_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw a large pool of guided samples (no `projection_constraint` set here)\n",
    "# to study which layouts the value model favours.\n",
    "NUM_GENERATED_BATCHES = 20\n",
    "GENERATED_BATCH_SIZE = 50\n",
    "\n",
    "# A single generator, built with the same constructor arguments as the\n",
    "# guided-generation cell above.\n",
    "generator = Generator(\n",
    "    batch_size=GENERATED_BATCH_SIZE,\n",
    "    generator_model_path=latest_checkpoint,\n",
    "    scenario=cfg.scenario,\n",
    "    guidance_wt=200,\n",
    "    representation=\"image\",\n",
    ")\n",
    "model.eval()\n",
    "operation = OptimizerDetails()\n",
    "operation.num_recurrences = 8\n",
    "operation.backward_steps = 0\n",
    "\n",
    "generated_envs = []\n",
    "for _ in tqdm(range(NUM_GENERATED_BATCHES)):\n",
    "    environment_batch = generator.generate_batch(\n",
    "        value=model, use_operation=True, operation_override=operation\n",
    "    )\n",
    "    generated_envs.extend(environment_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_envs_numpy = np.array(generated_envs)\n",
    "\n",
    "# Sum of all generated layouts per channel (a per-cell frequency if the\n",
    "# layouts are binary). Computed once rather than inside the plotting loop.\n",
    "channel_totals = generated_envs_numpy.sum(axis=0)\n",
    "\n",
    "fig, axs = plt.subplots(2, 2, figsize=(10, 10))\n",
    "for channel, ax in enumerate(axs.ravel()):\n",
    "    ax.imshow(channel_totals[channel])\n",
    "    ax.set_title(f\"Channel {channel}\")\n",
    "    ax.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get best 9 environments using the classifier\n",
    "model.eval()\n",
    "X_batch = torch.from_numpy(generated_envs_numpy).to(device=device, dtype=torch.float32)\n",
    "X_batch = (X_batch * 2) - 1  # Normalize the input\n",
    "# Inference only: no_grad avoids storing activations for the whole pool.\n",
    "with torch.no_grad():\n",
    "    scores = model(X_batch).squeeze()\n",
    "best_9_indices = torch.topk(scores, 9).indices\n",
    "generated_envs_numpy_best = generated_envs_numpy[best_9_indices.cpu().numpy()]\n",
    "print(scores[best_9_indices])\n",
    "\n",
    "print(scores.min().item(), scores.max().item(), scores.std().item())\n",
    "\n",
    "# Same `storage_to_layout(features, config)` call form as the other cells.\n",
    "for layout_image in generated_envs_numpy_best:\n",
    "    layout = storage_to_layout(layout_image, cfg.scenario)\n",
    "    print(len(layout.reset_shelves()))\n",
    "# Show all 9 selected layouts (show_batch defaults to 8).\n",
    "fig, axs = show_batch(generated_envs_numpy_best, n=9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of the per-cell sums over all generated layouts.\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(generated_envs_numpy.sum(axis=0).flatten(), bins=50)\n",
    "ax.set(\n",
    "    title=\"Per-cell totals across generated layouts\",\n",
    "    xlabel=\"sum over layouts\",\n",
    "    ylabel=\"number of cells\",\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counterfactual: samples drawn without passing the value model (presumably\n",
    "# unguided), scored by the same model for comparison with the guided batches.\n",
    "environment_batch = generator.generate_batch()\n",
    "\n",
    "for layout_image in environment_batch:\n",
    "    layout = storage_to_layout(layout_image, cfg.scenario)\n",
    "    print(len(layout.reset_shelves()))\n",
    "fig, axs = show_batch(environment_batch)\n",
    "fig.suptitle(\"Unguided Generation\")\n",
    "fig.tight_layout()\n",
    "\n",
    "X_batch = torch.from_numpy(environment_batch).to(device=device, dtype=torch.float32)\n",
    "X_batch = (X_batch * 2) - 1\n",
    "with torch.no_grad():\n",
    "    print(model(X_batch))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
