{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import yaml\n",
    "from ml_collections import ConfigDict\n",
    "from tqdm import tqdm\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "import logging\n",
    "\n",
    "import torch\n",
    "import torch_geometric\n",
    "\n",
    "from utils.data import load_dataset, make_dataset_splits, load_dataset_splits, check_dataset_valid\n",
    "from utils.split import SplitManager, node_induced_subgraph\n",
    "from utils.storage import TensorHash\n",
    "from utils.model import load_model_class, accuracy, load_model_instance, create_model_instance, load_robust_model_instance, from_sparse_GCN, from_sparse_GPRGNN\n",
    "from utils.attack import load_attack_class, attack_storage_label, create_attack_instance\n",
    "\n",
    "from robust_diffusion.data import count_edges_for_idx\n",
    "\n",
    "from robust_diffusion.attacks import create_attack\n",
    "# NOTE(review): `accuracy` is also imported from utils.model above; this later import shadows it.\n",
    "from robust_diffusion.helper.utils import accuracy, calculate_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Experiment configuration\n",
    "\n",
    "# --- Dataset and splits ---\n",
    "dataset_name = \"cora_ml\"\n",
    "n_splits = 10  # how many stored splits to evaluate\n",
    "\n",
    "# Split filters forwarded to check_dataset_valid when selecting split files.\n",
    "# NOTE(review): None presumably means \"do not filter on this field\" - confirm in utils.data.\n",
    "training_split = validation_split = test_split = None\n",
    "training_split_type = validation_split_type = test_split_type = None\n",
    "\n",
    "# --- Model ---\n",
    "# Expected values: \"GCN\", \"DenseGCN\", \"GAT\", \"GPRGNN\", \"DenseGPRGNN\", \"APPNP\", \"ChebNetII\", \"SoftMedian_GDC\"\n",
    "model_name = \"GCN\"\n",
    "model_params = None  # forwarded to the model loaders\n",
    "inductive = True\n",
    "self_training = False\n",
    "robust_training = False\n",
    "\n",
    "# --- Attack ---\n",
    "# Expected values: \"PRBCD\", \"LRBCD\", \"EvaAttack\", \"Evafast\", \"PGD\"\n",
    "attack_name = \"PRBCD\"\n",
    "attack_params = None\n",
    "# NOTE(review): epsilon, attack_params, self_training and robust_training are not read by the evaluation cells of this notebook.\n",
    "epsilon = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Experiment Started\n",
      "Loading dataset = cora_ml\n",
      "Found 10 splits!\n",
      "Loading pretrained GCN model on cora_ml dataset for 10 splits\n"
     ]
    }
   ],
   "source": [
    "## Load general settings (paths, etc.) and the default dataset/model/attack configs.\n",
    "# `with` closes each YAML file handle instead of leaving it open for the kernel's lifetime.\n",
    "with open(\"conf/general-config.yaml\") as config_file:\n",
    "    general_config = yaml.safe_load(config_file)\n",
    "with open(\"conf/data-configs.yaml\") as config_file:\n",
    "    default_dataset_configs = yaml.safe_load(config_file).get(\"configs\").get(\"default\")\n",
    "with open(\"conf/model-configs.yaml\") as config_file:\n",
    "    default_model_configs = yaml.safe_load(config_file).get(\"configs\")\n",
    "with open(\"conf/attack-configs.yaml\") as config_file:\n",
    "    default_attack_configs = yaml.safe_load(config_file).get(\"configs\")\n",
    "\n",
    "# Storage locations; each falls back to a default when the key is missing from general-config.yaml.\n",
    "dataset_root = general_config.get(\"dataset_root\", \"data/\")\n",
    "splits_root = general_config.get(\"splits_root\", \"splits/\")\n",
    "models_root = general_config.get(\"models_root\", \"models/\")\n",
    "results_root = general_config.get(\"results_root\", \"results/\")\n",
    "reports_root = general_config.get(\"reports_root\", \"reports/\")\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "print(\"Experiment Started\")\n",
    "# This notebook only evaluates pretrained models: it looks up existing split files\n",
    "# and stops if fewer than n_splits are available (no training or split creation here).\n",
    "print(\"Loading dataset =\", dataset_name)\n",
    "\n",
    "# Split files are named \"<dataset>-<split_code>.pt\". sorted() makes their order, and hence\n",
    "# which files dataset_splits[:n] selects, deterministic; os.listdir order is arbitrary.\n",
    "dataset_splits = sorted(\n",
    "    split_record for split_record in os.listdir(splits_root)\n",
    "    if split_record.split(\"-\")[0] == dataset_name\n",
    "    and check_dataset_valid(split_record=split_record, training_split=training_split,\n",
    "                            validation_split=validation_split, training_split_type=training_split_type,\n",
    "                            validation_split_type=validation_split_type, test_split=test_split,\n",
    "                            test_split_type=test_split_type, splits_root=splits_root))\n",
    "# Number of splits still missing (0 when enough exist).\n",
    "creating_splits = max(n_splits - len(dataset_splits), 0)\n",
    "\n",
    "if creating_splits > 0:\n",
    "    raise ValueError(\"Not enough splits for the dataset. Create the splits by running training scripts.\")\n",
    "\n",
    "print(f\"Found {len(dataset_splits)} splits!\")\n",
    "\n",
    "print(f\"Loading pretrained {model_name} model on {dataset_name} dataset for {n_splits} splits\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the pretrained vanilla model for each split and report its clean accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 30.14it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0.8201),\n",
       " tensor(0.0161),\n",
       " [0.8060200668896321,\n",
       "  0.8294314381270903,\n",
       "  0.8361204013377926,\n",
       "  0.8193979933110368,\n",
       "  0.8260869565217391,\n",
       "  0.842809364548495,\n",
       "  0.8093645484949833,\n",
       "  0.8327759197324415,\n",
       "  0.7926421404682275,\n",
       "  0.8060200668896321])"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the pretrained vanilla model of each split and collect its clean accuracy.\n",
    "clean_accs = []\n",
    "for split_file in tqdm(dataset_splits[:n_splits]):\n",
    "    # \"<dataset>-<split_code>.pt\" -> \"<split_code>\"\n",
    "    split_code = split_file.split(\"-\")[1].replace(\".pt\", \"\")\n",
    "\n",
    "    data = load_dataset_splits(\n",
    "        dataset_name, split_code, inductive=inductive,\n",
    "        dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "\n",
    "    # Only these fields are passed to load_model_instance below.\n",
    "    labels = data[\"labels\"]\n",
    "    test_attr = data[\"test_attr\"]\n",
    "    test_adj = data[\"test_adj\"]\n",
    "    unlabeled_mask = data[\"unlabeled_mask\"]\n",
    "    test_mask = data[\"test_mask\"]\n",
    "    dataset_info = data[\"dataset_info\"]\n",
    "    split_name = data[\"split_name\"]\n",
    "\n",
    "    try:\n",
    "        model_instance = load_model_instance(\n",
    "            model_name=model_name, model_params=model_params,\n",
    "            test_attr=test_attr, test_adj=test_adj, labels=labels,\n",
    "            test_mask=test_mask, unlabeled_mask=unlabeled_mask,\n",
    "            split_name=split_name, dataset_info=dataset_info,\n",
    "            inductive=inductive,\n",
    "            models_root=models_root,\n",
    "            default_model_configs=default_model_configs, device=device)\n",
    "    except FileNotFoundError as e:\n",
    "        # Chain the error so the underlying FileNotFoundError stays in the traceback.\n",
    "        raise ValueError(\"Model not found. Run training scripts to train the model.\") from e\n",
    "\n",
    "    # Do not assign model_instance[\"model_params\"] back to `model_params`: that global is the\n",
    "    # configured value, and overwriting it makes later splits and cells depend on run order.\n",
    "    clean_accs.append(model_instance[\"accuracy\"])\n",
    "\n",
    "clean_accs_tensor = torch.tensor(clean_accs)\n",
    "torch.mean(clean_accs_tensor), torch.std(clean_accs_tensor), clean_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['cora_ml-3c561bc6e9.pt',\n",
       " 'cora_ml-19f0927a57.pt',\n",
       " 'cora_ml-28286c92d8.pt',\n",
       " 'cora_ml-f98474d283.pt',\n",
       " 'cora_ml-61803e13b4.pt',\n",
       " 'cora_ml-0c1e4dfae9.pt',\n",
       " 'cora_ml-dcba8752e3.pt',\n",
       " 'cora_ml-d963f2bb9a.pt',\n",
       " 'cora_ml-3b080d0e45.pt',\n",
       " 'cora_ml-b170d5187b.pt']"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split files found by the setup cell; the evaluation loops use the leading entries, in this order.\n",
    "list(dataset_splits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the robustly trained model for each split and report its clean accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 38.04it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0.8184),\n",
       " tensor(0.0177),\n",
       " [0.8127090301003345,\n",
       "  0.8093645484949833,\n",
       "  0.802675585284281,\n",
       "  0.8160535117056856,\n",
       "  0.8528428093645485,\n",
       "  0.842809364548495,\n",
       "  0.8193979933110368,\n",
       "  0.8260869565217391,\n",
       "  0.802675585284281,\n",
       "  0.7993311036789298])"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the robustly trained model of each split and collect its clean accuracy.\n",
    "# Settings that select which stored robust model to load, named instead of inlined as literals.\n",
    "# NOTE(review): these differ from the configuration cell (self_training/robust_training=False,\n",
    "# epsilon=0.1) - confirm the robust models were trained with exactly these values.\n",
    "robust_self_training = True\n",
    "robust_adv_training = True\n",
    "robust_epsilon = 0.2\n",
    "robust_suffix = ''\n",
    "\n",
    "robust_accs = []\n",
    "for split_file in tqdm(dataset_splits[:n_splits]):\n",
    "    # \"<dataset>-<split_code>.pt\" -> \"<split_code>\"\n",
    "    split_code = split_file.split(\"-\")[1].replace(\".pt\", \"\")\n",
    "    data = load_dataset_splits(\n",
    "        dataset_name, split_code, inductive=inductive,\n",
    "        dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "\n",
    "    # Only these fields are passed to load_robust_model_instance below.\n",
    "    labels = data[\"labels\"]\n",
    "    test_attr = data[\"test_attr\"]\n",
    "    test_adj = data[\"test_adj\"]\n",
    "    unlabeled_mask = data[\"unlabeled_mask\"]\n",
    "    test_mask = data[\"test_mask\"]\n",
    "    dataset_info = data[\"dataset_info\"]\n",
    "    split_name = data[\"split_name\"]\n",
    "\n",
    "    # NOTE(review): model_params is the configuration-cell value unless an earlier cell reassigned it.\n",
    "    try:\n",
    "        model_instance = load_robust_model_instance(\n",
    "            model_name=model_name, model_params=model_params,\n",
    "            dataset_info=dataset_info,\n",
    "            test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, unlabeled_mask=unlabeled_mask,\n",
    "            split_name=split_name, inductive=inductive,\n",
    "            models_root=models_root, self_training=robust_self_training, robust_training=robust_adv_training,\n",
    "            attack_name=attack_name, robust_epsilon=robust_epsilon,\n",
    "            default_model_configs=default_model_configs, suffix=robust_suffix, device=device)\n",
    "    except FileNotFoundError as e:\n",
    "        # Assumes a missing model surfaces as FileNotFoundError, as with load_model_instance - TODO confirm.\n",
    "        raise ValueError(\"Robust model not found. Run training scripts to train the robust model.\") from e\n",
    "    robust_accs.append(model_instance[\"clean_accuracy\"])\n",
    "\n",
    "robust_accs_tensor = torch.tensor(robust_accs)\n",
    "mean_clean_acc = torch.mean(robust_accs_tensor)\n",
    "std_clean_acc = torch.std(robust_accs_tensor)\n",
    "mean_clean_acc, std_clean_acc, robust_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
