{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "WARNING:root:Cuda kernels could not loaded -> no CUDA support!\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import yaml\n",
    "from ml_collections import ConfigDict\n",
    "\n",
    "import torch\n",
    "import torch_geometric\n",
    "\n",
    "from sacred import Experiment\n",
    "\n",
    "from utils.data import load_dataset, make_dataset_splits, load_dataset_splits\n",
    "from utils.split import SplitManager, node_induced_subgraph\n",
    "from utils.storage import TensorHash\n",
    "from utils.model import load_model_class, accuracy, load_model_instance, create_model_instance\n",
    "from utils.attack import load_attack_class\n",
    "\n",
    "\n",
    "from robust_diffusion.data import SparseGraph\n",
    "from robust_diffusion.data import count_edges_for_idx\n",
    "from robust_diffusion.helper import utils as robust_utils\n",
    "from robust_diffusion.train import train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_yaml(path):\n",
    "    \"\"\"Parse the YAML file at `path` and close the file handle afterwards.\"\"\"\n",
    "    with open(path) as config_file:\n",
    "        return yaml.safe_load(config_file)\n",
    "\n",
    "\n",
    "## Loading general configs (like dataset_root, etc.)\n",
    "# Each file is read inside a `with` block so no handle is left open in the kernel.\n",
    "general_config = _load_yaml(\"conf/general-config.yaml\")\n",
    "default_dataset_configs = _load_yaml(\"conf/data-configs.yaml\").get(\"configs\").get(\"default\")\n",
    "default_model_configs = _load_yaml(\"conf/model-configs.yaml\").get(\"configs\")\n",
    "default_attack_configs = _load_yaml(\"conf/attack-configs.yaml\").get(\"configs\")\n",
    "\n",
    "# Storage locations; each falls back to a relative default when the key is\n",
    "# missing from the general config.\n",
    "dataset_root = general_config.get(\"dataset_root\", \"data/\")\n",
    "splits_root = general_config.get(\"splits_root\", \"splits/\")\n",
    "models_root = general_config.get(\"models_root\", \"models/\")\n",
    "results_root = general_config.get(\"results_root\", \"results/\")\n",
    "reports_root = general_config.get(\"reports_root\", \"reports/\")\n",
    "\n",
    "## Experiment configs\n",
    "dataset_name = 'cora_ml'\n",
    "# Split settings; in this notebook they are only referenced by the commented-out\n",
    "# make_dataset_splits call and by a later scratch cell.\n",
    "training_split = None\n",
    "validation_split = None\n",
    "training_split_type = None\n",
    "validation_split_type = None\n",
    "\n",
    "model_name = \"GCN\"\n",
    "model_params = None\n",
    "# Multiplied with the number of feasible edges to obtain the attack budget.\n",
    "epsilon = 0.1\n",
    "\n",
    "attack_name = \"PRBCD\"\n",
    "# When left as None, the experiment cell uses default_attack_configs[attack_name].\n",
    "attack_params = None\n",
    "\n",
    "# Selects the 'inductive' vs. 'transductive' setting recorded in the report.\n",
    "inductive = False\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Experiment Started\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.\n",
      "  warnings.warn(msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy (Clean):  0.8169429097605893\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [01:26<00:00,  5.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy (Perturbed): 0.6423572744014733\n",
      "Experiment Finished\n"
     ]
    }
   ],
   "source": [
    "print(\"Experiment Started\")\n",
    "# Loads a stored dataset split and a stored model, attacks the model on the test\n",
    "# graph, and saves the perturbed adjacency together with a summary report.\n",
    "\n",
    "# Loading the dataset splits. Alternative that creates and saves fresh splits\n",
    "# (for both transductive and inductive):\n",
    "# data = make_dataset_splits(dataset_name, \n",
    "#                            training_split, validation_split, training_split_type, validation_split_type, \n",
    "#                            inductive, \n",
    "#                            default_dataset_configs, dataset_root, splits_root, device)\n",
    "data = load_dataset_splits(\n",
    "    dataset_name,\n",
    "    \"0x30d202b2fcf2b06\",\n",
    "    inductive=inductive, dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "training_attr = data[\"training_attr\"]\n",
    "training_adj = data[\"training_adj\"]\n",
    "labels = data[\"labels\"]\n",
    "training_idx = data[\"training_idx\"]\n",
    "validation_idx = data[\"validation_idx\"]\n",
    "test_attr = data[\"test_attr\"]\n",
    "test_adj = data[\"test_adj\"]\n",
    "test_mask = data[\"test_mask\"]\n",
    "dataset_info = data[\"dataset_info\"]\n",
    "split_name = data[\"split_name\"]\n",
    "\n",
    "\n",
    "# Loading the model. Alternative that trains a new model instead:\n",
    "# model_instance = create_model_instance(\n",
    "#     model_name=model_name, model_params=model_params, dataset_info=dataset_info, \n",
    "#     training_attr=training_attr, training_adj=training_adj, labels=labels, training_idx=training_idx, validation_idx=validation_idx,\n",
    "#     test_attr=test_attr, test_adj=test_adj, test_mask=test_mask, inductive=inductive, split_name=split_name,\n",
    "#     models_root=models_root, \n",
    "#     default_model_configs=default_model_configs, \n",
    "#     device=device)\n",
    "\n",
    "model_instance = load_model_instance(\n",
    "    model_storage_name='GCN-0xbd14035a4016352-tr-cora_ml-0x30d202b2fcf2b06', \n",
    "    model_name=model_name, model_params=model_params, \n",
    "    test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, dataset_info=dataset_info, inductive=inductive,\n",
    "    models_root=models_root,\n",
    "    default_model_configs=default_model_configs, device=device)\n",
    "\n",
    "model = model_instance[\"model\"]\n",
    "acc = model_instance[\"accuracy\"]\n",
    "print(\"Accuracy (Clean): \", acc)\n",
    "model_params = model_instance[\"model_params\"]\n",
    "model_storage_name = model_instance[\"model_storage_name\"]\n",
    "\n",
    "# Indices of the nodes selected by test_mask; computed once and shared by the\n",
    "# budget calculation and the adversary.\n",
    "idx_attack = test_mask.nonzero(as_tuple=True)[0].cpu().numpy()\n",
    "# Attack budget = epsilon * feasible edges.\n",
    "# NOTE(review): the division by 2 presumably undoes double counting of undirected edges - confirm.\n",
    "n_feasible_edges = count_edges_for_idx(test_adj.cpu(), idx_attack) / (2)\n",
    "n_attack_edges = (n_feasible_edges * epsilon).int().item()\n",
    "\n",
    "if attack_params is None:\n",
    "    attack_params = ConfigDict(default_attack_configs.get(attack_name))\n",
    "attack_params.device = device\n",
    "adversary = load_attack_class(attack_name)(\n",
    "    attr=test_attr, adj=test_adj, labels=labels, model=model, \n",
    "    idx_attack=idx_attack,\n",
    "    data_device=device, make_undirected=True, binary_attr=False,\n",
    "    **attack_params.to_dict())\n",
    "adversary.attack(n_attack_edges)\n",
    "pert_adj, pert_attr = adversary.get_pertubations()\n",
    "adv_acc = accuracy(model, pert_attr, pert_adj, labels, test_mask)\n",
    "\n",
    "# The storage name encodes attack, budget, attack-config hash and the attacked model.\n",
    "attack_config_hash = TensorHash.hash_model_params(model_name=attack_name, model_params=attack_params)\n",
    "epsilon_str = str(epsilon).replace(\".\", \"_\")\n",
    "attack_storage_name = f\"{attack_name}-eps{epsilon_str}-{attack_config_hash}-{model_storage_name}\"\n",
    "\n",
    "# Best-effort save: directory or file errors surface as OSError, so both\n",
    "# exception types are reported instead of aborting the cell.\n",
    "try:\n",
    "    os.makedirs(results_root, exist_ok=True)\n",
    "    torch.save(pert_adj, os.path.join(results_root, f\"{attack_storage_name}-adj.pt\"))\n",
    "except (OSError, RuntimeError) as e:\n",
    "    print(f\"Error saving perturbed adj: {e}\")\n",
    "\n",
    "print(\"Accuracy (Perturbed):\", adv_acc)\n",
    "# NOTE(review): model_params and attack_params are stored as-is (config objects,\n",
    "# not plain dicts) - confirm that readers of the report expect that.\n",
    "report = {\n",
    "    \"model\": model_name,\n",
    "    \"dataset\": dataset_name,\n",
    "    \"dataset_info\": dataset_info.to_dict(),\n",
    "    \"model_params\": model_params,\n",
    "    \"clean_accuracy\": acc,\n",
    "    \"setting\": \"inductive\" if inductive else \"transductive\",\n",
    "    \"attack\": attack_name,\n",
    "    \"attack_params\": attack_params,\n",
    "    \"attack_accuracy\": adv_acc,\n",
    "    \"epsilon\": epsilon,\n",
    "    \"n_attack_edges\": n_attack_edges,\n",
    "    \"model_storage_name\": model_storage_name,\n",
    "    \"attack_storage_name\": attack_storage_name\n",
    "}\n",
    "\n",
    "try:\n",
    "    os.makedirs(reports_root, exist_ok=True)\n",
    "    torch.save(report, os.path.join(reports_root, f\"{attack_storage_name}-report.pt\"))\n",
    "except (OSError, RuntimeError) as e:\n",
    "    print(f\"Error saving report: {e}\")\n",
    "\n",
    "print(\"Experiment Finished\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'model': 'GCN',\n",
       " 'dataset': 'cora_ml',\n",
       " 'dataset_info': {'dataset_name': 'cora_ml',\n",
       "  'n_classes': 7,\n",
       "  'n_features': 2879,\n",
       "  'n_nodes': 2995},\n",
       " 'model_params': lr: 0.01\n",
       " max_epochs: 3000\n",
       " n_classes: 7\n",
       " n_features: 2879\n",
       " n_filters: 64\n",
       " patience: 200\n",
       " weight_decay: 0.001,\n",
       " 'clean_accuracy': 0.8169429097605893,\n",
       " 'setting': 'transductive',\n",
       " 'attack': 'PRBCD',\n",
       " 'attack_params': device: cuda\n",
       " do_synchronize: true\n",
       " epochs: 500\n",
       " fine_tune_epochs: 100\n",
       " keep_heuristic: WeightOnly\n",
       " loss_type: tanhMargin\n",
       " lr_factor: 100\n",
       " search_space_size: 500000,\n",
       " 'attack_accuracy': 0.6423572744014733,\n",
       " 'epsilon': 0.1,\n",
       " 'n_attack_edges': 805,\n",
       " 'model_storage_name': 'GCN-0xbd14035a4016352-tr-cora_ml-0x30d202b2fcf2b06',\n",
       " 'attack_storage_name': 'PRBCD-eps0_1-0x50d433fc013fbb99-GCN-0xbd14035a4016352-tr-cora_ml-0x30d202b2fcf2b06'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Commented-out variant of the report dict, kept for reference; unlike the live\n",
    "# version it converts model_params and attack_params with .to_dict().\n",
    "# report = {\n",
    "#     \"model\": model_name,\n",
    "#     \"model_params\": model_params.to_dict(),\n",
    "#     \"dataset\": dataset_name,\n",
    "#     \"dataset_info\": dataset_info.to_dict(),\n",
    "#     \"attack\": attack_name,\n",
    "#     \"attack_params\": attack_params.to_dict(),\n",
    "#     \"clean_accuracy\": acc,\n",
    "#     \"setting\": \"inductive\" if inductive else \"transductive\",\n",
    "#     \"attack_accuracy\": adv_acc,\n",
    "#     \"epsilon\": epsilon,\n",
    "#     \"n_attack_edges\": n_attack_edges,\n",
    "#     \"model_storage_name\": model_storage_name,\n",
    "#     \"attack_storage_name\": attack_storage_name\n",
    "# }\n",
    "\n",
    "# Display the report dict currently held in the kernel.\n",
    "report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [01:25<00:00,  5.82it/s]\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6799263351749539"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adv_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6523020257826887"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0x30d202b2fcf2b06'"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6523020257826887"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'GCN-0xbd14035a4016352-tr-cora_ml-0x30d202b2fcf2b06'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_storage_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'GCN-0xbd14035a4016352-ind-cora_ml-0x30d202b2fcf2b06'"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_storage_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0x30d202b2fcf2b06'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8386740331491712"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'GCN-0x1958de4e3815c4bf-tr-cora_ml-0x2f76270573c89b6'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebuild the storage name as <model>-<config hash>-<ind|tr>-<dataset>-<split>.\n",
    "# NOTE(review): model_config_hash is not assigned in any visible cell, so this\n",
    "# cell presumably relies on leftover kernel state and would fail after a restart.\n",
    "setting_str = \"ind\" if inductive else \"tr\"\n",
    "model_storage_name = f\"{model_name}-{model_config_hash}-{setting_str}-{dataset_name}-{split_name}\"\n",
    "model_storage_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0x21294382f5ba97d'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hex form of the absolute built-in hash of the JSON-serialised model params.\n",
    "# NOTE(review): if to_json() returns a str, hash() is salted per interpreter\n",
    "# process unless PYTHONHASHSEED is fixed, so this value is presumably not stable\n",
    "# across kernel restarts - confirm before using it as a storage key.\n",
    "hex(abs(hash(model_params.to_json())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(model, attr, adj, labels, evaluation_mask):\n",
    "    \"\"\"Return the share of correct predictions among the entries picked by `evaluation_mask`.\n",
    "\n",
    "    The model is switched to eval mode and called as `model(attr, adj)` with\n",
    "    gradients disabled; the predicted class is the index of the largest value\n",
    "    along dimension 1 of the output. The result is a Python float.\n",
    "\n",
    "    NOTE(review): this definition shadows the `accuracy` imported from\n",
    "    utils.model at the top of the notebook.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        _, predicted = torch.max(model(attr, adj), dim=1)\n",
    "        is_correct = predicted.type_as(labels).eq(labels).double()\n",
    "        score = is_correct[evaluation_mask].mean()\n",
    "    return score.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1e-2'"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_params.lr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GCN(\n",
       "  (activation): ReLU()\n",
       "  (layers): ModuleList(\n",
       "    (0): Sequential(\n",
       "      (gcn_0): ChainableGCNConv(2879, 64)\n",
       "      (activation_0): ReLU()\n",
       "      (dropout_0): Dropout(p=0.5, inplace=False)\n",
       "    )\n",
       "    (1): Sequential(\n",
       "      (gcn_1): ChainableGCNConv(64, 7)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Instantiate a GCN from the keyword arguments in model_params and move it to device.\n",
    "# NOTE(review): GCN is not imported in any visible cell - confirm where it comes from;\n",
    "# as written this presumably fails on a fresh kernel.\n",
    "GCN(**model_params).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'n_filters': 64,\n",
       " 'lr': '1e-2',\n",
       " 'weight_decay': '1e-3',\n",
       " 'patience': 200,\n",
       " 'max_epochs': 3000,\n",
       " 'n_features': 2879,\n",
       " 'n_classes': 7}"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the torch_geometric CitationFull dataset named dataset_name, rooted at dataset_root.\n",
    "dataset_obj = torch_geometric.datasets.CitationFull(root=dataset_root, name=dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0x53dee31560c809b4'"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a copy so the original object is left untouched.\n",
    "# NOTE(review): `dataset` is not assigned in any visible cell (only `dataset_obj` is) -\n",
    "# presumably leftover kernel state; confirm which object is meant.\n",
    "dataset_tr = dataset.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Replace the None placeholders from the config cell with the values stored in\n",
    "# the default dataset config (.get returns None when a key is absent).\n",
    "training_split = default_dataset_configs.get(\"training_split\")\n",
    "validation_split = default_dataset_configs.get(\"validation_split\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'training_split': 20,\n",
       " 'training_split_type': 'stratified',\n",
       " 'validation_split': 20,\n",
       " 'validation_split_type': 'stratified'}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "default_dataset_configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
