{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ee5f110",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import defaultdict\n",
    "import torch\n",
    "import wandb\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib as mpl\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib.patches import Patch\n",
    "from datetime import datetime\n",
    "import seaborn as sns\n",
    "\n",
    "from diffusion_co_design.common import OUTPUT_DIR, get_latest_model, cuda\n",
    "from diffusion_co_design.rware.diffusion.generate import generate\n",
    "from diffusion_co_design.rware.env import render_env as render_rware_env\n",
    "from diffusion_co_design.rware.design import (\n",
    "    DicodeDesigner as RwareDicodeDesigner,\n",
    "    DescentDesigner,\n",
    ")\n",
    "from diffusion_co_design.rware.diffusion.transform import (\n",
    "    storage_to_layout,\n",
    "    storage_to_layout_image,\n",
    "    graph_projection_constraint,\n",
    "    train_to_eval,\n",
    ")\n",
    "from diffusion_co_design.rware.diffusion.generator import (\n",
    "    Generator as RwareGenerator,\n",
    "    OptimizerDetails,\n",
    ")\n",
    "from diffusion_co_design.rware.model.classifier import make_model\n",
    "from diffusion_co_design.rware.schema import (\n",
    "    ScenarioConfig as RwareScenarioConfig,\n",
    "    TrainingConfig as RwareTrainingConfig,\n",
    ")\n",
    "from wfcrl.environments.data_cases import floris_ormonde\n",
    "from diffusion_co_design.wfcrl.schema import (\n",
    "    ScenarioConfig as WfcrlScenarioConfig,\n",
    "    TrainingConfig as WfcrlTrainingConfig,\n",
    ")\n",
    "from diffusion_co_design.wfcrl.design import DicodeDesigner\n",
    "from diffusion_co_design.wfcrl.diffusion.generate import Generate\n",
    "from diffusion_co_design.wfcrl.env import _create_designable_windfarm, render_layout\n",
    "from rware.warehouse import Warehouse\n",
    "\n",
    "\n",
    "# Wandb limits to 500\n",
    "def get_full_history(run, key):\n",
    "    \"\"\"Return every logged value of ``key`` for a wandb run as a NumPy array.\n",
    "\n",
    "    The values are read through ``run.scan_history`` restricted to ``key``,\n",
    "    in the order the rows are yielded.\n",
    "    \"\"\"\n",
    "    return np.array([row[key] for row in run.scan_history(keys=[key])])\n",
    "\n",
    "\n",
    "def ema(data: np.ndarray, alpha: float = 0.95):\n",
    "    \"\"\"Exponential moving average along the first axis of ``data``.\n",
    "\n",
    "    ``out[0] = data[0]`` and\n",
    "    ``out[i] = alpha * out[i - 1] + (1 - alpha) * data[i]`` afterwards.\n",
    "\n",
    "    Args:\n",
    "        data: Samples to smooth; the recursion runs over axis 0.\n",
    "        alpha: Weight given to the previous smoothed value.\n",
    "\n",
    "    Returns:\n",
    "        An array with the same shape as ``data``. Floating-point inputs keep\n",
    "        their dtype; any other dtype is promoted to float64 so the running\n",
    "        average is not truncated when it is stored. An empty input gives an\n",
    "        empty result instead of raising.\n",
    "    \"\"\"\n",
    "    data = np.asarray(data)\n",
    "    # np.zeros_like(data) would inherit an integer dtype and silently floor\n",
    "    # every smoothed value on assignment.\n",
    "    if np.issubdtype(data.dtype, np.floating):\n",
    "        out_dtype = data.dtype\n",
    "    else:\n",
    "        out_dtype = np.float64\n",
    "    smoothed = np.zeros(data.shape, dtype=out_dtype)\n",
    "\n",
    "    if data.shape[0] == 0:\n",
    "        return smoothed\n",
    "\n",
    "    smoothed[0] = data[0]\n",
    "    for i in range(1, data.shape[0]):\n",
    "        smoothed[i] = alpha * smoothed[i - 1] + (1 - alpha) * data[i]\n",
    "\n",
    "    return smoothed\n",
    "\n",
    "\n",
    "device = cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a66b0735",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration: blend a rendered D-RWARE layout with increasing amounts of noise.\n",
    "# The goal cells are the four corners of the 8x8 grid (0, 7, 56, 63); the same\n",
    "# indices must be given to the shelf generator and to the layout builder.\n",
    "shelf_im = generate(size=8, n_shelves=20, goal_idxs=[0, 7, 56, 63], n_colors=4)[0]\n",
    "layout = storage_to_layout_image(\n",
    "    shelf_im,\n",
    "    agent_idxs=[12, 23, 54, 8],\n",
    "    agent_colors=[-1, -1, -1, -1],\n",
    "    goal_idxs=[0, 7, 56, 63],\n",
    "    goal_colors=[0, 1, 2, 3],\n",
    ")\n",
    "warehouse = Warehouse(layout=layout)\n",
    "image = warehouse.render()\n",
    "# Keep every second pixel along both image axes.\n",
    "image = image[::2, ::2]\n",
    "warehouse.close()\n",
    "\n",
    "H, W, C = image.shape\n",
    "\n",
    "# Gaussian noise, min-max rescaled to [0, 255] and stored as uint8.\n",
    "noise = np.random.normal(loc=0.0, scale=1.0, size=image.shape)\n",
    "noise = (noise - noise.min()) / (noise.max() - noise.min())\n",
    "noise = noise * 255\n",
    "noise = noise.astype(np.uint8)\n",
    "\n",
    "# beta is the noise fraction: 0 is the clean render, 1 is pure noise.\n",
    "for i, beta in enumerate([0, 0.5, 0.75, 1]):\n",
    "    blended = ((1 - beta) * image + beta * noise).astype(np.uint8)\n",
    "    plt.figure(figsize=(12, 6))\n",
    "    plt.imshow(blended)\n",
    "    plt.axis(\"off\")\n",
    "    plt.savefig(f\"blended_{i}.png\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59336abe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render `layout` again at the renderer's native resolution and save it.\n",
    "warehouse = Warehouse(layout=layout)\n",
    "image = warehouse.render()\n",
    "warehouse.close()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(image)\n",
    "ax.axis(\"off\")\n",
    "fig.savefig(\"d-rware\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "322069ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): n_agents=3 while agent_idxs / agent_colors list four entries -\n",
    "# confirm which of the two the schema expects.\n",
    "scenario = RwareScenarioConfig(\n",
    "    name=\"d-rware-example\",\n",
    "    n_agents=3,\n",
    "    n_shelves=16,\n",
    "    n_colors=4,\n",
    "    goal_idxs=[0, 7, 56, 63],\n",
    "    agent_idxs=[12, 23, 54, 8],\n",
    "    agent_colors=[-1, -1, -1, -1],\n",
    "    n_goals=4,\n",
    "    goal_colors=[0, 1, 2, 3],\n",
    "    max_steps=100,\n",
    "    size=8,\n",
    ")\n",
    "\n",
    "\n",
    "# Clean sample x0: the 4x4 block of integer grid points with coordinates 2..5.\n",
    "x0 = np.array([(col, row) for col in range(2, 6) for row in range(2, 6)])\n",
    "\n",
    "\n",
    "def plot_points(layout, scenario: RwareScenarioConfig):\n",
    "    \"\"\"Scatter ``layout`` over a square grid with ``scenario.size`` lines per axis.\n",
    "\n",
    "    Args:\n",
    "        layout: Iterable of points; element 0 of each point is plotted as x\n",
    "            and element 1 as y.\n",
    "        scenario: Config whose ``size`` sets how many grid lines are drawn.\n",
    "\n",
    "    Returns:\n",
    "        The ``(fig, ax)`` pair that was drawn on.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(1, 1)\n",
    "    xs = [point[0] for point in layout]\n",
    "    ys = [point[1] for point in layout]\n",
    "\n",
    "    ax.scatter(\n",
    "        xs, ys, s=130, color=\"blue\", edgecolors=\"black\", linewidths=0.5, zorder=3\n",
    "    )\n",
    "    ax.axis(\"off\")\n",
    "\n",
    "    # Grid lines: all vertical lines first, then all horizontal ones.\n",
    "    far_edge = scenario.size - 1\n",
    "    for k in range(scenario.size):\n",
    "        ax.plot([k, k], [0, far_edge], color=\"gray\", linewidth=1)\n",
    "\n",
    "    for k in range(scenario.size):\n",
    "        ax.plot([0, far_edge], [k, k], color=\"gray\", linewidth=1)\n",
    "\n",
    "    ax.set_aspect(\"equal\")\n",
    "    fig.set_tight_layout(True)\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "# Largest grid coordinate; used to map between [0, size - 1] and [-1, 1].\n",
    "grid_max = scenario.size - 1\n",
    "\n",
    "# x0\n",
    "fig, axs = plot_points(x0, scenario=scenario)\n",
    "fig.savefig(\"pug_0.png\", bbox_inches=\"tight\", dpi=300)\n",
    "\n",
    "# xT\n",
    "xT = ((np.random.normal(0, 1.0, size=x0.shape) + 1) / 2 * grid_max).clip(\n",
    "    0, grid_max\n",
    ")\n",
    "fig, axs = plot_points(xT, scenario=scenario)\n",
    "fig.savefig(\"pug_T.png\", bbox_inches=\"tight\", dpi=300)\n",
    "\n",
    "# xt\n",
    "alpha = 0.7\n",
    "xt = alpha**0.4 * x0 + (1 - alpha) * xT\n",
    "fig, axs = plot_points(xt, scenario=scenario)\n",
    "fig.savefig(\"pug_t.png\", bbox_inches=\"tight\", dpi=300)\n",
    "\n",
    "# xt_0\n",
    "xt_0 = x0 + np.random.normal(0, 0.3, size=x0.shape)\n",
    "fig, axs = plot_points(xt_0, scenario=scenario)\n",
    "fig.savefig(\"pug_t0.png\", bbox_inches=\"tight\", dpi=300)\n",
    "\n",
    "# xt constrained: rescale to [-1, 1], project, rescale back and snap to the grid.\n",
    "project = graph_projection_constraint(scenario)\n",
    "xt_0_scaled = torch.tensor(xt_0 / grid_max * 2 - 1, dtype=torch.float32).unsqueeze(0)\n",
    "xt_constr = project(xt_0_scaled)[0].numpy()\n",
    "xt_constr = np.round((xt_constr + 1) / 2 * grid_max)\n",
    "fig, axs = plot_points(xt_constr, scenario=scenario)\n",
    "fig.savefig(\"pug_t_constr.png\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09a41678",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the Ormonde wind-farm layout after some random yaw actions.\n",
    "xcoords = np.array(floris_ormonde.xcoords)\n",
    "ycoords = np.array(floris_ormonde.ycoords)\n",
    "margin = 300\n",
    "\n",
    "# Shift the layout so its smallest x / y coordinate sits `margin` from the origin.\n",
    "xcoords = xcoords - xcoords.min() + margin\n",
    "ycoords = ycoords - ycoords.min() + margin\n",
    "\n",
    "# The map extends `margin` past the largest coordinate on each axis.\n",
    "# NOTE(review): max_steps reuses `margin` (300); the two quantities look\n",
    "# unrelated, so this is presumably a coincidence of values - confirm.\n",
    "scenario = WfcrlScenarioConfig(\n",
    "    name=\"ormonde_render_example\",\n",
    "    n_turbines=len(xcoords),\n",
    "    max_steps=margin,\n",
    "    map_x_length=int(xcoords.max() + margin),\n",
    "    map_y_length=int(ycoords.max() + margin),\n",
    "    min_distance_between_turbines=400,\n",
    ")\n",
    "\n",
    "env = _create_designable_windfarm(\n",
    "    scenario=scenario,\n",
    "    initial_xcoords=xcoords.tolist(),\n",
    "    initial_ycoords=ycoords.tolist(),\n",
    "    render=True,\n",
    ")\n",
    "\n",
    "# Take some random steps\n",
    "# Each action is a single yaw value drawn uniformly from [-5, 5).\n",
    "# NOTE(review): 2000 steps are taken although max_steps is 300 - confirm the\n",
    "# environment tolerates stepping past max_steps without a reset.\n",
    "env.reset()\n",
    "for _ in range(2000):\n",
    "    env.step({\"yaw\": np.array([np.random.rand() * 10 - 5])})\n",
    "\n",
    "fig, axs = plt.subplots(figsize=(6, 6))\n",
    "axs.axis(\"off\")\n",
    "axs.imshow(env.render(), aspect=\"auto\")\n",
    "\n",
    "fig.savefig(\"wfcrl_ormonde.png\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d40de3d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# D-RWARE Corners Plot\n",
    "project_name = \"diffusion-co-design-rware-rware_16_50_5_4_corners\"\n",
    "api = wandb.Api()\n",
    "runs = api.runs(path=project_name)\n",
    "\n",
    "total_steps = 4000\n",
    "runs_dict = defaultdict(list)\n",
    "train_reward_key = \"train/reward/episode_reward_mean\"\n",
    "\n",
    "# Runs created before this date logged the designer metrics under older keys.\n",
    "config_change_date = datetime(2025, 7, 3)\n",
    "\n",
    "\n",
    "def run_created_at(run):\n",
    "    \"\"\"Return ``run.created_at`` as a naive ``datetime``.\n",
    "\n",
    "    The wandb public API reports ``created_at`` as an ISO-8601 string, which\n",
    "    cannot be ordered against a ``datetime``. Strings are parsed (a trailing\n",
    "    ``Z`` is accepted); values that already are ``datetime`` objects pass\n",
    "    through. Timezone info is dropped so the result compares with a naive\n",
    "    cutoff.\n",
    "    \"\"\"\n",
    "    created = run.created_at\n",
    "    if isinstance(created, str):\n",
    "        created = datetime.fromisoformat(created.replace(\"Z\", \"+00:00\"))\n",
    "    return created.replace(tzinfo=None)\n",
    "\n",
    "\n",
    "for run in tqdm(runs):\n",
    "    name = run.name\n",
    "    cfg = run.config\n",
    "    reward = get_full_history(run, train_reward_key)\n",
    "\n",
    "    if run_created_at(run) < config_change_date:\n",
    "        # Old config version with different keys\n",
    "        d_loss = get_full_history(run, \"train/designer_loss\")\n",
    "        d_min = get_full_history(run, \"train/design_y_min\")\n",
    "        d_max = get_full_history(run, \"train/design_y_max\")\n",
    "    else:\n",
    "        d_loss = get_full_history(run, \"train/designer/prediction_loss\")\n",
    "        d_min = get_full_history(run, \"train/designer/train_y_min\")\n",
    "        d_max = get_full_history(run, \"train/designer/train_y_max\")\n",
    "\n",
    "    # Runs that share a name are collected together in one list.\n",
    "    runs_dict[name].append(\n",
    "        {\n",
    "            \"cfg\": cfg,\n",
    "            \"reward\": reward,\n",
    "            \"d_loss\": d_loss,\n",
    "            \"d_min\": d_min,\n",
    "            \"d_max\": d_max,\n",
    "            \"designer_artifact_path\": None,\n",
    "        }\n",
    "    )\n",
    "    run_date = str(run.created_at)\n",
    "    for artifact in run.logged_artifacts():\n",
    "        if artifact.name.startswith(\"designer_final\"):\n",
    "            path = f\"artifacts/{name}/{run_date}\"\n",
    "            runs_dict[name][-1][\"designer_artifact_path\"] = path\n",
    "            if os.path.exists(path):\n",
    "                # Already downloaded on an earlier execution of this cell.\n",
    "                continue\n",
    "            # Download outside the assert so it still runs under `python -O`.\n",
    "            downloaded_to = artifact.download(root=path)\n",
    "            assert downloaded_to == path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eaf0e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training curves: mean smoothed episode reward per method, with a 95% band.\n",
    "sns.set_theme(context=\"notebook\")\n",
    "fig, axs = plt.subplots(1, 1)\n",
    "fig.set_size_inches(4.0, 4.0 / 1.618)\n",
    "\n",
    "# wandb run name -> legend label. Labels ending in \"(ours)\" get solid lines.\n",
    "key_to_label_map = {\n",
    "    \"corners_agent_distill_image\": \"DiCoDe (ours)\",\n",
    "    \"corners_agent_distill_gnn\": \"DiCoDe-Coord (ours)\",\n",
    "    \"corners_agent_image\": \"DiCoDe-MC\",\n",
    "    \"corners_agent_image_add\": \"DiCoDe-ADD\",\n",
    "    \"corners_agent_descent\": \"DiCoDe-Descent\",\n",
    "    \"corners_agent_sampling\": \"DiCoDe-Sampling\",\n",
    "    \"corners_agent_fixed\": \"Fixed\",\n",
    "    \"corners_agent_random\": \"DR\",\n",
    "    \"corners_agent_rl\": \"RL\",\n",
    "}\n",
    "total_training_iterations = 4000\n",
    "samples_per_iteration = 5000\n",
    "colors = sns.color_palette(n_colors=len(key_to_label_map))\n",
    "\n",
    "\n",
    "for (key, label), color in zip(key_to_label_map.items(), colors):\n",
    "    runs = runs_dict[key]\n",
    "\n",
    "    rewards = []\n",
    "    for x in runs:\n",
    "        if len(x[\"reward\"]) != total_training_iterations and label != \"RL\":\n",
    "            # Run not complete, skip\n",
    "            continue\n",
    "        reward = ema(x[\"reward\"])\n",
    "        rewards.append(reward)\n",
    "    rewards = np.array(rewards)\n",
    "\n",
    "    if rewards.shape[0] == 0:\n",
    "        continue\n",
    "\n",
    "    if label == \"RL\":\n",
    "        # 1-based iteration index; np.arange keeps the indices integral, where\n",
    "        # np.linspace(1, n + 1, n) stretched them so the last one was n + 1.\n",
    "        X = np.arange(1, 2000 + 1)\n",
    "        X = X * (samples_per_iteration + 10000)\n",
    "        X = X[:1333]  # Too many samples\n",
    "\n",
    "        rl_rewards = rewards\n",
    "        rewards = rewards[:, :1333]\n",
    "\n",
    "        print(\n",
    "            \"RL Full\",\n",
    "            rl_rewards.mean(axis=0)[-1],\n",
    "            rl_rewards.std(axis=0)[-1] / np.sqrt(rl_rewards.shape[0]) * 1.96,\n",
    "        )\n",
    "\n",
    "    else:\n",
    "        X = np.arange(1, total_training_iterations + 1) * samples_per_iteration\n",
    "\n",
    "    mu = rewards.mean(axis=0)\n",
    "    print(label, f\"mean: {mu[-1]}\")\n",
    "    dashed = not label.endswith(\"(ours)\")\n",
    "    axs.plot(X, mu, color=color, label=label, linestyle=\"--\" if dashed else \"-\")\n",
    "    if rewards.shape[0] > 1:\n",
    "        # 1.96 * standard error of the mean across runs.\n",
    "        std_err = rewards.std(axis=0) / np.sqrt(rewards.shape[0]) * 1.96\n",
    "        print(f\"95%: {std_err[-1]}\")\n",
    "        axs.fill_between(X, y1=mu - std_err, y2=mu + std_err, color=color, alpha=0.3)\n",
    "\n",
    "axs.set_title(\"D-RWARE (Corner) Training Progress\")\n",
    "axs.set_xlabel(\"Frames\")\n",
    "axs.set_ylabel(\"Episode Reward\")\n",
    "\n",
    "# The legend is drawn on its own figure so it can be placed separately.\n",
    "legend_handles, legend_labels = axs.get_legend_handles_labels()\n",
    "legend_fig = plt.figure(figsize=(4.5, 0.5))  # ~2/3 paper width\n",
    "legend_fig.legend(\n",
    "    handles=legend_handles,\n",
    "    labels=legend_labels,\n",
    "    loc=\"center\",\n",
    "    ncol=len(key_to_label_map),\n",
    "    frameon=False,\n",
    "    fontsize=9,\n",
    ")\n",
    "\n",
    "\n",
    "fig.savefig(fname=\"dicode-corners.png\", bbox_inches=\"tight\", dpi=300)\n",
    "legend_fig.savefig(\"dicode-corners-legend.png\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ca4d720",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 2, figsize=(5.0, 2.3))\n",
    "# fig.suptitle(\"Training Targets for Environment Critic\")\n",
    "\n",
    "\n",
    "def plot_mu_and_std(ax, X, data, color, label=None, ls=\"-\"):\n",
    "    \"\"\"Draw the axis-0 mean of ``data`` with a band of one standard deviation.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes to draw on.\n",
    "        X: x-values shared by the line and the band.\n",
    "        data: Array reduced over axis 0 (mean and std).\n",
    "        color: Colour of both the line and the band.\n",
    "        label: Optional legend label for the line.\n",
    "        ls: Line style of the mean line.\n",
    "    \"\"\"\n",
    "    centre = data.mean(axis=0)\n",
    "    spread = data.std(axis=0)\n",
    "    ax.plot(X, centre, color=color, label=label, ls=ls)\n",
    "    ax.fill_between(X, y1=centre - spread, y2=centre + spread, color=color, alpha=0.3)\n",
    "\n",
    "\n",
    "colors = sns.color_palette(n_colors=2)\n",
    "\n",
    "# (wandb run name, legend label) for the two critic-target variants compared.\n",
    "ablation_series = (\n",
    "    (\"corners_agent_distill_image\", \"Critic Distillation\"),\n",
    "    (\"corners_agent_image\", \"Sampled Trajectory Returns\"),\n",
    ")\n",
    "\n",
    "axs[0].set_title(\"Critic Training Target\")\n",
    "axs[0].set_xlabel(\"Training Step\")\n",
    "axs[0].set_ylabel(\"Value\")\n",
    "\n",
    "axs[1].set_title(\"Critic Training Loss\")\n",
    "axs[1].set_xlabel(\"Training Step\")\n",
    "# axs[1].set_ylabel(\"Value\")\n",
    "\n",
    "for i, (key, label) in enumerate(ablation_series):\n",
    "    runs = runs_dict[key]\n",
    "    c = colors[i]\n",
    "\n",
    "    # Only runs with exactly 3995 logged designer-loss values are used.\n",
    "    complete_runs = [x for x in runs if len(x[\"d_loss\"]) == 3995]\n",
    "    d_loss = np.array([ema(x[\"d_loss\"]) for x in complete_runs])\n",
    "    d_min = np.array([ema(x[\"d_min\"]) for x in complete_runs])\n",
    "    d_max = np.array([ema(x[\"d_max\"]) for x in complete_runs])\n",
    "\n",
    "    steps = range(d_loss.shape[1])\n",
    "\n",
    "    # Left panel: dashed line for the minimum target, solid for the maximum.\n",
    "    plot_mu_and_std(ax=axs[0], X=steps, data=d_min, color=c, label=label, ls=\"--\")\n",
    "    plot_mu_and_std(ax=axs[0], X=steps, data=d_max, color=c)\n",
    "\n",
    "    # Right panel: prediction loss.\n",
    "    plot_mu_and_std(ax=axs[1], X=steps, data=d_loss, color=c)\n",
    "\n",
    "\n",
    "# Custom legend: just colored patches\n",
    "handles = [\n",
    "    Patch(color=colors[idx], label=series_label)\n",
    "    for idx, (_, series_label) in enumerate(ablation_series)\n",
    "]\n",
    "fig.legend(\n",
    "    handles=handles,\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.5, 0.02),\n",
    "    ncol=2,\n",
    "    frameon=False,\n",
    ")\n",
    "fig.set_tight_layout(True)\n",
    "fig.savefig(fname=\"ablation-distill-training.png\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d3b0f8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per method: (scenario config, training config, designer checkpoint directories).\n",
    "runs_dict_processed: dict[\n",
    "    str, tuple[RwareScenarioConfig, RwareTrainingConfig, list]\n",
    "] = {}\n",
    "scenario = None\n",
    "for key, label in key_to_label_map.items():\n",
    "    key_runs = runs_dict.get(key, [])\n",
    "    if not key_runs:\n",
    "        # No runs were fetched under this name, so there is no config to parse.\n",
    "        continue\n",
    "\n",
    "    # The first run's config stands in for every repeat of this method.\n",
    "    cfg = key_runs[0]\n",
    "    if scenario is None:\n",
    "        scenario = RwareScenarioConfig.from_raw(cfg[\"cfg\"][\"scenario\"])\n",
    "    train_cfg = RwareTrainingConfig.from_raw(cfg[\"cfg\"])\n",
    "\n",
    "    repeats = [\n",
    "        run[\"designer_artifact_path\"]\n",
    "        for run in key_runs\n",
    "        if run[\"designer_artifact_path\"] is not None\n",
    "    ]\n",
    "\n",
    "    # Index by the bare method name: the \" (ours)\" suffix is only a legend\n",
    "    # marker, and the lookups in this notebook use \"DiCoDe\" / \"DiCoDe-Coord\".\n",
    "    runs_dict_processed[label.removesuffix(\" (ours)\")] = (scenario, train_cfg, repeats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ec99584",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PUG Ablation\n",
    "# Build one batch of N environments per search method, all scored by the same\n",
    "# trained critic (`model`), for the comparison plots in the following cells.\n",
    "scenario, train_cfg, repeats = runs_dict_processed[\"DiCoDe-Coord\"]\n",
    "# diffusion_dir and pretrain_dir are bound to the same path.\n",
    "diffusion_dir = pretrain_dir = os.path.join(\n",
    "    OUTPUT_DIR, \"rware\", \"diffusion\", \"graph\", scenario.name\n",
    ")\n",
    "latest_diffusion_checkpoint = get_latest_model(diffusion_dir, \"model\")\n",
    "# Critic weights come from the first repeat's final checkpoint.\n",
    "# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.\n",
    "state_dict = torch.load(\n",
    "    os.path.join(repeats[0], \"designer_3999.pt\"), map_location=device\n",
    ")\n",
    "\n",
    "model = make_model(\n",
    "    model=train_cfg.designer.model.name,\n",
    "    scenario=scenario,\n",
    "    model_kwargs=train_cfg.designer.model.model_kwargs,\n",
    "    device=device,\n",
    ")\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "\n",
    "# Number of environments generated per method.\n",
    "N = 32\n",
    "\n",
    "# PUG\n",
    "generator = RwareGenerator(\n",
    "    batch_size=N,\n",
    "    generator_model_path=latest_diffusion_checkpoint,\n",
    "    scenario=scenario,\n",
    "    representation=\"graph\",\n",
    "    device=device,\n",
    ")\n",
    "guidance_model = model\n",
    "guidance_model.eval()\n",
    "\n",
    "# Guidance settings, including the graph projection constraint.\n",
    "operation = OptimizerDetails()\n",
    "operation.lr = 0.01\n",
    "operation.num_recurrences = 8\n",
    "operation.backward_steps = 16\n",
    "operation.forward_guidance_wt = 5.0\n",
    "operation.projection_constraint = graph_projection_constraint(scenario)\n",
    "\n",
    "pug_batch = np.array(\n",
    "    generator.generate_batch(\n",
    "        value=guidance_model,\n",
    "        use_operation=True,\n",
    "        operation_override=operation,\n",
    "    )\n",
    ")\n",
    "\n",
    "# UG\n",
    "# NOTE(review): this call is identical to the PUG call above, including the\n",
    "# same `operation` with its projection constraint, so as written the two\n",
    "# batches differ only through sampling randomness. Confirm whether UG was\n",
    "# meant to run without the projection constraint.\n",
    "ug_batch = None\n",
    "ug_batch = np.array(\n",
    "    generator.generate_batch(\n",
    "        value=guidance_model,\n",
    "        use_operation=True,\n",
    "        operation_override=operation,\n",
    "    )\n",
    ")\n",
    "\n",
    "# Descent\n",
    "# A placeholder designer is built and its value model swapped for the loaded critic.\n",
    "grad_designer = DescentDesigner.make_placeholder(\n",
    "    scenario=scenario,\n",
    "    representation=\"graph\",\n",
    "    classifier=train_cfg.designer.model,\n",
    "    n_epochs=32,\n",
    "    n_gradient_iterations=10,\n",
    "    lr=0.03,\n",
    "    device=device,\n",
    ").design_producer\n",
    "grad_designer.value_learner.model = model\n",
    "descent_batch = np.array(\n",
    "    [theta.numpy(force=True) for theta in grad_designer.generate_layout_batch(N)]\n",
    ")\n",
    "\n",
    "\n",
    "# Sampling\n",
    "# Draw N * K candidates, score them with the critic, and keep the\n",
    "# highest-scoring candidate out of each group of K.\n",
    "K = 32\n",
    "with torch.no_grad():\n",
    "    u_x = generate(\n",
    "        size=scenario.size,\n",
    "        n_shelves=scenario.n_shelves,\n",
    "        goal_idxs=scenario.goal_idxs,\n",
    "        n_colors=scenario.n_colors,\n",
    "        training_dataset=True,\n",
    "        representation=\"graph\",\n",
    "        n=N * K,\n",
    "    )\n",
    "\n",
    "    u_x = torch.tensor(np.array(u_x), device=device)\n",
    "\n",
    "    y = model.predict_theta_value(u_x)\n",
    "    # Regroup the flat candidate list into N groups of K layouts.\n",
    "    u_x = u_x.reshape(N, K, scenario.n_shelves, 2)\n",
    "    y = y.reshape(N, K)\n",
    "    _, best_idxs = y.max(dim=-1)\n",
    "    sampling_batch = u_x[torch.arange(N), best_idxs]\n",
    "    sampling_batch = train_to_eval(sampling_batch, scenario, \"graph\")\n",
    "    sampling_batch = sampling_batch.numpy(force=True)\n",
    "    # Drop the large candidate tensors before moving on.\n",
    "    del u_x, best_idxs, y\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "# Unguided samples from generate(); labelled DR in the comparison plot.\n",
    "random_batch = np.array(\n",
    "    generate(\n",
    "        size=scenario.size,\n",
    "        n_shelves=scenario.n_shelves,\n",
    "        goal_idxs=scenario.goal_idxs,\n",
    "        n_colors=scenario.n_colors,\n",
    "        training_dataset=False,\n",
    "        representation=\"graph\",\n",
    "        n=N,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce876255",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Box plot of the critic's predicted value for each search method's batch.\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "\n",
    "labels = []\n",
    "exp_returns = []\n",
    "\n",
    "\n",
    "selected_envs = {}\n",
    "\n",
    "batches_by_method = (\n",
    "    (\"PUG\", pug_batch),\n",
    "    (\"UG\", ug_batch),\n",
    "    (\"Descent\", descent_batch),\n",
    "    (\"Sampling\", sampling_batch),\n",
    "    (\"DR\", random_batch),\n",
    ")\n",
    "\n",
    "for label, batch in batches_by_method:\n",
    "    # Eval to train, run through model\n",
    "    x = torch.tensor(batch, device=device)\n",
    "    x = x / (scenario.size - 1) * 2 - 1\n",
    "    with torch.no_grad():\n",
    "        y = model.predict_theta_value(x).numpy(force=True)\n",
    "    print(label, y.mean().item(), y.max().item())\n",
    "\n",
    "    labels.append(label)\n",
    "    exp_returns.append(y)\n",
    "    # Remember which batch entries scored highest / lowest for later rendering.\n",
    "    selected_envs[label] = dict(best_idx=y.argmax(), worst_idx=y.argmin())\n",
    "\n",
    "colors = sns.color_palette(n_colors=len(exp_returns))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5, 2.6))\n",
    "box = ax.boxplot(\n",
    "    exp_returns,\n",
    "    patch_artist=True,\n",
    "    labels=labels,\n",
    "    boxprops={\"linewidth\": 1.2},\n",
    "    medianprops={\"color\": \"black\", \"linewidth\": 1.5},\n",
    "    whiskerprops={\"color\": \"gray\"},\n",
    "    capprops={\"color\": \"gray\"},\n",
    ")\n",
    "\n",
    "# Apply colors\n",
    "for box_patch, face_color in zip(box[\"boxes\"], colors):\n",
    "    box_patch.set_facecolor(face_color)\n",
    "    box_patch.set_edgecolor(\"black\")\n",
    "\n",
    "\n",
    "ax.set_title(\"Environment Search Comparison\")\n",
    "ax.set_ylabel(\"Critic Value\")\n",
    "ax.set_xlabel(\"Generator Method\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(fname=\"ablation-pug-box.png\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77019fce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the highest-scoring environment of three of the search methods.\n",
    "fig, axs = plt.subplots(1, 3, figsize=(9, 6))\n",
    "\n",
    "shown_methods = (\n",
    "    (\"PUG\", pug_batch),\n",
    "    (\"Descent\", descent_batch),\n",
    "    (\"Sampling\", sampling_batch),\n",
    ")\n",
    "\n",
    "for ax, (label, batch) in zip(axs, shown_methods):\n",
    "    best_idx = selected_envs[label][\"best_idx\"]\n",
    "    best_env = batch[best_idx]\n",
    "\n",
    "    ax.imshow(render_rware_env(best_env, scenario, \"graph\"))\n",
    "    ax.set_title(f\"{label}\", fontsize=10)\n",
    "    ax.axis(\"off\")\n",
    "\n",
    "fig.savefig(fname=\"ablation-envs.png\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3e80a44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate heatmaps\n",
    "# For each method, generate about N layouts split evenly across its trained\n",
    "# designer checkpoints and sum their storage arrays.\n",
    "N = 100\n",
    "\n",
    "label_to_heatmap = {}\n",
    "for label, representation in [(\"DiCoDe\", \"image\"), (\"DiCoDe-Coord\", \"graph\")]:\n",
    "    scenario, train_cfg, repeats = runs_dict_processed[label]\n",
    "    if not repeats:\n",
    "        # N // len(repeats) below would otherwise divide by zero.\n",
    "        print(f\"{label}: no designer checkpoints available, skipping heatmap\")\n",
    "        continue\n",
    "\n",
    "    designer = RwareDicodeDesigner.make_placeholder(\n",
    "        scenario=scenario,\n",
    "        classifier=train_cfg.designer.model,\n",
    "        diffusion=train_cfg.designer.diffusion,\n",
    "        representation=representation,\n",
    "    ).design_producer\n",
    "\n",
    "    # Layouts generated per checkpoint (integer division, so the total may be < N).\n",
    "    B = N // len(repeats)\n",
    "    envs = []\n",
    "    for checkpoint_dir in tqdm(repeats):\n",
    "        print(checkpoint_dir)\n",
    "        designer.model.load_state_dict(\n",
    "            torch.load(\n",
    "                os.path.join(checkpoint_dir, \"designer_3999.pt\"), map_location=device\n",
    "            )\n",
    "        )\n",
    "\n",
    "        batch = designer.generate_layout_batch(batch_size=B)\n",
    "        for env in batch:\n",
    "            # Named env_layout so the earlier notebook-level `layout` is not overwritten.\n",
    "            env_layout = storage_to_layout(\n",
    "                features=env, config=scenario, representation=representation\n",
    "            )\n",
    "            envs.append(env_layout.storage)\n",
    "    heatmap = np.stack(envs).sum(axis=0)\n",
    "    label_to_heatmap[label] = heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "918e123b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One 2x2 figure per method: a panel per colour name, all on one colour scale.\n",
    "colors = (\"Teal\", \"Purple\", \"Blue\", \"Green\")\n",
    "for label, heatmap in label_to_heatmap.items():\n",
    "    fig, axs = plt.subplots(2, 2, figsize=(3, 3))\n",
    "\n",
    "    # Shared limits so the four panels are directly comparable.\n",
    "    color_min = 0\n",
    "    color_max = np.max(heatmap)\n",
    "\n",
    "    for i, c in enumerate(colors):\n",
    "        grid_row, grid_col = divmod(i, 2)\n",
    "        ax = axs[grid_row][grid_col]\n",
    "        im = ax.imshow(\n",
    "            heatmap[i],\n",
    "            cmap=\"viridis\",\n",
    "            aspect=\"equal\",\n",
    "            vmin=color_min,\n",
    "            vmax=color_max,\n",
    "        )\n",
    "        ax.set_title(c)\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    # leave space for colorbar\n",
    "    fig.subplots_adjust(top=0.95, right=0.85)\n",
    "\n",
    "    # add colorbar\n",
    "    cbar_ax = fig.add_axes([0.88, 0.20, 0.03, 0.7])\n",
    "    fig.colorbar(im, cax=cbar_ax)\n",
    "\n",
    "    # vertical title to the right of the colorbar (x, y in figure coords)\n",
    "    fig.text(1.03, 0.75, label, rotation=-90, va=\"top\", ha=\"left\", fontsize=14)\n",
    "\n",
    "    fig.savefig(fname=f\"{label}_heatmap.png\", bbox_inches=\"tight\", dpi=300)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfc6c765",
   "metadata": {},
   "outputs": [],
   "source": [
    "# WFCRL Plot\n",
    "# Fetch reward histories and final designer weights for each turbine count.\n",
    "api = wandb.Api()\n",
    "\n",
    "# turbine count -> {run name -> list of per-run dicts}\n",
    "wfcrl_all_scenarios = {}\n",
    "\n",
    "for turbine_number, project_name in [\n",
    "    (2, \"diffusion-co-design-wfcrl-wfcrl_2\"),\n",
    "    (4, \"diffusion-co-design-wfcrl-wfcrl_4\"),\n",
    "    (8, \"diffusion-co-design-wfcrl-wfcrl_8\"),\n",
    "]:\n",
    "    runs = api.runs(path=project_name)\n",
    "\n",
    "    total_steps = 301\n",
    "    wfcrl_runs_dict = defaultdict(list)\n",
    "    train_reward_key = \"train/reward/episode_reward_mean\"\n",
    "\n",
    "    for run in tqdm(runs):\n",
    "        name = run.name\n",
    "        cfg = run.config\n",
    "        reward = get_full_history(run, train_reward_key)\n",
    "        run_data = {\"cfg\": cfg, \"reward\": reward}\n",
    "\n",
    "        # Stays None when the run logged no `designer_final*` artifact.\n",
    "        run_data[\"designer_state_dict\"] = None\n",
    "        for artifact in run.logged_artifacts():\n",
    "            if artifact.name.startswith(\"designer_final\"):\n",
    "                # If several artifacts match, the one loaded last is kept.\n",
    "                # NOTE(review): torch.load unpickles the file - only load\n",
    "                # trusted checkpoints.\n",
    "                artifact_dir = artifact.download()\n",
    "                state_dict = torch.load(\n",
    "                    os.path.join(artifact_dir, \"designer_300.pt\"), map_location=device\n",
    "                )\n",
    "                run_data[\"designer_state_dict\"] = state_dict\n",
    "\n",
    "        wfcrl_runs_dict[name].append(run_data)\n",
    "\n",
    "    wfcrl_all_scenarios[turbine_number] = wfcrl_runs_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52b221d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rewards\n",
    "# (turbine count, run name) -> final smoothed mean reward and its 95% half-width.\n",
    "wfcrl_results = {}\n",
    "\n",
    "for turbine_number, wfcrl_runs_dict in wfcrl_all_scenarios.items():\n",
    "    print(f\"Turbines: {turbine_number}\")\n",
    "    for key, runs in wfcrl_runs_dict.items():\n",
    "        # Filter on length before smoothing so incomplete (possibly empty)\n",
    "        # histories are never passed to ema().\n",
    "        rewards = np.array(\n",
    "            [ema(run[\"reward\"]) for run in runs if len(run[\"reward\"]) == 301]\n",
    "        )\n",
    "        if rewards.shape[0] == 0:\n",
    "            # Nothing to average; mean() over no runs would not be indexable.\n",
    "            print(key, \"no complete runs, skipping\")\n",
    "            continue\n",
    "        mu = rewards.mean(axis=0)\n",
    "        print(key, f\"mean: {mu[-1]}\")\n",
    "        std_err = rewards.std(axis=0) / np.sqrt(rewards.shape[0]) * 1.96\n",
    "        print(f\"95%: {std_err[-1]}\")\n",
    "\n",
    "        wfcrl_results[(turbine_number, key)] = {\"mu\": mu[-1], \"std_err\": std_err[-1]}\n",
    "\n",
    "algos = [\n",
    "    \"wfcrl_diffusion_distill\",\n",
    "    \"wfcrl_reinforce\",\n",
    "    \"wfcrl_fixed\",\n",
    "]  # DiCoDe, Reinforce, Fixed\n",
    "labels = {\n",
    "    \"wfcrl_diffusion_distill\": \"DiCoDe\",\n",
    "    \"wfcrl_reinforce\": \"RL\",\n",
    "    \"wfcrl_fixed\": \"Fixed\",\n",
    "}\n",
    "colors = {\n",
    "    \"wfcrl_diffusion_distill\": \"tab:blue\",\n",
    "    \"wfcrl_reinforce\": \"tab:orange\",\n",
    "    \"wfcrl_fixed\": \"tab:green\",\n",
    "}\n",
    "\n",
    "turbines = sorted(set(k[0] for k in wfcrl_results.keys()))\n",
    "x = np.arange(len(turbines))\n",
    "width = 0.25\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 3))\n",
    "\n",
    "for i, algo in enumerate(algos):\n",
    "    means = []\n",
    "    errs = []\n",
    "    for t in turbines:\n",
    "        result = wfcrl_results.get((t, algo))\n",
    "        baseline = wfcrl_results.get((t, \"wfcrl_random\"))  # random baseline\n",
    "        if result is None or baseline is None:\n",
    "            # Missing data: NaN leaves this bar out without shifting the others.\n",
    "            means.append(np.nan)\n",
    "            errs.append(np.nan)\n",
    "            continue\n",
    "\n",
    "        base = baseline[\"mu\"]\n",
    "        # Normalise by |base| so a negative baseline neither flips the sign of\n",
    "        # the improvement nor yields a negative error bar.\n",
    "        mu = (result[\"mu\"] - base) / abs(base)\n",
    "        err = result[\"std_err\"] / abs(base)\n",
    "        means.append(mu)\n",
    "        errs.append(err)\n",
    "\n",
    "    ax.bar(\n",
    "        x + i * width,\n",
    "        means,\n",
    "        width,\n",
    "        yerr=errs,\n",
    "        label=labels[algo],\n",
    "        color=colors[algo],\n",
    "        capsize=5,\n",
    "    )\n",
    "\n",
    "# Add horizontal line for \"0% improvement\"\n",
    "ax.axhline(0.0, color=\"gray\", linestyle=\"--\", linewidth=1)\n",
    "\n",
    "ax.set_xticks(x + width, [f\"{t}\" for t in turbines])\n",
    "ax.set_xlabel(\"Number of Turbines\")\n",
    "ax.set_ylabel(\"Relative Reward Δ (Random)\")\n",
    "ax.set_title(\"Performance Scaling on WFCRL\")\n",
    "\n",
    "ax.set_ylim(-0.1, 0.2)\n",
    "\n",
    "# Legend at bottom\n",
    "fig.legend(loc=\"upper center\", bbox_to_anchor=(0.5, -0), ncol=3)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.show()\n",
    "\n",
    "fig.savefig(\"wfcrl-progression.png\", bbox_inches=\"tight\", dpi=300)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3f0428b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# VMAS\n",
    "# Fetch reward histories and final designer weights for the VMAS project.\n",
    "api = wandb.Api()\n",
    "\n",
    "project_name = \"diffusion-co-design-vmas-obstacle_navigation_3\"\n",
    "runs = api.runs(path=project_name)\n",
    "\n",
    "total_steps = 201\n",
    "# run name -> list of per-run dicts\n",
    "vmas_runs_dict = defaultdict(list)\n",
    "train_reward_key = \"train/reward/episode_reward_mean\"\n",
    "\n",
    "for run in tqdm(runs):\n",
    "    name = run.name\n",
    "    cfg = run.config\n",
    "    reward = get_full_history(run, train_reward_key)\n",
    "    run_data = {\"cfg\": cfg, \"reward\": reward}\n",
    "\n",
    "    # Stays None when the run logged no `designer_final*` artifact.\n",
    "    run_data[\"designer_state_dict\"] = None\n",
    "    for artifact in run.logged_artifacts():\n",
    "        if artifact.name.startswith(\"designer_final\"):\n",
    "            # If several artifacts match, the one loaded last is kept.\n",
    "            # NOTE(review): torch.load unpickles the file - only load trusted\n",
    "            # checkpoints.\n",
    "            artifact_dir = artifact.download()\n",
    "            state_dict = torch.load(\n",
    "                os.path.join(artifact_dir, \"designer_200.pt\"), map_location=device\n",
    "            )\n",
    "            run_data[\"designer_state_dict\"] = state_dict\n",
    "\n",
    "    vmas_runs_dict[name].append(run_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ee538f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for key, runs in vmas_runs_dict.items():\n",
    "    rewards = []\n",
    "    for x in runs:\n",
    "        reward = ema(x[\"reward\"])\n",
    "        if reward.shape[0] != 201:\n",
    "            continue\n",
    "        rewards.append(reward)\n",
    "    rewards = np.array(rewards)\n",
    "    mu = rewards.mean(axis=0)\n",
    "    print(key, f\"mean: {mu[-1]}\")\n",
    "    std_err = rewards.std(axis=0) / np.sqrt(rewards.shape[0]) * 1.96\n",
    "    print(f\"95%: {std_err[-1]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "086d13ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "above_fixed_cont_avg = (\n",
    "    np.array([490 / 442, 430 / 387, 370 / 325, 2.29 / 2.24]).sum() / 4 - 1\n",
    ")\n",
    "print(\"Above Fixed Cont Avg:\", above_fixed_cont_avg)\n",
    "\n",
    "above_rl_cont_avg = (\n",
    "    np.array([490 / 485, 430 / 404, 370 / 323, 2.29 / 1.92]).sum() / 4 - 1\n",
    ")\n",
    "print(\"Above RL Cont Avg:\", above_rl_cont_avg)\n",
    "\n",
    "above_dr_cont_avg = (\n",
    "    np.array([490 / 443, 430 / 382, 370 / 314, 2.29 / 1.80]).sum() / 4 - 1\n",
    ")\n",
    "print(\"Above DR Cont Avg:\", above_dr_cont_avg)\n",
    "\n",
    "above_fixed_all = (\n",
    "    np.array([490 / 442, 430 / 387, 370 / 325, 2.29 / 2.24, 12.1 / 9.6]).sum() / 5 - 1\n",
    ")\n",
    "print(\"Above Fixed All:\", above_fixed_all)\n",
    "\n",
    "above_rl_all = (\n",
    "    np.array([490 / 485, 430 / 404, 370 / 323, 2.29 / 1.92, 12.1 / 8.7]).sum() / 5 - 1\n",
    ")\n",
    "print(\"Above RL All:\", above_rl_all)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86a35d20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deprecated\n",
    "\n",
    "# sns.set_theme(context=\"notebook\")\n",
    "# mpl.rcParams[\"font.family\"] = \"monospace\"\n",
    "# fig, axs = plt.subplots(1, 2)\n",
    "# fig.set_size_inches(15, 4)\n",
    "\n",
    "# key_to_label_map = {\n",
    "#     0: {\n",
    "#         \"wfcrl_fixed\": \"Fixed\",\n",
    "#         \"wfcrl_diffusion_distill\": \"DiCoDe\",\n",
    "#     },\n",
    "#     1: {\n",
    "#         \"wfcrl_fixed_rect_8\": \"Fixed\",\n",
    "#         \"wfcrl_diffusion_distill_rect_8\": \"DiCoDe\",\n",
    "#     },\n",
    "# }\n",
    "# total_training_iterations = 301\n",
    "# samples_per_iteration = 300\n",
    "# colors = sns.color_palette(n_colors=2)\n",
    "\n",
    "\n",
    "# for ax_id, data_dict in key_to_label_map.items():\n",
    "#     ax = axs[ax_id]\n",
    "\n",
    "#     for (key, label), color in zip(data_dict.items(), colors):\n",
    "#         runs = wfcrl_runs_dict[key]\n",
    "\n",
    "#         rewards = []\n",
    "#         for x in runs:\n",
    "#             reward = ema(x[\"reward\"])\n",
    "#             if len(x[\"reward\"]) != total_training_iterations:\n",
    "#                 # Run not complete, skip\n",
    "#                 continue\n",
    "#             rewards.append(reward)\n",
    "#         rewards = np.array(rewards)\n",
    "\n",
    "#         X = (\n",
    "#             np.linspace(1, total_training_iterations + 1, total_training_iterations)\n",
    "#             * samples_per_iteration\n",
    "#         )\n",
    "\n",
    "#         mu = rewards.mean(axis=0)\n",
    "#         print(ax_id, label, f\"mean: {mu[-1]}\")\n",
    "#         ax.plot(X, mu, color=color, label=label)\n",
    "#         if rewards.shape[0] > 1:\n",
    "#             std_err = rewards.std(axis=0)\n",
    "#             print(f\"std: {std_err[-1]}\")\n",
    "#             ax.fill_between(X, y1=mu - std_err, y2=mu + std_err, color=color, alpha=0.3)\n",
    "#         pass\n",
    "\n",
    "# fig.suptitle(\"WFCRL Training Progress\")\n",
    "# axs[0].set_title(\"Square-10\")\n",
    "# axs[0].set_xlabel(\"Frames\")\n",
    "# axs[0].set_ylabel(\"Mean Episode Reward\")\n",
    "# axs[1].set_title(\"Rect-8\")\n",
    "# axs[1].set_xlabel(\"Frames\")\n",
    "# axs[1].set_ylabel(\"Mean Episode Reward\")\n",
    "# handles, labels = axs[0].get_legend_handles_labels()\n",
    "# fig.legend(\n",
    "#     handles,\n",
    "#     labels,\n",
    "#     loc=\"upper center\",\n",
    "#     bbox_to_anchor=(0.5, 0.02),\n",
    "#     ncol=4,\n",
    "#     frameon=False,\n",
    "# )\n",
    "# fig.savefig(fname=\"wfcrl-training.png\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "073910ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for i, data_dict in key_to_label_map.items():\n",
    "#     for j, (key, label) in enumerate(data_dict.items()):\n",
    "#         runs = wfcrl_runs_dict[key]\n",
    "#         scenario = WfcrlScenarioConfig.from_raw(runs[0][\"cfg\"][\"scenario\"])\n",
    "#         training_cfg = WfcrlTrainingConfig.from_raw(runs[0][\"cfg\"])\n",
    "\n",
    "#         if label == \"DiCoDe\":\n",
    "#             diffusion = training_cfg.designer.diffusion.model_copy()\n",
    "#             diffusion.forward_guidance_annealing = False\n",
    "#             designer = DiffusionDesigner(\n",
    "#                 scenario=scenario,\n",
    "#                 classifier=training_cfg.designer.model,\n",
    "#                 diffusion=diffusion,\n",
    "#                 normalisation_statistics=training_cfg.normalisation,\n",
    "#                 total_training_iterations=10,\n",
    "#             )\n",
    "#             designer.model.load_state_dict(state_dict=runs[0][\"designer_state_dict\"])\n",
    "#             layout = designer._reset_env_buffer(1)[0]\n",
    "#         else:\n",
    "#             generator = Generate(\n",
    "#                 num_turbines=scenario.n_turbines,\n",
    "#                 map_x_length=scenario.map_x_length,\n",
    "#                 map_y_length=scenario.map_y_length,\n",
    "#                 minimum_distance_between_turbines=scenario.min_distance_between_turbines,\n",
    "#             )\n",
    "#             layout = generator(n=1, training_dataset=False).squeeze()\n",
    "\n",
    "#         im = render_layout(x=layout, scenario=scenario)\n",
    "#         filename = f\"{label.lower()}_{scenario.name}.png\"\n",
    "#         mpl.image.imsave(filename, im)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
