{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1bae0dab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n",
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import os\n",
    "import dgl\n",
    "from dgl import function as fn\n",
    "from dataset import load_graph_dataset\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "from model_node_influence import NodeInfluenceSGC\n",
    "from gcn_with_node_flipping import gcn_with_node_flipping\n",
    "import tensorflow.compat.v1 as tf\n",
    "from graph_neural_networks import SGC_layer1, SGC_layer2\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from scipy.special import softmax, log_softmax\n",
    "from scipy.linalg import cho_solve, cho_factor\n",
    "from tqdm import tqdm\n",
    "import cupy as cp\n",
    "from random import choice\n",
    "import heapq\n",
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2fa94b3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: dataset name, L2 regularisation strength and number of\n",
    "# feature-propagation layers. These are the values handed to node_flipping_attack_rev.\n",
    "dataname = 'cora'\n",
    "l2_regularlization_term = 0.01\n",
    "num_layer = 2\n",
    "\n",
    "# Alternative preset (pubmed): swap it with the active block above to use it.\n",
    "# dataname = 'pubmed'\n",
    "# l2_regularlization_term = 0.004\n",
    "# num_layer = 2\n",
    "\n",
    "# Alternative preset (citeseer): swap it with the active block above to use it.\n",
    "# dataname = 'citeseer'\n",
    "# l2_regularlization_term = 0.003\n",
    "# num_layer = 2\n",
    "\n",
    "# NOTE: the next line is a bare string literal used as a comment; it has no effect.\n",
    "\"\"\"set up random seed, perturb ratio\"\"\"\n",
    "# Perturbation ratios to sweep, seeds for the random split, and repeats per seed\n",
    "# (the commented-out sweep cell further down shows how they are combined).\n",
    "perturb_ratio_list = [0.05, 0.1, 0.15, 0.2]\n",
    "some_seed_list = [1, 11, 15, 42, 100]\n",
    "num_times_running = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ff07dcd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def index_to_mask(index, size):\n",
    "    mask = torch.zeros(size, dtype=torch.bool, device=index.device)\n",
    "    mask[index] = 1\n",
    "    return mask\n",
    "def random_splits_label_flip_attack(graph, labels, num_classes, seed,\n",
    "                                    num_train_per_class=20, num_val=500, num_test=1000):\n",
    "    \"\"\"Draw a seeded random Planetoid-style train / validation / test split.\n",
    "\n",
    "    With the default sizes this yields:\n",
    "      * 20 * num_classes labels for training (20 per class),\n",
    "      * 500 labels for validation,\n",
    "      * 1000 labels for testing.\n",
    "\n",
    "    Args:\n",
    "        graph: object exposing ``num_nodes()``; only the node count is read.\n",
    "        labels: tensor of class ids, compared against ``range(num_classes)``.\n",
    "        num_classes: number of classes to sample training nodes from.\n",
    "        seed: value passed to ``torch.manual_seed``.\n",
    "        num_train_per_class: training nodes taken from each class.\n",
    "        num_val: validation nodes taken from the shuffled remainder.\n",
    "        num_test: test nodes taken after the validation nodes.\n",
    "\n",
    "    Returns:\n",
    "        ``(train_mask, val_mask, test_mask)`` boolean masks over all nodes.\n",
    "    \"\"\"\n",
    "    # This seeds torch's *global* RNG, so later random draws in the process are affected too.\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    # One shuffled index tensor per class.\n",
    "    indices = []\n",
    "    for i in range(num_classes):\n",
    "        index = (labels == i).nonzero().view(-1)\n",
    "        index = index[torch.randperm(index.size(0))]\n",
    "        indices.append(index)\n",
    "\n",
    "    train_index = torch.cat([i[:num_train_per_class] for i in indices], dim=0)\n",
    "\n",
    "    # Everything not used for training is pooled and shuffled once more.\n",
    "    rest_index = torch.cat([i[num_train_per_class:] for i in indices], dim=0)\n",
    "    rest_index = rest_index[torch.randperm(rest_index.size(0))]\n",
    "\n",
    "    num_nodes = graph.num_nodes()\n",
    "    train_mask = index_to_mask(train_index, size=num_nodes)\n",
    "    val_mask = index_to_mask(rest_index[:num_val], size=num_nodes)\n",
    "    test_mask = index_to_mask(rest_index[num_val:num_val + num_test], size=num_nodes)\n",
    "\n",
    "    return train_mask, val_mask, test_mask\n",
    "\n",
    "def get_first_two_frequent(labels):\n",
    "    class_counts = np.bincount(labels)\n",
    "    a = np.argsort(class_counts)[-1]\n",
    "    b = np.argsort(class_counts)[-2]\n",
    "    return a, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3ca628b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_remove_index_train_all(from_indexes, to_indexes, train_mask):\n",
    "    \"\"\"Collect every edge whose source endpoint is a training node.\n",
    "\n",
    "    ``from_indexes`` and ``to_indexes`` are read as parallel edge-endpoint tensors:\n",
    "    position ``p`` of both describes one edge.\n",
    "\n",
    "    Returns:\n",
    "        Two tensors ``(sources, targets)`` listing the matching edges, grouped by\n",
    "        training node in the order the nodes appear in ``train_mask``.\n",
    "    \"\"\"\n",
    "    train_nodes = torch.where(train_mask == 1)[0]\n",
    "    sources, targets = [], []\n",
    "    for pos in tqdm(range(len(train_nodes))):\n",
    "        src = train_nodes[pos]\n",
    "        # Positions in the edge list whose source endpoint is this training node.\n",
    "        for edge_pos in torch.where(from_indexes == src)[0]:\n",
    "            sources.append(src)\n",
    "            targets.append(to_indexes[edge_pos])\n",
    "\n",
    "    return torch.tensor(sources), torch.tensor(targets)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2b7893f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def node_flipping_attack_rev(dataname = dataname, l2_regularlization_term = 0.01, perturb_ratio = 0.05,\n",
    "                             num_layer = 2, some_seed = 42):\n",
    "    \"\"\"Run an influence-guided label-flipping attack and report accuracy with and without it.\n",
    "\n",
    "    Steps:\n",
    "      1. Load the dataset and draw a seeded random split.\n",
    "      2. Propagate the features ``num_layer`` times (degree-normalise, sum over incoming\n",
    "         neighbours, degree-normalise again) and fit ``SimplifiedGraphNeuralNetwork`` on\n",
    "         the training rows.\n",
    "      3. For every training node, predict (gradient times inverse-Hessian-vector product)\n",
    "         and measure (by refitting) the change in evaluation loss caused by removing the\n",
    "         node via ``NodeInfluenceSGC``. Both series are written to\n",
    "         ``result_flip_attack/<dataname>/_pred_infl.csv``.\n",
    "      4. Among the training nodes of the two most frequent classes in the test split,\n",
    "         swap the labels of the nodes with the highest predicted influence.\n",
    "      5. Train ``gcn_with_node_flipping`` once on the flipped labels and once on the\n",
    "         clean labels.\n",
    "\n",
    "    Args:\n",
    "        dataname: dataset name passed to ``load_graph_dataset``. The default is the\n",
    "            module-level ``dataname`` as it was when this function was defined.\n",
    "        l2_regularlization_term: value passed as ``l2_reg`` to the regression model.\n",
    "        perturb_ratio: fraction that sizes the number of flipped labels.\n",
    "        num_layer: number of feature-propagation steps.\n",
    "        some_seed: seed for the random split.\n",
    "\n",
    "    Returns:\n",
    "        ``(acc_flip, acc_no_flip, predicted_influence_combined)``: the two values returned\n",
    "        by ``train_evaluate`` and the predicted influences of the candidate nodes (most\n",
    "        frequent class first, then the second most frequent).\n",
    "    \"\"\"\n",
    "    graph, feat, labels, _, _, _, number_classes = load_graph_dataset(dataname)\n",
    "    train_mask, val_mask, test_mask = random_splits_label_flip_attack(graph, labels, number_classes, seed=some_seed)\n",
    "\n",
    "    # Feature propagation: scale by deg^-1/2, sum over incoming edges, scale by deg^-1/2 again.\n",
    "    feat0 = feat.clone()\n",
    "    degs = graph.in_degrees().float().clamp(min = 1)\n",
    "    norm = torch.pow(degs, -0.5)\n",
    "    norm = norm.to(feat0.device).unsqueeze(1)\n",
    "\n",
    "    for _ in range(num_layer):\n",
    "        feat0 = feat0 * norm\n",
    "        graph.ndata['h'] = feat0\n",
    "        graph.update_all(fn.copy_u('h', 'm'),\n",
    "                         fn.sum('m', 'h'))\n",
    "        feat0 = graph.ndata.pop('h')\n",
    "        feat0 = feat0 * norm\n",
    "\n",
    "    train_x = feat0[train_mask].numpy().astype(np.float32)\n",
    "    train_y = labels[train_mask].numpy().astype(np.float32)\n",
    "\n",
    "    # NOTE: despite the val_ prefix these arrays are built from test_mask, not val_mask.\n",
    "    val_x = feat0[test_mask].numpy().astype(np.float32)\n",
    "    val_y = labels[test_mask].numpy().astype(np.float32)\n",
    "\n",
    "    train_node_idx = torch.where(train_mask == 1)[0]\n",
    "    # Converted once; the loop below only reads it.\n",
    "    train_node_idx_np = train_node_idx.numpy()\n",
    "\n",
    "    enc = OneHotEncoder(handle_unknown='ignore')\n",
    "    enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "    one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "    # Train the reference model on the unperturbed training rows.\n",
    "    lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "    lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "    logits_val_y_origin = val_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "    logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "    ori_val_loss, _ = lr_origin.log_loss(logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "    # numpy_theoritic_loss = log_loss(val_y, softmax(logits_val_y_origin, axis=1))\n",
    "    # # set l2_reg to False, verify the correctness of calculations\n",
    "    # assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "    val_loss_total_grad_orig, _ = lr_origin.grad(val_x, logits_val_y_origin,\n",
    "                                                 one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "    # Inverse-Hessian-vector product of the evaluation-loss gradient, computed on the GPU\n",
    "    # and copied back to host memory; the Hessian itself is released right away.\n",
    "    hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg = True)\n",
    "    loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True)\n",
    "    loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "    del hess\n",
    "\n",
    "    acctual_influence_1 = []\n",
    "    predict_influence_1 = []\n",
    "\n",
    "    for k in tqdm(range(len(train_node_idx))):\n",
    "\n",
    "        node_id = train_node_idx_np[k]\n",
    "        nis = NodeInfluenceSGC(graph = graph, feature=feat, node_index=node_id)\n",
    "\n",
    "        # Remove the node's edges and recompute the propagated features.\n",
    "        nis.remove_edges_sgc()\n",
    "        feat_removed1 = nis.calculate_modified_features()\n",
    "\n",
    "        # Rows whose propagated features differ after the removal.\n",
    "        extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0])\n",
    "\n",
    "        # ... restricted to training nodes (node ids) ...\n",
    "        extra_index_train = torch.tensor(\n",
    "            [extra_index[i] for i in range(len(extra_index)) if extra_index[i] in train_node_idx]).numpy()\n",
    "\n",
    "        # ... and mapped to their positions inside the training set.\n",
    "        extra_index_train_in_train = [\n",
    "            np.where(train_node_idx_np == extra_index_train[j])[0][0] for j in range(len(extra_index_train))]\n",
    "\n",
    "        # The removed node itself must not be re-added with modified features,\n",
    "        # so it is dropped from the list of rows to append.\n",
    "        extra_index_train_remove_node = extra_index_train.copy()\n",
    "        relative_node_id = np.where(extra_index_train_remove_node == node_id)[0]\n",
    "        extra_index_train_remove_node = np.delete(extra_index_train_remove_node, relative_node_id)\n",
    "\n",
    "        # Training-set positions of all affected rows, and the same list without node k.\n",
    "        perturb_index = extra_index_train_in_train\n",
    "        added_index = perturb_index.copy()\n",
    "        # Raises ValueError if node k's own features did not change.\n",
    "        added_index.remove(k)\n",
    "\n",
    "        # Modified features (and unchanged labels) of the other affected training nodes.\n",
    "        train_x_new = feat_removed1[extra_index_train_remove_node].numpy()\n",
    "        train_y_new = train_y[added_index]\n",
    "\n",
    "        # Original rows first, modified copies appended after them.\n",
    "        train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "        train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "\n",
    "        one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "        logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "        _, train_indiv_grad_orig = lr_origin.grad(train_x_orig,\n",
    "                                                  logits_train_y_origin_0,\n",
    "                                                  one_hot_labels_train_0, l2_reg = True)\n",
    "\n",
    "        pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "\n",
    "        # Refit without the affected original rows but with their modified copies.\n",
    "        weight_3 = np.ones(len(train_x_orig))\n",
    "        weight_3[perturb_index] = 0 # 1...0...11\n",
    "\n",
    "        lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "        train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "        train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "\n",
    "        lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "\n",
    "        logits_val_y_new_2 = val_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "        new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_val_y_new_2, one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "        # Predicted effect: influence of the dropped rows minus influence of the appended rows.\n",
    "        predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "        acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)\n",
    "\n",
    "    # Make sure the output directory exists before writing the CSV.\n",
    "    result_dir = 'result_flip_attack/'+ dataname + '/'\n",
    "    os.makedirs(result_dir, exist_ok=True)\n",
    "\n",
    "    df = pd.DataFrame([predict_influence_1, acctual_influence_1]).T\n",
    "    df.columns = ['predicted influence', 'actual influence']\n",
    "    df.to_csv(result_dir + '_pred_infl' +'.csv', index = False)\n",
    "    predicted_influence = np.array(predict_influence_1)\n",
    "\n",
    "    # a / b: most and second most frequent class in the test split;\n",
    "    # idx1 / idx2: training-set positions of the nodes carrying those labels.\n",
    "    a, b = get_first_two_frequent(labels[test_mask])\n",
    "    idx1 = np.where(labels[train_mask].numpy() == a)[0]\n",
    "    idx2 = np.where(labels[train_mask].numpy() == b)[0]\n",
    "\n",
    "    num_perturb = int(perturb_ratio * len(idx2) * 2)\n",
    "    predicted_influence_combined = np.concatenate([predicted_influence[idx1], predicted_influence[idx2]])\n",
    "\n",
    "    # Per-class cut-offs: the value at rank num_perturb / 2 in descending order. The strict\n",
    "    # comparison below therefore flips the top num_perturb / 2 nodes of each class (barring ties).\n",
    "    threshold_a = np.sort(predicted_influence[idx1])[::-1][int(num_perturb / 2)]\n",
    "    threshold_b = np.sort(predicted_influence[idx2])[::-1][int(num_perturb / 2)]\n",
    "\n",
    "    perturbed_train_y = train_y.copy()\n",
    "    new_labels_a = np.repeat(a, len(idx1))\n",
    "    new_labels_b = np.repeat(b, len(idx2))\n",
    "    assert(len(idx1) == len(idx2))\n",
    "\n",
    "    idx_a_to_b = np.where(predicted_influence[idx1] > threshold_a)[0]\n",
    "    idx_b_to_a = np.where(predicted_influence[idx2] > threshold_b)[0]\n",
    "\n",
    "    new_labels_a[idx_a_to_b] = b\n",
    "    new_labels_b[idx_b_to_a] = a\n",
    "\n",
    "    perturbed_train_y[idx1] = new_labels_a\n",
    "    perturbed_train_y[idx2] = new_labels_b\n",
    "\n",
    "    # Write the flipped training labels back into a full-length label vector.\n",
    "    new_labels = labels.numpy().copy()\n",
    "    new_labels[train_mask] = perturbed_train_y\n",
    "    new_labels = torch.tensor(new_labels)\n",
    "\n",
    "    # NOTE(review): dropout=0.5 is passed only for the flipped run; the clean run relies on\n",
    "    # the default of gcn_with_node_flipping -- confirm both runs use the same setting.\n",
    "    gcn_with_node_flip = gcn_with_node_flipping(graph= graph, features=feat, new_labels=new_labels,\n",
    "                                                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask,\n",
    "                                                num_classes=number_classes,  dropout=0.5)\n",
    "\n",
    "    gcn_without_node_flip = gcn_with_node_flipping(graph= graph, features=feat, new_labels=labels,\n",
    "                                                   train_mask=train_mask, val_mask=val_mask, test_mask=test_mask,\n",
    "                                                   num_classes=number_classes)\n",
    "\n",
    "    acc_flip = gcn_with_node_flip.train_evaluate()\n",
    "    acc_no_flip = gcn_without_node_flip.train_evaluate()\n",
    "\n",
    "    return acc_flip, acc_no_flip, predicted_influence_combined"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0a008754",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 3, 4])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: np.sort returns its input in ascending order.\n",
    "np.sort([4, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2f46275c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataname = 'cora'\n",
    "# l2_regularlization_term = 0.01\n",
    "# num_layer = 2\n",
    "\n",
    "# # dataname = 'pubmed'\n",
    "# # l2_regularlization_term = 0.004\n",
    "# # num_layer = 2\n",
    "\n",
    "# # dataname = 'citeseer'\n",
    "# # l2_regularlization_term = 0.003\n",
    "# # num_layer = 2\n",
    "\n",
    "\n",
    "# perturb_ratio_list = [0.1, 0.15, 0.2]\n",
    "# # some_seed_list = [1, 11, 15, 42, 100]\n",
    "# some_seed_list = [15]\n",
    "# num_times_running = 1\n",
    "\n",
    "# for perturb_ratio in perturb_ratio_list:\n",
    "#     temp_acc_flip_list = []\n",
    "#     temp_acc_no_flip_list = []\n",
    "#     for s in some_seed_list:\n",
    "        \n",
    "#         for _ in range(num_times_running):\n",
    "#             acc_flip, acc_no_flip, p_comb = node_flipping_attack_rev(dataname = dataname, l2_regularlization_term = l2_regularlization_term, \n",
    "#                                                              perturb_ratio = perturb_ratio, num_layer = num_layer, \n",
    "#                                                              some_seed = s)\n",
    "\n",
    "#             temp_acc_flip_list.append(acc_flip)\n",
    "#             temp_acc_no_flip_list.append(acc_no_flip)\n",
    "    \n",
    "#     flip_df = pd.DataFrame([temp_acc_flip_list, temp_acc_no_flip_list]).T\n",
    "#     flip_df.columns = ['filped accuracy', 'original accuracy']\n",
    "#     flip_df.to_csv('result_flip_attack/'+ dataname + '/'+ 'new_perturb_ratio_' + str(perturb_ratio)+'.csv', index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e6e772ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.sort(p_comb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b6369c19",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The mean acc. of flipped acc at ratio 0.05 is: 0.625\n",
      "The mean acc. of origin is: 0.8001999999999999\n",
      "The mean acc. of flipped acc at ratio 0.1 is: 0.578\n",
      "The mean acc. of origin is: 0.8001999999999999\n",
      "The mean acc. of flipped acc at ratio 0.15 is: 0.5678\n",
      "The mean acc. of origin is: 0.8001999999999999\n",
      "The mean acc. of flipped acc at ratio 0.20 is: 0.5257999999999999\n",
      "The mean acc. of origin is: 0.8001999999999999\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "50784412",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The mean acc. of flipped acc at ratio 0.1 is: 0.753\n",
      "The mean acc. of origin is: 0.809\n",
      "The mean acc. of flipped acc at ratio 0.15 is: 0.735\n",
      "The mean acc. of origin is: 0.809\n",
      "The mean acc. of flipped acc at ratio 0.20 is: 0.712\n",
      "The mean acc. of origin is: 0.809\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a8b407b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The mean acc. of flipped acc at ratio 0.05 is: 0.676\n",
      "The mean acc. of origin is: 0.7768\n",
      "The mean acc. of flipped acc at ratio 0.1 is: 0.617\n",
      "The mean acc. of origin is: 0.7768\n",
      "The mean acc. of flipped acc at ratio 0.15 is: 0.5592\n",
      "The mean acc. of origin is: 0.7768\n",
      "The mean acc. of flipped acc at ratio 0.20 is: 0.5110000000000001\n",
      "The mean acc. of origin is: 0.7768\n"
     ]
    }
   ],
   "source": [
    "# Report the mean flipped / clean accuracy stored for pubmed at every perturbation ratio.\n",
    "# Each pair is (suffix used in the CSV file name, ratio text shown in the printed message).\n",
    "pubmed_ratio_labels = [('0.05', '0.05'), ('0.1', '0.1'), ('0.15', '0.15'), ('0.2', '0.20')]\n",
    "\n",
    "pubmed_acc_frames = []\n",
    "for file_ratio, shown_ratio in pubmed_ratio_labels:\n",
    "    acc_df = pd.read_csv('result_flip_attack/pubmed/perturb_ratio_' + file_ratio + '.csv')\n",
    "    # 'filped accuracy' (sic) is the column name as it was written to the CSV files.\n",
    "    print('The mean acc. of flipped acc at ratio ' + shown_ratio + ' is:', np.mean(acc_df['filped accuracy'].values))\n",
    "    print('The mean acc. of origin is:', np.mean(acc_df['original accuracy'].values))\n",
    "    pubmed_acc_frames.append(acc_df)\n",
    "\n",
    "# Keep the original per-ratio names bound for any cell that still refers to them.\n",
    "df_pubmed_acc_005, df_pubmed_acc_010, df_pubmed_acc_015, df_pubmed_acc_020 = pubmed_acc_frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "63299cac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The mean acc. of flipped acc at ratio 0.05 is: 0.5781999999999999\n",
      "The mean acc. of origin is: 0.662\n",
      "The mean acc. of flipped acc at ratio 0.1 is: 0.5218\n",
      "The mean acc. of origin is: 0.662\n",
      "The mean acc. of flipped acc at ratio 0.15 is: 0.4704\n",
      "The mean acc. of origin is: 0.662\n",
      "The mean acc. of flipped acc at ratio 0.20 is: 0.4464\n",
      "The mean acc. of origin is: 0.662\n"
     ]
    }
   ],
   "source": [
    "# Report the mean flipped / clean accuracy stored for citeseer at every perturbation ratio.\n",
    "# Each pair is (suffix used in the CSV file name, ratio text shown in the printed message).\n",
    "citeseer_ratio_labels = [('0.05', '0.05'), ('0.1', '0.1'), ('0.15', '0.15'), ('0.2', '0.20')]\n",
    "\n",
    "citeseer_acc_frames = []\n",
    "for file_ratio, shown_ratio in citeseer_ratio_labels:\n",
    "    acc_df = pd.read_csv('result_flip_attack/citeseer/perturb_ratio_' + file_ratio + '.csv')\n",
    "    # 'filped accuracy' (sic) is the column name as it was written to the CSV files.\n",
    "    print('The mean acc. of flipped acc at ratio ' + shown_ratio + ' is:', np.mean(acc_df['filped accuracy'].values))\n",
    "    print('The mean acc. of origin is:', np.mean(acc_df['original accuracy'].values))\n",
    "    citeseer_acc_frames.append(acc_df)\n",
    "\n",
    "# Keep the original per-ratio names bound for any cell that still refers to them.\n",
    "df_citeseer_acc_005, df_citeseer_acc_010, df_citeseer_acc_015, df_citeseer_acc_020 = citeseer_acc_frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a9913cd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# State for a single manual run at one perturbation ratio. The two lists presumably collect\n",
    "# flipped / clean accuracies as in the commented-out sweep cell above -- TODO confirm.\n",
    "perturb_ratio = 0.20\n",
    "temp_acc_flip_list = []\n",
    "temp_acc_no_flip_list = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2de5ffde",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████| 139/139 [00:00<00:00, 210.61it/s]\n",
      "100%|█████████████████████████████████████████| 140/140 [00:10<00:00, 13.52it/s]\n",
      "/home/zizhang/anaconda3/lib/python3.8/site-packages/numpy/core/fromnumeric.py:3419: RuntimeWarning: Mean of empty slice.\n",
      "  return _methods._mean(a, axis=axis, dtype=dtype,\n",
      "/home/zizhang/anaconda3/lib/python3.8/site-packages/numpy/core/_methods.py:188: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  ret = ret.dtype.type(ret / rcount)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00000 | Time(s) nan | Loss 1.9470 | Accuracy 0.1660 | number of edges 13264.00\n",
      "Epoch 00001 | Time(s) nan | Loss 1.9431 | Accuracy 0.1540 | number of edges 13264.00\n",
      "Epoch 00002 | Time(s) nan | Loss 1.9384 | Accuracy 0.1580 | number of edges 13264.00\n",
      "Epoch 00003 | Time(s) 0.0046 | Loss 1.9317 | Accuracy 0.1640 | number of edges 13264.00\n",
      "Epoch 00004 | Time(s) 0.0045 | Loss 1.9250 | Accuracy 0.1660 | number of edges 13264.00\n",
      "Epoch 00005 | Time(s) 0.0046 | Loss 1.9174 | Accuracy 0.1840 | number of edges 13264.00\n",
      "Epoch 00006 | Time(s) 0.0045 | Loss 1.9095 | Accuracy 0.1920 | number of edges 13264.00\n",
      "Epoch 00007 | Time(s) 0.0046 | Loss 1.9013 | Accuracy 0.1980 | number of edges 13264.00\n",
      "Epoch 00008 | Time(s) 0.0045 | Loss 1.8927 | Accuracy 0.2020 | number of edges 13264.00\n",
      "Epoch 00009 | Time(s) 0.0044 | Loss 1.8834 | Accuracy 0.2100 | number of edges 13264.00\n",
      "Epoch 00010 | Time(s) 0.0044 | Loss 1.8735 | Accuracy 0.2140 | number of edges 13264.00\n",
      "Epoch 00011 | Time(s) 0.0045 | Loss 1.8630 | Accuracy 0.2380 | number of edges 13264.00\n",
      "Epoch 00012 | Time(s) 0.0047 | Loss 1.8520 | Accuracy 0.2500 | number of edges 13264.00\n",
      "Epoch 00013 | Time(s) 0.0047 | Loss 1.8407 | Accuracy 0.2700 | number of edges 13264.00\n",
      "Epoch 00014 | Time(s) 0.0047 | Loss 1.8291 | Accuracy 0.2740 | number of edges 13264.00\n",
      "Epoch 00015 | Time(s) 0.0047 | Loss 1.8173 | Accuracy 0.2820 | number of edges 13264.00\n",
      "Epoch 00016 | Time(s) 0.0047 | Loss 1.8049 | Accuracy 0.2920 | number of edges 13264.00\n",
      "Epoch 00017 | Time(s) 0.0047 | Loss 1.7921 | Accuracy 0.3120 | number of edges 13264.00\n",
      "Epoch 00018 | Time(s) 0.0046 | Loss 1.7788 | Accuracy 0.3220 | number of edges 13264.00\n",
      "Epoch 00019 | Time(s) 0.0046 | Loss 1.7651 | Accuracy 0.3340 | number of edges 13264.00\n",
      "Epoch 00020 | Time(s) 0.0046 | Loss 1.7511 | Accuracy 0.3440 | number of edges 13264.00\n",
      "Epoch 00021 | Time(s) 0.0046 | Loss 1.7368 | Accuracy 0.3500 | number of edges 13264.00\n",
      "Epoch 00022 | Time(s) 0.0046 | Loss 1.7219 | Accuracy 0.3560 | number of edges 13264.00\n",
      "Epoch 00023 | Time(s) 0.0046 | Loss 1.7066 | Accuracy 0.3620 | number of edges 13264.00\n",
      "Epoch 00024 | Time(s) 0.0046 | Loss 1.6906 | Accuracy 0.3680 | number of edges 13264.00\n",
      "Epoch 00025 | Time(s) 0.0046 | Loss 1.6743 | Accuracy 0.3840 | number of edges 13264.00\n",
      "Epoch 00026 | Time(s) 0.0046 | Loss 1.6577 | Accuracy 0.4060 | number of edges 13264.00\n",
      "Epoch 00027 | Time(s) 0.0046 | Loss 1.6405 | Accuracy 0.4260 | number of edges 13264.00\n",
      "Epoch 00028 | Time(s) 0.0047 | Loss 1.6228 | Accuracy 0.4400 | number of edges 13264.00\n",
      "Epoch 00029 | Time(s) 0.0047 | Loss 1.6047 | Accuracy 0.4500 | number of edges 13264.00\n",
      "Epoch 00030 | Time(s) 0.0046 | Loss 1.5861 | Accuracy 0.4600 | number of edges 13264.00\n",
      "Epoch 00031 | Time(s) 0.0046 | Loss 1.5671 | Accuracy 0.4740 | number of edges 13264.00\n",
      "Epoch 00032 | Time(s) 0.0046 | Loss 1.5477 | Accuracy 0.4840 | number of edges 13264.00\n",
      "Epoch 00033 | Time(s) 0.0046 | Loss 1.5279 | Accuracy 0.4980 | number of edges 13264.00\n",
      "Epoch 00034 | Time(s) 0.0046 | Loss 1.5079 | Accuracy 0.5020 | number of edges 13264.00\n",
      "Epoch 00035 | Time(s) 0.0046 | Loss 1.4875 | Accuracy 0.5060 | number of edges 13264.00\n",
      "Epoch 00036 | Time(s) 0.0046 | Loss 1.4667 | Accuracy 0.5120 | number of edges 13264.00\n",
      "Epoch 00037 | Time(s) 0.0046 | Loss 1.4457 | Accuracy 0.5140 | number of edges 13264.00\n",
      "Epoch 00038 | Time(s) 0.0046 | Loss 1.4244 | Accuracy 0.5200 | number of edges 13264.00\n",
      "Epoch 00039 | Time(s) 0.0046 | Loss 1.4028 | Accuracy 0.5240 | number of edges 13264.00\n",
      "Epoch 00040 | Time(s) 0.0046 | Loss 1.3811 | Accuracy 0.5300 | number of edges 13264.00\n",
      "Epoch 00041 | Time(s) 0.0046 | Loss 1.3592 | Accuracy 0.5380 | number of edges 13264.00\n",
      "Epoch 00042 | Time(s) 0.0046 | Loss 1.3371 | Accuracy 0.5400 | number of edges 13264.00\n",
      "Epoch 00043 | Time(s) 0.0046 | Loss 1.3150 | Accuracy 0.5480 | number of edges 13264.00\n",
      "Epoch 00044 | Time(s) 0.0046 | Loss 1.2927 | Accuracy 0.5500 | number of edges 13264.00\n",
      "Epoch 00045 | Time(s) 0.0046 | Loss 1.2703 | Accuracy 0.5500 | number of edges 13264.00\n",
      "Epoch 00046 | Time(s) 0.0046 | Loss 1.2479 | Accuracy 0.5520 | number of edges 13264.00\n",
      "Epoch 00047 | Time(s) 0.0045 | Loss 1.2255 | Accuracy 0.5580 | number of edges 13264.00\n",
      "Epoch 00048 | Time(s) 0.0045 | Loss 1.2031 | Accuracy 0.5640 | number of edges 13264.00\n",
      "Epoch 00049 | Time(s) 0.0045 | Loss 1.1807 | Accuracy 0.5680 | number of edges 13264.00\n",
      "Epoch 00050 | Time(s) 0.0045 | Loss 1.1585 | Accuracy 0.5760 | number of edges 13264.00\n",
      "Epoch 00051 | Time(s) 0.0045 | Loss 1.1364 | Accuracy 0.5780 | number of edges 13264.00\n",
      "Epoch 00052 | Time(s) 0.0045 | Loss 1.1144 | Accuracy 0.5780 | number of edges 13264.00\n",
      "Epoch 00053 | Time(s) 0.0046 | Loss 1.0927 | Accuracy 0.5780 | number of edges 13264.00\n",
      "Epoch 00054 | Time(s) 0.0046 | Loss 1.0711 | Accuracy 0.5760 | number of edges 13264.00\n",
      "Epoch 00055 | Time(s) 0.0046 | Loss 1.0497 | Accuracy 0.5800 | number of edges 13264.00\n",
      "Epoch 00056 | Time(s) 0.0046 | Loss 1.0286 | Accuracy 0.5800 | number of edges 13264.00\n",
      "Epoch 00057 | Time(s) 0.0046 | Loss 1.0078 | Accuracy 0.5840 | number of edges 13264.00\n",
      "Epoch 00058 | Time(s) 0.0046 | Loss 0.9873 | Accuracy 0.5820 | number of edges 13264.00\n",
      "Epoch 00059 | Time(s) 0.0046 | Loss 0.9672 | Accuracy 0.5840 | number of edges 13264.00\n",
      "Epoch 00060 | Time(s) 0.0046 | Loss 0.9474 | Accuracy 0.5920 | number of edges 13264.00\n",
      "Epoch 00061 | Time(s) 0.0046 | Loss 0.9279 | Accuracy 0.5980 | number of edges 13264.00\n",
      "Epoch 00062 | Time(s) 0.0046 | Loss 0.9089 | Accuracy 0.5980 | number of edges 13264.00\n",
      "Epoch 00063 | Time(s) 0.0046 | Loss 0.8902 | Accuracy 0.5980 | number of edges 13264.00\n",
      "Epoch 00064 | Time(s) 0.0046 | Loss 0.8720 | Accuracy 0.5980 | number of edges 13264.00\n",
      "Epoch 00065 | Time(s) 0.0046 | Loss 0.8542 | Accuracy 0.6000 | number of edges 13264.00\n",
      "Epoch 00066 | Time(s) 0.0046 | Loss 0.8369 | Accuracy 0.5980 | number of edges 13264.00\n",
      "Epoch 00067 | Time(s) 0.0046 | Loss 0.8200 | Accuracy 0.6000 | number of edges 13264.00\n",
      "Epoch 00068 | Time(s) 0.0046 | Loss 0.8035 | Accuracy 0.6020 | number of edges 13264.00\n",
      "Epoch 00069 | Time(s) 0.0046 | Loss 0.7875 | Accuracy 0.6040 | number of edges 13264.00\n",
      "Epoch 00070 | Time(s) 0.0046 | Loss 0.7719 | Accuracy 0.6000 | number of edges 13264.00\n",
      "Epoch 00071 | Time(s) 0.0046 | Loss 0.7568 | Accuracy 0.6020 | number of edges 13264.00\n",
      "Epoch 00072 | Time(s) 0.0046 | Loss 0.7421 | Accuracy 0.6040 | number of edges 13264.00\n",
      "Epoch 00073 | Time(s) 0.0046 | Loss 0.7279 | Accuracy 0.6040 | number of edges 13264.00\n",
      "Epoch 00074 | Time(s) 0.0046 | Loss 0.7140 | Accuracy 0.6060 | number of edges 13264.00\n",
      "Epoch 00075 | Time(s) 0.0046 | Loss 0.7006 | Accuracy 0.6060 | number of edges 13264.00\n",
      "Epoch 00076 | Time(s) 0.0046 | Loss 0.6876 | Accuracy 0.6040 | number of edges 13264.00\n",
      "Epoch 00077 | Time(s) 0.0046 | Loss 0.6751 | Accuracy 0.6040 | number of edges 13264.00\n",
      "Epoch 00078 | Time(s) 0.0046 | Loss 0.6629 | Accuracy 0.6040 | number of edges 13264.00\n",
      "Epoch 00079 | Time(s) 0.0046 | Loss 0.6511 | Accuracy 0.6060 | number of edges 13264.00\n",
      "Epoch 00080 | Time(s) 0.0046 | Loss 0.6396 | Accuracy 0.6060 | number of edges 13264.00\n",
      "Epoch 00081 | Time(s) 0.0046 | Loss 0.6286 | Accuracy 0.6060 | number of edges 13264.00\n",
      "Epoch 00082 | Time(s) 0.0046 | Loss 0.6178 | Accuracy 0.6060 | number of edges 13264.00\n",
      "Epoch 00083 | Time(s) 0.0046 | Loss 0.6075 | Accuracy 0.6040 | number of edges 13264.00\n",
      "Epoch 00084 | Time(s) 0.0046 | Loss 0.5974 | Accuracy 0.6040 | number of edges 13264.00\n",
      "Epoch 00085 | Time(s) 0.0046 | Loss 0.5877 | Accuracy 0.6040 | number of edges 13264.00\n",
      "Epoch 00086 | Time(s) 0.0046 | Loss 0.5783 | Accuracy 0.6060 | number of edges 13264.00\n",
      "Epoch 00087 | Time(s) 0.0046 | Loss 0.5692 | Accuracy 0.6080 | number of edges 13264.00\n",
      "Epoch 00088 | Time(s) 0.0046 | Loss 0.5604 | Accuracy 0.6100 | number of edges 13264.00\n",
      "Epoch 00089 | Time(s) 0.0046 | Loss 0.5519 | Accuracy 0.6140 | number of edges 13264.00\n",
      "Epoch 00090 | Time(s) 0.0046 | Loss 0.5436 | Accuracy 0.6160 | number of edges 13264.00\n",
      "Epoch 00091 | Time(s) 0.0046 | Loss 0.5356 | Accuracy 0.6200 | number of edges 13264.00\n",
      "Epoch 00092 | Time(s) 0.0046 | Loss 0.5278 | Accuracy 0.6200 | number of edges 13264.00\n",
      "Epoch 00093 | Time(s) 0.0046 | Loss 0.5203 | Accuracy 0.6200 | number of edges 13264.00\n",
      "Epoch 00094 | Time(s) 0.0046 | Loss 0.5130 | Accuracy 0.6220 | number of edges 13264.00\n",
      "Epoch 00095 | Time(s) 0.0047 | Loss 0.5060 | Accuracy 0.6220 | number of edges 13264.00\n",
      "Epoch 00096 | Time(s) 0.0047 | Loss 0.4991 | Accuracy 0.6220 | number of edges 13264.00\n",
      "Epoch 00097 | Time(s) 0.0047 | Loss 0.4925 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00098 | Time(s) 0.0047 | Loss 0.4861 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00099 | Time(s) 0.0047 | Loss 0.4798 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00100 | Time(s) 0.0047 | Loss 0.4738 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00101 | Time(s) 0.0047 | Loss 0.4679 | Accuracy 0.6260 | number of edges 13264.00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00102 | Time(s) 0.0047 | Loss 0.4621 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00103 | Time(s) 0.0047 | Loss 0.4566 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00104 | Time(s) 0.0047 | Loss 0.4511 | Accuracy 0.6280 | number of edges 13264.00\n",
      "Epoch 00105 | Time(s) 0.0047 | Loss 0.4459 | Accuracy 0.6280 | number of edges 13264.00\n",
      "Epoch 00106 | Time(s) 0.0047 | Loss 0.4408 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00107 | Time(s) 0.0047 | Loss 0.4358 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00108 | Time(s) 0.0047 | Loss 0.4309 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00109 | Time(s) 0.0047 | Loss 0.4262 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00110 | Time(s) 0.0047 | Loss 0.4216 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00111 | Time(s) 0.0047 | Loss 0.4171 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00112 | Time(s) 0.0047 | Loss 0.4127 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00113 | Time(s) 0.0047 | Loss 0.4085 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00114 | Time(s) 0.0047 | Loss 0.4043 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00115 | Time(s) 0.0047 | Loss 0.4003 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00116 | Time(s) 0.0047 | Loss 0.3963 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00117 | Time(s) 0.0047 | Loss 0.3924 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00118 | Time(s) 0.0047 | Loss 0.3887 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00119 | Time(s) 0.0047 | Loss 0.3850 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00120 | Time(s) 0.0047 | Loss 0.3814 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00121 | Time(s) 0.0047 | Loss 0.3779 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00122 | Time(s) 0.0047 | Loss 0.3744 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00123 | Time(s) 0.0047 | Loss 0.3711 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00124 | Time(s) 0.0047 | Loss 0.3678 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00125 | Time(s) 0.0047 | Loss 0.3646 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00126 | Time(s) 0.0047 | Loss 0.3614 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00127 | Time(s) 0.0047 | Loss 0.3584 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00128 | Time(s) 0.0047 | Loss 0.3553 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00129 | Time(s) 0.0047 | Loss 0.3524 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00130 | Time(s) 0.0047 | Loss 0.3495 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00131 | Time(s) 0.0047 | Loss 0.3467 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00132 | Time(s) 0.0047 | Loss 0.3439 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00133 | Time(s) 0.0047 | Loss 0.3412 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00134 | Time(s) 0.0047 | Loss 0.3385 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00135 | Time(s) 0.0047 | Loss 0.3359 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00136 | Time(s) 0.0047 | Loss 0.3334 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00137 | Time(s) 0.0047 | Loss 0.3309 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00138 | Time(s) 0.0047 | Loss 0.3284 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00139 | Time(s) 0.0047 | Loss 0.3260 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00140 | Time(s) 0.0047 | Loss 0.3236 | Accuracy 0.6340 | number of edges 13264.00\n",
      "Epoch 00141 | Time(s) 0.0047 | Loss 0.3213 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00142 | Time(s) 0.0047 | Loss 0.3190 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00143 | Time(s) 0.0047 | Loss 0.3167 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00144 | Time(s) 0.0047 | Loss 0.3145 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00145 | Time(s) 0.0047 | Loss 0.3124 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00146 | Time(s) 0.0047 | Loss 0.3102 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00147 | Time(s) 0.0047 | Loss 0.3082 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00148 | Time(s) 0.0047 | Loss 0.3061 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00149 | Time(s) 0.0047 | Loss 0.3041 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00150 | Time(s) 0.0047 | Loss 0.3021 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00151 | Time(s) 0.0047 | Loss 0.3001 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00152 | Time(s) 0.0048 | Loss 0.2982 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00153 | Time(s) 0.0048 | Loss 0.2963 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00154 | Time(s) 0.0047 | Loss 0.2945 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00155 | Time(s) 0.0047 | Loss 0.2927 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00156 | Time(s) 0.0047 | Loss 0.2909 | Accuracy 0.6320 | number of edges 13264.00\n",
      "Epoch 00157 | Time(s) 0.0047 | Loss 0.2891 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00158 | Time(s) 0.0047 | Loss 0.2873 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00159 | Time(s) 0.0047 | Loss 0.2856 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00160 | Time(s) 0.0047 | Loss 0.2839 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00161 | Time(s) 0.0047 | Loss 0.2823 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00162 | Time(s) 0.0047 | Loss 0.2806 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00163 | Time(s) 0.0047 | Loss 0.2790 | Accuracy 0.6300 | number of edges 13264.00\n",
      "Epoch 00164 | Time(s) 0.0047 | Loss 0.2775 | Accuracy 0.6280 | number of edges 13264.00\n",
      "Epoch 00165 | Time(s) 0.0047 | Loss 0.2759 | Accuracy 0.6280 | number of edges 13264.00\n",
      "Epoch 00166 | Time(s) 0.0047 | Loss 0.2744 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00167 | Time(s) 0.0047 | Loss 0.2729 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00168 | Time(s) 0.0047 | Loss 0.2714 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00169 | Time(s) 0.0047 | Loss 0.2699 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00170 | Time(s) 0.0047 | Loss 0.2685 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00171 | Time(s) 0.0047 | Loss 0.2670 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00172 | Time(s) 0.0047 | Loss 0.2656 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00173 | Time(s) 0.0047 | Loss 0.2642 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00174 | Time(s) 0.0047 | Loss 0.2629 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00175 | Time(s) 0.0047 | Loss 0.2615 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00176 | Time(s) 0.0047 | Loss 0.2602 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00177 | Time(s) 0.0047 | Loss 0.2589 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00178 | Time(s) 0.0047 | Loss 0.2576 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00179 | Time(s) 0.0047 | Loss 0.2563 | Accuracy 0.6280 | number of edges 13264.00\n",
      "Epoch 00180 | Time(s) 0.0047 | Loss 0.2551 | Accuracy 0.6280 | number of edges 13264.00\n",
      "Epoch 00181 | Time(s) 0.0047 | Loss 0.2538 | Accuracy 0.6280 | number of edges 13264.00\n",
      "Epoch 00182 | Time(s) 0.0047 | Loss 0.2526 | Accuracy 0.6280 | number of edges 13264.00\n",
      "Epoch 00183 | Time(s) 0.0047 | Loss 0.2514 | Accuracy 0.6280 | number of edges 13264.00\n",
      "Epoch 00184 | Time(s) 0.0047 | Loss 0.2502 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00185 | Time(s) 0.0047 | Loss 0.2490 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00186 | Time(s) 0.0047 | Loss 0.2479 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00187 | Time(s) 0.0047 | Loss 0.2467 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00188 | Time(s) 0.0047 | Loss 0.2456 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00189 | Time(s) 0.0047 | Loss 0.2445 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00190 | Time(s) 0.0047 | Loss 0.2434 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00191 | Time(s) 0.0047 | Loss 0.2423 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00192 | Time(s) 0.0047 | Loss 0.2412 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00193 | Time(s) 0.0047 | Loss 0.2401 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00194 | Time(s) 0.0047 | Loss 0.2391 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00195 | Time(s) 0.0047 | Loss 0.2380 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00196 | Time(s) 0.0047 | Loss 0.2370 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00197 | Time(s) 0.0047 | Loss 0.2360 | Accuracy 0.6260 | number of edges 13264.00\n",
      "Epoch 00198 | Time(s) 0.0047 | Loss 0.2350 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00199 | Time(s) 0.0047 | Loss 0.2340 | Accuracy 0.6240 | number of edges 13264.00\n",
      "\n",
      "Test accuracy 62.70%\n",
      "Epoch 00000 | Time(s) nan | Loss 1.9461 | Accuracy 0.3060 | number of edges 13264.00\n",
      "Epoch 00001 | Time(s) nan | Loss 1.9387 | Accuracy 0.2820 | number of edges 13264.00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00002 | Time(s) nan | Loss 1.9306 | Accuracy 0.4700 | number of edges 13264.00\n",
      "Epoch 00003 | Time(s) 0.0051 | Loss 1.9208 | Accuracy 0.5120 | number of edges 13264.00\n",
      "Epoch 00004 | Time(s) 0.0051 | Loss 1.9104 | Accuracy 0.5640 | number of edges 13264.00\n",
      "Epoch 00005 | Time(s) 0.0049 | Loss 1.8992 | Accuracy 0.6240 | number of edges 13264.00\n",
      "Epoch 00006 | Time(s) 0.0046 | Loss 1.8875 | Accuracy 0.6800 | number of edges 13264.00\n",
      "Epoch 00007 | Time(s) 0.0048 | Loss 1.8751 | Accuracy 0.7180 | number of edges 13264.00\n",
      "Epoch 00008 | Time(s) 0.0049 | Loss 1.8618 | Accuracy 0.7400 | number of edges 13264.00\n",
      "Epoch 00009 | Time(s) 0.0049 | Loss 1.8476 | Accuracy 0.7400 | number of edges 13264.00\n",
      "Epoch 00010 | Time(s) 0.0049 | Loss 1.8328 | Accuracy 0.7500 | number of edges 13264.00\n",
      "Epoch 00011 | Time(s) 0.0049 | Loss 1.8174 | Accuracy 0.7440 | number of edges 13264.00\n",
      "Epoch 00012 | Time(s) 0.0050 | Loss 1.8014 | Accuracy 0.7440 | number of edges 13264.00\n",
      "Epoch 00013 | Time(s) 0.0050 | Loss 1.7848 | Accuracy 0.7460 | number of edges 13264.00\n",
      "Epoch 00014 | Time(s) 0.0050 | Loss 1.7676 | Accuracy 0.7400 | number of edges 13264.00\n",
      "Epoch 00015 | Time(s) 0.0051 | Loss 1.7499 | Accuracy 0.7500 | number of edges 13264.00\n",
      "Epoch 00016 | Time(s) 0.0051 | Loss 1.7316 | Accuracy 0.7620 | number of edges 13264.00\n",
      "Epoch 00017 | Time(s) 0.0052 | Loss 1.7128 | Accuracy 0.7760 | number of edges 13264.00\n",
      "Epoch 00018 | Time(s) 0.0052 | Loss 1.6933 | Accuracy 0.7860 | number of edges 13264.00\n",
      "Epoch 00019 | Time(s) 0.0052 | Loss 1.6734 | Accuracy 0.7940 | number of edges 13264.00\n",
      "Epoch 00020 | Time(s) 0.0052 | Loss 1.6530 | Accuracy 0.7940 | number of edges 13264.00\n",
      "Epoch 00021 | Time(s) 0.0052 | Loss 1.6321 | Accuracy 0.7920 | number of edges 13264.00\n",
      "Epoch 00022 | Time(s) 0.0053 | Loss 1.6106 | Accuracy 0.7900 | number of edges 13264.00\n",
      "Epoch 00023 | Time(s) 0.0053 | Loss 1.5887 | Accuracy 0.7860 | number of edges 13264.00\n",
      "Epoch 00024 | Time(s) 0.0053 | Loss 1.5663 | Accuracy 0.7800 | number of edges 13264.00\n",
      "Epoch 00025 | Time(s) 0.0054 | Loss 1.5434 | Accuracy 0.7760 | number of edges 13264.00\n",
      "Epoch 00026 | Time(s) 0.0054 | Loss 1.5201 | Accuracy 0.7780 | number of edges 13264.00\n",
      "Epoch 00027 | Time(s) 0.0054 | Loss 1.4963 | Accuracy 0.7780 | number of edges 13264.00\n",
      "Epoch 00028 | Time(s) 0.0054 | Loss 1.4722 | Accuracy 0.7820 | number of edges 13264.00\n",
      "Epoch 00029 | Time(s) 0.0054 | Loss 1.4477 | Accuracy 0.7860 | number of edges 13264.00\n",
      "Epoch 00030 | Time(s) 0.0054 | Loss 1.4229 | Accuracy 0.7880 | number of edges 13264.00\n",
      "Epoch 00031 | Time(s) 0.0054 | Loss 1.3978 | Accuracy 0.7900 | number of edges 13264.00\n",
      "Epoch 00032 | Time(s) 0.0054 | Loss 1.3724 | Accuracy 0.7940 | number of edges 13264.00\n",
      "Epoch 00033 | Time(s) 0.0054 | Loss 1.3468 | Accuracy 0.7940 | number of edges 13264.00\n",
      "Epoch 00034 | Time(s) 0.0054 | Loss 1.3210 | Accuracy 0.7960 | number of edges 13264.00\n",
      "Epoch 00035 | Time(s) 0.0054 | Loss 1.2952 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00036 | Time(s) 0.0054 | Loss 1.2692 | Accuracy 0.7940 | number of edges 13264.00\n",
      "Epoch 00037 | Time(s) 0.0054 | Loss 1.2432 | Accuracy 0.7900 | number of edges 13264.00\n",
      "Epoch 00038 | Time(s) 0.0053 | Loss 1.2173 | Accuracy 0.7920 | number of edges 13264.00\n",
      "Epoch 00039 | Time(s) 0.0053 | Loss 1.1914 | Accuracy 0.7960 | number of edges 13264.00\n",
      "Epoch 00040 | Time(s) 0.0053 | Loss 1.1657 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00041 | Time(s) 0.0053 | Loss 1.1401 | Accuracy 0.8000 | number of edges 13264.00\n",
      "Epoch 00042 | Time(s) 0.0053 | Loss 1.1148 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00043 | Time(s) 0.0053 | Loss 1.0897 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00044 | Time(s) 0.0053 | Loss 1.0650 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00045 | Time(s) 0.0053 | Loss 1.0406 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00046 | Time(s) 0.0053 | Loss 1.0166 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00047 | Time(s) 0.0053 | Loss 0.9931 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00048 | Time(s) 0.0053 | Loss 0.9700 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00049 | Time(s) 0.0053 | Loss 0.9474 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00050 | Time(s) 0.0053 | Loss 0.9253 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00051 | Time(s) 0.0053 | Loss 0.9038 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00052 | Time(s) 0.0053 | Loss 0.8827 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00053 | Time(s) 0.0053 | Loss 0.8623 | Accuracy 0.8060 | number of edges 13264.00\n",
      "Epoch 00054 | Time(s) 0.0053 | Loss 0.8424 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00055 | Time(s) 0.0053 | Loss 0.8231 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00056 | Time(s) 0.0052 | Loss 0.8043 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00057 | Time(s) 0.0053 | Loss 0.7862 | Accuracy 0.8080 | number of edges 13264.00\n",
      "Epoch 00058 | Time(s) 0.0053 | Loss 0.7686 | Accuracy 0.8080 | number of edges 13264.00\n",
      "Epoch 00059 | Time(s) 0.0053 | Loss 0.7516 | Accuracy 0.8080 | number of edges 13264.00\n",
      "Epoch 00060 | Time(s) 0.0053 | Loss 0.7351 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00061 | Time(s) 0.0053 | Loss 0.7192 | Accuracy 0.8100 | number of edges 13264.00\n",
      "Epoch 00062 | Time(s) 0.0053 | Loss 0.7038 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00063 | Time(s) 0.0053 | Loss 0.6890 | Accuracy 0.8100 | number of edges 13264.00\n",
      "Epoch 00064 | Time(s) 0.0053 | Loss 0.6747 | Accuracy 0.8080 | number of edges 13264.00\n",
      "Epoch 00065 | Time(s) 0.0053 | Loss 0.6609 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00066 | Time(s) 0.0053 | Loss 0.6476 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00067 | Time(s) 0.0052 | Loss 0.6348 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00068 | Time(s) 0.0052 | Loss 0.6225 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00069 | Time(s) 0.0052 | Loss 0.6106 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00070 | Time(s) 0.0052 | Loss 0.5991 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00071 | Time(s) 0.0052 | Loss 0.5880 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00072 | Time(s) 0.0052 | Loss 0.5773 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00073 | Time(s) 0.0052 | Loss 0.5670 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00074 | Time(s) 0.0052 | Loss 0.5571 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00075 | Time(s) 0.0052 | Loss 0.5475 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00076 | Time(s) 0.0052 | Loss 0.5383 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00077 | Time(s) 0.0052 | Loss 0.5294 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00078 | Time(s) 0.0052 | Loss 0.5208 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00079 | Time(s) 0.0051 | Loss 0.5125 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00080 | Time(s) 0.0051 | Loss 0.5045 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00081 | Time(s) 0.0051 | Loss 0.4968 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00082 | Time(s) 0.0051 | Loss 0.4893 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00083 | Time(s) 0.0051 | Loss 0.4821 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00084 | Time(s) 0.0051 | Loss 0.4751 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00085 | Time(s) 0.0051 | Loss 0.4683 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00086 | Time(s) 0.0051 | Loss 0.4617 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00087 | Time(s) 0.0051 | Loss 0.4554 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00088 | Time(s) 0.0051 | Loss 0.4493 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00089 | Time(s) 0.0051 | Loss 0.4433 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00090 | Time(s) 0.0051 | Loss 0.4375 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00091 | Time(s) 0.0050 | Loss 0.4319 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00092 | Time(s) 0.0050 | Loss 0.4265 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00093 | Time(s) 0.0050 | Loss 0.4212 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00094 | Time(s) 0.0050 | Loss 0.4161 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00095 | Time(s) 0.0050 | Loss 0.4111 | Accuracy 0.8120 | number of edges 13264.00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00096 | Time(s) 0.0050 | Loss 0.4063 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00097 | Time(s) 0.0050 | Loss 0.4016 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00098 | Time(s) 0.0050 | Loss 0.3971 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00099 | Time(s) 0.0050 | Loss 0.3926 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00100 | Time(s) 0.0050 | Loss 0.3883 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00101 | Time(s) 0.0050 | Loss 0.3841 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00102 | Time(s) 0.0050 | Loss 0.3799 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00103 | Time(s) 0.0050 | Loss 0.3759 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00104 | Time(s) 0.0050 | Loss 0.3720 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00105 | Time(s) 0.0050 | Loss 0.3683 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00106 | Time(s) 0.0050 | Loss 0.3645 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00107 | Time(s) 0.0050 | Loss 0.3609 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00108 | Time(s) 0.0050 | Loss 0.3574 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00109 | Time(s) 0.0050 | Loss 0.3540 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00110 | Time(s) 0.0050 | Loss 0.3506 | Accuracy 0.8200 | number of edges 13264.00\n",
      "Epoch 00111 | Time(s) 0.0050 | Loss 0.3474 | Accuracy 0.8200 | number of edges 13264.00\n",
      "Epoch 00112 | Time(s) 0.0050 | Loss 0.3442 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00113 | Time(s) 0.0050 | Loss 0.3410 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00114 | Time(s) 0.0050 | Loss 0.3380 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00115 | Time(s) 0.0049 | Loss 0.3350 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00116 | Time(s) 0.0049 | Loss 0.3320 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00117 | Time(s) 0.0049 | Loss 0.3292 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00118 | Time(s) 0.0049 | Loss 0.3264 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00119 | Time(s) 0.0049 | Loss 0.3236 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00120 | Time(s) 0.0049 | Loss 0.3209 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00121 | Time(s) 0.0049 | Loss 0.3183 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00122 | Time(s) 0.0049 | Loss 0.3158 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00123 | Time(s) 0.0049 | Loss 0.3132 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00124 | Time(s) 0.0049 | Loss 0.3108 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00125 | Time(s) 0.0049 | Loss 0.3083 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00126 | Time(s) 0.0049 | Loss 0.3060 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00127 | Time(s) 0.0049 | Loss 0.3036 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00128 | Time(s) 0.0049 | Loss 0.3014 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00129 | Time(s) 0.0049 | Loss 0.2991 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00130 | Time(s) 0.0049 | Loss 0.2969 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00131 | Time(s) 0.0049 | Loss 0.2948 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00132 | Time(s) 0.0049 | Loss 0.2927 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00133 | Time(s) 0.0049 | Loss 0.2906 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00134 | Time(s) 0.0049 | Loss 0.2886 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00135 | Time(s) 0.0049 | Loss 0.2866 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00136 | Time(s) 0.0049 | Loss 0.2846 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00137 | Time(s) 0.0049 | Loss 0.2827 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00138 | Time(s) 0.0049 | Loss 0.2808 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00139 | Time(s) 0.0049 | Loss 0.2789 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00140 | Time(s) 0.0049 | Loss 0.2771 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00141 | Time(s) 0.0049 | Loss 0.2753 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00142 | Time(s) 0.0049 | Loss 0.2735 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00143 | Time(s) 0.0049 | Loss 0.2717 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00144 | Time(s) 0.0049 | Loss 0.2700 | Accuracy 0.8160 | number of edges 13264.00\n",
      "Epoch 00145 | Time(s) 0.0049 | Loss 0.2684 | Accuracy 0.8180 | number of edges 13264.00\n",
      "Epoch 00146 | Time(s) 0.0049 | Loss 0.2667 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00147 | Time(s) 0.0048 | Loss 0.2651 | Accuracy 0.8140 | number of edges 13264.00\n",
      "Epoch 00148 | Time(s) 0.0049 | Loss 0.2635 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00149 | Time(s) 0.0048 | Loss 0.2619 | Accuracy 0.8120 | number of edges 13264.00\n",
      "Epoch 00150 | Time(s) 0.0048 | Loss 0.2603 | Accuracy 0.8100 | number of edges 13264.00\n",
      "Epoch 00151 | Time(s) 0.0048 | Loss 0.2588 | Accuracy 0.8100 | number of edges 13264.00\n",
      "Epoch 00152 | Time(s) 0.0048 | Loss 0.2573 | Accuracy 0.8100 | number of edges 13264.00\n",
      "Epoch 00153 | Time(s) 0.0048 | Loss 0.2558 | Accuracy 0.8100 | number of edges 13264.00\n",
      "Epoch 00154 | Time(s) 0.0048 | Loss 0.2544 | Accuracy 0.8100 | number of edges 13264.00\n",
      "Epoch 00155 | Time(s) 0.0048 | Loss 0.2529 | Accuracy 0.8100 | number of edges 13264.00\n",
      "Epoch 00156 | Time(s) 0.0048 | Loss 0.2515 | Accuracy 0.8080 | number of edges 13264.00\n",
      "Epoch 00157 | Time(s) 0.0048 | Loss 0.2501 | Accuracy 0.8080 | number of edges 13264.00\n",
      "Epoch 00158 | Time(s) 0.0048 | Loss 0.2488 | Accuracy 0.8060 | number of edges 13264.00\n",
      "Epoch 00159 | Time(s) 0.0048 | Loss 0.2474 | Accuracy 0.8060 | number of edges 13264.00\n",
      "Epoch 00160 | Time(s) 0.0048 | Loss 0.2461 | Accuracy 0.8060 | number of edges 13264.00\n",
      "Epoch 00161 | Time(s) 0.0048 | Loss 0.2448 | Accuracy 0.8060 | number of edges 13264.00\n",
      "Epoch 00162 | Time(s) 0.0048 | Loss 0.2435 | Accuracy 0.8060 | number of edges 13264.00\n",
      "Epoch 00163 | Time(s) 0.0048 | Loss 0.2422 | Accuracy 0.8060 | number of edges 13264.00\n",
      "Epoch 00164 | Time(s) 0.0048 | Loss 0.2409 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00165 | Time(s) 0.0048 | Loss 0.2397 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00166 | Time(s) 0.0048 | Loss 0.2385 | Accuracy 0.8040 | number of edges 13264.00\n",
      "Epoch 00167 | Time(s) 0.0048 | Loss 0.2373 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00168 | Time(s) 0.0048 | Loss 0.2361 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00169 | Time(s) 0.0048 | Loss 0.2350 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00170 | Time(s) 0.0048 | Loss 0.2338 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00171 | Time(s) 0.0048 | Loss 0.2327 | Accuracy 0.8000 | number of edges 13264.00\n",
      "Epoch 00172 | Time(s) 0.0048 | Loss 0.2315 | Accuracy 0.8000 | number of edges 13264.00\n",
      "Epoch 00173 | Time(s) 0.0048 | Loss 0.2304 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00174 | Time(s) 0.0048 | Loss 0.2293 | Accuracy 0.8020 | number of edges 13264.00\n",
      "Epoch 00175 | Time(s) 0.0048 | Loss 0.2283 | Accuracy 0.8000 | number of edges 13264.00\n",
      "Epoch 00176 | Time(s) 0.0048 | Loss 0.2272 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00177 | Time(s) 0.0048 | Loss 0.2262 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00178 | Time(s) 0.0048 | Loss 0.2251 | Accuracy 0.8000 | number of edges 13264.00\n",
      "Epoch 00179 | Time(s) 0.0048 | Loss 0.2241 | Accuracy 0.8000 | number of edges 13264.00\n",
      "Epoch 00180 | Time(s) 0.0048 | Loss 0.2231 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00181 | Time(s) 0.0048 | Loss 0.2221 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00182 | Time(s) 0.0048 | Loss 0.2211 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00183 | Time(s) 0.0048 | Loss 0.2201 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00184 | Time(s) 0.0048 | Loss 0.2192 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00185 | Time(s) 0.0048 | Loss 0.2182 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00186 | Time(s) 0.0048 | Loss 0.2173 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00187 | Time(s) 0.0048 | Loss 0.2164 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00188 | Time(s) 0.0048 | Loss 0.2155 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00189 | Time(s) 0.0048 | Loss 0.2146 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00190 | Time(s) 0.0048 | Loss 0.2137 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00191 | Time(s) 0.0048 | Loss 0.2128 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00192 | Time(s) 0.0048 | Loss 0.2120 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00193 | Time(s) 0.0048 | Loss 0.2111 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00194 | Time(s) 0.0048 | Loss 0.2103 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00195 | Time(s) 0.0048 | Loss 0.2094 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00196 | Time(s) 0.0048 | Loss 0.2086 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00197 | Time(s) 0.0048 | Loss 0.2078 | Accuracy 0.7980 | number of edges 13264.00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00198 | Time(s) 0.0048 | Loss 0.2070 | Accuracy 0.7980 | number of edges 13264.00\n",
      "Epoch 00199 | Time(s) 0.0048 | Loss 0.2062 | Accuracy 0.7980 | number of edges 13264.00\n",
      "\n",
      "Test accuracy 80.30%\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Load the graph, node features and labels for the configured dataset; the three\n",
    "# return values bound to `_` are not used in this cell.\n",
    "graph, feat, labels, _, _, _, number_classes = load_graph_dataset(dataname)\n",
    "# Draw the train/val/test node masks with a fixed seed.\n",
    "# NOTE(review): the helper is defined elsewhere in the notebook -- presumably it also\n",
    "# runs the label-flip attack its name refers to; confirm against its definition.\n",
    "train_mask, val_mask, test_mask = random_splits_label_flip_attack(graph, labels, number_classes, seed=15)\n",
    "\n",
    "# Pre-compute propagated features on a copy of the raw features: two rounds of\n",
    "# degree-normalised neighbour aggregation (scale by D^{-1/2}, sum over in-edges,\n",
    "# scale by D^{-1/2} again), where D holds the in-degrees.\n",
    "# NOTE(review): the hop count is hard-coded to 2 while the config cell defines\n",
    "# `num_layer`; confirm whether the two are meant to be tied together.\n",
    "feat0 = feat.clone()\n",
    "# Clamp in-degrees to >= 1 so zero-in-degree nodes do not produce an infinite factor.\n",
    "degs = graph.in_degrees().float().clamp(min=1)\n",
    "norm = degs.pow(-0.5).to(feat0.device).unsqueeze(1)\n",
    "\n",
    "for _ in range(2):\n",
    "    graph.ndata['h'] = feat0 * norm\n",
    "    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h') * norm\n",
    "    \n",
    "# Propagated features and labels of the train / test splits as float32 NumPy arrays.\n",
    "train_x, test_x = (feat0[mask].numpy().astype(np.float32) for mask in (train_mask, test_mask))\n",
    "train_y, test_y = (labels[mask].numpy().astype(np.float32) for mask in (train_mask, test_mask))\n",
    "\n",
    "# The encoder is fitted on the training labels only; with handle_unknown='ignore'\n",
    "# a test label that never occurs in training is encoded as an all-zero row.\n",
    "enc = OneHotEncoder(handle_unknown='ignore').fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train, one_hot_labels_test = (\n",
    "    enc.transform(y.reshape(-1, 1)).toarray() for y in (train_y, test_y)\n",
    ")\n",
    "\n",
    "\"\"\" Train Logistic Regression \"\"\"\n",
    "# Fit the model on the propagated training features.\n",
    "lr = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "# Linear scores of the fitted model (softmax is applied separately below).\n",
    "logits_train_y = train_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "logits_test_y = test_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "\n",
    "# Baseline loss on the test split; the leave-one-out refits later in this cell\n",
    "# are measured against ori_val_loss.\n",
    "ori_val_loss, ave_ori_val_loss = lr.log_loss(logits_test_y, one_hot_labels_test, l2_reg=True)\n",
    "numpy_theoritic_loss = log_loss(test_y, softmax(logits_test_y, axis=1))\n",
    "\n",
    "# Gradients on the train and test splits (presumably the summed gradient and the\n",
    "# per-sample gradients, judging by the names -- confirm against lr.grad).\n",
    "train_total_grad, train_indiv_grad = lr.grad(\n",
    "    train_x, logits_train_y, one_hot_labels_train, l2_reg=True\n",
    ")\n",
    "val_loss_total_grad, val_loss_indiv_grad = lr.grad(\n",
    "    test_x, logits_test_y, one_hot_labels_test, l2_reg=True\n",
    ")\n",
    "\n",
    "# Hessian on the training split, fed together with the test-loss gradient to the\n",
    "# CUDA solver (presumably an inverse-Hessian-vector product); cp.asnumpy copies the\n",
    "# result back to a host NumPy array.\n",
    "hess = lr.hess_cuda(train_x, logits_train_y)\n",
    "loss_grad_hvp = cp.asnumpy(\n",
    "    fast_get_inv_hvp_cuda(hess, val_loss_total_grad.T, cholskey=True)\n",
    ")\n",
    "\n",
    "# Predicted influence scores, flattened into a plain Python list.\n",
    "pred_infl = list(train_indiv_grad.dot(loss_grad_hvp).reshape(-1))\n",
    "#\n",
    "# Actual influence by leave-one-out retraining: refit without training sample i and\n",
    "# record the change of the test loss relative to ori_val_loss.\n",
    "num_train = len(train_x)\n",
    "act_infl = []\n",
    "\n",
    "for i in tqdm(range(num_train)):\n",
    "    # Training set with row i removed.\n",
    "    train_x_new = np.delete(train_x, i, axis=0)\n",
    "    train_y_new = np.delete(train_y, i)\n",
    "\n",
    "    lr_new = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "    lr_new.fit(train_x_new, train_y_new)\n",
    "\n",
    "    logits_test_y_new = test_x @ lr_new.model.coef_.T + lr_new.model.intercept_\n",
    "    new_ori_val_loss, new_ave_ori_val_loss = lr_new.log_loss(\n",
    "        logits_test_y_new, one_hot_labels_test, l2_reg=True\n",
    "    )\n",
    "    act_infl.append(new_ori_val_loss - ori_val_loss)\n",
    "\n",
    "# Side-by-side table with one row per training sample.\n",
    "df = pd.DataFrame([pred_infl, act_infl], index=['predicted influence', 'actual influence']).T\n",
    "predicted_influence = np.array(pred_infl)\n",
    "\n",
    "# Candidate pool: training nodes labelled with one of the two classes returned by\n",
    "# get_first_two_frequent. idx1 / idx2 index into the training subset, not the full graph.\n",
    "a, b = get_first_two_frequent(labels)\n",
    "idx1 = np.where(labels[train_mask].numpy() == a)[0]\n",
    "idx2 = np.where(labels[train_mask].numpy() == b)[0]\n",
    "# Checked up front, before any label array is built from these index sets.\n",
    "assert len(idx1) == len(idx2), 'expected the same number of training nodes in both classes'\n",
    "\n",
    "num_train = np.sum(train_mask.numpy() == 1)  # not referenced again in this cell\n",
    "# NOTE(review): 40 is presumably the size of the candidate pool (len(idx1) + len(idx2)) -- confirm.\n",
    "num_perturb = int(perturb_ratio * 40)\n",
    "predicted_influence_combined = np.concatenate([predicted_influence[idx1], predicted_influence[idx2]])\n",
    "predicted_influence_combined_sorted = np.sort(predicted_influence_combined)[::-1]\n",
    "# Scores strictly above this value are the top num_perturb of the pool (fewer if scores tie).\n",
    "threshold = predicted_influence_combined_sorted[num_perturb]\n",
    "\n",
    "# Swap the label (a <-> b) of every pool node whose predicted influence exceeds the threshold.\n",
    "idx_a_to_b = np.where(predicted_influence[idx1] > threshold)[0]\n",
    "idx_b_to_a = np.where(predicted_influence[idx2] > threshold)[0]\n",
    "\n",
    "new_labels_a = np.repeat(a, len(idx1))\n",
    "new_labels_b = np.repeat(b, len(idx2))\n",
    "new_labels_a[idx_a_to_b] = b\n",
    "new_labels_b[idx_b_to_a] = a\n",
    "\n",
    "perturbed_train_y = train_y.copy()\n",
    "perturbed_train_y[idx1] = new_labels_a\n",
    "perturbed_train_y[idx2] = new_labels_b\n",
    "\n",
    "# Write the perturbed training labels into a copy of the full label vector.\n",
    "new_labels = labels.numpy().copy()\n",
    "new_labels[train_mask] = perturbed_train_y\n",
    "new_labels = torch.tensor(new_labels)\n",
    "\n",
    "# Two models on the same graph, features and split: one trained on the flipped\n",
    "# labels, one on the original labels as the reference.\n",
    "shared_gcn_args = dict(graph=graph, features=feat, train_mask=train_mask, val_mask=val_mask,\n",
    "                       test_mask=test_mask, num_classes=number_classes)\n",
    "gcn_with_node_flip = gcn_with_node_flipping(new_labels=new_labels, **shared_gcn_args)\n",
    "gcn_without_node_flip = gcn_with_node_flipping(new_labels=labels, **shared_gcn_args)\n",
    "\n",
    "acc_flip = gcn_with_node_flip.train_evaluate()\n",
    "acc_no_flip = gcn_without_node_flip.train_evaluate()\n",
    "\n",
    "# NOTE(review): both lists are created in a separate cell (the next one in this\n",
    "# notebook); that cell has to be run before this one.\n",
    "temp_acc_flip_list.append(acc_flip)\n",
    "temp_acc_no_flip_list.append(acc_no_flip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0e375aa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accumulators for the values returned by train_evaluate() with / without label\n",
    "# flipping; (re)running this cell discards anything collected so far.\n",
    "temp_acc_flip_list, temp_acc_no_flip_list = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cd463a9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a235738",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35741c83",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
