{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91295d07",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from diffusion_co_design.wfcrl.schema import TrainingConfig, Diffusion, ClassifierConfig\n",
    "from diffusion_co_design.wfcrl.diffusion.generator import (\n",
    "    OptimizerDetails,\n",
    "    soft_projection_constraint,\n",
    "    eval_to_train,\n",
    ")\n",
    "from diffusion_co_design.wfcrl.design import FixedDesigner, DesignerRegistry\n",
    "from diffusion_co_design.wfcrl.env import create_env, _create_designable_windfarm\n",
    "from diffusion_co_design.wfcrl.model.rl import wfcrl_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45e3dd4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training config, then build the reference environment, RL models and designer.\n",
    "train_cfg = TrainingConfig.from_file()\n",
    "scenario = train_cfg.scenario\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Reference environment with a fixed layout; it is only used to build the RL models.\n",
    "designer = FixedDesigner(scenario=scenario)\n",
    "env = create_env(mode=\"reference\", scenario=scenario, designer=designer, device=device)\n",
    "env.check_env_specs()\n",
    "\n",
    "# NOTE(review): `policy` and `critic` are not used by any later cell in this notebook.\n",
    "policy, critic = wfcrl_models(\n",
    "    env,\n",
    "    train_cfg.policy,\n",
    "    normalisation=train_cfg.normalisation,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Diffusion designer with a classifier model (ClassifierConfig). These hyperparameters\n",
    "# presumably have to match the checkpoint loaded in the next cell, otherwise\n",
    "# load_state_dict will reject it -- TODO confirm.\n",
    "diffusion, _ = DesignerRegistry.get(\n",
    "    designer=Diffusion(\n",
    "        type=\"diffusion\",\n",
    "        model=ClassifierConfig(\n",
    "            node_emb_size=64,\n",
    "            edge_emb_size=32,\n",
    "            depth=2,\n",
    "        ),\n",
    "    ),\n",
    "    artifact_dir=\".\",\n",
    "    normalisation_statistics=train_cfg.normalisation,\n",
    "    scenario=scenario,\n",
    "    ppo_cfg=train_cfg.ppo,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f09c980",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the trained classifier state_dict used for guidance. torch.load needs an\n",
    "# explicit file, so fail early with a clear message until this is set.\n",
    "CLASSIFIER_CHECKPOINT = None  # TODO: set to the classifier checkpoint path\n",
    "if CLASSIFIER_CHECKPOINT is None:\n",
    "    raise ValueError(\"Set CLASSIFIER_CHECKPOINT to the path of a trained classifier state_dict.\")\n",
    "\n",
    "model = diffusion.master_designer.model\n",
    "# map_location lets a checkpoint saved on GPU load on a CPU-only machine.\n",
    "model.load_state_dict(torch.load(CLASSIFIER_CHECKPOINT, map_location=device))\n",
    "\n",
    "generator = diffusion.master_designer.generator\n",
    "generator.guidance_weight = 2.0\n",
    "\n",
    "# Optimiser settings passed to generate_batch as the guidance operation.\n",
    "operation = OptimizerDetails()\n",
    "operation.num_recurrences = 4\n",
    "operation.lr = 0.01\n",
    "operation.backward_steps = 8\n",
    "operation.use_forward = True\n",
    "operation.projection_constraint = soft_projection_constraint(scenario)\n",
    "\n",
    "NUM_LAYOUTS = 9  # number of layouts to sample (rendered as a grid below)\n",
    "batch = generator.generate_batch(\n",
    "    value=model, use_operation=True, operation_override=operation, batch_size=NUM_LAYOUTS\n",
    ")\n",
    "\n",
    "# Score the sampled layouts with the same model that guided sampling.\n",
    "# torch.tensor copies, so `batch` itself is left untouched for the rendering cell.\n",
    "X = eval_to_train(torch.tensor(batch, device=device), cfg=scenario)\n",
    "with torch.no_grad():  # evaluation only; no autograd graph needed\n",
    "    y = model(X)\n",
    "print(f\"Mean predicted value over {NUM_LAYOUTS} layouts: {y.mean().item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "115a734a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render every sampled layout into a single grid figure. A separate name\n",
    "# (`layout_env`) keeps the reference `env` from the setup cell intact.\n",
    "n_cols = 3\n",
    "n_rows = max(1, -(-len(batch) // n_cols))  # ceiling division; at least one row\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)\n",
    "for ax in axes.flat:\n",
    "    ax.axis(\"off\")\n",
    "\n",
    "for ax, layout in zip(axes.flat, batch):\n",
    "    # Column 0 holds the x coordinates and column 1 the y coordinates of each turbine.\n",
    "    layout_env = _create_designable_windfarm(\n",
    "        scenario=scenario,\n",
    "        initial_xcoords=layout[:, 0].tolist(),\n",
    "        initial_ycoords=layout[:, 1].tolist(),\n",
    "        render=True,\n",
    "    )\n",
    "    layout_env.reset()\n",
    "    ax.imshow(layout_env.render(), aspect=\"auto\")\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
