{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "398a17de-126f-4877-ae39-22aaa4c346c1",
   "metadata": {},
   "source": [
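    "## Classifier guidance\n",
    "\n",
    "Generates circuits with an unconditional generative model that is steered by a noisy classifier toward the specs of the OCB101 test set. It then simulates the generated circuits and reports gain, unity-gain frequency and phase-margin accuracy, joint accuracy, and validity."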
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d47af575-9464-4f22-9b07-6236a89edac9",
   "metadata": {},
   "source": [
    "#### Model and data paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a285109b-f379-4be4-8363-93a1b1ae71e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- User-editable paths (relative to the current working directory) ---\n",
    "\n",
    "# Unconditional generative model: directory holding config.yaml and the *.ckpt checkpoint\n",
    "GEN_MODEL_PATH = 'model_weights/OCB_FULL_GENERATION'\n",
    "\n",
    "# Noisy classifier, passed to inference() as classifier_path\n",
    "CLASSIFIER_PATH = 'model_weights/CLASSIFIER'\n",
    "\n",
    "# Root directory of the OCB101 (CktBench101) dataset\n",
    "DATASET_PATH = 'datasets/CktBench101'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65f09770-35b2-4bc2-b0ad-189787855aef",
   "metadata": {},
   "source": [
    "#### Imports and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "524a920b-84a4-4268-ba2e-7d923569d763",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.metrics import accuracy_score\n",
    "from torch_geometric.graphgym.config import cfg, set_cfg\n",
    "from torch_geometric.graphgym.model_builder import create_model\n",
    "\n",
    "from src.graphgym.inference_utils import inference\n",
    "from src.graphgym.loader.datasets.ocb_dataset import OCBDataset\n",
    "from src.graphgym.pyspice_utils import simulation\n",
    "from src.graphgym.utils import gym_to_igraph, simul_outputs_to_bin_idx\n",
    "\n",
    "## Build the GraphGym config from the generative model's own config.yaml\n",
    "cfg_filep = os.path.join(GEN_MODEL_PATH, 'config.yaml')\n",
    "\n",
    "set_cfg(cfg)\n",
    "cfg.set_new_allowed(True)  # allow config.yaml to add keys that are not in the default config\n",
    "cfg.work_dir = os.getcwd()\n",
    "cfg.merge_from_file(cfg_filep)\n",
    "cfg.cfg_file = cfg_filep\n",
    "\n",
    "## Run on the GPU when one is available\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "cfg.accelerator = device\n",
    "cfg.device = device\n",
    "cfg.gt.conditioning_loss = 'cg'  # presumably selects classifier guidance - confirm in the model code\n",
    "\n",
    "## Classifier guidance strength (feature & node types) can be adjusted here ##\n",
    "cfg.gt.guidance_strength_features = 20\n",
    "cfg.gt.guidance_strength = 10\n",
    "\n",
    "## Fix RNG seeds so repeated runs draw the same random numbers\n",
    "## (GPU kernels may still introduce some nondeterminism)\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3e2c755-59b6-4447-800e-beb03fff9560",
   "metadata": {},
   "source": [
    "#### Generative model loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "26205cd6-568e-4b9c-b944-2a333fa4003a",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Pick the checkpoint deterministically (sorted) and fail clearly if none exists\n",
    "ckpt_paths = sorted(glob.glob(os.path.join(GEN_MODEL_PATH, '*.ckpt')))\n",
    "if not ckpt_paths:\n",
    "    raise FileNotFoundError(f'No .ckpt file found in {GEN_MODEL_PATH}')\n",
    "if len(ckpt_paths) > 1:\n",
    "    print(f'Found {len(ckpt_paths)} checkpoints; using {ckpt_paths[0]}')\n",
    "ckpt_path = ckpt_paths[0]\n",
    "\n",
    "model = create_model()\n",
    "\n",
    "## map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa)\n",
    "checkpoint = torch.load(ckpt_path, map_location=torch.device(cfg.device))\n",
    "model.load_state_dict(checkpoint['model_state'])\n",
    "model.eval()\n",
    "model.to(torch.device(cfg.device))\n",
    "print(f'Loaded checkpoint from {ckpt_path}.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06151ce8-ea38-480d-ac37-58794ab5d166",
   "metadata": {},
   "source": [
    "#### Load the test set, whose specs serve as the generation conditions (alternatively, define the conditions manually)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b0ef705a-2d07-4f14-af8d-a200833c4026",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Test split of OCB101. Its `dataset.y` supplies the conditioning targets that are later\n",
    "## compared against the binned gain / unity-gain frequency / phase margin of each circuit.\n",
    "dataset = OCBDataset(root=DATASET_PATH, split='test')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "280d9fa2-9b74-4d7b-bffe-90887db795d5",
   "metadata": {},
   "source": [
    "#### Conditional circuit generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "44ea10d5-9fa2-4da4-84d9-f6dee79501f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Set output directory\n",
    "save_dir = os.path.join(GEN_MODEL_PATH, 'SIMULATION_OUTPUTS')\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "G = []\n",
    "\n",
    "## Loop over the test set in batches of `num_samples` specs.\n",
    "## Batches start at 0, num_samples, 2*num_samples, ... and the last one is clipped at\n",
    "## len(dataset), so spec_idx never points past the end of the test set.\n",
    "num_samples = 50\n",
    "n_specs = len(dataset)\n",
    "for b_idx, start in enumerate(range(0, n_specs, num_samples)):\n",
    "    spec_idx = np.arange(start, min(start + num_samples, n_specs))\n",
    "    y_guidance = dataset.y[spec_idx]\n",
    "\n",
    "    ## Denoising: generate len(spec_idx) circuits guided by y_guidance\n",
    "    # NOTE(review): classifier guidance needs gradients of the classifier w.r.t. the noisy sample;\n",
    "    # under torch.no_grad() this relies on `inference` re-enabling autograd internally - confirm.\n",
    "    with torch.no_grad():\n",
    "        denoised_batch = inference(model, num_samples=len(spec_idx), euler_steps=100, noise_e=0.1, noise_x=0.1,\n",
    "                                   n_pow_e=8, n_pow_x=6, n_pow_f=5, y_guidance=y_guidance,\n",
    "                                   classifier_path=CLASSIFIER_PATH)\n",
    "\n",
    "        batch_graphs = []\n",
    "        for i in range(denoised_batch.num_graphs):  # Number of graphs in the batch\n",
    "            graph = denoised_batch.get_example(i).to('cpu')\n",
    "            batch_graphs.append((spec_idx[i], gym_to_igraph(graph)))\n",
    "\n",
    "    ## Save this batch right away so an error in a later batch does not discard finished work\n",
    "    for spec_i, i_graph in batch_graphs:\n",
    "        with open(os.path.join(save_dir, f'out_{spec_i}'), 'wb') as f:\n",
    "            pickle.dump(i_graph, f)\n",
    "    G.extend(batch_graphs)\n",
    "\n",
    "    print(f'Batch {b_idx} done')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fed11cd6-490d-4540-95b8-c7a816dca1bd",
   "metadata": {},
   "source": [
    "#### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "bc6be53a-9c27-4436-ab74-0217929128ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "## Simulate every saved circuit and score it against its target spec.\n",
    "## %%capture suppresses this cell's stdout/stderr (e.g. simulator logs); the next cell prints the summary.\n",
    "simul_res = 'No circuit simulated successfully - accuracies not computed.'\n",
    "\n",
    "save_dir = os.path.join(GEN_MODEL_PATH, 'SIMULATION_OUTPUTS')\n",
    "\n",
    "## Load the generated graphs; each file is named out_<test-set index>.\n",
    "## NOTE: pickle.load can execute arbitrary code - only load files this notebook produced.\n",
    "G, indices = [], []\n",
    "for p in sorted(glob.glob(os.path.join(save_dir, 'out*'))):\n",
    "    indices.append(int(os.path.basename(p).split('_')[-1]))\n",
    "    with open(p, 'rb') as f:\n",
    "        G.append(pickle.load(f))\n",
    "\n",
    "y_guidance = dataset.y[indices]\n",
    "\n",
    "## A circuit is valid when both the simulation and the metric extraction succeed\n",
    "simulation_out, valid_specs, valid_idx = [], [], []\n",
    "for i, i_graph in enumerate(G):\n",
    "    try:\n",
    "        sim = simulation(i_graph)\n",
    "        # Scale gain by 1/100, UGF by 1e-9 and PM by 1/90, rounded to 3 decimals\n",
    "        sim_out = (np.round(float(sim.gain[0] / 100), 3),\n",
    "                   np.round(float(sim.ugw / 1e9), 3),\n",
    "                   np.round(float(sim.pm / 90), 3))\n",
    "    except Exception:  # not bare: KeyboardInterrupt still stops the cell\n",
    "        continue  # failed circuits count against `validity` below\n",
    "    simulation_out.append(sim_out)\n",
    "    valid_specs.append(y_guidance[i].numpy())\n",
    "    valid_idx.append(i)\n",
    "\n",
    "if simulation_out:\n",
    "    # NOTE(review): assumes the helper returns a CPU array/tensor of shape (n, 3) - confirm\n",
    "    simulation_out_bins = np.asarray(simul_outputs_to_bin_idx(np.array(simulation_out)))\n",
    "    # reshape instead of squeeze so a single valid circuit still yields a 2-D (1, k) array\n",
    "    valid_specs = np.array(valid_specs).reshape(len(valid_specs), -1)\n",
    "    # Gain\n",
    "    accuracy_gain = accuracy_score(simulation_out_bins[:, 0].tolist(), valid_specs[:, 0])\n",
    "    # Unit-gain frequency\n",
    "    accuracy_bw = accuracy_score(simulation_out_bins[:, 1].tolist(), valid_specs[:, 1])\n",
    "    # Phase margin\n",
    "    accuracy_pm = accuracy_score(simulation_out_bins[:, 2].tolist(), valid_specs[:, 2])\n",
    "    # Joint accuracy: fraction of circuits whose bins all match the target row\n",
    "    joint_acc = np.all(simulation_out_bins == valid_specs, axis=-1).mean()\n",
    "\n",
    "    def reformat(x):\n",
    "        return round(float(x), cfg.round)\n",
    "\n",
    "    accuracy_gain, accuracy_bw, accuracy_pm = reformat(accuracy_gain), reformat(accuracy_bw), reformat(accuracy_pm)\n",
    "    joint_accuracy = reformat(joint_acc)\n",
    "    validity = len(simulation_out) / len(G)\n",
    "    simul_res = f'Acc. Gain: {accuracy_gain}, Acc. UGf: {accuracy_bw}, Acc. PM: {accuracy_pm}, Joint accuracy: {joint_accuracy}, Validity: {validity}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f1c1f78b-5b72-4059-bcdb-d149d5fe570e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary produced by the evaluation cell (per-spec accuracies, joint accuracy, validity)\n",
    "print(f'{simul_res}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35e29dca-350f-48bd-9744-f5a87e0c1ac0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
