{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7e15fe89-cce6-4a98-b84e-be80f2427001",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(K=3, T=1, beta=1.0, cached=False, cpu=False, data_dir='../data', dataset='elliptic', device=0, directed=False, display_step=1, dropout=0.0, epochs=200, gat_heads=2, gnn='gcn', gpr_alpha=0.1, hidden_channels=32, lp_alpha=0.1, lr=0.01, lr_a=0.005, max_iter=10, method='erm', no_bn=False, noise=1.0, num_layers=2, num_sample=5, r=0.2, rocauc=False, runs=5, sub_dataset='', temp=1.0, weight_decay=0.001)\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.utils import to_undirected\n",
    "from torch_scatter import scatter\n",
    "\n",
    "from logger import Logger, SimpleLogger\n",
    "from dataset import load_nc_dataset\n",
    "from data_utils import normalize, gen_normalized_adjs, evaluate, evaluate_whole_graph, eval_acc, eval_rocauc, eval_f1, to_sparse_tensor, load_fixed_splits\n",
    "from parse import parse_method_base, parse_method_ours, parse_method_gstopr, parser_add_main_args\n",
    "\n",
    "# NOTE: for consistent data splits, see data_utils.rand_train_test_idx\n",
    "def fix_seed(seed):\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "fix_seed(0)\n",
    "\n",
    "### Parse args ###\n",
    "parser = argparse.ArgumentParser(description='General Training Pipeline')\n",
    "parser_add_main_args(parser)\n",
    "# GSTOPR config\n",
    "parser.add_argument('--r', default=0.2, type=float, help='selected ratio')\n",
    "\n",
    "parser.add_argument('--noise', default=1., type=float, help='gumbel noise')\n",
    "parser.add_argument('--temp', default=1., type=float, help='sinkhorn temperature')\n",
    "parser.add_argument('--max_iter', default=10, type=int, help='sinkhorn max iter')\n",
    "args = parser.parse_args(args=[])\n",
    "print(args)\n",
    "\n",
    "device = torch.device(\"cuda:\" + str(args.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3dbb6924-6c85-403e-97a9-442e5c3bf6f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train num nodes 6048 | num classes 2 | num node feats 165\n",
      "Val num nodes 2047 | num classes 2 | num node feats 165\n",
      "Test 0 num nodes 3385 | num classes 2 | num node feats 165\n",
      "Test 1 num nodes 1976 | num classes 2 | num node feats 165\n",
      "Test 2 num nodes 3506 | num classes 2 | num node feats 165\n",
      "Test 3 num nodes 4291 | num classes 2 | num node feats 165\n",
      "Test 4 num nodes 3537 | num classes 2 | num node feats 165\n",
      "Test 5 num nodes 5894 | num classes 2 | num node feats 165\n",
      "Test 6 num nodes 4165 | num classes 2 | num node feats 165\n",
      "Test 7 num nodes 4592 | num classes 2 | num node feats 165\n",
      "Test 8 num nodes 2314 | num classes 2 | num node feats 165\n",
      "Test 9 num nodes 2523 | num classes 2 | num node feats 165\n",
      "Test 10 num nodes 1089 | num classes 2 | num node feats 165\n",
      "Test 11 num nodes 1653 | num classes 2 | num node feats 165\n",
      "Test 12 num nodes 4275 | num classes 2 | num node feats 165\n",
      "Test 13 num nodes 2483 | num classes 2 | num node feats 165\n",
      "Test 14 num nodes 2816 | num classes 2 | num node feats 165\n",
      "Test 15 num nodes 4525 | num classes 2 | num node feats 165\n",
      "Test 16 num nodes 3151 | num classes 2 | num node feats 165\n",
      "Test 17 num nodes 2486 | num classes 2 | num node feats 165\n",
      "Test 18 num nodes 5507 | num classes 2 | num node feats 165\n",
      "Test 19 num nodes 6393 | num classes 2 | num node feats 165\n",
      "Test 20 num nodes 3306 | num classes 2 | num node feats 165\n",
      "Test 21 num nodes 2891 | num classes 2 | num node feats 165\n",
      "Test 22 num nodes 2760 | num classes 2 | num node feats 165\n",
      "Test 23 num nodes 4481 | num classes 2 | num node feats 165\n",
      "Test 24 num nodes 5342 | num classes 2 | num node feats 165\n",
      "Test 25 num nodes 7140 | num classes 2 | num node feats 165\n",
      "Test 26 num nodes 5063 | num classes 2 | num node feats 165\n",
      "Test 27 num nodes 4975 | num classes 2 | num node feats 165\n",
      "Test 28 num nodes 5598 | num classes 2 | num node feats 165\n",
      "Test 29 num nodes 3519 | num classes 2 | num node feats 165\n",
      "Test 30 num nodes 5121 | num classes 2 | num node feats 165\n",
      "Test 31 num nodes 2954 | num classes 2 | num node feats 165\n",
      "Test 32 num nodes 2454 | num classes 2 | num node feats 165\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/nas/home/dingfangyu/projs/GraphOOD-EERM/temp_elliptic/dataset.py:75: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  /opt/conda/conda-bld/pytorch_1656352428622/work/torch/csrc/utils/tensor_new.cpp:204.)\n",
      "  edge_index = torch.tensor(A.nonzero(), dtype=torch.long)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def get_dataset(dataset, sub_dataset=None):\n",
    "    ### Load and preprocess data ###\n",
    "    if dataset == 'elliptic':\n",
    "        dataset = load_nc_dataset(args.data_dir, 'elliptic', sub_dataset)\n",
    "    else:\n",
    "        raise ValueError('Invalid dataname')\n",
    "\n",
    "    if len(dataset.label.shape) == 1:\n",
    "        dataset.label = dataset.label.unsqueeze(1)\n",
    "\n",
    "    dataset.n = dataset.graph['num_nodes']\n",
    "    dataset.c = max(dataset.label.max().item() + 1, dataset.label.shape[1])\n",
    "    dataset.d = dataset.graph['node_feat'].shape[1]\n",
    "\n",
    "    dataset.graph['edge_index'], dataset.graph['node_feat'] = \\\n",
    "        dataset.graph['edge_index'], dataset.graph['node_feat']\n",
    "\n",
    "    return dataset\n",
    "\n",
    "if args.dataset == 'elliptic':\n",
    "    tr_subs, val_subs, te_subs = [i for i in range(6, 11)], [i for i in range(11, 16)], [i for i in range(16, 49)]\n",
    "    datasets_tr = [get_dataset(dataset='elliptic', sub_dataset=tr_subs[i]) for i in range(len(tr_subs))]\n",
    "    datasets_val = [get_dataset(dataset='elliptic', sub_dataset=val_subs[i]) for i in range(len(val_subs))]\n",
    "    datasets_te = [get_dataset(dataset='elliptic', sub_dataset=te_subs[i]) for i in range(len(te_subs))]\n",
    "else:\n",
    "    raise ValueError('Invalid dataname')\n",
    "\n",
    "dataset_tr = datasets_tr[0]\n",
    "dataset_val = datasets_val[0]\n",
    "print(f\"Train num nodes {dataset_tr.n} | num classes {dataset_tr.c} | num node feats {dataset_tr.d}\")\n",
    "print(f\"Val num nodes {dataset_val.n} | num classes {dataset_val.c} | num node feats {dataset_val.d}\")\n",
    "for i in range(len(te_subs)):\n",
    "    dataset_te = datasets_te[i]\n",
    "    print(f\"Test {i} num nodes {dataset_te.n} | num classes {dataset_te.c} | num node feats {dataset_te.d}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "92f41f23-07d1-472b-91d5-5c5c7a1f6514",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "3\n",
      "MODEL: Model(\n",
      "  (gnn): GCN(\n",
      "    (convs): ModuleList(\n",
      "      (0): GCNConv(165, 32)\n",
      "      (1): GCNConv(32, 2)\n",
      "    )\n",
      "    (bns): ModuleList(\n",
      "      (0): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (gl): ModuleList(\n",
      "    (0): Graph_Editer()\n",
      "    (1): Graph_Editer()\n",
      "    (2): Graph_Editer()\n",
      "    (3): Graph_Editer()\n",
      "    (4): Graph_Editer()\n",
      "  )\n",
      ")\n",
      "DATASET: elliptic\n"
     ]
    }
   ],
   "source": [
    "\n",
    "### Load method ###\n",
    "if args.method == 'erm':\n",
    "    print(1)\n",
    "    model = parse_method_base(args, datasets_tr, device)\n",
    "elif args.method == 'gstopr':\n",
    "    print(2)\n",
    "    model = parse_method_gstopr(args, datasets_tr, device)\n",
    "else:\n",
    "    print(3)\n",
    "    model = parse_method_ours(args, datasets_tr, device)\n",
    "\n",
    "\n",
    "# using rocauc as the eval function\n",
    "if args.rocauc or args.dataset in ('twitch-e', 'fb100', 'elliptic'):\n",
    "    criterion = nn.BCEWithLogitsLoss()\n",
    "    eval_func = eval_f1\n",
    "else:\n",
    "    criterion = nn.NLLLoss()\n",
    "    eval_func = eval_acc\n",
    "\n",
    "logger = Logger(args.runs, args)\n",
    "\n",
    "model.train()\n",
    "print('MODEL:', model)\n",
    "print('DATASET:', args.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "718d21f2-32f7-4646-9e3a-50baaedd422f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "### Training loop ###\n",
    "for run in range(args.runs):\n",
    "    model.reset_parameters()\n",
    "    if args.method in ['erm', 'gstopr']:\n",
    "        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n",
    "    elif args.method == 'eerm':\n",
    "        optimizer_gnn = torch.optim.AdamW(model.gnn.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n",
    "        optimizer_aug = torch.optim.AdamW(model.gl.parameters(), lr=args.lr_a)\n",
    "    best_val = float('-inf')\n",
    "    for epoch in range(args.epochs):\n",
    "        model.train()\n",
    "        if args.method in ['erm', 'gstopr']:\n",
    "            optimizer.zero_grad()\n",
    "            loss = model(datasets_tr, criterion)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        elif args.method == 'eerm':\n",
    "            for m in range(args.T):\n",
    "                Var, Mean, Log_p = model(datasets_tr, criterion)\n",
    "                outer_loss = Var + args.beta * Mean\n",
    "                reward = Var.detach()\n",
    "                inner_loss = - reward * Log_p\n",
    "                if m == 0:\n",
    "                    optimizer_gnn.zero_grad()\n",
    "                    outer_loss.backward()\n",
    "                    optimizer_gnn.step()\n",
    "                optimizer_aug.zero_grad()\n",
    "                inner_loss.backward()\n",
    "                optimizer_aug.step()\n",
    "\n",
    "        accs, test_outs = evaluate_whole_graph(args, model, datasets_tr, datasets_val, datasets_te, eval_func)\n",
    "        logger.add_result(run, accs)\n",
    "\n",
    "        if epoch % args.display_step == 0:\n",
    "            if args.method in ['erm', 'gstopr']:\n",
    "                print(f'Epoch: {epoch:02d}, '\n",
    "                  f'Loss: {loss:.4f}, '\n",
    "                  f'Train: {100 * accs[0]:.2f}%, '\n",
    "                  f'Valid: {100 * accs[1]:.2f}%, ')\n",
    "                test_info = ''\n",
    "                for test_acc in accs[2:]:\n",
    "                    test_info += f'Test: {100 * test_acc:.2f}% '\n",
    "                print(test_info)\n",
    "            elif args.method == 'eerm':\n",
    "                print(f'Epoch: {epoch:02d}, '\n",
    "                      f'Mean Loss: {Mean:.4f}, '\n",
    "                      f'Var Loss: {Var:.4f}, '\n",
    "                      f'Train: {100 * accs[0]:.2f}%, '\n",
    "                      f'Valid: {100 * accs[1]:.2f}%, ')\n",
    "                test_info = ''\n",
    "                for test_acc in accs[2:]:\n",
    "                    test_info += f'Test: {100 * test_acc:.2f}% '\n",
    "                print(test_info)\n",
    "\n",
    "    logger.print_statistics(run)\n",
    "\n",
    "### Save results ###\n",
    "results = logger.print_statistics()\n",
    "filename = f'./results/{args.dataset}.csv'\n",
    "print(f\"Saving results to {filename}\")\n",
    "with open(f\"{filename}\", 'a+') as write_obj:\n",
    "    log = f\"{args.method},\" + (f\"r={args.r},g={args.noise},\" if args.method == 'gstopr' else '') + f\"{args.gnn},\"\n",
    "    torch.save(results, \"./results/\" + log.replace(',', '-') + '.pt')\n",
    "    for i in range(results.shape[1]):\n",
    "        r = results[:, i]\n",
    "        log += f\"{r.mean():.3f} ± {r.std():.3f},\"\n",
    "    write_obj.write(log + f\"\\n\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cu102",
   "language": "python",
   "name": "cu102"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
