{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "WARNING:root:Cuda kernels could not loaded -> no CUDA support!\n",
      "2024-04-09 15:06:09,770\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
      "WARNING:evotorch:The logger is already configured. The default configuration will not be applied. Call `set_default_logger_config` with `override=True` to override the current configuration.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import yaml\n",
    "from ml_collections import ConfigDict\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torch_geometric\n",
    "\n",
    "from utils.data import load_dataset, make_dataset_splits, load_dataset_splits\n",
    "from utils.split import SplitManager, node_induced_subgraph\n",
    "from utils.storage import TensorHash\n",
    "from utils.model import load_model_class, accuracy, load_model_instance, create_model_instance\n",
    "from utils.attack import load_attack_class\n",
    "\n",
    "from robust_diffusion.data import SparseGraph\n",
    "from robust_diffusion.data import count_edges_for_idx\n",
    "from robust_diffusion.helper import utils as robust_utils\n",
    "from robust_diffusion.train import train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Experiment configs\n",
    "# -- dataset & splits --\n",
    "dataset_name = \"cora_ml\"\n",
    "n_splits = 50  # number of splits to train and evaluate on\n",
    "recreate_splits = False  # NOTE(review): not read by any cell in this notebook\n",
    "inductive = False\n",
    "\n",
    "# -- split sizes and split types (None presumably means 'use dataset defaults'; TODO confirm in utils.data) --\n",
    "training_split, training_split_type = None, None\n",
    "validation_split, validation_split_type = None, None\n",
    "test_split, test_split_type = None, None\n",
    "# TODO: add unlabeled split\n",
    "\n",
    "# -- model --\n",
    "model_name = \"GCN\"\n",
    "model_params = None\n",
    "\n",
    "# -- attack (not used by the training cell below) --\n",
    "attack_name = \"PRBCD\"\n",
    "attack_params = None\n",
    "epsilon = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Loading general configs (like dataset_root, etc.) and initial parameters,\n",
    "## then training/evaluating `model_name` on `n_splits` splits of `dataset_name`.\n",
    "def _load_yaml(path):\n",
    "    \"\"\"Parse the YAML file at `path`, closing the file handle afterwards.\"\"\"\n",
    "    with open(path) as f:\n",
    "        return yaml.safe_load(f)\n",
    "\n",
    "\n",
    "def _list_dataset_splits(root, name):\n",
    "    \"\"\"Return the split files in directory `root` that belong to dataset `name`.\n",
    "\n",
    "    A file belongs to the dataset when the part of its filename before the first \"-\"\n",
    "    equals `name`. A missing `root` directory yields an empty list. The result is\n",
    "    sorted because os.listdir order is arbitrary, which would otherwise make the\n",
    "    `[:n_splits]` selection non-deterministic across machines and runs.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        entries = os.listdir(root)\n",
    "    except FileNotFoundError:\n",
    "        return []\n",
    "    return sorted(entry for entry in entries if entry.split(\"-\")[0] == name)\n",
    "\n",
    "\n",
    "general_config = _load_yaml(\"conf/general-config.yaml\")\n",
    "default_dataset_configs = _load_yaml(\"conf/data-configs.yaml\").get(\"configs\").get(\"default\")\n",
    "default_model_configs = _load_yaml(\"conf/model-configs.yaml\").get(\"configs\")\n",
    "default_attack_configs = _load_yaml(\"conf/attack-configs.yaml\").get(\"configs\")\n",
    "\n",
    "# extracting configs (relative default paths are used when a key is absent)\n",
    "dataset_root = general_config.get(\"dataset_root\", \"data/\")\n",
    "splits_root = general_config.get(\"splits_root\", \"splits/\")\n",
    "models_root = general_config.get(\"models_root\", \"models/\")\n",
    "results_root = general_config.get(\"results_root\", \"results/\")\n",
    "reports_root = general_config.get(\"reports_root\", \"reports/\")\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "\n",
    "print(\"Experiment Started\")\n",
    "# Trains the specified model on the given graph and saves the model artifacts, and the splits.\n",
    "\n",
    "print(\"Loading dataset =\", dataset_name)\n",
    "\n",
    "dataset_splits = _list_dataset_splits(splits_root, dataset_name)\n",
    "creating_splits = max(n_splits - len(dataset_splits), 0)\n",
    "\n",
    "# creating remaining needed dataset splits\n",
    "print(f\"Found {len(dataset_splits)} splits, creating {creating_splits} more splits\")\n",
    "for _ in tqdm(range(creating_splits)):\n",
    "    torch.cuda.empty_cache()\n",
    "    # The return value is not used here; the new split is presumably persisted under\n",
    "    # splits_root, where the re-listing below picks it up.\n",
    "    make_dataset_splits(dataset_name=dataset_name,\n",
    "                        training_split=training_split, validation_split=validation_split,\n",
    "                        training_split_type=training_split_type, validation_split_type=validation_split_type,\n",
    "                        inductive=inductive,\n",
    "                        default_dataset_configs=default_dataset_configs, dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "\n",
    "dataset_splits = _list_dataset_splits(splits_root, dataset_name)[:n_splits]\n",
    "print(f\"Training {model_name} model on {dataset_name} dataset for {n_splits} splits\")\n",
    "\n",
    "accs = []\n",
    "for split_file in tqdm(dataset_splits):\n",
    "    split_code = split_file.split(\"-\")[1].replace(\".pt\", \"\")\n",
    "\n",
    "    data = load_dataset_splits(\n",
    "        dataset_name, split_code, inductive=inductive,\n",
    "        dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "\n",
    "    training_attr = data[\"training_attr\"]\n",
    "    training_adj = data[\"training_adj\"]\n",
    "    labels = data[\"labels\"]\n",
    "    training_idx = data[\"training_idx\"]\n",
    "    validation_idx = data[\"validation_idx\"]\n",
    "    test_attr = data[\"test_attr\"]\n",
    "    test_adj = data[\"test_adj\"]\n",
    "    test_mask = data[\"test_mask\"]\n",
    "    dataset_info = data[\"dataset_info\"]\n",
    "    split_name = data[\"split_name\"]\n",
    "\n",
    "    # Every split receives the configured `model_params` from the config cell.\n",
    "    try:\n",
    "        model_instance = load_model_instance(\n",
    "            model_name=model_name, model_params=model_params,\n",
    "            test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, split_name=split_name, dataset_info=dataset_info, inductive=inductive,\n",
    "            models_root=models_root,\n",
    "            default_model_configs=default_model_configs, device=device)\n",
    "    except FileNotFoundError as e:\n",
    "        print(e)\n",
    "        print(\"Creating model from scratch\")\n",
    "        model_instance = create_model_instance(\n",
    "            model_name=model_name, model_params=model_params, dataset_info=dataset_info,\n",
    "            training_attr=training_attr, training_adj=training_adj, labels=labels, training_idx=training_idx, validation_idx=validation_idx,\n",
    "            test_attr=test_attr, test_adj=test_adj, test_mask=test_mask, inductive=inductive, split_name=split_name,\n",
    "            models_root=models_root,\n",
    "            default_model_configs=default_model_configs,\n",
    "            device=device)\n",
    "\n",
    "    model = model_instance[\"model\"]\n",
    "    acc = model_instance[\"accuracy\"]\n",
    "    # Keep the resolved params under their own name: reassigning `model_params` here\n",
    "    # would feed split i's params into split i+1 and into any re-run of this cell.\n",
    "    trained_model_params = model_instance[\"model_params\"]\n",
    "    model_storage_name = model_instance[\"model_storage_name\"]\n",
    "    # float() accepts Python numbers as well as one-element tensors on any device.\n",
    "    accs.append(float(acc))\n",
    "\n",
    "if accs:\n",
    "    acc_tensor = torch.tensor(accs)\n",
    "    acc_mean = torch.mean(acc_tensor)\n",
    "    # torch.std is the sample (n-1) std by default, so it is nan for a single split.\n",
    "    acc_std = torch.std(acc_tensor)\n",
    "    print(f\"Mean accuracy: {acc_mean}, std: {acc_std}\")\n",
    "else:\n",
    "    print(f\"No {dataset_name} splits found in {splits_root}; no accuracy to report\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
