{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac6fc1a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from conf.schema import Config\n",
    "from diffusion_co_design.common import cuda as device\n",
    "from diffusion_co_design.wfcrl.schema import TrainingConfig\n",
    "from diffusion_co_design.wfcrl.model.classifier import GNNCritic\n",
    "from diffusion_co_design.wfcrl.model.rl import maybe_make_denormaliser\n",
    "from diffusion_co_design.wfcrl.diffusion.generator import eval_to_train\n",
    "from diffusion_co_design.wfcrl.diffusion.generator import Generator, OptimizerDetails\n",
    "from diffusion_co_design.wfcrl.diffusion.constraints import soft_projection_constraint\n",
    "from diffusion_co_design.wfcrl.env import _create_designable_windfarm\n",
    "from diffusion_co_design.common import OUTPUT_DIR, get_latest_model\n",
    "\n",
    "from dataset import load_dataset\n",
    "\n",
    "\n",
    "training_cfg = TrainingConfig.from_file(\n",
    "    os.path.join(training_dir, \".hydra\", \"config.yaml\")\n",
    ")\n",
    "\n",
    "\n",
    "train_dataset, eval_dataset = load_dataset(\n",
    "    scenario=training_cfg.scenario,\n",
    "    training_dir=training_dir,\n",
    "    dataset_size=10_000,\n",
    "    num_workers=25,\n",
    "    test_proportion=0.2,\n",
    "    recompute=False,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "147e890f",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_returns = train_dataset.dataset.env_returns.storage\n",
    "\n",
    "episode_reward = env_returns[\"episode_reward\"]\n",
    "expected_reward = env_returns[\"expected_reward\"]\n",
    "\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "ax.hist(episode_reward, label=\"Sample\")\n",
    "ax.hist(expected_reward, label=\"Critic\")\n",
    "ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e220e49a",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, sorted_idxs = torch.sort(episode_reward)\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "ax.scatter(episode_reward[sorted_idxs], expected_reward[sorted_idxs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62005e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = Config.from_file(\"conf/config.yaml\")\n",
    "\n",
    "model = GNNCritic(\n",
    "    cfg=training_cfg.scenario,\n",
    "    node_emb_dim=cfg.model.node_emb_size,\n",
    "    n_layers=cfg.model.depth,\n",
    "    connectivity=cfg.model.connectivity,\n",
    "    post_hook=maybe_make_denormaliser(training_cfg.normalisation),\n",
    ").to(device=device)\n",
    "\n",
    "\n",
    "# Load designer\n",
    "model.load_state_dict(torch.load())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7531b01c",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = eval_to_train(env_returns[\"env\"][:1024].clone(), training_cfg.scenario).to(\n",
    "    device, torch.float32\n",
    ")\n",
    "\n",
    "y_pred = model(X)\n",
    "actual_y = expected_reward[:1024]\n",
    "\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "ax.scatter(actual_y, y_pred.numpy(force=True))\n",
    "# ax.hist(actual_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dd210dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "pretrain_dir = os.path.join(\n",
    "    OUTPUT_DIR, \"wfcrl\", \"diffusion\", training_cfg.scenario_name\n",
    ")\n",
    "latest_checkpoint = get_latest_model(pretrain_dir, \"model\")\n",
    "\n",
    "generator = Generator(\n",
    "    generator_model_path=latest_checkpoint,\n",
    "    scenario=training_cfg.scenario,\n",
    "    batch_size=9,\n",
    "    default_guidance_wt=5,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "operation = OptimizerDetails()\n",
    "operation.projection_constraint = soft_projection_constraint(training_cfg.scenario)\n",
    "generator.guidance_weight = 4.0\n",
    "operation.num_recurrences = 8\n",
    "operation.lr = 0.01\n",
    "operation.backward_steps = 16\n",
    "operation.use_forward = True\n",
    "\n",
    "\n",
    "batch = generator.generate_batch(\n",
    "    value=model, use_operation=True, operation_override=operation, batch_size=9\n",
    ")\n",
    "X = eval_to_train(torch.tensor(batch).to(device), cfg=training_cfg.scenario)\n",
    "y = model(X)\n",
    "print(y.mean())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c07cd698",
   "metadata": {},
   "outputs": [],
   "source": [
    "for x in batch:\n",
    "    env = _create_designable_windfarm(\n",
    "        scenario=training_cfg.scenario,\n",
    "        initial_xcoords=x[:, 0].tolist(),\n",
    "        initial_ycoords=x[:, 1].tolist(),\n",
    "        render=True,\n",
    "    )\n",
    "\n",
    "    env.reset()\n",
    "    fig, ax = plt.subplots(figsize=(4, 4))\n",
    "    ax.axis(\"off\")\n",
    "    ax.imshow(env.render(), aspect=\"auto\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
