{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "67348db0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhaogroup/anaconda3/envs/GNN/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "WARNING:root:The OGB package is out of date. Your version is 1.3.3, while the latest version is 1.3.5.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import argparse\n",
    "import torch.nn.functional as F\n",
    "from ogb.nodeproppred import Evaluator\n",
    "from utils import set_seed, load_data, get_model\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ae279d78-c17d-4589-9018-847f31800eaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_args():\n",
    "    parser = argparse.ArgumentParser(description='GCN')\n",
    "    parser.add_argument('--repetitions', type=int, default=10)\n",
    "    parser.add_argument('--random_seed', type=int, default=1000)\n",
    "    parser.add_argument('--dataset', type=str, default='Citeseer-adv')\n",
    "    parser.add_argument('--device', type=int, default=1)\n",
    "    parser.add_argument('--num_layers', type=int, default=2)\n",
    "    parser.add_argument('--weight_decay', type=float, default=5e-4)\n",
    "    parser.add_argument('--lr', type=float, default=0.01)\n",
    "    parser.add_argument('--dropout', type=float, default=0.1)\n",
    "    parser.add_argument('--type_model', type=str, default='ASGNN')   \n",
    "    parser.add_argument('--alpha', type=float, default=0.1) \n",
    "    parser.add_argument('--lamda', type=float, default=0.5)\n",
    "    parser.add_argument('--transductive', type=bool, default=True)\n",
    "    parser.add_argument('--ptb_rate', type=float, default=5.0)\n",
    "    parser.add_argument('--attack', type=str, default='nettack')\n",
    "    args = parser.parse_args(args=[])\n",
    "    return args\n",
    "args = parse_args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "83279eac-adb0-4b92-be76-a2fa512fb1ae",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(alpha=0.1, attack='nettack', dataset='Citeseer-adv', device=1, dropout=0.1, lamda=0.5, lr=0.01, num_layers=2, ptb_rate=5.0, random_seed=1000, repetitions=10, transductive=True, type_model='ASGNN', weight_decay=0.0005)\n",
      "Repetition <0>\n",
      "Loading citeseer dataset...\n",
      "Selecting 1 largest connected components\n",
      "Loading citeseer dataset perturbed by 5.0 nettack...\n",
      "test_acc:0.7761\n",
      "Repetition <1>\n",
      "Loading citeseer dataset...\n",
      "Selecting 1 largest connected components\n",
      "Loading citeseer dataset perturbed by 5.0 nettack...\n",
      "test_acc:0.8209\n",
      "Repetition <2>\n",
      "Loading citeseer dataset...\n",
      "Selecting 1 largest connected components\n",
      "Loading citeseer dataset perturbed by 5.0 nettack...\n",
      "test_acc:0.8209\n",
      "Repetition <3>\n",
      "Loading citeseer dataset...\n",
      "Selecting 1 largest connected components\n",
      "Loading citeseer dataset perturbed by 5.0 nettack...\n",
      "test_acc:0.8060\n",
      "Repetition <4>\n",
      "Loading citeseer dataset...\n",
      "Selecting 1 largest connected components\n",
      "Loading citeseer dataset perturbed by 5.0 nettack...\n",
      "test_acc:0.7910\n",
      "Repetition <5>\n",
      "Loading citeseer dataset...\n",
      "Selecting 1 largest connected components\n",
      "Loading citeseer dataset perturbed by 5.0 nettack...\n",
      "test_acc:0.8060\n",
      "Repetition <6>\n",
      "Loading citeseer dataset...\n",
      "Selecting 1 largest connected components\n",
      "Loading citeseer dataset perturbed by 5.0 nettack...\n",
      "test_acc:0.7761\n",
      "Repetition <7>\n",
      "Loading citeseer dataset...\n",
      "Selecting 1 largest connected components\n",
      "Loading citeseer dataset perturbed by 5.0 nettack...\n",
      "test_acc:0.7910\n",
      "Repetition <8>\n",
      "Loading citeseer dataset...\n",
      "Selecting 1 largest connected components\n",
      "Loading citeseer dataset perturbed by 5.0 nettack...\n",
      "test_acc:0.7910\n",
      "Repetition <9>\n",
      "Loading citeseer dataset...\n",
      "Selecting 1 largest connected components\n",
      "Loading citeseer dataset perturbed by 5.0 nettack...\n",
      "test_acc:0.8358\n",
      "final mean and std of test acc with <10> runs: 0.8015±0.0189\n"
     ]
    }
   ],
   "source": [
    "print(args)\n",
    "best_acc_mean = 0\n",
    "best_acc_std = 0\n",
    "device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')     \n",
    "list_test_acc = []\n",
    "list_valid_acc = []\n",
    "list_train_loss = []    \n",
    "for repetition in range(args.repetitions):\n",
    "    print(f'Repetition <{repetition}>')\n",
    "    set_seed(repetition)\n",
    "    args, data = load_data(args)                \n",
    "\n",
    "    data = data.to(device)\n",
    "    args.num_features = data.num_node_features \n",
    "    torch.cuda.empty_cache()                                       \n",
    "    model = get_model(args)               \n",
    "    model.cuda(device) \n",
    "    best_train_loss = 100.\n",
    "    best_val_loss = 100.\n",
    "    best_train_acc = 0.\n",
    "    best_val_acc = 0.        \n",
    "    best_test_acc = 0.\n",
    "    bad_counter = 0.\n",
    "    for epoch in range(args.epochs):\n",
    "        model.train()       \n",
    "        output = model(data.x, data.edge_index)           \n",
    "        loss = 0.\n",
    "        loss_train = F.nll_loss(output[data.train_mask], data.y[data.train_mask])   \n",
    "        model.optimizer.zero_grad()\n",
    "        loss_train.backward()\n",
    "        model.optimizer.step()\n",
    "\n",
    "        model.eval()\n",
    "        output = model(data.x, data.edge_index)\n",
    "        acc_train = torch.sum(torch.argmax(output, dim=1)[data.train_mask] == \n",
    "                              data.y[data.train_mask]).item() * 1.0 / data.train_mask.sum().item()\n",
    "        acc_val = torch.sum(torch.argmax(output, dim=1)[data.val_mask] == \n",
    "                              data.y[data.val_mask]).item() * 1.0 / data.val_mask.sum().item()\n",
    "        acc_test = torch.sum(torch.argmax(output, dim=1)[data.test_mask] == \n",
    "                              data.y[data.test_mask]).item() * 1.0 / data.test_mask.sum().item()\n",
    "\n",
    "        loss_val = F.nll_loss(output[data.val_mask], data.y[data.val_mask])     \n",
    "\n",
    "        if loss_val < best_val_loss:\n",
    "            best_train_loss = loss_train\n",
    "            best_val_loss = loss_val\n",
    "            best_train_acc = acc_train\n",
    "            best_val_acc = acc_val\n",
    "            best_test_acc = acc_test               \n",
    "            bad_counter = 0\n",
    "        else:\n",
    "            bad_counter += 1\n",
    "        if bad_counter == args.patience:\n",
    "            break\n",
    "    print('test_acc:{:.4f}'.format(best_test_acc))\n",
    "    list_train_loss.append(best_train_loss)\n",
    "    list_valid_acc.append(best_val_acc)\n",
    "    list_test_acc.append(best_test_acc)\n",
    "print('final mean and std of test acc with <{}> runs: {:.4f}±{:.4f}'.format(\n",
    "    args.repetitions, np.mean(list_test_acc), np.std(list_test_acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86b010d1-9025-4ac8-9d17-5b3486a51c61",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
