{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "WARNING:root:Cuda kernels could not loaded -> no CUDA support!\n",
      "2024-03-31 22:50:16,508\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
      "WARNING:evotorch:The logger is already configured. The default configuration will not be applied. Call `set_default_logger_config` with `override=True` to override the current configuration.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import yaml\n",
    "from ml_collections import ConfigDict\n",
    "from tqdm import tqdm\n",
    "from copy import deepcopy\n",
    "\n",
    "import torch\n",
    "import torch_geometric\n",
    "\n",
    "from utils.data import load_dataset, make_dataset_splits, load_dataset_splits\n",
    "from utils.split import SplitManager, node_induced_subgraph\n",
    "from utils.storage import TensorHash\n",
    "from utils.model import load_model_class, accuracy, load_model_instance, create_model_instance\n",
    "from utils.attack import load_attack_class, attack_storage_label, create_attack_instance\n",
    "\n",
    "from robust_diffusion.data import SparseGraph\n",
    "from robust_diffusion.data import count_edges_for_idx\n",
    "from robust_diffusion.helper import utils as robust_utils\n",
    "from robust_diffusion.train import train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Experiment configs\n",
    "# Dataset and model under evaluation\n",
    "dataset_name = \"cora_ml\"\n",
    "model_name = \"GCN\"\n",
    "# None presumably means \"use default_model_configs\" (see utils.model) -- TODO confirm\n",
    "model_params = None\n",
    "inductive = False  # forwarded to the split loader and the model/attack helpers\n",
    "\n",
    "# How many stored splits to evaluate\n",
    "n_splits = 50\n",
    "\n",
    "# Split sizes/types; in this notebook only make_dataset_splits reads them\n",
    "training_split = None\n",
    "training_split_type = None\n",
    "validation_split = None\n",
    "validation_split_type = None\n",
    "\n",
    "# Attack; epsilon times the number of feasible edges gives the perturbation budget\n",
    "attack_name = \"PRBCD\"\n",
    "attack_params = None\n",
    "epsilon = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Experiment Started\n",
      "Loading dataset = cora_ml\n",
      "Found 50 splits!\n",
      "Loading pretrained GCN model on cora_ml dataset for 50 splits\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/50 [00:00<?, ?it/s]/opt/conda/lib/python3.8/site-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.\n",
      "  warnings.warn(msg)\n",
      "100%|██████████| 400/400 [00:11<00:00, 36.18it/s]\n",
      "100%|██████████| 400/400 [00:11<00:00, 35.73it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.43it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.65it/s]\n",
      "100%|██████████| 400/400 [00:11<00:00, 36.11it/s]\n",
      "100%|██████████| 400/400 [00:11<00:00, 36.04it/s]\n",
      "100%|██████████| 400/400 [00:11<00:00, 36.33it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.49it/s]\n",
      "100%|██████████| 400/400 [00:11<00:00, 35.91it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.51it/s]\n",
      "100%|██████████| 400/400 [00:11<00:00, 36.11it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.77it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.67it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.47it/s]\n",
      "100%|██████████| 400/400 [00:11<00:00, 35.89it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.98it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.82it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.44it/s]\n",
      "100%|██████████| 400/400 [00:11<00:00, 36.06it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.50it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.38it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.54it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 37.04it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.61it/s]\n",
      "100%|██████████| 400/400 [00:10<00:00, 36.52it/s]\n",
      " 83%|████████▎ | 331/400 [00:09<00:01, 35.95it/s]\n",
      " 50%|█████     | 25/50 [04:57<04:57, 11.91s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [3]\u001b[0m, in \u001b[0;36m<cell line: 34>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     67\u001b[0m model_storage_name \u001b[38;5;241m=\u001b[39m model_instance[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_storage_name\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m     68\u001b[0m clean_accs\u001b[38;5;241m.\u001b[39mappend(acc)\n\u001b[0;32m---> 70\u001b[0m attack \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_attack_instance\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     71\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattack_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattack_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattack_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattack_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepsilon\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepsilon\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     72\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtest_attr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_adj\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_adj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     73\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdataset_info\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset_info\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minductive\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minductive\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msplit_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m     74\u001b[0m \u001b[43m    
\u001b[49m\u001b[43mtest_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdefault_attack_configs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdefault_attack_configs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     75\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreports_root\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreports_root\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     76\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     77\u001b[0m pert_acc \u001b[38;5;241m=\u001b[39m attack[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpert_acc\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m     78\u001b[0m pert_accs\u001b[38;5;241m.\u001b[39mappend(pert_acc)\n",
      "File \u001b[0;32m~/anonymous-home/projects/eva-evolutionary-attacks-on-graphs/EVAttack/experiments/utils/attack.py:62\u001b[0m, in \u001b[0;36mcreate_attack_instance\u001b[0;34m(attack_name, attack_params, epsilon, test_attr, test_adj, labels, model, dataset_info, inductive, split_name, test_mask, default_attack_configs, reports_root, device)\u001b[0m\n\u001b[1;32m     60\u001b[0m \u001b[38;5;66;03m# TODO: Check if for undirected this number should be devided\u001b[39;00m\n\u001b[1;32m     61\u001b[0m n_attack_edge \u001b[38;5;241m=\u001b[39m (n_feasible_edges \u001b[38;5;241m*\u001b[39m epsilon)\u001b[38;5;241m.\u001b[39mint()\u001b[38;5;241m.\u001b[39mitem()\n\u001b[0;32m---> 62\u001b[0m \u001b[43madversary\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn_attack_edge\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     63\u001b[0m pert_adj, pert_attr \u001b[38;5;241m=\u001b[39m adversary\u001b[38;5;241m.\u001b[39mget_pertubations()\n\u001b[1;32m     65\u001b[0m pert_acc \u001b[38;5;241m=\u001b[39m accuracy(\n\u001b[1;32m     66\u001b[0m     model\u001b[38;5;241m=\u001b[39mmodel, attr\u001b[38;5;241m=\u001b[39mpert_attr, adj\u001b[38;5;241m=\u001b[39mpert_adj, \n\u001b[1;32m     67\u001b[0m     labels\u001b[38;5;241m=\u001b[39mlabels, evaluation_mask\u001b[38;5;241m=\u001b[39mtest_mask)\n",
      "File \u001b[0;32m/opt/conda/lib/python3.8/site-packages/typeguard/__init__.py:912\u001b[0m, in \u001b[0;36mtypechecked.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    910\u001b[0m memo \u001b[38;5;241m=\u001b[39m _CallMemo(python_func, _localns, args\u001b[38;5;241m=\u001b[39margs, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[1;32m    911\u001b[0m check_argument_types(memo)\n\u001b[0;32m--> 912\u001b[0m retval \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    913\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    914\u001b[0m     check_return_type(retval, memo)\n",
      "File \u001b[0;32m~/anonymous-home/projects/eva-evolutionary-attacks-on-graphs/adversarial_training/robust_diffusion/attacks/base_attack.py:129\u001b[0m, in \u001b[0;36mAttack.attack\u001b[0;34m(self, n_perturbations, **kwargs)\u001b[0m\n\u001b[1;32m    119\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    120\u001b[0m \u001b[38;5;124;03mExecutes the attack on the model updating the attributes\u001b[39;00m\n\u001b[1;32m    121\u001b[0m \u001b[38;5;124;03mself.adj_adversary and self.attr_adversary accordingly.\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    126\u001b[0m \u001b[38;5;124;03m    number of perturbations (attack budget in terms of node additions/deletions) that constrain the atack\u001b[39;00m\n\u001b[1;32m    127\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n_perturbations \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 129\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_attack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn_perturbations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    130\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    131\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattr_adversary \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattr\n",
      "File \u001b[0;32m~/anonymous-home/projects/eva-evolutionary-attacks-on-graphs/adversarial_training/robust_diffusion/attacks/prbcd.py:131\u001b[0m, in \u001b[0;36mPRBCD._attack\u001b[0;34m(self, n_perturbations, **kwargs)\u001b[0m\n\u001b[1;32m    129\u001b[0m edge_index, edge_weight \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_modified_adj()\n\u001b[1;32m    130\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattacked_model(data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mattr\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice), adj\u001b[38;5;241m=\u001b[39m(edge_index, edge_weight))\n\u001b[0;32m--> 131\u001b[0m accuracy \u001b[38;5;241m=\u001b[39m \u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccuracy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlogits\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43midx_attack\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    132\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m edge_index, edge_weight, logits\n\u001b[1;32m    134\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m epoch \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdisplay_step \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "File \u001b[0;32m~/anonymous-home/projects/eva-evolutionary-attacks-on-graphs/adversarial_training/robust_diffusion/helper/utils.py:665\u001b[0m, in \u001b[0;36maccuracy\u001b[0;34m(logits, labels, split_idx)\u001b[0m\n\u001b[1;32m    648\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21maccuracy\u001b[39m(logits: torch\u001b[38;5;241m.\u001b[39mTensor, labels: torch\u001b[38;5;241m.\u001b[39mTensor, split_idx: np\u001b[38;5;241m.\u001b[39mndarray) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mfloat\u001b[39m:\n\u001b[1;32m    649\u001b[0m     \u001b[38;5;124;03m\"\"\"Returns the accuracy for a tensor of logits, a list of lables and and a split indices.\u001b[39;00m\n\u001b[1;32m    650\u001b[0m \n\u001b[1;32m    651\u001b[0m \u001b[38;5;124;03m    Parameters\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    663\u001b[0m \u001b[38;5;124;03m        the Accuracy\u001b[39;00m\n\u001b[1;32m    664\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 665\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[43mlogits\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margmax\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m[\u001b[49m\u001b[43msplit_idx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;241m==\u001b[39m labels[split_idx])\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39mmean()\u001b[38;5;241m.\u001b[39mitem()\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "## Evaluate pretrained models under attack on every stored split\n",
    "# Load general configs (like dataset_root, etc.) and per-component defaults;\n",
    "# `with` closes each YAML file handle instead of leaving it open.\n",
    "with open(\"conf/general-config.yaml\") as config_file:\n",
    "    general_config = yaml.safe_load(config_file)\n",
    "with open(\"conf/data-configs.yaml\") as config_file:\n",
    "    default_dataset_configs = yaml.safe_load(config_file).get(\"configs\").get(\"default\")\n",
    "with open(\"conf/model-configs.yaml\") as config_file:\n",
    "    default_model_configs = yaml.safe_load(config_file).get(\"configs\")\n",
    "with open(\"conf/attack-configs.yaml\") as config_file:\n",
    "    default_attack_configs = yaml.safe_load(config_file).get(\"configs\")\n",
    "\n",
    "# extracting configs (relative-path fallbacks when a key is missing)\n",
    "dataset_root = general_config.get(\"dataset_root\", \"data/\")\n",
    "splits_root = general_config.get(\"splits_root\", \"splits/\")\n",
    "models_root = general_config.get(\"models_root\", \"models/\")\n",
    "results_root = general_config.get(\"results_root\", \"results/\")\n",
    "reports_root = general_config.get(\"reports_root\", \"reports/\")\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "print(\"Experiment Started\")\n",
    "# Loads the pretrained model of every split, attacks it and reports clean vs. perturbed accuracy.\n",
    "# Nothing is trained here: missing splits or models raise and point to the training scripts.\n",
    "\n",
    "print(\"Loading dataset =\", dataset_name)\n",
    "\n",
    "# Split files are named \"<dataset_name>-<split_code>.pt\".\n",
    "# NOTE(review): os.listdir order is filesystem-dependent and decides which splits `[:n_splits]`\n",
    "# picks; keep it in sync with the training cell (e.g. sort both) if this must be reproducible.\n",
    "dataset_splits = [split_record for split_record in os.listdir(splits_root) if split_record.split(\"-\")[0] == dataset_name]\n",
    "creating_splits = max(n_splits - len(dataset_splits), 0)\n",
    "\n",
    "if creating_splits > 0:\n",
    "    raise ValueError(\"Not enough splits for the dataset. Create the splits by running training scripts.\")\n",
    "\n",
    "print(f\"Found {len(dataset_splits)} splits!\")\n",
    "\n",
    "print(f\"Loading pretrained {model_name} model on {dataset_name} dataset for {n_splits} splits\")\n",
    "\n",
    "clean_accs = []\n",
    "pert_accs = []\n",
    "for split_file in tqdm(dataset_splits[:n_splits]):\n",
    "    split_code = split_file.split(\"-\")[1].replace(\".pt\", \"\")\n",
    "\n",
    "    data = load_dataset_splits(\n",
    "        dataset_name, split_code, inductive=inductive, \n",
    "        dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "\n",
    "    # Only the evaluation-side entries are needed here; the training tensors are not used.\n",
    "    labels = data[\"labels\"]\n",
    "    test_attr = data[\"test_attr\"]\n",
    "    test_adj = data[\"test_adj\"]\n",
    "    test_mask = data[\"test_mask\"]\n",
    "    dataset_info = data[\"dataset_info\"]\n",
    "    split_name = data[\"split_name\"]\n",
    "\n",
    "    try:\n",
    "        # `model_params` is passed unchanged for every split; the resolved params the loader\n",
    "        # returns are not written back, so they cannot leak into later splits or re-runs.\n",
    "        model_instance = load_model_instance(\n",
    "            model_name=model_name, model_params=model_params, \n",
    "            test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, split_name=split_name, dataset_info=dataset_info, inductive=inductive,\n",
    "            models_root=models_root,\n",
    "            default_model_configs=default_model_configs, device=device)\n",
    "    except FileNotFoundError as e:\n",
    "        # chain the original error so the missing path stays visible in the traceback\n",
    "        raise ValueError(\"Model not found. Run training scripts to train the model.\") from e\n",
    "\n",
    "    model = model_instance[\"model\"]\n",
    "    model_storage_name = model_instance[\"model_storage_name\"]\n",
    "    clean_accs.append(model_instance[\"accuracy\"])\n",
    "\n",
    "    # NOTE(review): a previously saved traceback of this cell shows a create_attack_instance\n",
    "    # signature taking `inductive` and no `model_storage_name`; confirm these keyword\n",
    "    # arguments against utils/attack.py.\n",
    "    attack = create_attack_instance(\n",
    "        attack_name=attack_name, attack_params=attack_params, epsilon=epsilon,\n",
    "        test_attr=test_attr, test_adj=test_adj, labels=labels, model=model,\n",
    "        dataset_info=dataset_info, model_storage_name=model_storage_name, \n",
    "        split_name=split_name, test_mask=test_mask, \n",
    "        default_attack_configs=default_attack_configs, reports_root=reports_root,\n",
    "        device=device)\n",
    "    pert_accs.append(attack[\"pert_acc\"])\n",
    "\n",
    "# Mean and sample std over splits (torch.std returns nan when only one split is evaluated).\n",
    "mean_clean_acc = torch.mean(torch.tensor(clean_accs))\n",
    "mean_pert_acc = torch.mean(torch.tensor(pert_accs))\n",
    "std_clean_acc = torch.std(torch.tensor(clean_accs))\n",
    "std_pert_acc = torch.std(torch.tensor(pert_accs))\n",
    "\n",
    "print(f\"Mean clean accuracy: {mean_clean_acc:.4f} $\\\\pm$ {std_clean_acc:.4f}\")\n",
    "print(f\"Mean perturbed accuracy: {mean_pert_acc:.4f} $\\\\pm$ {std_pert_acc:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'GCN-0ae4e175b2-tr-cora_ml-8804994889-'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Storage label of the model handled in the last iteration of a split loop\n",
    "model_storage_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'PRBCD-86c23e7c47-0_1-cora_ml-tr-193ba13820'"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5189686924493554"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'epochs': 400,\n",
       " 'fine_tune_epochs': 100,\n",
       " 'keep_heuristic': 'WeightOnly',\n",
       " 'search_space_size': 500000,\n",
       " 'do_synchronize': True,\n",
       " 'lr_factor': 100,\n",
       " 'loss_type': 'tanhMargin'}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Default hyper-parameters of the configured attack, looked up from the YAML defaults\n",
    "# (`attack_configs` is not defined anywhere in this notebook, so it failed on a fresh kernel)\n",
    "default_attack_configs[attack_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'PRBCD': {'epochs': 400,\n",
       "  'fine_tune_epochs': 100,\n",
       "  'keep_heuristic': 'WeightOnly',\n",
       "  'search_space_size': 500000,\n",
       "  'do_synchronize': True,\n",
       "  'lr_factor': 100,\n",
       "  'loss_type': 'tanhMargin'},\n",
       " 'LRBCD': {'epochs': 400,\n",
       "  'fine_tune_epochs': 100,\n",
       "  'keep_heuristic': 'WeightOnly',\n",
       "  'search_space_size': 500000,\n",
       "  'do_synchronize': True,\n",
       "  'lr_factor': 100,\n",
       "  'loss_type': 'tanhMargin'},\n",
       " 'PGD': {'epochs': 400, 'base_lr': 0.1, 'loss_type': 'tanhMargin'}}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Attack defaults read from conf/attack-configs.yaml, keyed by attack name\n",
    "default_attack_configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['cora_ml-193ba13820.pt',\n",
       " 'cora_ml-07e338df9c.pt',\n",
       " 'cora_ml-ebaab37b41.pt',\n",
       " 'cora_ml-d35511362f.pt',\n",
       " 'cora_ml-ac40e4949a.pt',\n",
       " 'cora_ml-1da6f9719b.pt',\n",
       " 'cora_ml-8f758fcc6c.pt',\n",
       " 'cora_ml-0ae39ac841.pt',\n",
       " 'cora_ml-8d29ff13d0.pt',\n",
       " 'cora_ml-8804994889.pt',\n",
       " 'cora_ml-d5e6390e2a.pt',\n",
       " 'cora_ml-e397188883.pt',\n",
       " 'cora_ml-1240d69316.pt',\n",
       " 'cora_ml-09b45831a3.pt',\n",
       " 'cora_ml-3d644bf71b.pt',\n",
       " 'cora_ml-f67ecd84c6.pt',\n",
       " 'cora_ml-6acbc00713.pt',\n",
       " 'cora_ml-e262894156.pt',\n",
       " 'cora_ml-63696f9605.pt',\n",
       " 'cora_ml-e7f9844bb3.pt',\n",
       " 'cora_ml-172a3afaf8.pt',\n",
       " 'cora_ml-8977893341.pt',\n",
       " 'cora_ml-4d0b76e1fe.pt',\n",
       " 'cora_ml-22cd69f64a.pt',\n",
       " 'cora_ml-fdc47a877a.pt',\n",
       " 'cora_ml-db21959af5.pt',\n",
       " 'cora_ml-7c6acb9f09.pt',\n",
       " 'cora_ml-94bb1221d9.pt',\n",
       " 'cora_ml-a19dad2d8b.pt',\n",
       " 'cora_ml-8a6b60d222.pt',\n",
       " 'cora_ml-a5fb595e97.pt',\n",
       " 'cora_ml-a05c63f81d.pt',\n",
       " 'cora_ml-86502bcd09.pt',\n",
       " 'cora_ml-8d736fb2a8.pt',\n",
       " 'cora_ml-369cc457c9.pt',\n",
       " 'cora_ml-1bd879c44c.pt',\n",
       " 'cora_ml-b9cb5c84b7.pt',\n",
       " 'cora_ml-441a16c245.pt',\n",
       " 'cora_ml-ed7f63338c.pt',\n",
       " 'cora_ml-4b9df2abbb.pt',\n",
       " 'cora_ml-54794942b2.pt',\n",
       " 'cora_ml-19152d7e72.pt',\n",
       " 'cora_ml-0a8fc274d0.pt',\n",
       " 'cora_ml-da3dfef1fa.pt',\n",
       " 'cora_ml-ba8a934d87.pt',\n",
       " 'cora_ml-2259d56da3.pt',\n",
       " 'cora_ml-9d1568e157.pt',\n",
       " 'cora_ml-3f09fbf6fb.pt',\n",
       " 'cora_ml-c6b817d797.pt',\n",
       " 'cora_ml-26cd1d5e2f.pt']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split files found for `dataset_name` under `splits_root`\n",
    "dataset_splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Train (or load) a model for every split, creating missing splits first\n",
    "# Run this before the attack-evaluation cell when splits or pretrained models are missing.\n",
    "# Load general configs (like dataset_root, etc.) and per-component defaults;\n",
    "# `with` closes each YAML file handle instead of leaving it open.\n",
    "with open(\"conf/general-config.yaml\") as config_file:\n",
    "    general_config = yaml.safe_load(config_file)\n",
    "with open(\"conf/data-configs.yaml\") as config_file:\n",
    "    default_dataset_configs = yaml.safe_load(config_file).get(\"configs\").get(\"default\")\n",
    "with open(\"conf/model-configs.yaml\") as config_file:\n",
    "    default_model_configs = yaml.safe_load(config_file).get(\"configs\")\n",
    "with open(\"conf/attack-configs.yaml\") as config_file:\n",
    "    default_attack_configs = yaml.safe_load(config_file).get(\"configs\")\n",
    "\n",
    "# extracting configs (relative-path fallbacks when a key is missing)\n",
    "dataset_root = general_config.get(\"dataset_root\", \"data/\")\n",
    "splits_root = general_config.get(\"splits_root\", \"splits/\")\n",
    "models_root = general_config.get(\"models_root\", \"models/\")\n",
    "results_root = general_config.get(\"results_root\", \"results/\")\n",
    "reports_root = general_config.get(\"reports_root\", \"reports/\")\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "print(\"Experiment Started\")\n",
    "# Trains the specified model on the given graph and saves the model artifacts, and the splits.\n",
    "\n",
    "print(\"Loading dataset =\", dataset_name)\n",
    "\n",
    "# Split files are named \"<dataset_name>-<split_code>.pt\".\n",
    "# NOTE(review): os.listdir order is filesystem-dependent and decides which splits `[:n_splits]`\n",
    "# picks; keep it in sync with the evaluation cell (e.g. sort both) if this must be reproducible.\n",
    "dataset_splits = [split_record for split_record in os.listdir(splits_root) if split_record.split(\"-\")[0] == dataset_name]\n",
    "creating_splits = max(n_splits - len(dataset_splits), 0)\n",
    "\n",
    "# creating remaining needed dataset splits; make_dataset_splits presumably writes each new\n",
    "# split into splits_root (the directory is re-listed below), so its return value is not kept.\n",
    "print(f\"Found {len(dataset_splits)} splits, creating {creating_splits} more splits\")\n",
    "for _ in tqdm(range(creating_splits)):\n",
    "    torch.cuda.empty_cache()\n",
    "    make_dataset_splits(dataset_name=dataset_name, \n",
    "                        training_split=training_split, validation_split=validation_split, \n",
    "                        training_split_type=training_split_type, validation_split_type=validation_split_type, \n",
    "                        inductive=inductive, \n",
    "                        default_dataset_configs=default_dataset_configs, dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "\n",
    "dataset_splits = [split_record for split_record in os.listdir(splits_root) if split_record.split(\"-\")[0] == dataset_name][:n_splits]\n",
    "print(f\"Training {model_name} model on {dataset_name} dataset for {n_splits} splits\")\n",
    "\n",
    "accs = []\n",
    "for split_file in tqdm(dataset_splits):\n",
    "    split_code = split_file.split(\"-\")[1].replace(\".pt\", \"\")\n",
    "\n",
    "    data = load_dataset_splits(\n",
    "        dataset_name, split_code, inductive=inductive, \n",
    "        dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "\n",
    "    training_attr = data[\"training_attr\"]\n",
    "    training_adj = data[\"training_adj\"]\n",
    "    labels = data[\"labels\"]\n",
    "    training_idx = data[\"training_idx\"]\n",
    "    validation_idx = data[\"validation_idx\"]\n",
    "    test_attr = data[\"test_attr\"]\n",
    "    test_adj = data[\"test_adj\"]\n",
    "    test_mask = data[\"test_mask\"]\n",
    "    dataset_info = data[\"dataset_info\"]\n",
    "    split_name = data[\"split_name\"]\n",
    "\n",
    "    # `model_params` is passed unchanged for every split; the resolved params returned by the\n",
    "    # helpers are not written back, so they cannot leak into later splits or re-runs.\n",
    "    try:\n",
    "        model_instance = load_model_instance(\n",
    "            model_name=model_name, model_params=model_params, \n",
    "            test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, split_name=split_name, dataset_info=dataset_info, inductive=inductive,\n",
    "            models_root=models_root,\n",
    "            default_model_configs=default_model_configs, device=device)\n",
    "    except FileNotFoundError as e:\n",
    "        print(e)\n",
    "        print(\"Creating model from scratch\")\n",
    "        model_instance = create_model_instance(\n",
    "            model_name=model_name, model_params=model_params, dataset_info=dataset_info, \n",
    "            training_attr=training_attr, training_adj=training_adj, labels=labels, training_idx=training_idx, validation_idx=validation_idx,\n",
    "            test_attr=test_attr, test_adj=test_adj, test_mask=test_mask, inductive=inductive, split_name=split_name,\n",
    "            models_root=models_root, \n",
    "            default_model_configs=default_model_configs, \n",
    "            device=device)\n",
    "\n",
    "    model_storage_name = model_instance[\"model_storage_name\"]\n",
    "    accs.append(model_instance[\"accuracy\"])\n",
    "\n",
    "acc_mean = torch.mean(torch.tensor(accs))\n",
    "acc_std = torch.std(torch.tensor(accs))\n",
    "\n",
    "print(f\"Mean accuracy: {acc_mean}, std: {acc_std}\")\n",
    "\n",
    "print(\"Experiment Finished\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
