{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d2e2b33b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import math\n",
    "import torch\n",
    "import dgl\n",
    "import numpy_ml\n",
    "import torch.nn.functional as F\n",
    "import tensorflow.compat.v1 as tf\n",
    "import scipy\n",
    "import cupy as cp\n",
    "import cupyx\n",
    "\n",
    "from dgl import function as fn\n",
    "from dgl.base import DGLError\n",
    "from generate_mid_layer_feature import FeatureExtraction\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from scipy.special import softmax, log_softmax\n",
    "from numpy.linalg import inv, pinv\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "from sklearn.utils import check_array\n",
    "from dataset import load_graph_dataset\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from model_node_influence import NodeInfluenceSGC\n",
    "\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy.linalg import cho_solve, cho_factor\n",
    "from model_edge_influence import EdgeInfluenceSGC, generate_remove_index_train\n",
    "from tqdm import tqdm\n",
    "import networkx as nx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "aeee0344",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/zizhang/Desktop/Projects/Project6_influence_function/graph_influence_function/experiments'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the working directory. NOTE(review): sys.path.append('..') in the import cell is relative\n",
    "# to this directory, so the notebook presumably has to be launched from its own folder - confirm.\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "aa453292",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "# Enable the Intel(R) Extension for Scikit-learn.\n",
    "# NOTE(review): sklearn estimators were already imported in the first cell, before this patch call;\n",
    "# the scikit-learn-intelex docs ask for patching before importing sklearn - confirm the patch is effective here.\n",
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5744cc1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: keep exactly one dataset preset active.\n",
    "#   data_set  - dataset name passed to load_graph_dataset\n",
    "#   l2_term   - value passed as l2_reg to SimplifiedGraphNeuralNetwork\n",
    "#   num_layer - number of feature-propagation rounds applied to the node features\n",
    "\n",
    "# data_set = 'cora'\n",
    "# # l2_term = 0.01\n",
    "# l2_term = 0.01\n",
    "# # batch_edges = 1\n",
    "# num_layer = 2\n",
    "\n",
    "# data_set = 'pubmed'\n",
    "# l2_term = 0.004\n",
    "# num_layer = 2\n",
    "\n",
    "# Active preset.\n",
    "data_set = 'citeseer'\n",
    "l2_term = 0.003\n",
    "num_layer = 2\n",
    "\n",
    "# data_set = 'reddit'\n",
    "# l2_term = 0.99\n",
    "# num_layer = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "33162a0d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "# Load the dataset selected by data_set. The unpacked names suggest: graph, node features, labels,\n",
    "# train/val/test masks and the class count - confirm against dataset.load_graph_dataset.\n",
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(data_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4841385b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precompute the propagated node features (feat0) that are later sliced into train/val/test inputs.\n",
    "feat0 = feat.clone()\n",
    "# Clamp in-degrees to at least 1 so that zero-degree nodes do not produce inf under the -0.5 power.\n",
    "degs = graph.in_degrees().float().clamp(min = 1)\n",
    "norm = torch.pow(degs, -0.5)\n",
    "# Add a trailing axis so the per-node factor broadcasts against feat0 (assumed nodes x features - TODO confirm).\n",
    "norm = norm.to(feat0.device).unsqueeze(1)\n",
    "\n",
    "# Each round: scale by deg^-0.5, sum the features arriving over incoming edges, scale by deg^-0.5 again.\n",
    "for _ in range(num_layer):\n",
    "    feat0 = feat0 * norm\n",
    "    # 'h' is used as temporary node storage on the graph and popped again below.\n",
    "    graph.ndata['h'] = feat0\n",
    "    graph.update_all(fn.copy_u('h', 'm'),\n",
    "                     fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h')\n",
    "    feat0 = feat0 * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9322e579",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Slice the propagated features and the labels by split mask and convert them to float32 NumPy arrays.\n",
    "# Note that the labels are cast to float32 as well; the one-hot encoder below is fit on these values.\n",
    "train_x = feat0[train_mask].numpy().astype(np.float32)\n",
    "train_y = labels[train_mask].numpy().astype(np.float32)\n",
    "\n",
    "test_x = feat0[test_mask].numpy().astype(np.float32)\n",
    "test_y = labels[test_mask].numpy().astype(np.float32)\n",
    "\n",
    "val_x = feat0[val_mask].numpy().astype(np.float32)\n",
    "val_y = labels[val_mask].numpy().astype(np.float32)\n",
    "\n",
    "# Node ids of the training nodes; row i of train_x / train_y belongs to node train_node_idx[i].\n",
    "train_node_idx = torch.where(train_mask == 1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ab4a7f40",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' Train Logistic Regression '"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot encode the labels of every split. The encoder is fit on the training labels only;\n",
    "# with handle_unknown='ignore' a label value not seen in training is encoded as an all-zero row.\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "# NOTE(review): the bare string below is the last expression of the cell, so it is echoed as the\n",
    "# cell output; it only serves as a section label for the commented-out fit that follows.\n",
    "\"\"\" Train Logistic Regression \"\"\"\n",
    "# lr = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "# lr.fit(train_x, train_y, sample_weight=None, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6a398dcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_remove_index_train_all(from_indexes, to_indexes, train_mask, seed_val=10):\n",
    "    \"\"\"Collect every edge whose source node is a training node.\n",
    "\n",
    "    Args:\n",
    "        from_indexes: 1-D tensor with the source node id of each edge.\n",
    "        to_indexes: 1-D tensor with the destination node id of each edge,\n",
    "            aligned element-wise with from_indexes.\n",
    "        train_mask: per-node mask; entries equal to 1 mark training nodes.\n",
    "        seed_val: value passed to torch.manual_seed. No random draw happens in\n",
    "            this function; the call is kept so the global torch RNG state seen\n",
    "            by later cells is unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (remove_from, remove_to): two int64 tensors of equal length listing the\n",
    "        selected edges, grouped by training node in mask order and, within a\n",
    "        node, in the order the edges appear in from_indexes. Both are empty\n",
    "        int64 tensors when no edge starts at a training node.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed_val)\n",
    "    train_index = torch.where(train_mask == 1)[0]\n",
    "    remove_from_list = []\n",
    "    remove_to_list = []\n",
    "    for i in tqdm(range(len(train_index))):\n",
    "        f_index = train_index[i]\n",
    "\n",
    "        # Positions in the edge list whose source is this training node.\n",
    "        to_index_list = torch.where(from_indexes == f_index)[0]\n",
    "        targets = to_indexes[to_index_list].tolist()\n",
    "\n",
    "        # Store plain Python ints instead of 0-d tensors: one gather per node\n",
    "        # replaces the former per-edge Python loop.\n",
    "        remove_from_list.extend([int(f_index)] * len(targets))\n",
    "        remove_to_list.extend(targets)\n",
    "\n",
    "    # Explicit dtype: torch.tensor([]) would otherwise default to float32, which\n",
    "    # cannot be used as an index when no edge matches.\n",
    "    return (torch.tensor(remove_from_list, dtype=torch.long),\n",
    "            torch.tensor(remove_to_list, dtype=torch.long))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cc03ddc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████| 120/120 [00:00<00:00, 20989.89it/s]\n"
     ]
    }
   ],
   "source": [
    "# Source / destination node ids of every edge in the graph.\n",
    "from_indexes, to_indexes = graph.edges()\n",
    "\n",
    "# NOTE(review): the training-node edge selection computed on the next line is immediately\n",
    "# overwritten by the assignment after it, so the influence loop runs over every edge of the graph.\n",
    "f_l, t_l = generate_remove_index_train_all(from_indexes, to_indexes, train_mask)\n",
    "f_l, t_l = from_indexes, to_indexes\n",
    "\n",
    "# acctual_influence_1 = []\n",
    "# acctual_influence_2 = []\n",
    "\n",
    "# predict_influence_1 = []\n",
    "# predict_influence_2 = []\n",
    "\n",
    "# Alternative edge source (disabled): reload a previously saved edge list from CSV.\n",
    "# df = pd.read_csv('result_data/' + data_set + '_edge_influence.csv', header = None)\n",
    "# df = df.loc[df[0] != 0]\n",
    "# f_l = torch.tensor(df[2].values.astype(int))\n",
    "# t_l = torch.tensor(df[3].values.astype(int))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7e9f9334",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# split_side = int(len(f_l) / batch_edges)\n",
    "# total_edges = np.arange(len(f_l))\n",
    "# np.random.shuffle(total_edges)\n",
    "# new_index = np.array_split(total_edges, split_side)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5718ca75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empty accumulators. NOTE(review): neither list is referenced in the cells visible in this part\n",
    "# of the notebook - confirm whether a later cell still needs them.\n",
    "changed_list = []\n",
    "changed_list_change = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e320ad5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.64it/s]\n"
     ]
    }
   ],
   "source": [
    "# Fit the reference model on the unperturbed training set and precompute the quantities that the\n",
    "# influence loop reuses: the original validation loss and loss_grad_hvp.\n",
    "\n",
    "# convert to one-hot labels\n",
    "# (the encoder is re-created and re-fit on the training labels here, as in the earlier encoding cell)\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "one_hot_labels_train_orig = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# train the original data\n",
    "# calculate the hessian matrix\n",
    "lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "\n",
    "lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "# Logits are computed by hand from the fitted coefficients and intercept.\n",
    "logits_val_y_origin = val_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "# log_loss returns two values; the names suggest total and average loss - TODO confirm in model_softmax.\n",
    "ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "# numpy_theoritic_loss = log_loss(val_y, softmax(logits_val_y_origin, axis=1))\n",
    "\n",
    "# assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "# Gradient of the validation loss at the fitted parameters (total and per-sample variants, going by the names).\n",
    "val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(val_x, \n",
    "                                                                    logits_val_y_origin,\n",
    "                                                                    one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "# Hessian built from the training features and training logits via the CUDA helper.\n",
    "hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg = True)\n",
    "\n",
    "\n",
    "###### write out hessian matrix\n",
    "\n",
    "# hess = cp.asnumpy(hess)\n",
    "\n",
    "# pd.DataFrame(hess).to_csv(\"reddit/hess.csv\", index = False)\n",
    "\n",
    "######\n",
    "\n",
    "\n",
    "# Presumably solves H x = g for the validation-loss gradient g (inverse-Hessian-vector product),\n",
    "# using a Cholesky-based solve - confirm in model_softmax.fast_get_inv_hvp_cuda.\n",
    "loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True)\n",
    "\n",
    "# Convert the CuPy result to a NumPy array for the NumPy dot products in the influence loop.\n",
    "loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "# Drop the reference to the Hessian; it is not needed once loss_grad_hvp exists.\n",
    "del hess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "15c40ede",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # convert to one-hot labels\n",
    "# enc = OneHotEncoder(handle_unknown='ignore')\n",
    "# enc.fit(train_y.reshape(-1, 1))\n",
    "# one_hot_labels_train_orig = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# # train the original data\n",
    "# # calculate the hessian matrix\n",
    "# lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "\n",
    "# lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "# logits_test_y_origin = test_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "# logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "# ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(logits_test_y_origin, one_hot_labels_test, l2_reg=True)\n",
    "\n",
    "# # numpy_theoritic_loss = log_loss(test_y, softmax(logits_test_y_origin, axis=1))\n",
    "\n",
    "# # assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "# val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(test_x, \n",
    "#                                                                     logits_test_y_origin,\n",
    "#                                                                     one_hot_labels_test, l2_reg = True)\n",
    "\n",
    "# hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg = True)\n",
    "\n",
    "# loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True)\n",
    "\n",
    "# loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "# del hess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e5abf40c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 12431/12431 [49:07<00:00,  4.22it/s]\n"
     ]
    }
   ],
   "source": [
    "# Per-edge results: retraining-based (\"acctual\") and influence-function (\"predict\") change of the\n",
    "# validation loss. The *_2 lists are created but not filled by this cell.\n",
    "acctual_influence_1 = []\n",
    "acctual_influence_2 = []\n",
    "\n",
    "predict_influence_1 = []\n",
    "predict_influence_2 = []\n",
    "\n",
    "# Hoisted out of the loop: node ids of the training nodes as a NumPy array.\n",
    "train_node_idx_np = train_node_idx.numpy()\n",
    "\n",
    "# One iteration per candidate edge (f_l[k] -> t_l[k]). A batched variant would pass index arrays\n",
    "# (f_l[new_index[k]], t_l[new_index[k]]) to EdgeInfluenceSGC instead of single edges.\n",
    "for k in tqdm(range(len(f_l))):\n",
    "    eis = EdgeInfluenceSGC(graph=graph, feature=feat, from_index=f_l[k], to_index=t_l[k])\n",
    "    eis.remove_edges_sgc_from_influence()\n",
    "    feat_removed1 = eis.calculate_modified_features()\n",
    "\n",
    "    # Node ids whose propagated feature row differs from the unperturbed features.\n",
    "    extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0])\n",
    "    extra_index_np = extra_index.numpy()\n",
    "\n",
    "    # Keep only the affected nodes that are training nodes (vectorised membership test).\n",
    "    extra_index_train = extra_index_np[np.isin(extra_index_np, train_node_idx_np)]\n",
    "\n",
    "    # Row position of each affected training node inside train_x / train_y.\n",
    "    extra_index_train_in_train = [\n",
    "        np.where(train_node_idx_np == node_id)[0][0] for node_id in extra_index_train]\n",
    "\n",
    "    # Fix: the former test `extra_index_train == []` compared a NumPy array with a list and was\n",
    "    # never true, so edges that touch no training node still triggered a full retrain.\n",
    "    if len(extra_index_train) == 0:\n",
    "        predict_influence_1.append(0.0)\n",
    "        acctual_influence_1.append(0.0)\n",
    "        continue\n",
    "\n",
    "    feat_to_be_added = feat_removed1[extra_index_train].numpy()\n",
    "    perturb_index = extra_index_train_in_train\n",
    "\n",
    "    # Augmented training set: the original rows followed by the post-removal rows of the affected nodes.\n",
    "    train_x_new = feat_to_be_added\n",
    "    train_y_new = train_y[perturb_index]\n",
    "\n",
    "    train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "    train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "\n",
    "    one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "    logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "    train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(\n",
    "        train_x_orig, logits_train_y_origin_0, one_hot_labels_train_0, l2_reg = True)\n",
    "\n",
    "    # One predicted-influence value per row of the augmented set (per-sample gradient dotted with\n",
    "    # the precomputed loss_grad_hvp).\n",
    "    pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "\n",
    "    # Retraining mask 1...0...11: drop the original rows of the affected nodes and keep the appended\n",
    "    # post-removal rows, i.e. the affected rows are replaced by their modified features.\n",
    "    weight_3 = np.ones(len(train_x_orig))\n",
    "    weight_3[perturb_index] = 0\n",
    "\n",
    "    lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "    train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "    train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "\n",
    "    lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "    logits_val_y_new_2 = val_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "    new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_val_y_new_2, one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "    # Predicted change: influence of the dropped original rows minus influence of the appended rows.\n",
    "    predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "    # Actual change: validation loss after retraining minus the original validation loss.\n",
    "    acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5f258a0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "30badc33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# acctual_influence_1 = []\n",
    "# acctual_influence_2 = []\n",
    "\n",
    "# predict_influence_1 = []\n",
    "# predict_influence_2 = []\n",
    "\n",
    "# # for k in tqdm(range(len(train_node_idx))):\n",
    "# for k in tqdm(range(len(f_l))):\n",
    "\n",
    "#     eis = EdgeInfluenceSGC(graph=graph, feature=feat, from_index=f_l[k], to_index=t_l[k])\n",
    "#     eis.remove_edges_sgc()\n",
    "#     feat_removed1 = eis.calculate_modified_features()\n",
    "    \n",
    "# #     node_id = train_node_idx.numpy()[k]\n",
    "# #     nis = NodeInfluenceSGC(graph = graph, feature=feat, node_index=node_id)\n",
    "# #     nis.remove_edges_sgc()\n",
    "# #     feat_removed1 = nis.calculate_modified_features()\n",
    "\n",
    "#     extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0])\n",
    "    \n",
    "    \n",
    "#     extra_index_train = torch.tensor(\n",
    "#         [extra_index[i] for i in range(len(extra_index)) if extra_index[i] in train_node_idx]).numpy()\n",
    "    \n",
    "#     extra_index_train_in_train = [\n",
    "#         np.where(train_node_idx.numpy() == extra_index_train[j])[0][0] for j in range(len(extra_index_train))]\n",
    "    \n",
    "#     if extra_index_train != []:\n",
    "#         predict_influence_1.append(0.0)\n",
    "#         acctual_influence_1.append(0.0)\n",
    "#         continue\n",
    "    \n",
    "#     feat_to_be_added = feat_removed1[extra_index_train].numpy()\n",
    "#     perturb_index = extra_index_train_in_train\n",
    "    \n",
    "    \n",
    "#     train_x_new = feat_to_be_added\n",
    "#     train_y_new = train_y[perturb_index]\n",
    "    \n",
    "#     train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "#     train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "    \n",
    "#     one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "#     logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "#     train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(train_x_orig, \n",
    "#                                             logits_train_y_origin_0, \n",
    "#                                             one_hot_labels_train_0, l2_reg = True)\n",
    "    \n",
    "    \n",
    "#     pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "# #     weight_orig = np.ones(len(train_x_orig)) # 1...1...11\n",
    "    \n",
    "# #     weight_1 = np.ones(len(train_x_orig))\n",
    "# #     weight_1[len(train_x_orig) - len(perturb_index):] = 0 # 1...1...10\n",
    "    \n",
    "# #     weight_2 = np.ones(len(train_x_orig))\n",
    "# #     weight_2[len(train_x_orig) - len(perturb_index):] = 0 \n",
    "# #     weight_2[perturb_index] = 0 # 1...0...10\n",
    "    \n",
    "#     weight_3 = np.ones(len(train_x_orig))\n",
    "#     weight_3[perturb_index] = 0 # 1...0...11\n",
    "    \n",
    "# #     lr_new_1 = SimplifiedGraphNeuralNetwork(l2_reg=1.0, fit_intercept=True)\n",
    "# #     train_x_delete_1 = train_x_orig[weight_1 == 1]\n",
    "# #     train_y_delete_1 = train_y_orig[weight_1 == 1]\n",
    "    \n",
    "# #     assert(np.allclose(train_x_delete_1, train_x))\n",
    "# #     assert(np.allclose(train_y_delete_1, train_y))\n",
    "    \n",
    "# #     lr_new_1.fit(train_x_delete_1, train_y_delete_1)\n",
    "# #     logits_test_y_new_1 = test_x @ lr_new_1.model.coef_.T + lr_new_1.model.intercept_\n",
    "# #     new_ori_val_loss_1, _ = lr_new_1.log_loss(logits_test_y_new_1, one_hot_labels_test)\n",
    "    \n",
    "    \n",
    "    \n",
    "#     lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "#     train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "#     train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "    \n",
    "#     lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "#     logits_test_y_new_2 = test_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "#     new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_test_y_new_2, one_hot_labels_test, l2_reg = True)\n",
    "    \n",
    "#     predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "#     acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a663d124",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZwAAAEKCAYAAAAmfuNnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA1LUlEQVR4nO3dd3hUZfbA8e9J6KH3GjqEjhCKYMGOqCiiu+iufWXZ1Z+76+4KGHvvq1ssiL0rRRARQVbFVVEQJQkQWmgh9JIEQvr5/XFvZBgnYYC5M5PkfJ5nnpl77zt3Tm7IHO5733teUVWMMcYYr8VEOgBjjDFVgyUcY4wxYWEJxxhjTFhYwjHGGBMWlnCMMcaEhSUcY4wxYRHRhCMiL4vIThFJ9VnXWEQWiMha97lRGe8dKSKrRWSdiEwKX9TGGGOOR6TPcF4FRvqtmwQsVNWuwEJ3+QgiEgv8Bzgf6AlcISI9vQ3VGGPMiYhowlHVRcBev9UXA6+5r18DLgnw1sHAOlVNV9UC4F33fcYYY6JUtUgHEEALVd0GoKrbRKR5gDZtgC0+yxnAkEA7E5HxwHiAuLi4gQkJCSEO1xhjKp+CohK27j/EgfwiCrav262qzU50n9GYcIIhAdYFrNGjqlOAKQCJiYm6dOlSL+MyxpgKrbhEef3bjTz+6WpaAhPPT+CaYR03hWLf0ZhwdohIK/fsphWwM0CbDKCdz3JbIDMs0RljTCW1bmcOE6en8MOmfZzerRkPXdqHNg1rc02I9h+NCWc2cA3wiPs8K0CbJUBXEekIbAXGAVeGLUJjjKlECotLeOHL9fxz4Trq1IzlqV/1Y8xJbRAJ1Jl0/CKacETkHWAE0FREMoC7cRLN+yJyA7AZuNxt2xqYqqqjVLVIRG4GPgVigZdVdUUkfgZjjKnIUrdm8fdpyazals0FfVtxz0W9aFavpiefFdGEo6pXlLHprABtM4FRPstzgbkehWaMMZVaXmExT3+2lhe/SqdJXA1euGog5/Vq6elnRmOXmjHGGA99l76HSTNS2LD7IL9ObMftF/SgQe3qnn+uJRxjjKkicvIKeWzeat5YvIl2jWvz1u+GMLxL07B9viUcY4ypAj5fvZOkGSlsy87j+uEd+dt53ahTI7wpwBKOMcZUYvsOFnD/nJXM+HErXZvXZfofhjEgPmCJSs9ZwjHGmEpIVfk4ZRt3z1pB1qFCbjmrKzed0Zma1WIjFpMlHGOMqWR2ZOdx54epzF+5g75tG/Dm74bQo1X9SIdlCccYYyoLVeX9pVt44ONVFBSVcPuoBK4f3pFqsZGeGMBhCccYYyqBzXtymTwzma/X7WFIx8Y8OrYvHZrGRTqsI1jCMcaYCqy4RHn1m4088elqYmOEBy7pzZWD44mJCW1ZmlCwhGOMMRXUmh053DYtmZ+27OfMhOY8OKY3rRrUjnRYZbKEY4wxFUxBUQnPf7mef/13LXVrVuOZcf0Z3a91yItthpolHGOMqUCWb9nPxOnJpG3PYXS/1tx9UU+a1PWm2GaoWcIxxpgK4FBBMU9/toYXv0qnWb2aTL06kbN7toh0WMfEEo4xxkS5b9fvYfKMZDbuyeWKwfFMHpVA/VreF9sMNUs4xhgTpbLzCnnkkzTe/m4z7ZvU4e0bhzCsc/iKbYaaJRxjjIlC/03bwe0zUtmZk8eNp3bk1nO6U7tG5MrShIIlHGOMiSJ7DuRz35yVzPopk+4t6vH8VQPp365hpMMKCUs4xhgTBVSV2cszufejleTkFfKns7py0xldqFEtOsrShEJUJhwR6Q6857OqE3CXqj7t02YEMAvY4K6aoar3hSlEY4wJmW1Zh7hjZioL03bSr11DHhvbl+4t60U6rJCLyoSjqquB/gAiEgtsBWYGaPqVql4YxtCMMSZkSkqUd5ds4eG5qygsKeGOC3pw3fCOxEZhWZpQiMqE4+csYL2qbop0IMYYEyobdx9k0oxkFqfv5eRO
TXhkbB/aN4muYpuhVhESzjjgnTK2nSwiy4FM4G+quiJ8YRljzLErKi7h5a838OT8NdSIjeHhS/swblC7qC9LEwpRnXBEpAYwGpgcYPMyoL2qHhCRUcCHQNcA+xgPjAeIj4/3LlhjjDmKtO3ZTJyWzPKMLM7u0ZwHLulDywa1Ih1W2ER1wgHOB5ap6g7/Daqa7fN6rog8KyJNVXW3X7spwBSAxMRE9TpgY4zxl19UzH8+X8+zn6+jQe3q/OuKk7iwb6sqcVbjK9oTzhWU0Z0mIi2BHaqqIjIYiAH2hDM4Y4w5mh8372Pi9GTW7DjAmJPacOeFPWkcVyPSYUVE1CYcEakDnAP83mfdBABVfR64DPiDiBQBh4BxqmpnMMaYqJBbUMST89fw8tcbaFm/Fi9fm8iZCRWr2GaoRW3CUdVcoInfuud9Xv8b+He44zLGmKP5Zt1uJs1IYfPeXH47NJ6JIxOoVwGLbYZa1CYcY4ypaLIOFfLw3FW8u2QLHZvG8d74oQzp1OTob6wiLOEYY0wIzF+xnTs+TGX3gXx+f3on/nJ2N2pVr9jFNkPNEo4xxpyA3QfyuWf2CuYkbyOhZT2mXpNI37YNIx1WVLKEY4wxx0FV+fCnrdz70Upy84u59ZxuTDi9c6UqthlqlnCMMeYYZe4/RNLMFD5fvYuT4p1im11bVL5im6FmCccYY4JUUqK89f1mHv0kjeIS5a4Le3LNsA6VtthmqFnCMcaYIGzYfZCJ05P5fsNeTunSlIcv7UO7xnUiHVaFYgnHGGPKUVRcwtT/beAfC9ZQo1oMj43ty+WJbatcWZpQsIRjjDFlWJmZzW3Tl5O6NZvzerXg/ot707x+1Sm2GWqWcIwxxk9+UTH//u86nvtiPQ3rVOfZ3wzg/N4t7azmBFnCMcYYHz9scoptrtt5gEsHtOHOC3rSqIoW2ww1SzjGGAMczC/iifmrefWbjbRuUJtXrxvEiO7NIx1WpWIJxxhT5X21dheTZ6SQse8QVw1tz8TzE6hb074eQ82OqDGmysrKLeTBuSt5f2kGnZrG8f7vT2Zwx8aRDqvSsoRjjKmS5qVu585Zqew9WMAfR3TmlrO6WrFNj1nCMcZUKbtynGKbH6dso2er+rxy7SB6t2kQ6bCqBEs4xpgqQVWZsWwr981ZyaGCYv5+XnfGn9aJ6rFWbDNcjppwRORyYJ6q5ojIHcAA4AFVXeZ5dMYYEwIZ+3K5fWYqi9bsYmD7Rjw6ti9dmteNdFhVTjBnOHeq6gcicgpwHvAE8BwwxNPIjDHmBJWUKG9+t4lHP0lDgXtH9+Kqoe2JsWKbERFMwil2ny8AnlPVWSJyj3chOURkI5Djfn6Rqib6bRfgGWAUkAtca2ddxphS63cdYNL0ZJZs3Mdp3Zrx0JjetG1kxTYjKZiEs1VEXgDOBh4VkZpAuDo9z1DV3WVsOx/o6j6GYGddxhigsLiEF79K5+nP1lK7eixPXN6PsQPaWFmaKBBMwvkVMBJ4QlX3i0gr4O/ehhWUi4HXVVWBxSLSUERaqeq2SAdmjImM1K1ZTJyezIrMbEb1ack9o3vRvJ4V24wWwSScVsDHqpovIiOAvsDrXgblUmC+iCjwgqpO8dveBtjis5zhrjsi4YjIeGA8QHx8vHfRGmMiJq+wmH8uXMsLi9JpVKcGz/92ACN7t4p0WMZPMAlnOpAoIl2Al4DZwNs41068NFxVM0WkObBARNJUdZHP9kDnx/qLFU6imgKQmJj4i+3GmIptyca9TJyeTPqug1w+sC13XNCTBnWqRzosE0AwCadEVYtE5FLgaVX9l4j86HVgqprpPu8UkZnAYMA34WQA7XyW2wKZXsdljIkOB/KLeHxeGq8v3kTrBrV5/frBnNatWaTDMuUIJuEUisgVwNXARe46T//7ICJxQIx7708ccC5wn1+z2cDNIvIuzmCBLLt+Y0zV8OWaXdw+I4XMrENcc3IH/n5ed+Ks2GbUC+Y3dB0wAXhQVTeI
SEfgTW/DogUw0x1VUg14W1XnicgEAFV9HpiL0623DmdY9HUex2SMibD9uQXcN2clM5ZtpXOzOKZNOJmB7a3YZkUhziCvozQSqQF0cxdXq2qhp1F5JDExUZcuXRrpMIwxx2FuyjbumpXK/txCJpzemZvP7GLFNsNERH7wvxfyeART2mYE8BqwEedCfTsRucbvAr4xxnhiZ3Yed81awbwV2+ndpj6vXT+YXq2t2GZFFEyX2pPAuaq6GkBEugHvAAO9DMwYU7WpKh/8kMEDc1aSX1TCpPMT+N0pHalmxTYrrGASTvXSZAOgqmtExMYcGmM8s2VvLpNnpPC/dbsZ3KExj4ztQ6dmVmyzogsm4SwVkZeAN9zl3wA/eBeSMaaqKi5RXv92I4/NW02MwP2X9OY3g+Ot2GYlEUzC+QNwE3ALzjWcRcCzXgZljKl61u3M4bZpySzbvJ8R3Zvx4Jg+tGlYO9JhmRA6asJR1XzgKfdhjDEhVVhcwgtfruefC9dRp2Ys//h1Py7pb8U2K6MyE46IpBCgVEwpVe3rSUTGmCojJSOLv09bTtr2HC7o24p7R/eiad2akQ7LeKS8M5wLwxaFMaZKySss5unP1vLiV+k0iavBlKsGcm6vlpEOy3iszISjqpvCGYgxpmr4Ln0Pk2aksGH3QcYNasfkUT1oUNsGvlYFVnzIGBMWOXmFPDovjTcXb6Zd49q89bshDO/SNNJhmTCyhGOM8dznaTtJmpnCtuw8bjilI389txt1atjXT1Vjv3FjjGf2Hizg/jkrmfnjVro2r8v0PwxjQHyjSIdlIiSYWmrDgXuA9m57AVRVO3kbmjGmolJVPk7Zxt2zVpB1qJBbzurKTWd0pmY1K7ZZlQVzhvMS8Bec6gLF3oZjjKnodmTncceHqSxYuYO+bRvw5u+G0KNV/UiHZaJAMAknS1U/8TwSY0yFpqq8v3QLD3y8ioKiEpJG9eC64R2s2Kb5WTAJ53MReRyYAeSXrlTVZZ5FZYypUDbvyWXSjGS+Wb+HIR0b8+jYvnRoGhfpsEyUCSbhDHGffSffUeDM0IdjjKlIikuUV77ewBPzV1MtJoaHxvRh3KB2VmzTBBRMLbUzwhGIMaZiWbPDKbb505b9nJnQnAfH9KZVAyu2acpWXi2136rqmyJya6DtqupZMU8RaQe8DrQESoApqvqMX5sRwCxgg7tqhqre51VMxhhHQVEJz32xnn9/vpZ6tarzzLj+jO7X2optmqMq7wyntAO2XjgC8VME/FVVl4lIPeAHEVmgqiv92n2lqlbzzZgwWb5lPxOnJ5O2PYeL+rXmnot60sSKbZoglVdL7QX3+d7whfPzZ28Dtrmvc0RkFdAG8E84xpgwOFRQzD8+W8PUr9JpXq8WU69O5OyeLSIdlqlgor7SgIh0AE4Cvguw+WQRWQ5kAn9T1RUB3j8eGA8QHx/vYaTGVB5JSTB7NoweDRdev4dJM5LZtCeXK4fEM+n8BOrXsmKb5thFdcIRkbrAdODPqprtt3kZ0F5VD4jIKOBDoKv/PlR1CjAFIDExscz5fYwxjqQkeOghkBqFZP6QxlvFm2nfpA5v3ziEYZ2t2KY5flGbcESkOk6yeUtVZ/hv901AqjpXRJ4VkaaqujuccRpT2cyeDbU776DxeanExuVRfX0n5t3Xjdo1rCyNOTFHvQVYRFqIyEsi8om73FNEbvAyKHGGu7wErCprNJyItHTbISKDcX6WPV7GZUxlt+dAPo0v/JHmly2lJK86298czq+69bBkY0IimJoTrwKfAq3d5TXAnz2Kp9Rw4CrgTBH5yX2MEpEJIjLBbXMZkOpew/knME5VrcvMmCAkJUGfPs4zOGVpZv20lbOf+pLMmG30lW40XnIKf7uuIQ8+GNlYTeURTJdaU1V9X0QmA6hqkYh4WsRTVf+HU5W6vDb/Bv7tZRzGVEal12gAUlMhVw+R3S2VhWk76d+uIY9d1pduLerBw5GN01Q+wSScgyLSBKec
DSIyFMjyNCpjjGdmzy59pdTtt5mZ+WnUWl/CHRf04LrhHYm1sjTGI8EknFuB2UBnEfkaaIbTnWWMqYBGj4a0jIM0GZlMrfZ7aUETPvhzX+Kb1Il0aKaSC6aW2jIROR3ojtPNtVpVCz2PzBgTckXFJbQ/bwPxxWsoLoxhiPTh3YfaWVkaExbBzPh5qd+qbiKSBaSo6k5vwjLGhNqqbdlMnJ5MckYW5/RswQOX9KZF/VqRDstUIcF0qd0AnAx87i6PABbjJJ77VPUNj2IzxoRAflEx//l8Pc9+vo4Gtavz7ytP4oI+rX5xVuNbXcBGphkvBJNwSoAeqroDnPtygOdw5slZBFjCMSaCyksUyzbvY+K0ZNbuPMCYk9pw14U9aRRXI+A+fEeugSUdE3rBJJwOpcnGtRPopqp7RcSu5RgTIUlJ8OKLsGuXs+ybKHILinhy/hpe/noDLevX4pVrB3FGQvMj3uubpA6PXHPMnm0Jx4ReMAnnKxGZA3zgLo8FFolIHLDfq8CMMWXzPSPx9eSTsF13s6ZhMlv2HuK3Q+OZODKBem6xzbKS1OjRh1+XLhsTasEknJtwksxwnFFqrwPT3bv6bTZQYyLA/4wEQGoWEnfGKhaWbKHe3jjeGz+U2S81YdidhxNIoCQ1ezakpBx+bddwjFeCGRatwDT3YYyJAv5nJHW6bafR2anExhWQtbgzuau7cutXsXzzjbM9NRWaNSt7X+AkGUs0xkvBDot+FGiOc4YjOHmovsexGWPKUJoYPpyXT5NzV7BZt1Gwsx67pg+iYEcDAHZtL38fzZrBjTdakjHhE0yX2mPARaq6yutgjDGBJSXBP/4BBQXQoAGA0nf0VqpfvJKM3GL2LepG9nedoaTserw33ug8W7eZiZRgEs4OSzbGhF/pSLL69fm5awwgq/AQTc5LYUOLXRSsa0jxt33JXlev3H0NG3Y4wViiMZESTMJZKiLv4cyomV+6MtCkaMaY0Ag8Ck2p238zjUasAoG9n/UkZ1kH0CNv4CztKgM7mzHRJZiEUx/IBc71WaeAJRxjPOI/Cq1aowM0OT+FWu32cmhjU/bO60NR1pHFNmvXhr/85cjkYonGRJNgRqldF45AjDGHZWa6L6SE+oM20OCUNWhxDLvn9uVgSlsCTRd16FBYQzTmmAUzSq0WTj21XsDPlf5U9XoP4zKmyiotcVa9WTZNRi2nZstscle3YO+C3hQfLL/YplUIMNEsmCmm3wBaAucBXwJtgRwvgzKmKklKgvbtoXp1N9nEFtPw1NW0uuZ/VKubz66ZA9j1YeIvks2wYc7Dl1UIMNEsmGs4XVT1chG5WFVfE5G3gU+9DkxERgLPALHAVFV9xG+7uNtH4VxjulZVl3kdlzGh5D84oGabvTQemUKNpgc4kNKWff/tQUneL4ttAqxdCzt3WpVnU3EEk3BKC3TuF5HewHagg2cRASISC/wHOAfIAJaIyGxVXenT7Hygq/sYwuEK1sZUCL6zA0j1Ihqetpp6AzdSnF2bHe8PIm9D87Lf7MMqBJiKIpgutSki0gi4E2eq6ZU4N4N6aTCwTlXTVbUAeBe42K/NxcDr6lgMNBSRVh7HZcwJS0o6MtnU6rCL1jcson7iRnKWtSfz5dOCSjalQ5+NqSiCGaU21X35JdDJ23B+1gbY4rOcwS/PXgK1aQNs820kIuOB8QDx8fEhD9SYY+HbhRZTs5BGZ62kbp8MCvfEsf3Nk8nf2rjc99etCx06WPeZqZiCGaVWE6dadAff9qp6n3dhBRjz6dz7c6xtUNUpwBSAxMTEX2w3JhySkuDNN2HzZme5drdtND5nBbF1Csj6tjP7v+4KxbHl7mPgQFi69PD++vSxxGMqlmCu4cwCsoAf8Kk04LEMoJ3Pclsg8zjaGBNxw4cfLk0TE5dH43NWENd9OwU76rNr2uFim+WJiYG77nJe+8/O+cUXkJ1tycdEv2ASTltVHel5JEdaAnQVkY7AVmAccKVfm9nA
zSLyLk53W5aqbsOYY+TlKK+kpNJko8T1zqDRmauIqV7Mvi+7k/19p3KLbfoqKYHvvnNi9K9C4DsFAVjSMdErmH/t34hIH88j8aGqRcDNOMOvVwHvq+oKEZkgIhPcZnOBdGAd8CLwx3DGaCqH0rOF1FTnOSmp/LZ9+hxu47/sa/ZsSEx09hlbP5fmv/qephckU7inLpmvnEr24i5BJxvffUL599oEmpjNmGhR5hmOiKTgXBOpBlwnIuk4XWql8+H09TIwVZ2Lk1R81z3v81pxZiM15rj5f0GXdad+oG6sQGcWw4fD999DcTGoKvUGbKTh6asB2LugFznL2hP48uNhw4ZB69YwZw7k5R1e7ztRWmms/pWk7cZPE83K61K7MGxRGBMh/jNnlvWF7Z+Yfvjhl9t9k1C1xgdocn4ytdru41B6M/Z82pvi7COLbfqqVg0SEn7ZrVdWd5/vvTd246epKMQ5USingchQYIWq5rjL9YCeqvpdGOILqcTERF1aOszHGFcwX9j+FQGGDTvyzOL22+HJJyG/sIT6g9NpOHwtJYWx7PtvTw6mtuFoZzW3327JwkQvEflBVRNPeD9BJJwfgQFuFxYiEgMsVdUBJ/rh4WYJx5wI/8Tkvzzo3CwyWidTs2U2B9NasvezXpSUU2xz4EDIz7czExP9wplwflLV/n7rkr2+huMFSzjGC3mFxfxz4Vqe+yKdogM12LOgF4fWOEUvataEiy6Cb7+FrVudCgOqUKcOvPOOXXMxFUOoEk4ww2TSReQWEanuPv6EMzrMmCpv6ca9jPrnVzz7xXpK0tuQOfX0n5MNwPvvQ7duTrIBJ9kMHGjJxlRNwSScCcAwnPthSkvMjPcyKGOi3YH8Iu6elcrlL3xLQVEJr18/mGbp/SjJr/5zm86dA983k5rq3FNjTFUTTC21nTg3XhpjgC/X7OL2GSlkZh3impM7kLu4OzeNrUZCAixfDkVFzqizp55y2tevf+T78/MPD0CwazemKgmm0oAxBtifW8B9c1YyY9lWOjeLY9qEk5kxpTGPP+xsT02Fyy6DFi3g3HMPd5llZwfen83OaaoaSzjGBGFuyjbumpXK/txCbj6jCzef2YVa1WO51q+7LC0NPvjgyHX+9/r4rjemKrGEY0w5dmbncdesFcxbsZ3eberz2vWD6dX6cLHNhISj3zjqXxnACm2aqqrMYdEicmt5b1TVpzyJyEM2LNoES1X54IcMHpizkryiEv5ydjduPLUj1WIPj7OZPRuuuAJyc53lhARYtSpCARvjoVANiy7vDKee+9wdGIRTnRngImDRiX6wMdFqy95cbp+ZwldrdzO4Q2MeGduHTs3q/qLd/PmHkw043WlJSXbmYkxZyhwWrar3quq9QFOcSgN/VdW/AgNx5p4xplIpLlFe+XoD5z29iGWb9nH/xb14d/zQgMkGnIEB4lexxqo1G1O2YK7hxAMFPssFOLN/GlNprNuZw23Tklm2eT8jujfjwTF9aNOwdrnvGT0axo6FadOOXGeMCSyYhPMG8L2IzMSZrmAM8LqnURkTJoXFJbzw5Xr+uXAddWrG8tSv+jHmpDaI/6lLGT74wKo1GxOso9ZSAxCRAcCp7uIiVf3R06g8YoMGjK+UjCz+Pm05adtzuKBvK+4d3YumdWtGOixjok44Bg34qgNkq+orItJMRDqq6oYT/XBjIiGvsJinP1vLi1+l0ySuBi9cNZDzerWMdFjGVHpHTTgicjeQiDNa7RWgOvAmMNzb0IwJve837GXi9GQ27D7IrxPbcfsFPWhQu/rR32iMOWHBnOGMAU4ClgGoaqY7CZsnRORxnKHXBcB64DpV3R+g3UYgBygGikJxumcqr5y8Qh6bt5o3Fm+iXePavPW7IQzv0jTSYRlTpQSTcApUVUWkdAK2OI9jWgBMVtUiEXkUmAxMLKPtGaq62+N4TAX3edpOkmamsC07jxtO6chfz+1GnRpWZMOYcAvmr+59EXkBaCgiNwLX
A1O9CkhV5/ssLgYu8+qzTOW292AB989Zycwft9K1eV2m/2EYA+IbRTosY6qsYKYneEJEzgGyca7j3KWqCzyPzHE98F5ZoQHz3TOvF1R1SqBGIjIed/6e+Ph4T4I00UVV+ThlG3fPWkHWoUJuOasrN53RmZrVYiMdmjFVWjCDBh5V1Yk4XV3+646LiHwGBBoWlKSqs9w2SUAR8FYZuxnuXk9qDiwQkTRV/UXJHTcRTQFnWPTxxmwqhh3ZedzxYSoLVu6gb9sGvPm7IfRoVf/obzTGeC6YLrVz+OU1lPMDrAuaqp5d3nYRuQa4EDhLy7hRSFUz3eed7k2pg7Eab1WWqvLeki08OHcVBUUlJI3qwXXDOxxRbNMYE1llJhwR+QPwR6CziCT7bKoHfONVQCIyEieZna6quWW0iQNiVDXHfX0ucJ9XMZnotmnPQSbPSOGb9XsY0rExj47tS4emXo9tMcYcq/LOcN4GPgEeBib5rM9R1b0exvRvoCZONxnAYlWdICKtgamqOgpoAcx0t1cD3lbVeR7GZKJQabHNJ+avplpMDA+N6cO4Qe2IiQmuLI0xJrzKTDiqmgVkicgzwF5VzQEQkXoiMkRVv/MiIFXtUsb6TGCU+zod6OfF55uKYfX2HG6bnszyLfs5K6E5D4zpTasG5RfbNMZEVjDXcJ4DBvgsHwywzpiwKCgq4dkv1vGfz9dRr1Z1nhnXn9H9WgddbNMYEznBJBzxvXCvqiUiYnfNmbBbvmU/t01LZvWOHC7u35q7LuxJEyu2aUyFEUziSBeRW3DOasAZSJDuXUjGHOlQQTFPLVjNS//bQPN6tZh6dSJn92wR6bCMMccomIQzAfgncAfOzZYLcW+kNMZr367fw6QZyWzak8uVQ+KZdH4C9WtZsU1jKqJgKg3sBMaFIRZjfpadV8jDc9N45/vNtG9Sh7dvHMKwzlZs05iKrLz7cG5T1cdE5F84ZzZHUNVbPI3MVFkLV+0gaWYqO3PyuPHUjtx6Tndq17CyNMZUdOWd4axyn22KTBMWew7kc+9HK5m9PJPuLerx/FUD6d+uYaTDMsaESHn34XzkPr8WvnBMVaSqzF6eyb0frSQnr5C/nN2NP4zoTI1qVpbGmMqkvC61jwjQlVZKVUd7EpGpUrZlHeLOD1P5bNVO+rdryGOX9aVbC8/m9zPGRFB5XWpPuM+X4lR2ftNdvgLY6GFMpgooKVHeXbKFh+euorCkhDsu6MF1wzsSa2VpjKm0yutS+xJARO5X1dN8Nn0kIlaV2Ry3jbsPMmlGMovT9zKscxMeubQv8U3qRDosY4zHgrkPp5mIdHLrlyEiHYFm3oZlKqOi4hJe/noDT85fQ43YGB65tA+/HtTOytIYU0UEk3D+AnwhIqXVBToAv/csIlMppW3PZuK0ZJZnZHF2jxY8cElvWjaoFemwjDFhFMyNn/NEpCuQ4K5KU9V8b8MylUV+UTHPfr6eZ79YR/1a1fn3lSdxQZ9WdlZjTBUUzBTTdYBbgfaqeqOIdBWR7qo6x/vwTEX24+Z9TJyezJodBxhzUhvuurAnjeJqRDosY0yEBNOl9grwA3Cyu5wBfABYwjEB5RYU8eT8Nbz89QZa1q/FK9cO4oyE5pEOyxgTYcEknM6q+msRuQJAVQ+J9YeYMnyzbjeTZqSweW8uvx0az8SRCdSzYpvGGIJLOAUiUhv3JlAR6QzYNRxzhKxDhTw8dxXvLtlCx6ZxvDd+KEM6NYl0WMaYKBJMwrkbmAe0E5G3gOHAtV4FJCL3ADcCu9xVt6vq3ADtRgLPALHAVFV9xKuYTPnmr9jOHR+msudgARNO78yfz+5KrepWbNMYc6RyE46IxACNcKoNDAUE+JOq7vY4rn+o6hNlbRSRWOA/wDk415SWiMhsVV3pcVzGx+4D+dwzewVzkreR0LIeL10ziD5tG0Q6LGNMlCo34bjTSd+squ8DH4cppmAMBtb53Iz6LnAxYAknDFSVD3/ayr0frSQ3
v5i/ntONCSM6Uz3Wim0aY8oWTJfaAhH5G/AecLB0paru9SwquFlErsaZGuGvqrrPb3sbYIvPcgYwJNCORGQ87gyl8fHxHoRatWTuP0TSzBQ+X72LAfENeXRsX7pasU1jTBCCSTjXu883+axToNPxfqiIfIZTENRfEvAccL/7GfcDT/rE8PMuArw3YGVrVZ0CTAFITEwss/q1KV9JifLW95t59JM0ikuUuy/qydUnd7Bim8aYoAVTaaBjqD9UVc8Opp2IvEjg+30ygHY+y22BzBCEZgJI33WASdNT+H7jXk7p0pSHL+1Du8ZWbNMYc2yCqTRQC/gjcArOWcRXwPOqmudFQCLSSlW3uYtjgNQAzZYAXd1ColuBccCVXsRTlRUVlzDlq3Se/mwttarF8Nhlfbl8YFsrS2OMOS7BdKm9DuQA/3KXrwDeAC73KKbHRKQ/TnLbiFsoVERa4wx/HqWqRSJyM/ApzrDol1V1hUfxVEkrMrOYOD2Z1K3ZnNerBfdf3Jvm9a3YpjHm+AWTcLqraj+f5c9FZLlXAanqVWWszwRG+SzPBX5xf445MXmFxfzrv2t5/st0GtWpwXO/GcD5fVpFOixjTCUQTML5UUSGqupiABEZAnztbVgmEn7YtJfbpiWzftdBxg5oy50X9qBhHSu2aYwJjWASzhDgahHZ7C7HA6tEJAVQVe3rWXQmLA7mF/H4p6t57duNtG5Qm9euH8zp3WyOPWNMaAWTcEZ6HoWJmEVrdjF5RgqZWYe4emh7/j4ygbo1g/lnYYwxxyaYYdGbwhGICa+s3ELu/3gl037IoFOzON7//ckM6tA40mEZYyox+69sFTQvdRt3zlrB3oMF3HRGZ/7vTCu2aYzxniWcKmRnTh53z1rBJ6nb6dW6Pq9eN4hera3YpjEmPCzhVAGqyvRlW7l/zkoOFRZz28ju3HhqJyu2aYwJK0s4lVzGvlxun5nKojW7GNShEY+M7UvnZnUjHZYxpgqyhFNJlZQobyzexKPz0gC47+Je/HZIe2Ks2KYxJkIs4VRC63YeYNL0ZJZu2sdp3Zrx0JjetG1kxTaNMZFlCacSKSwuYcqidJ5ZuJba1WN54vJ+jB3QxoptGmOigiWcSiJ1axa3TUtm5bZsRvVpyT2je9G8nhXbNMZED0s4FVxeYTHPLFzLlEXpNI6rwfO/HcDI3lZs0xgTfSzhVGBLNu5l4rRk0ncf5PKBbbnjgp40qFM90mEZY0xAlnAqoAP5RTw2L43Xv91E20a1eeOGwZza1YptGmOimyWcCuaL1TtJmplKZtYhrh3Wgb+f1504K7ZpjKkA7Juqgth3sID7P17JjGVb6dK8LtMmDGNg+0aRDssYY4JmCSfKqSpzU7Zz9+xU9ucWcsuZXbjpzC7UrGbFNo0xFUvUJRwReQ/o7i42BParav8A7TYCOUAxUKSqiWEKMWx2Zudxx4epzF+5gz5tGvD69UPo2bp+pMMyxpjjEnUJR1V/XfpaRJ4Essppfoaq7vY+qvBSVT5YmsH9H6+koKiEyecncMMpHalmxTaNMRVY1CWcUuLcHv8r4MxIxxJOW/bmMnlGCv9bt5vBHRvzyKV96GTFNo0xlUDUJhzgVGCHqq4tY7sC80VEgRdUdUr4Qgu94hLltW828vinq4mNER64pDdXDo63YpvGmEojIglHRD4DWgbYlKSqs9zXVwDvlLOb4aqaKSLNgQUikqaqiwJ81nhgPEB8fPwJRu6NtTtymDg9mWWb93NG92Y8OKYPrRvWjnRYxhgTUqKqkY7hF0SkGrAVGKiqGUG0vwc4oKpPlNcuMTFRly5dGpogQ6CwuITnv1jPv/67jriasdx9US8u7t/aim0aY6KKiPwQioFZ0dqldjaQVlayEZE4IEZVc9zX5wL3hTPAE5WcsZ/bpiWTtj2HC/u24p7RvWhat2akwzLGGM9Ea8IZh193moi0Bqaq6iigBTDTPROoBrytqvPCHuVxyCss5h8L1vDiV+k0
q1eTF69O5JyeLSIdljHGeC4qE46qXhtgXSYwyn2dDvQLc1gnbHH6HibPSGHD7oOMG9SOyaN60KC2Fds0xlQNUZlwKpucvEIe+SSNt77bTLvGtXnrd0MY3qVppMMyxpiwsoTjsc/TdnL7zBR2ZOfxu1M6cuu53ahTww67MabqsW8+j+w9WMB9H63gw58y6dq8Ls/+YRgnxVuxTWNM1WUJJ8RUlTnJ27hn9gqyDhXyp7O68sczOluxTWNMlWcJJ4R2ZOeRNDOVz1btoF/bBrx14xASWlqxTWOMAUs4IaGqvLdkCw/OXUVhcQlJo3pw/SkdibWyNMYY8zNLOCdo056DTJqewrfpexjaqTGPXNqXDk3jIh2WMcZEHUs4x6m4RHnl6w08MX811WNieGhMH8YNamfFNo0xpgyWcI7D6u053DY9meVb9nNWQnMeGNObVg2s2KYxxpTHEs4xKCgq4dkv1vGfz9dRr1Z1nhnXn9H9rNimMcYEwxJOkH7asp+J05JZvSOHi/u35u6LetE4rkakwzLGmArDEs5RHCoo5qkFq3npfxtoXq8WL12TyFk9rNimMcYcK0s45fhm/W4mTU9h895crhwSz6TzE6hfy4ptGmPM8bCEE0B2XiEPz03jne83075JHd65cSgnd24S6bCMMaZCs4Tj57OVO0j6MIVdOfn8/rRO/PnsbtSuYWVpjDHmRFnCce05kM+9H61k9vJMElrW48WrE+nbtmGkwzLGmEqjyiccVWX28kzumb2CA/lF3HpONyac3pka1WIiHZoxxlQqVTrhbMs6xB0zU1mYtpP+7Rry2GV96daiXqTDMsaYSqlKJpySEuWdJZt5eG4axSXKnRf25NphHazYpjHGeCgi/UYicrmIrBCREhFJ9Ns2WUTWichqETmvjPc3FpEFIrLWfQ56ZrONuw9y5dTFJM1MpV+7Bnz659O4wSo7G2OM5yJ1oSIVuBRY5LtSRHoC44BewEjgWREJNERsErBQVbsCC93lo9p1IJ/znl7EisxsHh3bhzdvGEJ8kzon8nMYY4wJUkS61FR1FRCoBtnFwLuqmg9sEJF1wGDg2wDtRrivXwO+ACYe7XO3Z+VxdbdmPHBJb1rUr3Xc8RtjjDl20XYNpw2w2Gc5w13nr4WqbgNQ1W0i0rysHYrIeGC8u5g/9ZpBqVNDFa13mgK7Ix1EECzO0KkIMYLFGWoVJc7uodiJZwlHRD4DWgbYlKSqs8p6W4B1eiJxqOoUYIob01JVTTzKWyLO4gytihBnRYgRLM5Qq0hxhmI/niUcVT37ON6WAbTzWW4LZAZot0NEWrlnN62AnccTozHGmPCJtrsbZwPjRKSmiHQEugLfl9HuGvf1NUBZZ0zGGGOiRKSGRY8RkQzgZOBjEfkUQFVXAO8DK4F5wE2qWuy+Z6rPEOpHgHNEZC1wjrscjCkh/DG8ZHGGVkWIsyLECBZnqFWpOEX1hC6RGGOMMUGJti41Y4wxlZQlHGOMMWFR6RJOJMvmnEDM74nIT+5jo4j8VEa7jSKS4rYLyTDFYyEi94jIVp9YR5XRbqR7jNeJSFBVIEIY4+MikiYiySIyU0QaltEuIsfyaMdGHP90tyeLyIBwxeYTQzsR+VxEVrl/S38K0GaEiGT5/Fu4K9xxunGU+3uMkuPZ3ec4/SQi2SLyZ782ETmeIvKyiOwUkVSfdUF9Bx7X37mqVqoH0APnJqUvgESf9T2B5UBNoCOwHogN8P7HgEnu60nAo2GO/0ngrjK2bQSaRvDY3gP87ShtYt1j2wmo4R7znmGM8Vygmvv60bJ+f5E4lsEcG2AU8AnOPWlDge8i8HtuBQxwX9cD1gSIcwQwJ9yxHevvMRqOZ4B/A9uB9tFwPIHTgAFAqs+6o34HHu/feaU7w1HVVaq6OsCmn8vmqOoGoLRsTqB2r7mvXwMu8STQAMSp9fMr4J1wfaYHBgPrVDVdVQuAd3GOaVio6nxVLXIXF+PcyxUtgjk2FwOvq2Mx
0NC91yxsVHWbqi5zX+cAqwhc8aMiiPjx9HMWsF5VN0Uwhp+p6iJgr9/qYL4Dj+vvvNIlnHK0Abb4LAdVNgcos2yOB04Fdqjq2jK2KzBfRH5wS/ZEws1u18TLZZxqB3ucw+F6nP/dBhKJYxnMsYmm44eIdABOAr4LsPlkEVkuIp+ISK/wRvazo/0eo+p44hQnLus/lNFwPCG478DjOq7RVkstKBIlZXOORZAxX0H5ZzfDVTVTnNpxC0Qkzf0fSljiBJ4D7sc5bvfjdP9d77+LAO8N6XEO5liKSBJQBLxVxm48P5YBBHNsIvrv1JeI1AWmA39W1Wy/zctwuoUOuNfyPsS5UTvcjvZ7jKbjWQMYDUwOsDlajmewjuu4VsiEoxWwbM7RYhaRajhTNgwsZx+Z7vNOEZmJc1ob0i/JYI+tiLwIzAmwKdjjfNyCOJbXABcCZ6nb4RxgH54fywCCOTaeH79giEh1nGTzlqrO8N/um4BUda6IPCsiTVU1rIUog/g9RsXxdJ0PLFPVHf4bouV4uoL5Djyu41qVutSivWzO2UCaqmYE2igicSJSr/Q1zsXx1EBtveLX9z2mjM9fAnQVkY7u/+jG4RzTsBCRkThTVYxW1dwy2kTqWAZzbGYDV7ujq4YCWaXdG+HiXkt8CVilqk+V0aal2w4RGYzzXbInfFEG/XuM+PH0UWYPRjQcTx/BfAce3995uEdFeP3A+SLMAPKBHcCnPtuScEZWrAbO91k/FXdEG9AEZ1K3te5z4zDF/SowwW9da2Cu+7oTzkiQ5cAKnO6jcB/bN4AUINn9x9XKP053eRTOyKb14Y4TZzDIFuAn9/F8NB3LQMcGmFD6u8fpqviPuz0Fn5GWYYzxFJzukWSf4zjKL86b3WO3HGdwxrAIxBnw9xhtx9ONow5OAmngsy7ixxMnAW4DCt3vzRvK+g4Mxd+5lbYxxhgTFlWpS80YY0wEWcIxxhgTFpZwjDHGhIUlHGOMMWFhCccYY0xYWMIxlZ5biXfYCe7jwDG0fVVELjuRzwsVEfnmGNtHTeym8rGEY6qCEcAJJZyKSlWr5M9topMlHFMhiciHbsHGFb5FG905Opa5RRAXusUnJwB/cecZOdX/f/GlZy8iUtd9zzJx5lg5avVbEbnaLWa6XETe8Nl0moh8IyLppZ9V1v5FpIM4c8+86P4880WktrttkLv/b8WZ6yfVXR/rLi9xt/++jPhKf7YRIvKFiEwTZ76gt0rvbC/nZztLRH50Y31ZRGq66x8RkZXu5z7hrrtcRFLd4+B1iSBTUUXirlt72ONEHxy++7k2TjmTJkAznCoDHf3a3IPPPD44VR0u81k+4D5XA+q7r5viVC0Q3zZ+MfTCqVrR1O/zXgU+wPkPXU+cMu5l7h/ogFNotL+77X3gt+7rVNy7zoFHcOctAcYDd7ivawJLS39uvxhLf7YRQBZOzasY4FvglADtXwUuA2q5x7Kbu/514M9AY/dnLj0uDd3nFKCN7zp72MP/YWc4pqK6RURKy4C0w6mNNxRYpM58R6iq/zwfRyPAQyKSDHyGU269RTntzwSmqVtg0e/zPlTVElVd6bOP8va/QVV/cl//AHQQZ7bSeqpaeh3mbZ/9n4tTI+wnnKkDmnD06sLfq2qGqpbglKvpUE7b7m5Ma9zl13Am68oG8oCpInIpUFqv7mvgVRG5EWdyLmN+oUJWizZVm4iMwCl2erKq5orIFzj/IxeCKz1fhNud7HYr1XDX/wbnLGmgqhaKyEZ3v2WGUs7n5fu1O9r+fdsX45y5ldflJcD/qeqn5bQpL6Ziyv/7D/jZqlrkFpc8C6dg483Amao6QUSGABcAP4lIf1WNVPFJE6XsDMdURA2AfW6yScA5swGnm+h0caqBIyKN3fU5OFMll9rI4WkgLgaq++x3p5sMzgDaHyWOhcCvRKSJ3+eVF3fQ+1fVfUCOW+UYnC/4Up8CfxBnGgFEpJtbMTlU0nDOsrq4y1cBX4ozR04D
VZ2L08XW3/38zqr6nareBezmyNL1xgB2hmMqpnnABLdrajVOtxqqussdQDBDRGJw5vE4B/gImOZepP8/4EVgloh8j5M0Drr7fQv4SESW4nQ5pZUXhKquEJEHcb6Ii4EfgWvLecsx7d91A/CiiBwEvsC5DgNOhfMOwDL3LG0XIZwOXVXzROQ64ANx5mpaAjyPcw1nloiUnlH+xX3L4yLS1V23EKfqsTFHsGrRxkQxEamrqqUjzSbhTAnxpwiHZcxxsTMcY6LbBSIyGedvdRPln0EZE9XsDMcYY0xY2KABY4wxYWEJxxhjTFhYwjHGGBMWlnCMMcaEhSUcY4wxYfH/7fpNA1+2ybYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of edges in consider 12431\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "low_limit = -2.5\n",
    "up_limit = 2.5\n",
    "x = np.linspace(low_limit, up_limit)\n",
    "# x = np.linspace(-0.2, 0.15)\n",
    "plt.plot(x, x)\n",
    "plt.scatter(acctual_influence_1, predict_influence_1, s = 10, color='blue')\n",
    "plt.xlabel('actual change in loss')\n",
    "plt.ylabel('predicted change in loss')\n",
    "# plt.title('Influence function on edges of Cora dataset perturb edges:' + str(batch_edges))\n",
    "# plt.title('Influence function on Complete Node of Citeseer dataset')\n",
    "plt.ylim(low_limit, up_limit)\n",
    "plt.xlim(low_limit, up_limit)\n",
    "# plt.title('Influence function on Iris dataset')\n",
    "plt.show()\n",
    "print('Number of edges in consider %d'% len(f_l))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "acf2c3bd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([33])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "extra_index_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "79444846",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.75212483, -0.04163444, -0.03761754, ..., -0.19991923,\n",
       "       -0.10448183,  0.28174796])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(acctual_influence_1)[np.array(acctual_influence_1) !=0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "46e70270",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.6304563637989626, pvalue=0.0)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scipy.stats.spearmanr(np.array(acctual_influence_1)[np.array(acctual_influence_1) !=0], np.array(predict_influence_1)[np.array(predict_influence_1) !=0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "6f1331c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.5272169586664675, pvalue=0.0)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scipy.stats.spearmanr(acctual_influence_1, predict_influence_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "180947ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame([acctual_influence_1, predict_influence_1, f_l.numpy().astype(int), t_l.numpy().astype(int)]).T\n",
    "df.columns = ['acctual_influence', 'predict_influence', 'from_edges', 'to_edges']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "305e1865",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv('result_data/' + data_set + '_edge_influence_002.csv', header = None, index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7780b98f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pandas as pd\n",
    "# df = pd.read_csv('result_data/pubmed_edge_influence.csv', header = None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "2aa2428d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import matplotlib.pyplot as plt\n",
    "# import numpy as np\n",
    "# low_limit = -10\n",
    "# up_limit = 10\n",
    "# x = np.linspace(low_limit, up_limit)\n",
    "# # x = np.linspace(-0.2, 0.15)\n",
    "# plt.plot(x, x)\n",
    "# plt.scatter(df[0], df[1], s = 10, color='blue')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "406dade0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# g = dgl.graph(([1, 1, 2, 3, 4], [0, 2, 0, 4, 2]))\n",
    "# g.edata['w'] = torch.arange(10).view(5, 2)\n",
    "# sg, inverse_indices = dgl.khop_in_subgraph(g, 0, k=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
