{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0ea11bfb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n",
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import torch\n",
    "import dgl\n",
    "import numpy_ml\n",
    "import scipy\n",
    "import cupy as cp\n",
    "import cupyx\n",
    "\n",
    "from dgl import function as fn\n",
    "from dgl.base import DGLError\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from scipy.special import softmax, log_softmax\n",
    "from numpy.linalg import inv, pinv\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from dataset import load_graph_dataset\n",
    "\n",
    "from model_node_influence import NodeInfluenceSGC\n",
    "\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy.linalg import cho_solve, cho_factor\n",
    "from model_edge_influence import EdgeInfluenceSGC, generate_remove_index_train\n",
    "from tqdm import tqdm\n",
    "\n",
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ade1d263",
   "metadata": {},
   "outputs": [],
   "source": [
    "def index_to_mask(index, size):\n",
    "    mask = torch.zeros(size, dtype=torch.bool, device=index.device)\n",
    "    mask[index] = 1\n",
    "    return mask\n",
    "def random_splits_label_flip_attack(graph, labels, num_classes, seed):\n",
    "    # Set new random planetoid splits:\n",
    "    # * 20 * num_classes labels for training\n",
    "    # * 500 labels for validation\n",
    "    # * 1000 labels for testing\n",
    "    torch.manual_seed(seed)\n",
    "    \n",
    "    indices = []\n",
    "\n",
    "    for i in range(num_classes):\n",
    "        index = (labels == i).nonzero().view(-1)\n",
    "        index = index[torch.randperm(index.size(0))]\n",
    "        indices.append(index)\n",
    "\n",
    "    train_index = torch.cat([i[:20] for i in indices], dim=0)\n",
    "\n",
    "    rest_index = torch.cat([i[20:] for i in indices], dim=0)\n",
    "    rest_index = rest_index[torch.randperm(rest_index.size(0))]\n",
    "    \n",
    "    train_mask = index_to_mask(train_index, size=graph.num_nodes())\n",
    "    val_mask = index_to_mask(rest_index[:500], size=graph.num_nodes())\n",
    "    test_mask = index_to_mask(rest_index[500:1500], size=graph.num_nodes())\n",
    "\n",
    "    return train_mask, val_mask, test_mask\n",
    "\n",
    "def get_first_two_frequent(labels):\n",
    "    class_counts = np.bincount(labels)\n",
    "    a = np.argsort(class_counts)[-1]\n",
    "    b = np.argsort(class_counts)[-2]\n",
    "    return a, b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4d301f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_set = 'cora'\n",
    "# # l2_term = 0.01\n",
    "l2_term = 0.01\n",
    "# # batch_edges = 1\n",
    "# num_layer = 2\n",
    "\n",
    "# data_set = 'pubmed'\n",
    "# l2_term = 0.004\n",
    "# num_layer = 2\n",
    "\n",
    "data_set = 'citeseer'\n",
    "# l2_term = 0.003\n",
    "num_layer = 2\n",
    "\n",
    "# data_set = 'reddit'\n",
    "# l2_term = 0.99\n",
    "# num_layer = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74a0ffc0",
   "metadata": {},
   "source": [
    "##### 1, load data, convert to one hot encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "01f74bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"retrieve all edges connected to traing nodes\"\"\"\n",
    "def generate_remove_index_train_all(from_indexes, to_indexes, train_mask):\n",
    "    train_index = torch.where(train_mask == 1)[0]\n",
    "    remove_from_list = []\n",
    "    remove_to_list = []\n",
    "    for i in tqdm(range(len(train_index))):\n",
    "        f_index = train_index[i]\n",
    "        to_index_list = torch.where(from_indexes == f_index)[0]\n",
    "        for to_index_e in to_index_list:\n",
    "            j = to_index_e\n",
    "            t_index = to_indexes[j]\n",
    "\n",
    "            remove_from_list.append(f_index)\n",
    "            remove_to_list.append(t_index)\n",
    "\n",
    "    return torch.tensor(remove_from_list), torch.tensor(remove_to_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "23d92dc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# some_seed_list = [15, 42, 123, 211]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2060d0d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(data_set)\n",
    "# train_mask, val_mask, test_mask = random_splits_label_flip_attack(graph, \n",
    "#                                                                   labels, number_classes, seed=some_seed_list[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0b0daa19",
   "metadata": {},
   "outputs": [],
   "source": [
    "a, b = get_first_two_frequent(labels[test_mask])\n",
    "idx1 = np.where(labels[train_mask].numpy() == a)[0]\n",
    "idx2 = np.where(labels[train_mask].numpy() == b)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "90ee0725",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_new = np.concatenate([idx1, idx2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1fadbf2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "feat0 = feat.clone()\n",
    "degs = graph.in_degrees().float().clamp(min = 1)\n",
    "norm = torch.pow(degs, -0.5)\n",
    "norm = norm.to(feat0.device).unsqueeze(1)\n",
    "\n",
    "for _ in range(num_layer):\n",
    "    feat0 = feat0 * norm\n",
    "    graph.ndata['h'] = feat0\n",
    "    graph.update_all(fn.copy_u('h', 'm'),\n",
    "                     fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h')\n",
    "    feat0 = feat0 * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d1d1bcfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_num_labels = np.max(labels.numpy()) + 1\n",
    "\n",
    "number_of_training_data = np.sum(train_mask.numpy())\n",
    "\n",
    "pred_infl_mat = np.zeros([number_of_training_data, total_num_labels +1])\n",
    "act_infl_mat = np.zeros([number_of_training_data, total_num_labels +1])\n",
    "\n",
    "train_node_idx = torch.where(train_mask == 1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6ffbe77f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████| 120/120 [00:00<00:00, 22687.24it/s]\n"
     ]
    }
   ],
   "source": [
    "from_indexes, to_indexes = graph.edges()\n",
    "\n",
    "f_l, t_l = generate_remove_index_train_all(from_indexes, to_indexes, train_mask)\n",
    "\n",
    "acctual_influence_node_features = []\n",
    "acctual_influence_edges = []\n",
    "\n",
    "predict_influence_node_features = []\n",
    "predict_influence_edges = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a42eee20",
   "metadata": {},
   "outputs": [],
   "source": [
    "acctual_influence_1 = []\n",
    "acctual_influence_2 = []\n",
    "\n",
    "predict_influence_1 = []\n",
    "predict_influence_2 = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b436da6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = feat0[train_mask].numpy().astype(np.float32)\n",
    "train_y = labels[train_mask].numpy().astype(np.float32)\n",
    "\n",
    "val_x = feat0[test_mask].numpy().astype(np.float32)\n",
    "val_y = labels[test_mask].numpy().astype(np.float32)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fc29b32",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "bfd6decb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                    | 0/40 [00:00<?, ?it/s]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.64it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.30it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.23it/s]\u001b[A\n",
      "  2%|█                                           | 1/40 [00:41<27:16, 41.97s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 460.31it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.23it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.95it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      "  5%|██▏                                         | 2/40 [01:19<25:03, 39.58s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.47it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.86it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      "  8%|███▎                                        | 3/40 [01:57<23:54, 38.78s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.27it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.77it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      " 10%|████▍                                       | 4/40 [02:35<23:05, 38.49s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.32it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.29it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.19it/s]\u001b[A\n",
      " 12%|█████▌                                      | 5/40 [03:13<22:21, 38.33s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.67it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.59it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 15%|██████▌                                     | 6/40 [03:51<21:39, 38.21s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.49it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 460.26it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      " 18%|███████▋                                    | 7/40 [04:29<20:58, 38.15s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 460.06it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.52it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 20%|████████▊                                   | 8/40 [05:07<20:19, 38.11s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.80it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.32it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 22%|█████████▉                                  | 9/40 [05:45<19:40, 38.08s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.60it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.57it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 25%|██████████▊                                | 10/40 [06:23<19:02, 38.07s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.54it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.36it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 28%|███████████▊                               | 11/40 [07:01<18:23, 38.06s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.45it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.23it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.19it/s]\u001b[A\n",
      " 30%|████████████▉                              | 12/40 [07:39<17:45, 38.04s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.44it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.45it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 32%|█████████████▉                             | 13/40 [08:17<17:06, 38.02s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.65it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.19it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.80it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 35%|███████████████                            | 14/40 [08:55<16:28, 38.03s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.35it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.41it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 38%|████████████████▏                          | 15/40 [09:33<15:50, 38.02s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.18it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.29it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 40%|█████████████████▏                         | 16/40 [10:11<15:11, 37.99s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.37it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.19it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.39it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 42%|██████████████████▎                        | 17/40 [10:49<14:33, 37.99s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.37it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.27it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 45%|███████████████████▎                       | 18/40 [11:27<13:55, 38.00s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.55it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.15it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.19it/s]\u001b[A\n",
      " 48%|████████████████████▍                      | 19/40 [12:05<13:18, 38.01s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.64it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.53it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 50%|█████████████████████▌                     | 20/40 [12:43<12:39, 37.99s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.29it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.53it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 52%|██████████████████████▌                    | 21/40 [13:21<12:01, 37.99s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.28it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.48it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 55%|███████████████████████▋                   | 22/40 [13:59<11:24, 38.02s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.43it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 458.57it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.19it/s]\u001b[A\n",
      " 57%|████████████████████████▋                  | 23/40 [14:37<10:45, 38.00s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.18it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.40it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 60%|█████████████████████████▊                 | 24/40 [15:15<10:08, 38.01s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.16it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.19it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.50it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 62%|██████████████████████████▉                | 25/40 [15:53<09:30, 38.02s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.42it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.48it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.19it/s]\u001b[A\n",
      " 65%|███████████████████████████▉               | 26/40 [16:32<08:52, 38.02s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.66it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.33it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 68%|█████████████████████████████              | 27/40 [17:10<08:14, 38.06s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.33it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.41it/s]\u001b[A\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 70%|██████████████████████████████             | 28/40 [17:48<07:36, 38.07s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.31it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.41it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 72%|███████████████████████████████▏           | 29/40 [18:26<06:58, 38.03s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.38it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.40it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 75%|████████████████████████████████▎          | 30/40 [19:04<06:20, 38.02s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.21it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.19it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.31it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 78%|█████████████████████████████████▎         | 31/40 [19:42<05:42, 38.02s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.11it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.40it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 80%|██████████████████████████████████▍        | 32/40 [20:20<05:04, 38.02s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.52it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.33it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 82%|███████████████████████████████████▍       | 33/40 [20:58<04:26, 38.02s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.27it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.79it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 85%|████████████████████████████████████▌      | 34/40 [21:36<03:48, 38.02s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.48it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.41it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      " 88%|█████████████████████████████████████▋     | 35/40 [22:14<03:10, 38.03s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.30it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.84it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 90%|██████████████████████████████████████▋    | 36/40 [22:52<02:31, 38.00s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.31it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.47it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 92%|███████████████████████████████████████▊   | 37/40 [23:30<01:54, 38.07s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.57it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.70it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 95%|████████████████████████████████████████▊  | 38/40 [24:08<01:16, 38.14s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.68it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.64it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.20it/s]\u001b[A\n",
      " 98%|█████████████████████████████████████████▉ | 39/40 [24:46<00:38, 38.12s/it]\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 460.48it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      "\n",
      "  0%|                                                   | 0/119 [00:00<?, ?it/s]\u001b[A\n",
      " 53%|█████████████████████▋                   | 63/119 [00:00<00:00, 459.79it/s]\u001b[A\n",
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.21it/s]\u001b[A\n",
      "100%|███████████████████████████████████████████| 40/40 [25:25<00:00, 38.13s/it]\n"
     ]
    }
   ],
   "source": [
    "# Label-flip influence experiment.\n",
    "# For every training node k and every candidate label: flip k's label, fit the\n",
    "# model, then compare the influence-function estimate of removing node k (plus the\n",
    "# feature perturbation that removal causes on other training nodes) with the\n",
    "# actual change in test loss after retraining.\n",
    "# NOTE(review): assumes pred_infl_mat, act_infl_mat, predict_influence_1 and\n",
    "# acctual_influence_1 were initialised in an earlier cell.\n",
    "\n",
    "# Loop-invariant inputs (independent of k and of the flipped label).\n",
    "train_x = feat0[train_mask].numpy().astype(np.float32)\n",
    "train_y_clean = labels[train_mask].numpy().astype(np.float32)\n",
    "val_x = feat0[test_mask].numpy().astype(np.float32)\n",
    "val_y = labels[test_mask].numpy().astype(np.float32)\n",
    "train_node_ids = train_node_idx.numpy()\n",
    "# graph node id -> row position in train_x / train_y\n",
    "train_pos_of_node = {int(nid): pos for pos, nid in enumerate(train_node_ids)}\n",
    "\n",
    "for k in tqdm(range(len(train_node_ids))):\n",
    "    # fresh copy so the flip applied to an earlier node does not leak into this one\n",
    "    train_y = train_y_clean.copy()\n",
    "\n",
    "    # --- perturbation caused by removing node k (does not depend on the label) ---\n",
    "    node_id = train_node_ids[k]\n",
    "    nis = NodeInfluenceSGC(graph=graph, feature=feat, node_index=node_id)\n",
    "    nis.remove_edges_sgc()\n",
    "    feat_removed1 = nis.calculate_modified_features()\n",
    "\n",
    "    # nodes whose propagated feature row changes after the edge removal\n",
    "    extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0]).numpy()\n",
    "    # restrict to training nodes; order is preserved and the dtype stays integer\n",
    "    # even when nothing changed (torch.tensor([]) would have produced floats)\n",
    "    extra_index_train = extra_index[np.isin(extra_index, train_node_ids)]\n",
    "    extra_index_train_in_train = [train_pos_of_node[int(nid)] for nid in extra_index_train]\n",
    "\n",
    "    # perturbed features of the affected training nodes other than node k itself\n",
    "    extra_index_train_remove_node = extra_index_train[extra_index_train != node_id]\n",
    "    feat_to_be_added = feat_removed1[extra_index_train_remove_node].numpy()\n",
    "\n",
    "    # rows of the original training set that are removed (node k + affected nodes) ...\n",
    "    perturb_index = extra_index_train_in_train\n",
    "    # ... and rows re-added with perturbed features (affected nodes only);\n",
    "    # raises ValueError if node k itself was not among the changed nodes\n",
    "    added_index = perturb_index.copy()\n",
    "    added_index.remove(k)\n",
    "\n",
    "    # original rows followed by the appended perturbed rows\n",
    "    train_x_new = feat_to_be_added\n",
    "    train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "\n",
    "    # retraining mask: 0 drops the original rows in perturb_index, appended rows are kept\n",
    "    weight_3 = np.ones(len(train_x_orig))\n",
    "    weight_3[perturb_index] = 0\n",
    "\n",
    "    for flipped_label in range(total_num_labels):\n",
    "        train_y[k] = flipped_label\n",
    "\n",
    "        # NOTE(review): the encoder only learns labels present in train_y; if a flip\n",
    "        # removes the last example of a class, val rows of that class encode as zeros.\n",
    "        enc = OneHotEncoder(handle_unknown='ignore')\n",
    "        enc.fit(train_y.reshape(-1, 1))\n",
    "        one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "        # model trained on the (label-flipped) original training set\n",
    "        lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "        lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "        coef = lr_origin.model.coef_\n",
    "        intercept = lr_origin.model.intercept_\n",
    "\n",
    "        logits_val_y_origin = val_x @ coef.T + intercept\n",
    "        logits_train_y_origin = train_x @ coef.T + intercept\n",
    "        ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "        # inverse-Hessian-vector product with the total test-loss gradient (cupy)\n",
    "        val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(\n",
    "            val_x, logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "        hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg=True)\n",
    "        loss_grad_hvp = cp.asnumpy(\n",
    "            fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True))\n",
    "        del hess  # drop the Hessian reference before the next iteration\n",
    "\n",
    "        # per-row influence over original rows followed by the appended perturbed rows\n",
    "        train_y_new = train_y[added_index]\n",
    "        train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "        one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "        logits_train_y_origin_0 = train_x_orig @ coef.T + intercept\n",
    "        train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(\n",
    "            train_x_orig, logits_train_y_origin_0, one_hot_labels_train_0, l2_reg=True)\n",
    "        pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "\n",
    "        # actual influence: retrain with the original perturb_index rows dropped\n",
    "        lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "        train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "        train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "        lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "\n",
    "        logits_val_y_new_2 = val_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "        new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_val_y_new_2, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "        # predicted change = influence of removed rows - influence of re-added rows\n",
    "        p_if_temp = np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):])\n",
    "        a_if_temp = new_ori_val_loss_2 - ori_val_loss\n",
    "\n",
    "        predict_influence_1.append(p_if_temp)\n",
    "        acctual_influence_1.append(a_if_temp)\n",
    "        pred_infl_mat[k][flipped_label] = p_if_temp\n",
    "        act_infl_mat[k][flipped_label] = a_if_temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ae716ba0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist both influence matrices (rows: training-node position k,\n",
    "# columns: flipped label). Create the output folder first so to_csv does\n",
    "# not fail on a fresh checkout.\n",
    "out_dir = 'result_flip_attack/citeseer_influence_flip'\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "df_pred = pd.DataFrame(pred_infl_mat)\n",
    "df_pred.to_csv(os.path.join(out_dir, 'public_split_citeseer_pred_infl_flip_class.csv'))\n",
    "\n",
    "df_act = pd.DataFrame(act_infl_mat)\n",
    "df_act.to_csv(os.path.join(out_dir, 'public_split_citeseer_act_infl_flip_class.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6bd7624",
   "metadata": {},
   "outputs": [],
   "source": [
    "# k, flipped_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eae68bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# shape of the test-split logits from the last retrained model (lr_new_2)\n",
    "logits_val_y_new_2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "237b1f5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# operand shapes of the retrained model's test logits and the one-hot targets\n",
    "val_x.shape, lr_new_2.model.coef_.T.shape, lr_new_2.model.intercept_.shape, one_hot_labels_val.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d19dcd7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# predicted influences, appended once per (k, flipped_label) iteration of the main loop\n",
    "predict_influence_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f12a65eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# shared axis range for both axes and the y = x reference line\n",
    "low_limit, up_limit = -20, 20\n",
    "\n",
    "sns.set_theme()\n",
    "sns.set_style('whitegrid')\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ref = np.linspace(low_limit, up_limit)\n",
    "# points on this line are predicted exactly\n",
    "ax.plot(ref, ref, color='grey', alpha=0.25, zorder=0)\n",
    "ax.scatter(acctual_influence_1, predict_influence_1, s=20, linewidths=0)\n",
    "ax.ticklabel_format(style='sci', scilimits=(-4, 4))\n",
    "ax.axis('square')\n",
    "ax.set(xlabel='Act. Influence', ylabel='Pred. Influence',\n",
    "       title='Predicted vs. actual influence (label flip)',\n",
    "       xlim=(low_limit, up_limit), ylim=(low_limit, up_limit))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "027c26ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# df_pred = pd.DataFrameme(pred_infl_mat)\n",
    "# df_pred.to_csv('result_flip_attack/cora_pred_infl_flip_class.csv')\n",
    "\n",
    "# df_act = pd.DataFrameme(act_infl_mat)\n",
    "# df_act.to_csv('result_flip_attack/cora_act_infl_flip_class.csv')\n",
    "\n",
    "\n",
    "# df_pred = pd.DataFrame(pred_infl_mat)\n",
    "# df_pred.to_csv('result_flip_attack/' + data_set + '_pred_infl_flip_class_'+ str(some_seed_list[1]) +'.csv')\n",
    "\n",
    "# df_act = pd.DataFrame(act_infl_mat)\n",
    "# df_act.to_csv('result_flip_attack/' + data_set + '_act_infl_flip_class_'+ str(some_seed_list[1]) +'.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40de2107",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a80d5cdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import stats\n",
    "\n",
    "# Drop pairs where either influence is exactly zero using ONE joint mask, so each\n",
    "# actual value stays paired with its own prediction (separate masks mis-pair them).\n",
    "act_arr = np.asarray(acctual_influence_1)\n",
    "pred_arr = np.asarray(predict_influence_1)\n",
    "nonzero = (act_arr != 0) & (pred_arr != 0)\n",
    "stats.spearmanr(act_arr[nonzero], pred_arr[nonzero])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a2b4eca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import stats\n",
    "\n",
    "# rank correlation over all (actual, predicted) pairs, zeros included\n",
    "stats.spearmanr(acctual_influence_1, predict_influence_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68da4ba6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data1 = pd.DataFrame([time_infl[i] / time_retrain[i] for i in range(len(time_infl))])\n",
    "# data1.to_csv('running time/citeseer_running_time.csv', header = None, index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dc90140",
   "metadata": {},
   "outputs": [],
   "source": [
    "# two columns; row i pairs the i-th actual and i-th predicted influence\n",
    "influence_cols = ['acctual_influence', 'predict_influence']\n",
    "df = pd.DataFrame([acctual_influence_1, predict_influence_1], index=influence_cols).T\n",
    "# df.to_csv('complete_node/' +data_set + '.csv', index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3398243c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pd.read_csv('complete_node/cora_flip.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7805e6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# positions re-added with perturbed features for the last node (perturb_index minus k)\n",
    "added_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "569a71b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# position in train_node_idx of the last processed training node\n",
    "k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea4e73e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# training-row positions removed for the last processed node (includes k)\n",
    "perturb_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16138629",
   "metadata": {},
   "outputs": [],
   "source": [
    "# original training rows plus the appended perturbed-feature rows\n",
    "train_x_orig.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5aa800ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# retraining mask: 0 marks rows dropped before fitting lr_new_2\n",
    "weight_3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3307b612",
   "metadata": {},
   "outputs": [],
   "source": [
    "# positions of the dropped rows; the same set as perturb_index\n",
    "np.where(weight_3 == 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b162f43a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# re-display perturb_index to compare with the zero positions of weight_3\n",
    "perturb_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceaccc86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# labels of the re-added perturbed rows (train_y at added_index)\n",
    "train_y_new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfff86eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_y (row k holding the last flipped label) followed by train_y_new\n",
    "train_y_orig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1e6f642",
   "metadata": {},
   "outputs": [],
   "source": [
    "# label assigned to node k in the last flipped_label iteration\n",
    "train_y_orig[k]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5362124",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
