{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "318e880e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n",
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import os\n",
    "import dgl\n",
    "from dgl import function as fn\n",
    "from dataset import load_graph_dataset\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "from model_node_influence import NodeInfluenceSGC\n",
    "from gcn_with_node_flipping import gcn_with_node_flipping\n",
    "import tensorflow.compat.v1 as tf\n",
    "from graph_neural_networks import SGC_layer1, SGC_layer2\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from scipy.special import softmax, log_softmax\n",
    "from scipy.linalg import cho_solve, cho_factor\n",
    "from tqdm import tqdm\n",
    "import cupy as cp\n",
    "from random import choice\n",
    "import heapq\n",
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "858e1062",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataname = 'cora'\n",
    "# l2_regularlization_term = 0.01\n",
    "# num_layer = 2\n",
    "\n",
    "# dataname = 'pubmed'\n",
    "# l2_regularlization_term = 0.004\n",
    "# num_layer = 2\n",
    "\n",
    "dataname = 'citeseer'\n",
    "l2_regularlization_term = 0.003\n",
    "num_layer = 2\n",
    "\n",
    "\"\"\"set up random seed, perturb ratio\"\"\"\n",
    "perturb_ratio_list = [0.05, 0.1, 0.15, 0.2]\n",
    "some_seed_list = [1, 11, 15, 42, 100]\n",
    "num_times_running = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "458f06fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def index_to_mask(index, size):\n",
    "    mask = torch.zeros(size, dtype=torch.bool, device=index.device)\n",
    "    mask[index] = 1\n",
    "    return mask\n",
    "def random_splits_label_flip_attack(graph, labels, num_classes, seed):\n",
    "    # Set new random planetoid splits:\n",
    "    # * 20 * num_classes labels for training\n",
    "    # * 500 labels for validation\n",
    "    # * 1000 labels for testing\n",
    "    torch.manual_seed(seed)\n",
    "    \n",
    "    indices = []\n",
    "\n",
    "    for i in range(num_classes):\n",
    "        index = (labels == i).nonzero().view(-1)\n",
    "        index = index[torch.randperm(index.size(0))]\n",
    "        indices.append(index)\n",
    "\n",
    "    train_index = torch.cat([i[:20] for i in indices], dim=0)\n",
    "\n",
    "    rest_index = torch.cat([i[20:] for i in indices], dim=0)\n",
    "    rest_index = rest_index[torch.randperm(rest_index.size(0))]\n",
    "    \n",
    "    train_mask = index_to_mask(train_index, size=graph.num_nodes())\n",
    "    val_mask = index_to_mask(rest_index[:500], size=graph.num_nodes())\n",
    "    test_mask = index_to_mask(rest_index[500:1500], size=graph.num_nodes())\n",
    "\n",
    "    return train_mask, val_mask, test_mask\n",
    "\n",
    "def get_first_two_frequent(labels):\n",
    "    class_counts = np.bincount(labels)\n",
    "    a = np.argsort(class_counts)[-1]\n",
    "    b = np.argsort(class_counts)[-2]\n",
    "    return a, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "edbc2db9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_remove_index_train_all(from_indexes, to_indexes, train_mask):\n",
    "    train_index = torch.where(train_mask == 1)[0]\n",
    "    remove_from_list = []\n",
    "    remove_to_list = []\n",
    "    for i in tqdm(range(len(train_index))):\n",
    "        f_index = train_index[i]\n",
    "        to_index_list = torch.where(from_indexes == f_index)[0]\n",
    "        for to_index_e in to_index_list:\n",
    "            j = to_index_e\n",
    "            t_index = to_indexes[j]\n",
    "\n",
    "            remove_from_list.append(f_index)\n",
    "            remove_to_list.append(t_index)\n",
    "\n",
    "    return torch.tensor(remove_from_list), torch.tensor(remove_to_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bc4b854a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def node_flipping_attack_rev(dataname = dataname, l2_regularlization_term = 0.01, perturb_ratio = 0.05, \n",
    "                             num_layer = 2):\n",
    "    \n",
    "\n",
    "#     graph, feat, labels, _, _, _, number_classes = load_graph_dataset(dataname)\n",
    "    train_mask, val_mask, test_mask = random_splits_label_flip_attack(graph, labels, number_classes, seed=some_seed)\n",
    "\n",
    "    lr = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "\n",
    "    feat0 = feat.clone()\n",
    "    degs = graph.in_degrees().float().clamp(min = 1)\n",
    "    norm = torch.pow(degs, -0.5)\n",
    "    norm = norm.to(feat0.device).unsqueeze(1)\n",
    "\n",
    "    for _ in range(num_layer):\n",
    "        feat0 = feat0 * norm\n",
    "        graph.ndata['h'] = feat0\n",
    "        graph.update_all(fn.copy_u('h', 'm'),\n",
    "                         fn.sum('m', 'h'))\n",
    "        feat0 = graph.ndata.pop('h')\n",
    "        feat0 = feat0 * norm\n",
    "\n",
    "    train_x = feat0[train_mask].numpy().astype(np.float32)\n",
    "    train_y = labels[train_mask].numpy().astype(np.float32)\n",
    "\n",
    "    val_x = feat0[test_mask].numpy().astype(np.float32)\n",
    "    val_y = labels[test_mask].numpy().astype(np.float32)\n",
    "\n",
    "    train_node_idx = torch.where(train_mask == 1)[0]\n",
    "\n",
    "    enc = OneHotEncoder(handle_unknown='ignore')\n",
    "    enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "    one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "    one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "    \"\"\" Train Logistic Regression \"\"\"\n",
    "    \n",
    "    lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "\n",
    "    lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "    logits_val_y_origin = val_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "    logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "    ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "    # numpy_theoritic_loss = log_loss(val_y, softmax(logits_val_y_origin, axis=1))\n",
    "    # # set l2_reg to False, verify the correctness of calculations\n",
    "    # assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "    val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(val_x, \n",
    "                                                                        logits_val_y_origin,\n",
    "                                                                        one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "    hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg = True)\n",
    "\n",
    "    loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True)\n",
    "\n",
    "    loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "    del hess\n",
    "\n",
    "    acctual_influence_1 = []\n",
    "    acctual_influence_2 = []\n",
    "\n",
    "    predict_influence_1 = []\n",
    "    predict_influence_2 = []\n",
    "\n",
    "    for k in tqdm(range(len(train_node_idx))):\n",
    "\n",
    "        node_id = train_node_idx.numpy()[k]\n",
    "        nis = NodeInfluenceSGC(graph = graph, feature=feat, node_index=node_id)\n",
    "\n",
    "        # 2, remove the edges, calculate the perturbated feature\n",
    "        nis.remove_edges_sgc()\n",
    "        feat_removed1 = nis.calculate_modified_features()\n",
    "\n",
    "        extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0])\n",
    "\n",
    "\n",
    "        extra_index_train = torch.tensor(\n",
    "            [extra_index[i] for i in range(len(extra_index)) if extra_index[i] in train_node_idx]).numpy()\n",
    "\n",
    "        extra_index_train_in_train = [\n",
    "            np.where(train_node_idx.numpy() == extra_index_train[j])[0][0] for j in range(len(extra_index_train))]\n",
    "\n",
    "        # 1, we need to remove the changed node feature from the perturbated feature, \n",
    "        # let it not added to the original feature\n",
    "\n",
    "\n",
    "        \"\"\"modified node features\"\"\"\n",
    "        extra_index_train_remove_node = extra_index_train.copy()\n",
    "        relative_node_id = np.where(extra_index_train_remove_node == node_id)[0]\n",
    "        extra_index_train_remove_node = np.delete(extra_index_train_remove_node, relative_node_id)\n",
    "        feat_to_be_added = feat_removed1[extra_index_train_remove_node].numpy()\n",
    "\n",
    "        \"\"\"index corresponding to modified node features\"\"\"\n",
    "        perturb_index = extra_index_train_in_train\n",
    "        added_index = perturb_index.copy()\n",
    "        added_index.remove(k)\n",
    "\n",
    "\n",
    "\n",
    "        train_x_new = feat_to_be_added\n",
    "        train_y_new = train_y[added_index]\n",
    "\n",
    "        train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "        train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "\n",
    "        one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "        logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "        train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(train_x_orig, \n",
    "                                                logits_train_y_origin_0, \n",
    "                                                one_hot_labels_train_0, l2_reg = True)\n",
    "\n",
    "\n",
    "\n",
    "        pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "        \n",
    "        weight_3 = np.ones(len(train_x_orig))\n",
    "        weight_3[perturb_index] = 0 # 1...0...11\n",
    "\n",
    "        lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "        train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "        train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "\n",
    "\n",
    "        lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "\n",
    "\n",
    "        logits_val_y_new_2 = val_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "        new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_val_y_new_2, one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "        predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "        acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)\n",
    "\n",
    "\n",
    "    df = pd.DataFrame([predict_influence_1, acctual_influence_1]).T\n",
    "    df.columns = ['predicted influence', 'actual influence']\n",
    "    df.to_csv('result_flip_attack/'+ dataname + '/'+ '_pred_infl' +'.csv', index = False)\n",
    "    predicted_influence = np.array(predict_influence_1)\n",
    "\n",
    "    a, b = get_first_two_frequent(labels[test_mask])\n",
    "    idx1 = np.where(labels[train_mask].numpy() == a)[0]\n",
    "    idx2 = np.where(labels[train_mask].numpy() == b)[0]\n",
    "\n",
    "\n",
    "\n",
    "    num_train = np.sum(train_mask.numpy() == 1)\n",
    "    num_perturb = int(perturb_ratio * len(idx2) * 2)\n",
    "    predicted_influence_combined = np.concatenate([predicted_influence[idx1], predicted_influence[idx2]])\n",
    "    \n",
    "    predicted_influence_combined_sorted = np.sort(predicted_influence_combined)[::-1]\n",
    "\n",
    "    threshold = predicted_influence_combined_sorted[num_perturb]\n",
    "\n",
    "    \n",
    "    threshold_a = np.sort(predicted_influence[idx1])[::-1][int(num_perturb / 2)]\n",
    "    threshold_b = np.sort(predicted_influence[idx2])[::-1][int(num_perturb / 2)]\n",
    "    \n",
    "    \n",
    "    perturbed_train_y = train_y.copy()\n",
    "    new_labels_a = np.repeat(a, len(idx1))\n",
    "    new_labels_b = np.repeat(b, len(idx2))\n",
    "    assert(len(idx1) == len(idx2))\n",
    "\n",
    "    predicted_influence_combined_sorted\n",
    "\n",
    "    idx_a_to_b = np.where(predicted_influence[idx1] > threshold_a)[0]\n",
    "    \n",
    "    idx_b_to_a = np.where(predicted_influence[idx2] > threshold_b)[0]\n",
    "    \n",
    "\n",
    "    new_labels_a[idx_a_to_b] = b\n",
    "    new_labels_b[idx_b_to_a] = a\n",
    "\n",
    "    perturbed_train_y[idx1] = new_labels_a\n",
    "    perturbed_train_y[idx2] = new_labels_b\n",
    "\n",
    "    new_labels = labels.numpy().copy()\n",
    "    new_labels[train_mask] = perturbed_train_y\n",
    "    new_labels = torch.tensor(new_labels)\n",
    "\n",
    "    gcn_with_node_flip = gcn_with_node_flipping(graph= graph, features=feat, new_labels=new_labels, \n",
    "                                                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask,\n",
    "                                               num_classes=number_classes,  dropout=0.5)\n",
    "\n",
    "    gcn_without_node_flip = gcn_with_node_flipping(graph= graph, features=feat, new_labels=labels, \n",
    "                                                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask,\n",
    "                                               num_classes=number_classes)\n",
    "\n",
    "    acc_flip = gcn_with_node_flip.train_evaluate()\n",
    "    acc_no_flip = gcn_without_node_flip.train_evaluate()\n",
    "    \n",
    "    return acc_flip, acc_no_flip, predicted_influence_combined"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "23bc3742",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 3, 4])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sort([4, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0b123565",
   "metadata": {},
   "outputs": [],
   "source": [
    "perturb_ratio = 0.20\n",
    "temp_acc_flip_list = []\n",
    "temp_acc_no_flip_list = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "33a3c80b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(dataname)\n",
    "# # train_mask, val_mask, test_mask = random_splits_label_flip_attack(graph, labels, number_classes, seed=15)\n",
    "\n",
    "# lr = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "\n",
    "# feat0 = feat.clone()\n",
    "# degs = graph.in_degrees().float().clamp(min = 1)\n",
    "# norm = torch.pow(degs, -0.5)\n",
    "# norm = norm.to(feat0.device).unsqueeze(1)\n",
    "\n",
    "# for _ in range(2):\n",
    "#     feat0 = feat0 * norm\n",
    "#     graph.ndata['h'] = feat0\n",
    "#     graph.update_all(fn.copy_u('h', 'm'),\n",
    "#                      fn.sum('m', 'h'))\n",
    "#     feat0 = graph.ndata.pop('h')\n",
    "#     feat0 = feat0 * norm\n",
    "    \n",
    "# train_x = feat0[train_mask].numpy().astype(np.float32)\n",
    "# train_y = labels[train_mask].numpy().astype(np.float32)\n",
    "\n",
    "# test_x = feat0[test_mask].numpy().astype(np.float32)\n",
    "# test_y = labels[test_mask].numpy().astype(np.float32)\n",
    "\n",
    "# enc = OneHotEncoder(handle_unknown='ignore')\n",
    "# enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "# one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "# one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# \"\"\" Train Logistic Regression \"\"\"\n",
    "# lr = SimplifiedGraphNeuralNetwork(l2_reg = l2_regularlization_term, fit_intercept=True)\n",
    "# lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "# logits_test_y = test_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "# logits_train_y = train_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "\n",
    "# ori_val_loss, ave_ori_val_loss = lr.log_loss(logits_test_y, one_hot_labels_test, l2_reg = True)\n",
    "\n",
    "# numpy_theoritic_loss = log_loss(test_y, softmax(logits_test_y, axis=1))\n",
    "\n",
    "# train_total_grad, train_indiv_grad = lr.grad(train_x, logits_train_y, \n",
    "#                                              one_hot_labels_train, l2_reg=True)\n",
    "# val_loss_total_grad, val_loss_indiv_grad = lr.grad(test_x, logits_test_y, \n",
    "#                                                    one_hot_labels_test, l2_reg=True)\n",
    "\n",
    "# # hessian_no_reg, hess, hessian_reg_term = lr.hess(train_x, logits_train_y)\n",
    "# # hess = fast_hess_cuda(train_x, logits_train_y)\n",
    "# hess = lr.hess_cuda(train_x, logits_train_y)\n",
    "\n",
    "\n",
    "# loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad.T, cholskey=True)\n",
    "# # loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad.T, cholskey=False)\n",
    "# loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "\n",
    "# pred_infl = train_indiv_grad.dot(loss_grad_hvp)\n",
    "\n",
    "# pred_infl = list(pred_infl.reshape(-1))\n",
    "# #\n",
    "# num_train = len(train_x)\n",
    "# act_infl = []\n",
    "\n",
    "# for i in tqdm(range(num_train)):\n",
    "#     lr_new = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "#     train_x_new = np.delete(train_x, i, axis = 0)\n",
    "#     train_y_new = np.delete(train_y, i)\n",
    "#     lr_new.fit(train_x_new, train_y_new)\n",
    "    \n",
    "#     logits_test_y_new = test_x @ lr_new.model.coef_.T + lr_new.model.intercept_\n",
    "    \n",
    "    \n",
    "#     new_ori_val_loss, new_ave_ori_val_loss = lr_new.log_loss(logits_test_y_new, one_hot_labels_test, l2_reg = True)\n",
    "#     act_infl.append(new_ori_val_loss - ori_val_loss)\n",
    "    \n",
    "    \n",
    "# df = pd.DataFrame([pred_infl, act_infl]).T\n",
    "# df.columns = ['predicted influence', 'actual influence']\n",
    "# predicted_influence = np.array(pred_infl)\n",
    "\n",
    "# a, b = get_first_two_frequent(labels)\n",
    "# idx1 = np.where(labels[train_mask].numpy() == a)[0]\n",
    "# idx2 = np.where(labels[train_mask].numpy() == b)[0]\n",
    "\n",
    "\n",
    "\n",
    "# num_train = np.sum(train_mask.numpy() == 1)\n",
    "# num_perturb = int(perturb_ratio * 40)\n",
    "# predicted_influence_combined = np.concatenate([predicted_influence[idx1], predicted_influence[idx2]])\n",
    "# predicted_influence_combined_sorted = np.sort(predicted_influence_combined)[::-1]\n",
    "# threshold = predicted_influence_combined_sorted[num_perturb]\n",
    "\n",
    "# perturbed_train_y = train_y.copy()\n",
    "# new_labels_a = np.repeat(a, len(idx1))\n",
    "# new_labels_b = np.repeat(b, len(idx2))\n",
    "# assert(len(idx1) == len(idx2))\n",
    "\n",
    "# predicted_influence_combined_sorted\n",
    "\n",
    "# idx_a_to_b = np.where(predicted_influence[idx1] >= threshold)[0]\n",
    "# idx_b_to_a = np.where(predicted_influence[idx2] >= threshold)[0]\n",
    "\n",
    "# new_labels_a[idx_a_to_b] = b\n",
    "# new_labels_b[idx_b_to_a] = a\n",
    "\n",
    "# perturbed_train_y[idx1] = new_labels_a\n",
    "# perturbed_train_y[idx2] = new_labels_b\n",
    "\n",
    "# new_labels = labels.numpy().copy()\n",
    "# new_labels[train_mask] = perturbed_train_y\n",
    "# new_labels = torch.tensor(new_labels)\n",
    "\n",
    "# gcn_with_node_flip = gcn_with_node_flipping(graph= graph, features=feat, new_labels=new_labels, \n",
    "#                                             train_mask=train_mask, val_mask=val_mask, test_mask=test_mask,\n",
    "#                                            num_classes=number_classes)\n",
    "\n",
    "# gcn_without_node_flip = gcn_with_node_flipping(graph= graph, features=feat, new_labels=labels, \n",
    "#                                             train_mask=train_mask, val_mask=val_mask, test_mask=test_mask,\n",
    "#                                            num_classes=number_classes)\n",
    "\n",
    "# acc_flip = gcn_with_node_flip.train_evaluate()\n",
    "# acc_no_flip = gcn_without_node_flip.train_evaluate()\n",
    "\n",
    "# temp_acc_flip_list.append(acc_flip)\n",
    "# temp_acc_no_flip_list.append(acc_no_flip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b78e7318",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_flip_list_all = []\n",
    "acc_no_flip_list_all = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "32b512fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# perturb_ratio = 0.20\n",
    "temp_acc_flip_list = []\n",
    "temp_acc_no_flip_list = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "10c22619",
   "metadata": {},
   "outputs": [],
   "source": [
    "perturb_ratio_list = [0.1, 0.15, 0.2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "20a22f6d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.64it/s]\n",
      "100%|█████████████████████████████████████████| 120/120 [00:20<00:00,  5.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy 56.50%\n",
      "Test accuracy 70.90%\n",
      "Test accuracy 56.90%\n",
      "Test accuracy 70.10%\n",
      "Test accuracy 56.30%\n",
      "Test accuracy 69.00%\n",
      "Test accuracy 57.20%\n",
      "Test accuracy 69.90%\n",
      "Test accuracy 57.00%\n",
      "Test accuracy 70.30%\n",
      "Test accuracy 57.50%\n",
      "Test accuracy 69.90%\n",
      "Test accuracy 56.40%\n",
      "Test accuracy 70.40%\n",
      "Test accuracy 56.60%\n",
      "Test accuracy 70.40%\n",
      "Test accuracy 56.50%\n",
      "Test accuracy 70.90%\n",
      "Test accuracy 57.00%\n",
      "Test accuracy 71.30%\n",
      "Test accuracy 56.60%\n",
      "Test accuracy 70.20%\n",
      "Test accuracy 56.60%\n",
      "Test accuracy 71.30%\n",
      "Test accuracy 56.30%\n",
      "Test accuracy 71.80%\n",
      "Test accuracy 56.30%\n",
      "Test accuracy 70.80%\n",
      "Test accuracy 57.60%\n",
      "Test accuracy 71.40%\n",
      "Test accuracy 57.10%\n",
      "Test accuracy 70.30%\n",
      "Test accuracy 56.50%\n",
      "Test accuracy 70.20%\n",
      "Test accuracy 57.10%\n",
      "Test accuracy 71.40%\n",
      "Test accuracy 56.90%\n",
      "Test accuracy 70.40%\n",
      "Test accuracy 56.40%\n",
      "Test accuracy 71.20%\n",
      "Test accuracy 57.10%\n",
      "Test accuracy 70.70%\n",
      "Test accuracy 57.10%\n",
      "Test accuracy 70.60%\n",
      "Test accuracy 57.20%\n",
      "Test accuracy 70.40%\n",
      "Test accuracy 56.00%\n",
      "Test accuracy 71.00%\n",
      "Test accuracy 56.00%\n",
      "Test accuracy 71.40%\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.13it/s]\n",
      "100%|█████████████████████████████████████████| 120/120 [00:19<00:00,  6.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy 54.70%\n",
      "Test accuracy 70.30%\n",
      "Test accuracy 54.90%\n",
      "Test accuracy 71.00%\n",
      "Test accuracy 55.60%\n",
      "Test accuracy 70.80%\n",
      "Test accuracy 55.10%\n",
      "Test accuracy 69.90%\n",
      "Test accuracy 55.70%\n",
      "Test accuracy 68.90%\n",
      "Test accuracy 56.30%\n",
      "Test accuracy 71.40%\n",
      "Test accuracy 55.70%\n",
      "Test accuracy 70.90%\n",
      "Test accuracy 55.80%\n",
      "Test accuracy 70.50%\n",
      "Test accuracy 54.90%\n",
      "Test accuracy 70.80%\n",
      "Test accuracy 54.70%\n",
      "Test accuracy 71.60%\n",
      "Test accuracy 56.00%\n",
      "Test accuracy 71.50%\n",
      "Test accuracy 53.90%\n",
      "Test accuracy 69.60%\n",
      "Test accuracy 55.70%\n",
      "Test accuracy 70.50%\n",
      "Test accuracy 54.60%\n",
      "Test accuracy 71.90%\n",
      "Test accuracy 55.90%\n",
      "Test accuracy 71.20%\n",
      "Test accuracy 55.50%\n",
      "Test accuracy 71.00%\n",
      "Test accuracy 55.40%\n",
      "Test accuracy 70.80%\n",
      "Test accuracy 55.50%\n",
      "Test accuracy 69.60%\n",
      "Test accuracy 55.10%\n",
      "Test accuracy 70.10%\n",
      "Test accuracy 55.30%\n",
      "Test accuracy 69.80%\n",
      "Test accuracy 55.10%\n",
      "Test accuracy 70.40%\n",
      "Test accuracy 54.60%\n",
      "Test accuracy 70.20%\n",
      "Test accuracy 55.90%\n",
      "Test accuracy 71.30%\n",
      "Test accuracy 55.00%\n",
      "Test accuracy 71.60%\n",
      "Test accuracy 55.70%\n",
      "Test accuracy 70.80%\n",
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.07it/s]\n",
      "100%|█████████████████████████████████████████| 120/120 [00:20<00:00,  5.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy 54.20%\n",
      "Test accuracy 70.90%\n",
      "Test accuracy 55.30%\n",
      "Test accuracy 70.30%\n",
      "Test accuracy 54.30%\n",
      "Test accuracy 70.40%\n",
      "Test accuracy 54.20%\n",
      "Test accuracy 70.40%\n",
      "Test accuracy 55.40%\n",
      "Test accuracy 70.20%\n",
      "Test accuracy 55.70%\n",
      "Test accuracy 71.50%\n",
      "Test accuracy 55.10%\n",
      "Test accuracy 71.80%\n",
      "Test accuracy 54.60%\n",
      "Test accuracy 72.00%\n",
      "Test accuracy 53.10%\n",
      "Test accuracy 71.80%\n",
      "Test accuracy 54.60%\n",
      "Test accuracy 71.60%\n",
      "Test accuracy 54.60%\n",
      "Test accuracy 70.60%\n",
      "Test accuracy 54.10%\n",
      "Test accuracy 70.80%\n",
      "Test accuracy 54.70%\n",
      "Test accuracy 71.60%\n",
      "Test accuracy 55.50%\n",
      "Test accuracy 71.40%\n",
      "Test accuracy 54.30%\n",
      "Test accuracy 71.20%\n",
      "Test accuracy 54.60%\n",
      "Test accuracy 70.10%\n",
      "Test accuracy 55.50%\n",
      "Test accuracy 69.40%\n",
      "Test accuracy 54.50%\n",
      "Test accuracy 70.30%\n",
      "Test accuracy 55.30%\n",
      "Test accuracy 71.10%\n",
      "Test accuracy 54.30%\n",
      "Test accuracy 70.70%\n",
      "Test accuracy 55.60%\n",
      "Test accuracy 70.90%\n",
      "Test accuracy 53.80%\n",
      "Test accuracy 71.10%\n",
      "Test accuracy 53.00%\n",
      "Test accuracy 70.60%\n",
      "Test accuracy 54.60%\n",
      "Test accuracy 70.60%\n",
      "Test accuracy 55.30%\n",
      "Test accuracy 71.50%\n"
     ]
    }
   ],
   "source": [
    "for perturb_ratio in perturb_ratio_list:\n",
    "    \n",
    "    \n",
    "    temp_acc_flip_list = []\n",
    "    temp_acc_no_flip_list = []\n",
    "\n",
    "    graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(dataname)\n",
    "    # train_mask, val_mask, test_mask = random_splits_label_flip_attack(graph, labels, number_classes, seed=15)\n",
    "\n",
    "    lr = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "\n",
    "    feat0 = feat.clone()\n",
    "    degs = graph.in_degrees().float().clamp(min = 1)\n",
    "    norm = torch.pow(degs, -0.5)\n",
    "    norm = norm.to(feat0.device).unsqueeze(1)\n",
    "\n",
    "    for _ in range(2):\n",
    "        feat0 = feat0 * norm\n",
    "        graph.ndata['h'] = feat0\n",
    "        graph.update_all(fn.copy_u('h', 'm'),\n",
    "                         fn.sum('m', 'h'))\n",
    "        feat0 = graph.ndata.pop('h')\n",
    "        feat0 = feat0 * norm\n",
    "\n",
    "    train_x = feat0[train_mask].numpy().astype(np.float32)\n",
    "    train_y = labels[train_mask].numpy().astype(np.float32)\n",
    "\n",
    "    test_x = feat0[test_mask].numpy().astype(np.float32)\n",
    "    test_y = labels[test_mask].numpy().astype(np.float32)\n",
    "\n",
    "    enc = OneHotEncoder(handle_unknown='ignore')\n",
    "    enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "    one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "    one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "    \"\"\" Train Logistic Regression \"\"\"\n",
    "    lr = SimplifiedGraphNeuralNetwork(l2_reg = l2_regularlization_term, fit_intercept=True)\n",
    "    lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "    logits_test_y = test_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "    logits_train_y = train_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "\n",
    "    ori_val_loss, ave_ori_val_loss = lr.log_loss(logits_test_y, one_hot_labels_test, l2_reg = True)\n",
    "\n",
    "    numpy_theoritic_loss = log_loss(test_y, softmax(logits_test_y, axis=1))\n",
    "\n",
    "    train_total_grad, train_indiv_grad = lr.grad(train_x, logits_train_y, \n",
    "                                                 one_hot_labels_train, l2_reg=True)\n",
    "    val_loss_total_grad, val_loss_indiv_grad = lr.grad(test_x, logits_test_y, \n",
    "                                                       one_hot_labels_test, l2_reg=True)\n",
    "\n",
    "    # hessian_no_reg, hess, hessian_reg_term = lr.hess(train_x, logits_train_y)\n",
    "    # hess = fast_hess_cuda(train_x, logits_train_y)\n",
    "    hess = lr.hess_cuda(train_x, logits_train_y)\n",
    "\n",
    "\n",
    "    loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad.T, cholskey=True)\n",
    "    # loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad.T, cholskey=False)\n",
    "    loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "\n",
    "    pred_infl = train_indiv_grad.dot(loss_grad_hvp)\n",
    "\n",
    "    pred_infl = list(pred_infl.reshape(-1))\n",
    "    #\n",
    "    num_train = len(train_x)\n",
    "    act_infl = []\n",
    "\n",
    "    for i in tqdm(range(num_train)):\n",
    "        lr_new = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "        train_x_new = np.delete(train_x, i, axis = 0)\n",
    "        train_y_new = np.delete(train_y, i)\n",
    "        lr_new.fit(train_x_new, train_y_new)\n",
    "\n",
    "        logits_test_y_new = test_x @ lr_new.model.coef_.T + lr_new.model.intercept_\n",
    "\n",
    "\n",
    "        new_ori_val_loss, new_ave_ori_val_loss = lr_new.log_loss(logits_test_y_new, one_hot_labels_test, l2_reg = True)\n",
    "        act_infl.append(new_ori_val_loss - ori_val_loss)\n",
    "\n",
    "\n",
    "    df = pd.DataFrame([pred_infl, act_infl]).T\n",
    "    df.columns = ['predicted influence', 'actual influence']\n",
    "    predicted_influence = np.array(pred_infl)\n",
    "\n",
    "    a, b = get_first_two_frequent(labels)\n",
    "    idx1 = np.where(labels[train_mask].numpy() == a)[0]\n",
    "    idx2 = np.where(labels[train_mask].numpy() == b)[0]\n",
    "\n",
    "\n",
    "\n",
    "    num_train = np.sum(train_mask.numpy() == 1)\n",
    "    num_perturb = int(perturb_ratio * 40)\n",
    "    predicted_influence_combined = np.concatenate([predicted_influence[idx1], predicted_influence[idx2]])\n",
    "    predicted_influence_combined_sorted = np.sort(predicted_influence_combined)[::-1]\n",
    "    threshold = predicted_influence_combined_sorted[num_perturb]\n",
    "\n",
    "    perturbed_train_y = train_y.copy()\n",
    "    new_labels_a = np.repeat(a, len(idx1))\n",
    "    new_labels_b = np.repeat(b, len(idx2))\n",
    "    assert(len(idx1) == len(idx2))\n",
    "\n",
    "    predicted_influence_combined_sorted\n",
    "\n",
    "    idx_a_to_b = np.where(predicted_influence[idx1] >= threshold)[0]\n",
    "    idx_b_to_a = np.where(predicted_influence[idx2] >= threshold)[0]\n",
    "\n",
    "    new_labels_a[idx_a_to_b] = b\n",
    "    new_labels_b[idx_b_to_a] = a\n",
    "\n",
    "    perturbed_train_y[idx1] = new_labels_a\n",
    "    perturbed_train_y[idx2] = new_labels_b\n",
    "\n",
    "    new_labels = labels.numpy().copy()\n",
    "    new_labels[train_mask] = perturbed_train_y\n",
    "    new_labels = torch.tensor(new_labels)\n",
    "\n",
    "    for _ in range(25):\n",
    "        gcn_with_node_flip = gcn_with_node_flipping(graph= graph, features=feat, new_labels=new_labels, \n",
    "                                                    train_mask=train_mask, val_mask=val_mask, test_mask=test_mask,\n",
    "                                                   num_classes=number_classes)\n",
    "\n",
    "        gcn_without_node_flip = gcn_with_node_flipping(graph= graph, features=feat, new_labels=labels, \n",
    "                                                    train_mask=train_mask, val_mask=val_mask, test_mask=test_mask,\n",
    "                                                   num_classes=number_classes)\n",
    "\n",
    "        acc_flip = gcn_with_node_flip.train_evaluate()\n",
    "        acc_no_flip = gcn_without_node_flip.train_evaluate()\n",
    "\n",
    "        temp_acc_flip_list.append(acc_flip)\n",
    "        temp_acc_no_flip_list.append(acc_no_flip)\n",
    "\n",
    "    \n",
    "    acc_flip_list_all.append(temp_acc_flip_list)\n",
    "    acc_no_flip_list_all.append(temp_acc_no_flip_list)\n",
    "#     acc_flip_list_all = []\n",
    "#     acc_no_flip_list_all = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1bfc7ccb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5464800000000001"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(temp_acc_flip_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9d246b8d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.56748, 0.55304, 0.54648])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(np.array(acc_flip_list_all), axis =1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "fde5f027",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame(np.array(acc_flip_list_all)).to_csv('result_flip_attack/pubmed_public_final.csv', index = False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
