{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "import argparse\n",
    "import multiprocessing as mp\n",
    "\n",
    "# Notebook-level default for the `calculate_grad_vars` keyword that the training\n",
    "# functions defined further down accept.\n",
    "calculate_grad_vars = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Blank out the process arguments so the argparse call further down parses an\n",
    "# empty command line (a Jupyter kernel is started with arguments of its own).\n",
    "import sys\n",
    "sys.argv = ['']\n",
    "del sys\n",
    "\n",
    "# Suppress RuntimeWarning and UserWarning messages for the rest of the session.\n",
    "import warnings\n",
    "warnings.filterwarnings(action=\"ignore\", category=RuntimeWarning)\n",
    "warnings.filterwarnings(action=\"ignore\", category=UserWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Experiment arguments (dataset, sampler, model size and training schedule)\n",
    "\"\"\"\n",
    "parser = argparse.ArgumentParser(\n",
    "    description='Training GCN on Large-scale Graph Datasets')\n",
    "parser.add_argument('--dataset', type=str, default='reddit',\n",
    "                    help='Dataset name: pubmed/flickr/reddit/ppi-large')\n",
    "parser.add_argument('--sample_method', type=str, default='full',\n",
    "                    help='Sampling algorithms, joined with /: full/ladies/fastgcn/graphsage/exact/graphsaint/vrgcn')\n",
    "parser.add_argument('--nhid', type=int, default=256,\n",
    "                    help='Hidden state dimension')\n",
    "parser.add_argument('--epoch_num', type=int, default=200,\n",
    "                    help='Number of epochs')\n",
    "parser.add_argument('--pool_num', type=int, default=10,\n",
    "                    help='Number of worker processes in the sampling pool')\n",
    "parser.add_argument('--batch_num', type=int, default=10,\n",
    "                    help='Maximum Batch Number')\n",
    "parser.add_argument('--batch_size', type=int, default=512,\n",
    "                    help='Number of output nodes in a mini-batch')\n",
    "parser.add_argument('--large_batch_size', type=int, default=81920,\n",
    "                    help='Number of output nodes in a large batch')\n",
    "parser.add_argument('--n_layers', type=int, default=2,\n",
    "                    help='Number of GCN layers')\n",
    "parser.add_argument('--n_stops', type=int, default=200,\n",
    "                    help='Early stops')\n",
    "parser.add_argument('--dropout', type=float, default=0,\n",
    "                    help='Dropout rate')\n",
    "parser.add_argument('--cuda', type=int, default=1,\n",
    "                    help='Available GPU ID (-1 runs on the CPU)')\n",
    "parser.add_argument('--save_prefix', type=str, default='exps',\n",
    "                    help='Save file prefix')\n",
    "parser.add_argument('--run_options', type=str, default='True-True-True',\n",
    "                    help='Run Vanilla? Zeroth? Doubly? Three True/False flags joined with -')\n",
    "parser.add_argument('--dist_bound', type=float, default=5e-3,\n",
    "                    help='Restart if the difference is large')\n",
    "parser.add_argument('--use_SGD', type=str, default='False',\n",
    "                    help='Whether to use SGD instead of Adam (True/False)')\n",
    "args = parser.parse_args()\n",
    "print(args)\n",
    "\n",
    "# --run_options must carry exactly three '-'-separated flags; fail with a readable\n",
    "# message instead of an opaque tuple-unpacking error when it does not.\n",
    "run_flags = args.run_options.split('-')\n",
    "if len(run_flags) != 3:\n",
    "    raise ValueError(\n",
    "        '--run_options expects three True/False flags joined with - '\n",
    "        '(Vanilla-Zeroth-Doubly), got %r' % args.run_options)\n",
    "# Only the exact string True switches a variant on.\n",
    "vanilla, zeroth_order, doubly_order = [flag == 'True' for flag in run_flags]\n",
    "# Several sampling methods can be requested at once, separated by '/'.\n",
    "sample_method_list = args.sample_method.split('/')\n",
    "dist_bound = args.dist_bound\n",
    "print(vanilla, zeroth_order, doubly_order, dist_bound)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Prepare devices\n",
    "\"\"\"\n",
    "# --cuda -1 selects the CPU; any other value is used as the CUDA device index.\n",
    "device = torch.device(\"cpu\" if args.cuda == -1 else \"cuda:\" + str(args.cuda))\n",
    "    \n",
    "\"\"\"\n",
    "Prepare data using multi-process\n",
    "\"\"\"\n",
    "def prepare_data(pool, sampler, process_ids, candidate_nodes, samp_num_list, num_nodes, lap_matrix, lap_matrix_sq, depth):\n",
    "    \"\"\"Queue one asynchronous sampling task per entry of `process_ids`.\n",
    "\n",
    "    Each task gets its own random seed and a random subset of `candidate_nodes`\n",
    "    of at most `args.batch_size` nodes (`args` is the notebook-level namespace).\n",
    "    The remaining arguments are forwarded to `sampler` untouched.\n",
    "\n",
    "    Returns:\n",
    "        list: the `AsyncResult` handles returned by `pool.apply_async`, in submission order.\n",
    "    \"\"\"\n",
    "    def _submit_one():\n",
    "        # Draw the batch before the seed so the global NumPy RNG is consumed in a fixed order.\n",
    "        chosen_nodes = np.random.permutation(candidate_nodes)[:args.batch_size]\n",
    "        task_seed = np.random.randint(2**32 - 1)\n",
    "        return pool.apply_async(\n",
    "            sampler,\n",
    "            args=(task_seed, chosen_nodes, samp_num_list, num_nodes, lap_matrix, lap_matrix_sq, depth))\n",
    "\n",
    "    return [_submit_one() for _ in process_ids]\n",
    "\n",
    "lap_matrix, labels, feat_data, train_nodes, valid_nodes, test_nodes = preprocess_data(args.dataset, False)\n",
    "\n",
    "print(\"Dataset information\")\n",
    "print(lap_matrix.shape, labels.shape, feat_data.shape,\n",
    "      train_nodes.shape, valid_nodes.shape, test_nodes.shape)\n",
    "# Densify the features when they arrive in any SciPy sparse format. The previous\n",
    "# exact-type check only recognised lil_matrix and reached it through the deprecated\n",
    "# scipy.sparse.lil namespace.\n",
    "if sp.issparse(feat_data):\n",
    "    feat_data = feat_data.toarray()\n",
    "feat_data = torch.FloatTensor(feat_data).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Setup datasets and models for training (multi-label datasets use sigmoid+binary_cross_entropy, single-label ones use softmax+nll_loss)\n",
    "\"\"\"\n",
    "\n",
    "if args.dataset in ['cora', 'citeseer', 'pubmed', 'flickr', 'reddit']:\n",
    "    # Single-label datasets: integer class ids, class count taken from the largest id.\n",
    "    from model import GCN\n",
    "    from optimizers import sgcn_first, sgcn_zeroth, sgcn_doubly, sgd_step, full_step, VRGCN_step, VRGCN_doubly\n",
    "    from optimizers import ForwardWrapper, VRGCNWrapper, package_mxl\n",
    "    from samplers import fastgcn_sampler, ladies_sampler, graphsage_sampler, exact_sampler, full_batch_sampler, graphsaint_sampler, vrgcn_sampler\n",
    "    labels = torch.LongTensor(labels).to(device)\n",
    "    num_classes = labels.max().item()+1\n",
    "elif args.dataset in ['ppi', 'ppi-large', 'amazon', 'yelp']:\n",
    "    # Multi-label datasets: one float column per class.\n",
    "    from model_mc import GCN\n",
    "    from optimizers_mc import sgcn_first, sgcn_zeroth, sgcn_doubly, sgd_step, full_step, VRGCN_step, VRGCN_doubly\n",
    "    from optimizers_mc import ForwardWrapper, VRGCNWrapper, package_mxl\n",
    "    from samplers_sage_support import fastgcn_sampler, ladies_sampler, graphsage_sampler, exact_sampler, full_batch_sampler, graphsaint_sampler, vrgcn_sampler\n",
    "    labels = torch.FloatTensor(labels).to(device)\n",
    "    num_classes = labels.shape[1]\n",
    "else:\n",
    "    # Without this branch an unknown dataset leaves GCN, num_classes and the samplers\n",
    "    # undefined and only fails later, inside training, with a NameError.\n",
    "    raise ValueError(\n",
    "        'Unsupported dataset %r: cannot choose the model/optimizer/sampler modules' % args.dataset)\n",
    "    \n",
    "\n",
    "def calculate_grad_variance(net, feat_data, labels, train_nodes, adjs_full):\n",
    "    \"\"\"Measure how far the gradients stored on `net` are from freshly computed ones.\n",
    "\n",
    "    The gradients already attached to the parameters of `net` are compared with the\n",
    "    gradients held by a deep copy of `net` after `calculate_loss_grad` has been run\n",
    "    on that copy with `adjs_full` and `train_nodes`. Every parameter of `net` must\n",
    "    already carry a gradient.\n",
    "\n",
    "    Returns:\n",
    "        The square root of the summed squared L2 norms of the per-parameter gradient\n",
    "        differences (a distance between the two gradient sets, despite the name).\n",
    "    \"\"\"\n",
    "    def _stored_grads(model):\n",
    "        return [param.grad.data for param in model.parameters()]\n",
    "\n",
    "    sampled_grads = _stored_grads(net)\n",
    "\n",
    "    # Recompute the gradients on a copy so that `net` itself is left untouched.\n",
    "    reference_net = copy.deepcopy(net)\n",
    "    _, _ = reference_net.calculate_loss_grad(\n",
    "        feat_data, adjs_full, labels, train_nodes)\n",
    "    reference_grads = _stored_grads(reference_net)\n",
    "    del reference_net\n",
    "\n",
    "    squared_distance = 0.0\n",
    "    for sampled, reference in zip(sampled_grads, reference_grads):\n",
    "        squared_distance = squared_distance + (sampled - reference).norm(2) ** 2\n",
    "    return torch.sqrt(squared_distance)\n",
    "\n",
    "def sample_large_batch(args, train_nodes, samp_num_list, num_nodes, lap_matrix, lap_matrix_sq, depth):\n",
    "    \"\"\"Pick a random large batch of training nodes and run `exact_sampler` on it.\n",
    "\n",
    "    At most `args.large_batch_size` nodes are taken from a random permutation of\n",
    "    `train_nodes`; the sampler also receives a freshly drawn random seed.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (adjs, input_nodes, output_nodes, sampled_nodes) as produced by `exact_sampler`.\n",
    "    \"\"\"\n",
    "    # Shuffle first, then draw the seed: the global NumPy RNG is consumed in that order.\n",
    "    shuffled_train_nodes = np.random.permutation(train_nodes)\n",
    "    batch_nodes = shuffled_train_nodes[:args.large_batch_size]\n",
    "    sampler_seed = np.random.randint(2**32 - 1)\n",
    "    adjs, input_nodes, output_nodes, sampled_nodes = exact_sampler(\n",
    "        sampler_seed, batch_nodes, samp_num_list, num_nodes, lap_matrix, lap_matrix_sq, depth)\n",
    "    return adjs, input_nodes, output_nodes, sampled_nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "This is a zeroth-order and first-order variance reduced version of SGCN++\n",
    "\"\"\"\n",
    "\n",
    "def algorithm_sgcn_doubly(feat_data, labels, lap_matrix,\n",
    "                          train_nodes, valid_nodes, test_nodes,\n",
    "                          args, device, calculate_grad_vars=False):\n",
    "    \"\"\"Train a GCN with the doubly variance-reduced step `sgcn_doubly`.\n",
    "\n",
    "    Every epoch draws one large batch with `sample_large_batch`, samples\n",
    "    `args.batch_num` mini-batches from the large batch's output nodes in a worker\n",
    "    pool and hands both to `sgcn_doubly`. The model with the lowest full-batch\n",
    "    validation loss is kept and scored on `test_nodes` at the end.\n",
    "\n",
    "    Relies on the notebook-level names `num_classes`, `sampler`, `samp_num_list`,\n",
    "    `large_samp_num_list` and `dist_bound` being defined before it is called.\n",
    "\n",
    "    Args:\n",
    "        feat_data: node features; `feat_data.shape[1]` is the input width and\n",
    "            `len(feat_data)` is used as the number of nodes.\n",
    "        labels: labels forwarded to the model and optimizer helpers.\n",
    "        lap_matrix: matrix exposing `.multiply` (presumably a SciPy sparse matrix).\n",
    "        train_nodes, valid_nodes, test_nodes: node ids of the three splits.\n",
    "        args: parsed options (`batch_num`, `pool_num`, `nhid`, `n_layers`, `dropout`,\n",
    "            `use_SGD`, `epoch_num`, `n_stops`, `cuda`, plus what the samplers read).\n",
    "        device: torch device that holds the model and the packaged adjacency data.\n",
    "        calculate_grad_vars: forwarded to `sgcn_doubly`; forced off from epoch 20 on.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (best_model, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "        grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time).\n",
    "    \"\"\"\n",
    "    memory_allocated, max_memory_allocated = [], []\n",
    "    lap_matrix_sq = lap_matrix.multiply(lap_matrix)\n",
    "\n",
    "    # one sampling job is queued per entry of process_ids\n",
    "    process_ids = np.arange(args.batch_num)\n",
    "\n",
    "    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,\n",
    "                 layers=args.n_layers, dropout=args.dropout).to(device)\n",
    "\n",
    "    print(susage)\n",
    "\n",
    "    adjs_full, _, _ = full_batch_sampler(\n",
    "        train_nodes, len(feat_data), lap_matrix, args.n_layers)\n",
    "    adjs_full = package_mxl(adjs_full, device)\n",
    "\n",
    "    forward_wrapper = ForwardWrapper(\n",
    "        len(feat_data), args.nhid, args.n_layers, num_classes)\n",
    "\n",
    "    optimizer = optim.Adam(susage.parameters()) if args.use_SGD=='False' else optim.SGD(susage.parameters(), lr=0.7)\n",
    "\n",
    "    best_model = copy.deepcopy(susage)\n",
    "    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)\n",
    "    cnt = 0\n",
    "\n",
    "    # Every loss history starts from the untrained model's validation loss.\n",
    "    wall_clock_time = [0]\n",
    "    loss_train = [best_val_loss]\n",
    "    loss_test = [best_val_loss]\n",
    "    grad_variance_all = []\n",
    "    loss_train_all = [best_val_loss]\n",
    "\n",
    "    for epoch in np.arange(args.epoch_num):\n",
    "        # create large batch\n",
    "        large_batch_sample_time_start = time.perf_counter()\n",
    "        large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes = sample_large_batch(\n",
    "            args, train_nodes, large_samp_num_list, len(feat_data),\n",
    "            lap_matrix, lap_matrix_sq, args.n_layers)\n",
    "        # Stop the sampling clock before packaging so the sample and transfer timers do not overlap.\n",
    "        large_batch_sample_time = time.perf_counter() - large_batch_sample_time_start\n",
    "\n",
    "        large_batch_transfer_time_start = time.perf_counter()\n",
    "        large_adjs = package_mxl(large_adjs, device)\n",
    "        large_batch_transfer_time = time.perf_counter() - large_batch_transfer_time_start\n",
    "\n",
    "        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "\n",
    "        # prepare next epoch train data\n",
    "        mini_batch_sample_time_start = time.perf_counter()\n",
    "        # The with-block terminates the workers if a sampling job raises, so they are never leaked.\n",
    "        with mp.Pool(args.pool_num) as pool:\n",
    "            jobs = prepare_data(pool, sampler, process_ids, large_output_nodes, samp_num_list, len(feat_data),\n",
    "                                lap_matrix, lap_matrix_sq, args.n_layers)\n",
    "            # fetch train data\n",
    "            train_data = [job.get() for job in jobs]\n",
    "            pool.close()\n",
    "            pool.join()\n",
    "        mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start\n",
    "\n",
    "        inner_loop_num = args.batch_num\n",
    "        # Once an epoch >= 20 is reached the flag stays False for the rest of the run.\n",
    "        calculate_grad_vars = calculate_grad_vars and epoch<20\n",
    "        cur_train_loss, cur_train_loss_all, grad_variance, time_counter = sgcn_doubly(\n",
    "            susage, optimizer, feat_data, labels,\n",
    "            train_nodes, valid_nodes,\n",
    "            large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes,\n",
    "            train_data, inner_loop_num, forward_wrapper, device, dist_bound=dist_bound,\n",
    "            calculate_grad_vars=calculate_grad_vars)\n",
    "\n",
    "        epoch_time_counter = {\n",
    "            'large_batch_sample_time': large_batch_sample_time,\n",
    "            'large_batch_transfer_time': large_batch_transfer_time,\n",
    "            'mini_batch_sample_time': mini_batch_sample_time,\n",
    "            'compute_time': time_counter['compute_time'],\n",
    "            'mini_batch_transfer_time': time_counter['transfer_time'],\n",
    "        }\n",
    "        wall_clock_time.append(epoch_time_counter)\n",
    "        loss_train_all.extend(cur_train_loss_all)\n",
    "        grad_variance_all.extend(grad_variance)\n",
    "\n",
    "        # calculate validate loss\n",
    "        susage.eval()\n",
    "\n",
    "        susage.zero_grad()\n",
    "        val_loss, _ = susage.calculate_loss_grad(\n",
    "            feat_data, adjs_full, labels, valid_nodes)\n",
    "\n",
    "        if val_loss < best_val_loss:\n",
    "            best_model = copy.deepcopy(susage)\n",
    "            best_val_loss = val_loss\n",
    "            cnt = 0\n",
    "        else:\n",
    "            cnt += 1\n",
    "\n",
    "        # Early stop; note the break happens before this epoch's losses are appended.\n",
    "        if cnt == args.n_stops//args.batch_num:\n",
    "            break\n",
    "\n",
    "        # The value reported as \"test loss\" is the validation loss computed above.\n",
    "        cur_test_loss = val_loss\n",
    "\n",
    "        loss_train.append(cur_train_loss)\n",
    "        loss_test.append(cur_test_loss)\n",
    "\n",
    "        # print progress\n",
    "        print('Epoch: ', epoch,\n",
    "              '| train loss: %.8f' % cur_train_loss,\n",
    "              '| test loss: %.8f' % cur_test_loss)\n",
    "\n",
    "    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)\n",
    "    print('f1_score_test', f1_score_test)\n",
    "    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "This is just an unchanged SGCN \n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def sgcn(feat_data, labels, lap_matrix,\n",
    "         train_nodes, valid_nodes, test_nodes,\n",
    "         args, device, calculate_grad_vars=False, full_batch=False):\n",
    "    \"\"\"Train a GCN with `sgd_step` on sampled mini-batches, or with `full_step`.\n",
    "\n",
    "    With `full_batch` false, every epoch samples `args.batch_num` mini-batches from\n",
    "    `train_nodes` in a worker pool and runs `sgd_step` on them; with `full_batch`\n",
    "    true no sampling happens and `full_step` is run on the full adjacency data.\n",
    "    The model with the lowest full-batch validation loss is kept and scored on\n",
    "    `test_nodes` at the end.\n",
    "\n",
    "    Relies on the notebook-level names `num_classes`, `sampler` and `samp_num_list`\n",
    "    being defined before it is called.\n",
    "\n",
    "    Args:\n",
    "        feat_data: node features; `feat_data.shape[1]` is the input width and\n",
    "            `len(feat_data)` is used as the number of nodes.\n",
    "        labels: labels forwarded to the model and optimizer helpers.\n",
    "        lap_matrix: matrix exposing `.multiply` (presumably a SciPy sparse matrix).\n",
    "        train_nodes, valid_nodes, test_nodes: node ids of the three splits.\n",
    "        args: parsed options (`batch_num`, `pool_num`, `nhid`, `n_layers`, `dropout`,\n",
    "            `use_SGD`, `epoch_num`, `n_stops`, `cuda`, plus what the samplers read).\n",
    "        device: torch device that holds the model and the packaged adjacency data.\n",
    "        calculate_grad_vars: forwarded unchanged to the step function every epoch.\n",
    "        full_batch: run `full_step` instead of sampling mini-batches.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (best_model, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "        grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time).\n",
    "    \"\"\"\n",
    "    memory_allocated, max_memory_allocated = [], []\n",
    "\n",
    "    # one sampling job is queued per entry of process_ids\n",
    "    process_ids = np.arange(args.batch_num)\n",
    "    lap_matrix_sq = lap_matrix.multiply(lap_matrix)\n",
    "\n",
    "    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,\n",
    "                 layers=args.n_layers, dropout=args.dropout).to(device)\n",
    "\n",
    "    print(susage)\n",
    "\n",
    "    adjs_full, _, _ = full_batch_sampler(\n",
    "        train_nodes, len(feat_data), lap_matrix, args.n_layers)\n",
    "    adjs_full = package_mxl(adjs_full, device)\n",
    "\n",
    "    optimizer = optim.Adam(susage.parameters()) if args.use_SGD=='False' else optim.SGD(susage.parameters(), lr=0.7)\n",
    "\n",
    "    best_model = copy.deepcopy(susage)\n",
    "    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)\n",
    "    cnt = 0\n",
    "\n",
    "    # Every loss history starts from the untrained model's validation loss.\n",
    "    wall_clock_time = [0]\n",
    "    loss_train = [best_val_loss]\n",
    "    loss_test = [best_val_loss]\n",
    "    grad_variance_all = []\n",
    "    loss_train_all = [best_val_loss]\n",
    "\n",
    "    for epoch in np.arange(args.epoch_num):\n",
    "        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "\n",
    "        inner_loop_num = args.batch_num\n",
    "\n",
    "        # it can also run full-batch GD by ignoring all the samplings\n",
    "        if full_batch:\n",
    "            mini_batch_sample_time = 0\n",
    "            cur_train_loss, cur_train_loss_all, grad_variance, time_counter = full_step(\n",
    "                susage, optimizer, feat_data, labels,\n",
    "                train_nodes, valid_nodes,\n",
    "                adjs_full, inner_loop_num, device,\n",
    "                calculate_grad_vars=calculate_grad_vars)\n",
    "        else:\n",
    "            # prepare next epoch train data\n",
    "            mini_batch_sample_time_start = time.perf_counter()\n",
    "            # The with-block terminates the workers if a sampling job raises, so they are never leaked.\n",
    "            with mp.Pool(args.pool_num) as pool:\n",
    "                jobs = prepare_data(pool, sampler, process_ids, train_nodes, samp_num_list, len(feat_data),\n",
    "                                    lap_matrix, lap_matrix_sq, args.n_layers)\n",
    "                # fetch train data\n",
    "                train_data = [job.get() for job in jobs]\n",
    "                pool.close()\n",
    "                pool.join()\n",
    "            mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start\n",
    "\n",
    "            cur_train_loss, cur_train_loss_all, grad_variance, time_counter = sgd_step(\n",
    "                susage, optimizer, feat_data, labels,\n",
    "                train_nodes, valid_nodes,\n",
    "                adjs_full, train_data, inner_loop_num, device,\n",
    "                calculate_grad_vars=calculate_grad_vars)\n",
    "\n",
    "        epoch_time_counter = {\n",
    "            'mini_batch_sample_time': mini_batch_sample_time,\n",
    "            'compute_time': time_counter['compute_time'],\n",
    "            'mini_batch_transfer_time': time_counter['transfer_time'],\n",
    "        }\n",
    "\n",
    "        wall_clock_time.append(epoch_time_counter)\n",
    "\n",
    "        loss_train_all.extend(cur_train_loss_all)\n",
    "        grad_variance_all.extend(grad_variance)\n",
    "\n",
    "        # calculate validate loss\n",
    "        susage.eval()\n",
    "\n",
    "        susage.zero_grad()\n",
    "        val_loss, _ = susage.calculate_loss_grad(\n",
    "            feat_data, adjs_full, labels, valid_nodes)\n",
    "\n",
    "        if val_loss < best_val_loss:\n",
    "            best_model = copy.deepcopy(susage)\n",
    "            best_val_loss = val_loss\n",
    "            cnt = 0\n",
    "        else:\n",
    "            cnt += 1\n",
    "\n",
    "        # Early stop; note the break happens before this epoch's losses are appended.\n",
    "        if cnt == args.n_stops//args.batch_num:\n",
    "            break\n",
    "\n",
    "        # The value reported as \"test loss\" is the validation loss computed above.\n",
    "        cur_test_loss = val_loss\n",
    "\n",
    "        loss_train.append(cur_train_loss)\n",
    "        loss_test.append(cur_test_loss)\n",
    "\n",
    "        # print progress\n",
    "        print('Epoch: ', epoch,\n",
    "              '| train loss: %.8f' % cur_train_loss,\n",
    "              '| test loss: %.8f' % cur_test_loss)\n",
    "    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)\n",
    "    print('f1_score_test', f1_score_test)\n",
    "    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time\n",
    "\n",
    "def algorithm_vrgcn(feat_data, labels, lap_matrix,\n",
    "                    train_nodes, valid_nodes, test_nodes,\n",
    "                    args, device, calculate_grad_vars=False):\n",
    "    \"\"\"Train a GCN with the `VRGCN_step` update and a `VRGCNWrapper`.\n",
    "\n",
    "    Every epoch samples `args.batch_num` mini-batches from `train_nodes` in a worker\n",
    "    pool and runs `VRGCN_step` on them. The model with the lowest full-batch\n",
    "    validation loss is kept and scored on `test_nodes` at the end.\n",
    "\n",
    "    Relies on the notebook-level names `num_classes`, `sampler` and `samp_num_list`\n",
    "    being defined before it is called.\n",
    "\n",
    "    Args:\n",
    "        feat_data: node features; `feat_data.shape[1]` is the input width and\n",
    "            `len(feat_data)` is used as the number of nodes.\n",
    "        labels: labels forwarded to the model and optimizer helpers.\n",
    "        lap_matrix: matrix exposing `.multiply` (presumably a SciPy sparse matrix).\n",
    "        train_nodes, valid_nodes, test_nodes: node ids of the three splits.\n",
    "        args: parsed options (`batch_num`, `pool_num`, `nhid`, `n_layers`, `dropout`,\n",
    "            `use_SGD`, `epoch_num`, `n_stops`, `cuda`, plus what the samplers read).\n",
    "        device: torch device that holds the model and the packaged adjacency data.\n",
    "        calculate_grad_vars: forwarded to `VRGCN_step`; forced off from epoch 20 on.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (best_model, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "        grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time).\n",
    "    \"\"\"\n",
    "    memory_allocated, max_memory_allocated = [], []\n",
    "\n",
    "    # one sampling job is queued per entry of process_ids\n",
    "    process_ids = np.arange(args.batch_num)\n",
    "    lap_matrix_sq = lap_matrix.multiply(lap_matrix)\n",
    "\n",
    "    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,\n",
    "                 layers=args.n_layers, dropout=args.dropout).to(device)\n",
    "\n",
    "    print(susage)\n",
    "\n",
    "    adjs_full, _, _ = full_batch_sampler(\n",
    "        train_nodes, len(feat_data), lap_matrix, args.n_layers)\n",
    "    adjs_full = package_mxl(adjs_full, device)\n",
    "\n",
    "    forward_wrapper = VRGCNWrapper(\n",
    "        len(feat_data), args.nhid, args.n_layers, num_classes)\n",
    "\n",
    "    optimizer = optim.Adam(susage.parameters()) if args.use_SGD=='False' else optim.SGD(susage.parameters(), lr=0.7)\n",
    "\n",
    "    best_model = copy.deepcopy(susage)\n",
    "    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)\n",
    "    cnt = 0\n",
    "\n",
    "    # Every loss history starts from the untrained model's validation loss.\n",
    "    wall_clock_time = [0]\n",
    "    loss_train = [best_val_loss]\n",
    "    loss_test = [best_val_loss]\n",
    "    grad_variance_all = []\n",
    "    loss_train_all = [best_val_loss]\n",
    "\n",
    "    for epoch in np.arange(args.epoch_num):\n",
    "        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "\n",
    "        # prepare next epoch train data\n",
    "        mini_batch_sample_time_start = time.perf_counter()\n",
    "        # The with-block terminates the workers if a sampling job raises, so they are never leaked.\n",
    "        with mp.Pool(args.pool_num) as pool:\n",
    "            jobs = prepare_data(pool, sampler, process_ids, train_nodes, samp_num_list, len(feat_data),\n",
    "                                lap_matrix, lap_matrix_sq, args.n_layers)\n",
    "            # fetch train data\n",
    "            train_data = [job.get() for job in jobs]\n",
    "            pool.close()\n",
    "            pool.join()\n",
    "        mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start\n",
    "\n",
    "        inner_loop_num = args.batch_num\n",
    "        # Once an epoch >= 20 is reached the flag stays False for the rest of the run.\n",
    "        calculate_grad_vars = calculate_grad_vars and epoch<20\n",
    "\n",
    "        cur_train_loss, cur_train_loss_all, grad_variance, time_counter = VRGCN_step(\n",
    "            susage, optimizer, feat_data, labels,\n",
    "            train_nodes, valid_nodes, adjs_full,\n",
    "            train_data, inner_loop_num, forward_wrapper, device,\n",
    "            calculate_grad_vars=calculate_grad_vars)\n",
    "\n",
    "        epoch_time_counter = {\n",
    "            'mini_batch_sample_time': mini_batch_sample_time,\n",
    "            'compute_time': time_counter['compute_time'],\n",
    "            'mini_batch_transfer_time': time_counter['transfer_time'],\n",
    "        }\n",
    "        wall_clock_time.append(epoch_time_counter)\n",
    "\n",
    "        loss_train_all.extend(cur_train_loss_all)\n",
    "        grad_variance_all.extend(grad_variance)\n",
    "        # calculate validate loss\n",
    "        susage.eval()\n",
    "\n",
    "        susage.zero_grad()\n",
    "        val_loss, _ = susage.calculate_loss_grad(\n",
    "            feat_data, adjs_full, labels, valid_nodes)\n",
    "\n",
    "        if val_loss < best_val_loss:\n",
    "            best_model = copy.deepcopy(susage)\n",
    "            best_val_loss = val_loss\n",
    "            cnt = 0\n",
    "        else:\n",
    "            cnt += 1\n",
    "\n",
    "        # Early stop; note the break happens before this epoch's losses are appended.\n",
    "        if cnt == args.n_stops//args.batch_num:\n",
    "            break\n",
    "\n",
    "        # The value reported as \"test loss\" is the validation loss computed above.\n",
    "        cur_test_loss = val_loss\n",
    "\n",
    "        loss_train.append(cur_train_loss)\n",
    "        loss_test.append(cur_test_loss)\n",
    "        # print progress\n",
    "        print('Epoch: ', epoch,\n",
    "              '| train loss: %.8f' % cur_train_loss,\n",
    "              '| test loss: %.8f' % cur_test_loss)\n",
    "\n",
    "    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)\n",
    "    print('f1_score_test', f1_score_test)\n",
    "    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time\n",
    "\n",
    "def algorithm_vrgcn_doubly(feat_data, labels, lap_matrix,\n",
    "                          train_nodes, valid_nodes, test_nodes,\n",
    "                          args, device, calculate_grad_vars=False):\n",
    "    \"\"\"Train a GCN with the `VRGCN_doubly` update and a `VRGCNWrapper`.\n",
    "\n",
    "    Every epoch samples `args.batch_num` mini-batches from `train_nodes` in a worker\n",
    "    pool and runs `VRGCN_doubly` on them. The model with the lowest full-batch\n",
    "    validation loss is kept and scored on `test_nodes` at the end.\n",
    "\n",
    "    Relies on the notebook-level names `num_classes`, `sampler` and `samp_num_list`\n",
    "    being defined before it is called.\n",
    "\n",
    "    Args:\n",
    "        feat_data: node features; `feat_data.shape[1]` is the input width and\n",
    "            `len(feat_data)` is used as the number of nodes.\n",
    "        labels: labels forwarded to the model and optimizer helpers.\n",
    "        lap_matrix: matrix exposing `.multiply` (presumably a SciPy sparse matrix).\n",
    "        train_nodes, valid_nodes, test_nodes: node ids of the three splits.\n",
    "        args: parsed options (`batch_num`, `pool_num`, `nhid`, `n_layers`, `dropout`,\n",
    "            `use_SGD`, `epoch_num`, `n_stops`, `cuda`, plus what the samplers read).\n",
    "        device: torch device that holds the model and the packaged adjacency data.\n",
    "        calculate_grad_vars: forwarded to `VRGCN_doubly`; forced off from epoch 20 on.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (best_model, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "        grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time).\n",
    "    \"\"\"\n",
    "    memory_allocated, max_memory_allocated = [], []\n",
    "\n",
    "    # one sampling job is queued per entry of process_ids\n",
    "    process_ids = np.arange(args.batch_num)\n",
    "    lap_matrix_sq = lap_matrix.multiply(lap_matrix)\n",
    "\n",
    "    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,\n",
    "                 layers=args.n_layers, dropout=args.dropout).to(device)\n",
    "\n",
    "    print(susage)\n",
    "\n",
    "    adjs_full, _, _ = full_batch_sampler(\n",
    "        train_nodes, len(feat_data), lap_matrix, args.n_layers)\n",
    "    adjs_full = package_mxl(adjs_full, device)\n",
    "\n",
    "    forward_wrapper = VRGCNWrapper(len(feat_data), args.nhid, args.n_layers, num_classes)\n",
    "\n",
    "    optimizer = optim.Adam(susage.parameters()) if args.use_SGD=='False' else optim.SGD(susage.parameters(), lr=0.7)\n",
    "\n",
    "    best_model = copy.deepcopy(susage)\n",
    "    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)\n",
    "    cnt = 0\n",
    "\n",
    "    # Every loss history starts from the untrained model's validation loss.\n",
    "    wall_clock_time = [0]\n",
    "    loss_train = [best_val_loss]\n",
    "    loss_test = [best_val_loss]\n",
    "    grad_variance_all = []\n",
    "    loss_train_all = [best_val_loss]\n",
    "\n",
    "    for epoch in np.arange(args.epoch_num):\n",
    "        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "\n",
    "        # prepare next epoch train data\n",
    "        mini_batch_sample_time_start = time.perf_counter()\n",
    "        # The with-block terminates the workers if a sampling job raises, so they are never leaked.\n",
    "        with mp.Pool(args.pool_num) as pool:\n",
    "            jobs = prepare_data(pool, sampler, process_ids, train_nodes, samp_num_list, len(feat_data),\n",
    "                                lap_matrix, lap_matrix_sq, args.n_layers)\n",
    "            # fetch train data\n",
    "            train_data = [job.get() for job in jobs]\n",
    "            pool.close()\n",
    "            pool.join()\n",
    "        mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start\n",
    "\n",
    "        inner_loop_num = args.batch_num\n",
    "        # Once an epoch >= 20 is reached the flag stays False for the rest of the run.\n",
    "        calculate_grad_vars = calculate_grad_vars and epoch<20\n",
    "\n",
    "        cur_train_loss, cur_train_loss_all, grad_variance, time_counter = VRGCN_doubly(\n",
    "            susage, optimizer, feat_data, labels,\n",
    "            train_nodes, valid_nodes,\n",
    "            adjs_full, train_data, inner_loop_num, forward_wrapper, device,\n",
    "            calculate_grad_vars=calculate_grad_vars)\n",
    "\n",
    "        epoch_time_counter = {\n",
    "            'mini_batch_sample_time': mini_batch_sample_time,\n",
    "            'compute_time': time_counter['compute_time'],\n",
    "            'mini_batch_transfer_time': time_counter['transfer_time'],\n",
    "        }\n",
    "\n",
    "        wall_clock_time.append(epoch_time_counter)\n",
    "\n",
    "        loss_train_all.extend(cur_train_loss_all)\n",
    "        grad_variance_all.extend(grad_variance)\n",
    "        # calculate validate loss\n",
    "        susage.eval()\n",
    "\n",
    "        susage.zero_grad()\n",
    "        val_loss, _ = susage.calculate_loss_grad(\n",
    "            feat_data, adjs_full, labels, valid_nodes)\n",
    "\n",
    "        if val_loss < best_val_loss:\n",
    "            best_model = copy.deepcopy(susage)\n",
    "            best_val_loss = val_loss\n",
    "            cnt = 0\n",
    "        else:\n",
    "            cnt += 1\n",
    "\n",
    "        # Early stop; note the break happens before this epoch's losses are appended.\n",
    "        if cnt == args.n_stops//args.batch_num:\n",
    "            break\n",
    "\n",
    "        # The value reported as \"test loss\" is the validation loss computed above.\n",
    "        cur_test_loss = val_loss\n",
    "\n",
    "        loss_train.append(cur_train_loss)\n",
    "        loss_test.append(cur_test_loss)\n",
    "\n",
    "        # print progress\n",
    "        print('Epoch: ', epoch,\n",
    "              '| train loss: %.8f' % cur_train_loss,\n",
    "              '| test loss: %.8f' % cur_test_loss)\n",
    "\n",
    "    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)\n",
    "    print('f1_score_test', f1_score_test)\n",
    "    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time\n",
    "\n",
    "\"\"\"\n",
    "This is a zeroth-order variance reduced version of SGCN+\n",
    "\"\"\"\n",
    "\n",
    "def algorithm_sgcn_zeroth(feat_data, labels, lap_matrix, \n",
    "                          train_nodes, valid_nodes, test_nodes,  \n",
    "                          args, device, calculate_grad_vars=False):\n",
    "    memory_allocated, max_memory_allocated = [], []\n",
    "    \n",
    "    # use multiprocess sample data\n",
    "    process_ids = np.arange(args.batch_num)\n",
    "    lap_matrix_sq = lap_matrix.multiply(lap_matrix)\n",
    "\n",
    "    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,\n",
    "                 layers=args.n_layers, dropout=args.dropout).to(device)\n",
    "    susage.to(device)\n",
    "\n",
    "    print(susage)\n",
    "\n",
    "    adjs_full, input_nodes_full, sampled_nodes_full = full_batch_sampler(\n",
    "        train_nodes, len(feat_data), lap_matrix, args.n_layers)\n",
    "    adjs_full = package_mxl(adjs_full, device)\n",
    "\n",
    "    # this stupid wrapper is only used for sgcn++\n",
    "    forward_wrapper = ForwardWrapper(\n",
    "        len(feat_data), args.nhid, args.n_layers, num_classes)\n",
    "\n",
    "    optimizer = optim.Adam(\n",
    "        filter(lambda p: p.requires_grad, susage.parameters()))\n",
    "\n",
    "    best_model = copy.deepcopy(susage)\n",
    "    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)\n",
    "    cnt = 0\n",
    "    \n",
    "    wall_clock_time = [0]\n",
    "    loss_train = [best_val_loss]\n",
    "    loss_test = [best_val_loss]\n",
    "    grad_variance_all = []\n",
    "    loss_train_all = [best_val_loss]\n",
    "\n",
    "    for epoch in np.arange(args.epoch_num):\n",
    "        # create large batch\n",
    "        large_batch_sample_time_start = time.perf_counter()\n",
    "        large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes =  sample_large_batch(args, train_nodes, \n",
    "                                                                                                     large_samp_num_list, len(feat_data),\n",
    "                                                                                                     lap_matrix, lap_matrix_sq, args.n_layers)\n",
    "        large_batch_sample_time = time.perf_counter() - large_batch_sample_time_start\n",
    "\n",
    "        large_batch_transfer_time_start = time.perf_counter()\n",
    "        large_adjs = package_mxl(large_adjs, device)\n",
    "        large_batch_transfer_time = time.perf_counter() - large_batch_transfer_time_start\n",
    "        \n",
    "        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "        \n",
    "        # prepare next epoch train data\n",
    "        mini_batch_sample_time_start = time.perf_counter()\n",
    "        start_time = time.time()\n",
    "        pool = mp.Pool(args.pool_num)\n",
    "        jobs = prepare_data(pool, sampler, process_ids, large_output_nodes, samp_num_list, len(feat_data),\n",
    "                            lap_matrix, lap_matrix_sq, args.n_layers)\n",
    "        # fetch train data\n",
    "        train_data = [job.get() for job in jobs]\n",
    "        pool.close()\n",
    "        pool.join()\n",
    "        mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start\n",
    "\n",
    "        inner_loop_num = args.batch_num\n",
    "        calculate_grad_vars = calculate_grad_vars and epoch<20\n",
    "        cur_train_loss, cur_train_loss_all, grad_variance, time_counter = sgcn_zeroth(susage, optimizer, feat_data, labels,\n",
    "                                                                        train_nodes, valid_nodes, \n",
    "                                                                        large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes,\n",
    "                                                                        train_data, inner_loop_num, forward_wrapper, device, dist_bound=dist_bound,\n",
    "                                                                        calculate_grad_vars=calculate_grad_vars)\n",
    "        compute_time = time_counter['compute_time']\n",
    "        transfer_time = time_counter['transfer_time']\n",
    "        \n",
    "        epoch_time_counter = {\n",
    "            'large_batch_sample_time': large_batch_sample_time,\n",
    "            'large_batch_transfer_time': large_batch_transfer_time,\n",
    "            'mini_batch_sample_time': mini_batch_sample_time,\n",
    "            'compute_time': time_counter['compute_time'],\n",
    "            'mini_batch_transfer_time': time_counter['transfer_time'],\n",
    "        }\n",
    "        wall_clock_time.append(epoch_time_counter)\n",
    "        \n",
    "        loss_train_all.extend(cur_train_loss_all)\n",
    "        grad_variance_all.extend(grad_variance)\n",
    "        # calculate validate loss\n",
    "        susage.eval()\n",
    "\n",
    "        susage.zero_grad()\n",
    "        val_loss, _ = susage.calculate_loss_grad(\n",
    "            feat_data, adjs_full, labels, valid_nodes)\n",
    "\n",
    "        if val_loss < best_val_loss:\n",
    "            best_model = copy.deepcopy(susage)\n",
    "            best_val_loss = val_loss\n",
    "            cnt = 0\n",
    "        else:\n",
    "            cnt += 1\n",
    "            \n",
    "        if cnt == args.n_stops//args.batch_num:\n",
    "            break\n",
    "\n",
    "        cur_test_loss = val_loss\n",
    "\n",
    "        loss_train.append(cur_train_loss)\n",
    "        loss_test.append(cur_test_loss)\n",
    "        \n",
    "        # print progress\n",
    "        print('Epoch: ', epoch,\n",
    "              '| train loss: %.8f' % cur_train_loss,\n",
    "              '| test loss: %.8f' % cur_test_loss)\n",
    "\n",
    "    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)\n",
    "    print('f1_score_test', f1_score_test)\n",
    "    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time\n",
    "\n",
    "\"\"\"\n",
    "This is a first-order variance reduced version of SGCN++\n",
    "\"\"\"\n",
    "\n",
    "def algorithm_sgcn_first(feat_data, labels, lap_matrix, \n",
    "                         train_nodes, valid_nodes, test_nodes,  \n",
    "                         args, device, calculate_grad_vars=False):\n",
    "    \"\"\"Train a GCN with the `sgcn_first` update and validation-based early stopping.\n",
    "\n",
    "    Each epoch samples one large batch with `sample_large_batch`, samples\n",
    "    `args.batch_num` mini-batches from the large batch's output nodes in a\n",
    "    worker pool, runs `sgcn_first` on them, and then evaluates the loss on\n",
    "    `valid_nodes` with the full-batch adjacency matrices.\n",
    "\n",
    "    Besides its arguments, the function reads these module-level names, which\n",
    "    must be set before it is called: `num_classes`, `sampler`, `samp_num_list`,\n",
    "    `large_samp_num_list` and `dist_bound`.\n",
    "\n",
    "    Args:\n",
    "        feat_data: node features; `feat_data.shape[1]` sizes the model input and\n",
    "            `len(feat_data)` is passed to the samplers.\n",
    "        labels: labels forwarded to `sgcn_first` and the loss / F1 helpers.\n",
    "        lap_matrix: matrix handed to the samplers; must provide `.multiply`\n",
    "            (presumably a scipy sparse matrix -- confirm).\n",
    "        train_nodes: nodes passed to the large-batch sampler and `sgcn_first`.\n",
    "        valid_nodes: nodes used for the early-stopping loss.\n",
    "        test_nodes: nodes used for the final F1 score.\n",
    "        args: parsed options; reads `batch_num`, `nhid`, `n_layers`, `dropout`,\n",
    "            `use_SGD`, `epoch_num`, `pool_num`, `cuda` and `n_stops`.\n",
    "        device: torch device the model and adjacency matrices are moved to.\n",
    "        calculate_grad_vars: forwarded to `sgcn_first`; forced to False from\n",
    "            epoch 20 onwards.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(best_model, loss_train, loss_test, loss_train_all,\n",
    "        f1_score_test, grad_variance_all, memory_allocated,\n",
    "        max_memory_allocated, wall_clock_time)`. `loss_test` holds the\n",
    "        validation loss of the recorded epochs, and `wall_clock_time` is `[0]`\n",
    "        followed by one timing dict per epoch.\n",
    "    \"\"\"\n",
    "    memory_allocated, max_memory_allocated = [], []\n",
    "\n",
    "    # one sampling job id per mini-batch\n",
    "    process_ids = np.arange(args.batch_num)\n",
    "    # element-wise square of the matrix, handed to the samplers\n",
    "    lap_matrix_sq = lap_matrix.multiply(lap_matrix)\n",
    "\n",
    "    susage = GCN(nfeat=feat_data.shape[1], nhid=args.nhid, num_classes=num_classes,\n",
    "                 layers=args.n_layers, dropout=args.dropout).to(device)\n",
    "\n",
    "    print(susage)\n",
    "\n",
    "    # only the adjacency matrices of the full-batch sample are needed here\n",
    "    adjs_full, _, _ = full_batch_sampler(\n",
    "        train_nodes, len(feat_data), lap_matrix, args.n_layers)\n",
    "    adjs_full = package_mxl(adjs_full, device)\n",
    "\n",
    "    # `use_SGD` is compared against the string 'False', not the boolean\n",
    "    optimizer = optim.Adam(susage.parameters()) if args.use_SGD=='False' else optim.SGD(susage.parameters(), lr=0.7)\n",
    "\n",
    "    best_model = copy.deepcopy(susage)\n",
    "    best_val_loss, _ = susage.calculate_loss_grad(feat_data, adjs_full, labels, valid_nodes)\n",
    "    cnt = 0  # consecutive epochs without validation improvement\n",
    "\n",
    "    # every history starts from the initial validation loss (0 for the timings)\n",
    "    wall_clock_time = [0]\n",
    "    loss_train = [best_val_loss]\n",
    "    loss_test = [best_val_loss]\n",
    "    grad_variance_all = []\n",
    "    loss_train_all = [best_val_loss]\n",
    "\n",
    "    for epoch in np.arange(args.epoch_num):\n",
    "        # create large batch\n",
    "        large_batch_sample_time_start = time.perf_counter()\n",
    "        large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes = sample_large_batch(\n",
    "            args, train_nodes, large_samp_num_list, len(feat_data),\n",
    "            lap_matrix, lap_matrix_sq, args.n_layers)\n",
    "        large_batch_sample_time = time.perf_counter() - large_batch_sample_time_start\n",
    "\n",
    "        large_batch_transfer_time_start = time.perf_counter()\n",
    "        large_adjs = package_mxl(large_adjs, device)\n",
    "        large_batch_transfer_time = time.perf_counter() - large_batch_transfer_time_start\n",
    "\n",
    "        # CUDA memory counters divided by 1024*1024; 0 when running without CUDA\n",
    "        memory_allocated += [torch.cuda.memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "        max_memory_allocated += [torch.cuda.max_memory_allocated(device)/1024/1024 if args.cuda != -1 else 0]\n",
    "\n",
    "        # sample this epoch's mini-batches in worker processes; the context\n",
    "        # manager releases the workers even if a sampling job raises\n",
    "        mini_batch_sample_time_start = time.perf_counter()\n",
    "        with mp.Pool(args.pool_num) as pool:\n",
    "            jobs = prepare_data(pool, sampler, process_ids, large_output_nodes, samp_num_list, len(feat_data),\n",
    "                                lap_matrix, lap_matrix_sq, args.n_layers)\n",
    "            # all results are fetched before the pool is torn down\n",
    "            train_data = [job.get() for job in jobs]\n",
    "        mini_batch_sample_time = time.perf_counter() - mini_batch_sample_time_start\n",
    "\n",
    "        inner_loop_num = args.batch_num\n",
    "        # once this turns False it stays False for the remaining epochs\n",
    "        calculate_grad_vars = calculate_grad_vars and epoch<20\n",
    "        cur_train_loss, cur_train_loss_all, grad_variance, time_counter = sgcn_first(\n",
    "            susage, optimizer, feat_data, labels,\n",
    "            train_nodes, valid_nodes,\n",
    "            large_adjs, large_input_nodes, large_output_nodes, large_sampled_nodes,\n",
    "            train_data, inner_loop_num, device, dist_bound=dist_bound,\n",
    "            calculate_grad_vars=calculate_grad_vars)\n",
    "\n",
    "        epoch_time_counter = {\n",
    "            'large_batch_sample_time': large_batch_sample_time,\n",
    "            'large_batch_transfer_time': large_batch_transfer_time,\n",
    "            'mini_batch_sample_time': mini_batch_sample_time,\n",
    "            'compute_time': time_counter['compute_time'],\n",
    "            'mini_batch_transfer_time': time_counter['transfer_time'],\n",
    "        }\n",
    "        wall_clock_time.append(epoch_time_counter)\n",
    "        loss_train_all.extend(cur_train_loss_all)\n",
    "        grad_variance_all.extend(grad_variance)\n",
    "\n",
    "        # validation loss on the full-batch adjacency drives early stopping\n",
    "        susage.eval()\n",
    "\n",
    "        susage.zero_grad()\n",
    "        val_loss, _ = susage.calculate_loss_grad(\n",
    "            feat_data, adjs_full, labels, valid_nodes)\n",
    "\n",
    "        if val_loss < best_val_loss:\n",
    "            best_model = copy.deepcopy(susage)\n",
    "            best_val_loss = val_loss\n",
    "            cnt = 0\n",
    "        else:\n",
    "            cnt += 1\n",
    "\n",
    "        # patience is expressed in epochs: n_stops iterations / batch_num per epoch.\n",
    "        # Note the stopping epoch is not appended to loss_train / loss_test.\n",
    "        if cnt == args.n_stops//args.batch_num:\n",
    "            break\n",
    "\n",
    "        # the value reported as 'test loss' is the validation loss\n",
    "        cur_test_loss = val_loss\n",
    "\n",
    "        loss_train.append(cur_train_loss)\n",
    "        loss_test.append(cur_test_loss)\n",
    "\n",
    "        # print progress\n",
    "        print('Epoch: ', epoch,\n",
    "              '| train loss: %.8f' % cur_train_loss,\n",
    "              '| test loss: %.8f' % cur_test_loss)\n",
    "\n",
    "    # the held-out F1 score is computed with the best-validation snapshot\n",
    "    f1_score_test = best_model.calculate_f1(feat_data, adjs_full, labels, test_nodes)\n",
    "    print('f1_score_test', f1_score_test)\n",
    "    return best_model, loss_train, loss_test, loss_train_all, f1_score_test, grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All experiments of one (save_prefix, dataset) pair share a single pickle:\n",
    "# load it if it already exists so the cells below add to it.\n",
    "fn = './results/%s_%s_results.pkl'%(args.save_prefix, args.dataset)\n",
    "# The cells below open `fn` for writing, which fails if the directory is missing.\n",
    "os.makedirs(os.path.dirname(fn), exist_ok=True)\n",
    "if os.path.exists(fn):\n",
    "    # NOTE: unpickling can execute arbitrary code -- only reuse result files you created.\n",
    "    with open(fn, 'rb') as f:\n",
    "        results = pkl.load(f)\n",
    "else:\n",
    "    results = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ###########################################################################################\n",
    "# ########################################### Full ##########################################\n",
    "# ###########################################################################################\n",
    "# Run `sgcn` with full_batch=True and store its histories under 'fullgcn'.\n",
    "if 'full' in sample_method_list:\n",
    "    st = time.time()\n",
    "    print('>>> Full')\n",
    "    (susage, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "     grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time) = sgcn(\n",
    "        feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes,\n",
    "        args, device, calculate_grad_vars, full_batch=True)\n",
    "    # everything except the trained model itself is kept\n",
    "    results['fullgcn'] = [\n",
    "        loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "        grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time,\n",
    "    ]\n",
    "    print('fullgcn', time.time() - st)\n",
    "\n",
    "    # rewrite the shared results file with the new entry\n",
    "    with open(fn, 'wb') as f:\n",
    "        pkl.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ###########################################################################################\n",
    "# ########################################### LADIES ########################################\n",
    "# ###########################################################################################\n",
    "# `sampler`, `samp_num_list` and `large_samp_num_list` are read as globals by the\n",
    "# training functions (e.g. algorithm_sgcn_zeroth), so they are set before any run.\n",
    "if 'ladies' in sample_method_list:\n",
    "    sampler = ladies_sampler\n",
    "    samp_num = 512\n",
    "    samp_num_list = np.array([samp_num] * args.n_layers)\n",
    "    large_samp_num = samp_num*10\n",
    "    large_samp_num_list = np.array([large_samp_num] * args.n_layers)\n",
    "\n",
    "    # (enabled flag, results key, deferred training call), executed in this order\n",
    "    for enabled, run_name, run_training in (\n",
    "        (doubly_order, 'ladies_doubly', lambda: algorithm_sgcn_doubly(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "        (vanilla, 'ladies', lambda: sgcn(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars, full_batch=False)),\n",
    "        (zeroth_order, 'ladies_zeroth', lambda: algorithm_sgcn_zeroth(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "    ):\n",
    "        if not enabled:\n",
    "            continue\n",
    "        st = time.time()\n",
    "        print('>>> %s' % run_name)\n",
    "        (susage, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "         grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time) = run_training()\n",
    "        results[run_name] = [loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "                             grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time]\n",
    "        print(run_name, time.time() - st)\n",
    "\n",
    "        # the results file is rewritten after every finished run\n",
    "        with open(fn, 'wb') as f:\n",
    "            pkl.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###########################################################################################\n",
    "########################################### FastGCN #######################################\n",
    "###########################################################################################\n",
    "# `sampler`, `samp_num_list` and `large_samp_num_list` are read as globals by the\n",
    "# training functions (e.g. algorithm_sgcn_zeroth), so they are set before any run.\n",
    "if 'fastgcn' in sample_method_list:\n",
    "    sampler = fastgcn_sampler\n",
    "    samp_num = 4096\n",
    "    samp_num_list = np.array([samp_num] * args.n_layers)\n",
    "    large_samp_num = samp_num*10\n",
    "    large_samp_num_list = np.array([large_samp_num] * args.n_layers)\n",
    "\n",
    "    # (enabled flag, results key, deferred training call), executed in this order\n",
    "    for enabled, run_name, run_training in (\n",
    "        (doubly_order, 'fastgcn_doubly', lambda: algorithm_sgcn_doubly(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "        (vanilla, 'fastgcn', lambda: sgcn(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars, full_batch=False)),\n",
    "        (zeroth_order, 'fastgcn_zeroth', lambda: algorithm_sgcn_zeroth(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "    ):\n",
    "        if not enabled:\n",
    "            continue\n",
    "        st = time.time()\n",
    "        print('>>> %s' % run_name)\n",
    "        (susage, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "         grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time) = run_training()\n",
    "        results[run_name] = [loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "                             grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time]\n",
    "        print(run_name, time.time() - st)\n",
    "\n",
    "        # the results file is rewritten after every finished run\n",
    "        with open(fn, 'wb') as f:\n",
    "            pkl.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###########################################################################################\n",
    "########################################### GraphSaint ####################################\n",
    "###########################################################################################\n",
    "# `sampler`, `samp_num_list` and `large_samp_num_list` are read as globals by the\n",
    "# training functions (e.g. algorithm_sgcn_zeroth), so they are set before any run.\n",
    "if 'graphsaint' in sample_method_list:\n",
    "    sampler = graphsaint_sampler\n",
    "    samp_num = 2048\n",
    "    samp_num_list = np.array([samp_num] * args.n_layers)\n",
    "    large_samp_num = samp_num*10\n",
    "    large_samp_num_list = np.array([large_samp_num] * args.n_layers)\n",
    "\n",
    "    # (enabled flag, results key, deferred training call), executed in this order\n",
    "    for enabled, run_name, run_training in (\n",
    "        (doubly_order, 'graphsaint_doubly', lambda: algorithm_sgcn_doubly(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "        (vanilla, 'graphsaint', lambda: sgcn(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars, full_batch=False)),\n",
    "        (zeroth_order, 'graphsaint_zeroth', lambda: algorithm_sgcn_zeroth(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "    ):\n",
    "        if not enabled:\n",
    "            continue\n",
    "        st = time.time()\n",
    "        print('>>> %s' % run_name)\n",
    "        (susage, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "         grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time) = run_training()\n",
    "        results[run_name] = [loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "                             grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time]\n",
    "        print(run_name, time.time() - st)\n",
    "\n",
    "        # the results file is rewritten after every finished run\n",
    "        with open(fn, 'wb') as f:\n",
    "            pkl.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###########################################################################################\n",
    "########################################### GraphSage #####################################\n",
    "###########################################################################################\n",
    "# `sampler`, `samp_num_list` and `large_samp_num_list` are read as globals by the\n",
    "# training functions (e.g. algorithm_sgcn_zeroth), so they are set before any run.\n",
    "if 'graphsage' in sample_method_list:\n",
    "    sampler = graphsage_sampler\n",
    "    samp_num = 5\n",
    "    samp_num_list = np.array([samp_num] * args.n_layers)\n",
    "    large_samp_num = samp_num*10\n",
    "    large_samp_num_list = np.array([large_samp_num] * args.n_layers)\n",
    "\n",
    "    # (enabled flag, results key, deferred training call), executed in this order\n",
    "    for enabled, run_name, run_training in (\n",
    "        (doubly_order, 'graphsage_doubly', lambda: algorithm_sgcn_doubly(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "        (vanilla, 'graphsage', lambda: sgcn(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars, full_batch=False)),\n",
    "        (zeroth_order, 'graphsage_zeroth', lambda: algorithm_sgcn_zeroth(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "    ):\n",
    "        if not enabled:\n",
    "            continue\n",
    "        st = time.time()\n",
    "        print('>>> %s' % run_name)\n",
    "        (susage, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "         grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time) = run_training()\n",
    "        results[run_name] = [loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "                             grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time]\n",
    "        print(run_name, time.time() - st)\n",
    "\n",
    "        # the results file is rewritten after every finished run\n",
    "        with open(fn, 'wb') as f:\n",
    "            pkl.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###########################################################################################\n",
    "########################################### Exact #########################################\n",
    "###########################################################################################\n",
    "# `sampler`, `samp_num_list` and `large_samp_num_list` are read as globals by the\n",
    "# training functions (e.g. algorithm_sgcn_first), so they are set before any run.\n",
    "if 'exact' in sample_method_list:\n",
    "    sampler = exact_sampler\n",
    "    samp_num = 0\n",
    "    samp_num_list = np.array([samp_num] * args.n_layers)\n",
    "    large_samp_num = samp_num*10\n",
    "    large_samp_num_list = np.array([large_samp_num] * args.n_layers)\n",
    "\n",
    "    # (enabled flag, results key, deferred training call), executed in this order.\n",
    "    # NOTE(review): the first-order run is gated by `doubly_order` -- confirm this flag is intended.\n",
    "    for enabled, run_name, run_training in (\n",
    "        (vanilla, 'exact', lambda: sgcn(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars, full_batch=False)),\n",
    "        (doubly_order, 'exact_first', lambda: algorithm_sgcn_first(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "    ):\n",
    "        if not enabled:\n",
    "            continue\n",
    "        st = time.time()\n",
    "        print('>>> %s' % run_name)\n",
    "        (susage, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "         grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time) = run_training()\n",
    "        results[run_name] = [loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "                             grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time]\n",
    "        print(run_name, time.time() - st)\n",
    "\n",
    "        # the results file is rewritten after every finished run\n",
    "        with open(fn, 'wb') as f:\n",
    "            pkl.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###########################################################################################\n",
    "########################################### VRGCN #########################################\n",
    "###########################################################################################\n",
    "# Unlike the other sampler cells, no `large_samp_num_list` is assigned here.\n",
    "if 'vrgcn' in sample_method_list:\n",
    "    samp_num = 2\n",
    "    sampler = vrgcn_sampler\n",
    "    samp_num_list = np.array([samp_num] * args.n_layers)\n",
    "\n",
    "    # (enabled flag, results key, deferred training call), executed in this order.\n",
    "    # NOTE(review): the `vanilla` run is stored under 'vrgcn_zeroth' -- confirm the key is intended.\n",
    "    for enabled, run_name, run_training in (\n",
    "        (doubly_order, 'vrgcn_doubly', lambda: algorithm_vrgcn_doubly(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "        (vanilla, 'vrgcn_zeroth', lambda: algorithm_vrgcn(\n",
    "            feat_data, labels, lap_matrix, train_nodes, valid_nodes, test_nodes, args, device, calculate_grad_vars)),\n",
    "    ):\n",
    "        if not enabled:\n",
    "            continue\n",
    "        st = time.time()\n",
    "        print('>>> %s' % run_name)\n",
    "        (susage, loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "         grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time) = run_training()\n",
    "        results[run_name] = [loss_train, loss_test, loss_train_all, f1_score_test,\n",
    "                             grad_variance_all, memory_allocated, max_memory_allocated, wall_clock_time]\n",
    "        print(run_name, time.time() - st)\n",
    "\n",
    "        # the results file is rewritten after every finished run\n",
    "        with open(fn, 'wb') as f:\n",
    "            pkl.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
