{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "97458bb9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import math\n",
    "import torch\n",
    "import dgl\n",
    "import numpy_ml\n",
    "import torch.nn.functional as F\n",
    "import tensorflow.compat.v1 as tf\n",
    "import scipy\n",
    "import cupy as cp\n",
    "import cupyx\n",
    "\n",
    "from dgl import function as fn\n",
    "from dgl.base import DGLError\n",
    "from generate_mid_layer_feature import FeatureExtraction\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from scipy.special import softmax, log_softmax\n",
    "from numpy.linalg import inv, pinv\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "from sklearn.utils import check_array\n",
    "from dataset import load_graph_dataset\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from model_node_influence import NodeInfluenceSGC\n",
    "\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy.linalg import cho_solve, cho_factor\n",
    "from model_edge_influence import EdgeInfluenceSGC, generate_remove_index_train\n",
    "from tqdm import tqdm\n",
    "import networkx as nx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a00c4f26",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/zizhang/Desktop/Projects/Project6_influence_function/graph_influence_function/experiments'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the working directory: sys.path.append('..') in the first cell is resolved relative to it.\n",
    "os.path.abspath(os.curdir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0bf84684",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "# Route supported scikit-learn estimators through the Intel sklearnex backend.\n",
    "# NOTE(review): the sklearnex docs ask for patch_sklearn() to run before sklearn\n",
    "# estimators are imported; the first cell already imports LogisticRegression etc.,\n",
    "# so those names may be unpatched. The first cell's stderr already reports the\n",
    "# extension as enabled, presumably via a project module - confirm.\n",
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "32d092d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset hyper-parameters: L2 regularisation strength (l2_term) and number of\n",
    "# SGC propagation rounds (num_layer). Select the dataset by name below instead of\n",
    "# commenting/uncommenting blocks.\n",
    "DATASET_CONFIGS = {\n",
    "    'cora':     {'l2_term': 0.01,  'num_layer': 2},\n",
    "    'pubmed':   {'l2_term': 0.004, 'num_layer': 2},\n",
    "    'citeseer': {'l2_term': 1,     'num_layer': 2},\n",
    "    'reddit':   {'l2_term': 0.99,  'num_layer': 2},\n",
    "}\n",
    "\n",
    "data_set = 'citeseer'\n",
    "l2_term = DATASET_CONFIGS[data_set]['l2_term']\n",
    "num_layer = DATASET_CONFIGS[data_set]['num_layer']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8ff867b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 3327\n",
      "  NumEdges: 9228\n",
      "  NumFeats: 3703\n",
      "  NumClasses: 6\n",
      "  NumTrainingSamples: 120\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "# Load the graph, raw node features, labels and the train/val/test node masks.\n",
    "(graph, feat, labels,\n",
    " train_mask, val_mask, test_mask,\n",
    " number_classes) = load_graph_dataset(data_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "82d565de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SGC pre-propagation: num_layer rounds of degree-normalised neighbour summation.\n",
    "# Messages are scaled by deg^-1/2 before sending and the aggregate again after,\n",
    "# using in-degrees clamped at 1 so zero-in-degree nodes do not produce inf.\n",
    "feat0 = feat.clone()\n",
    "degs = graph.in_degrees().float().clamp(min = 1)\n",
    "norm = torch.pow(degs, -0.5).to(feat0.device).unsqueeze(1)  # (num_nodes, 1) row scale\n",
    "\n",
    "for _hop in range(num_layer):\n",
    "    graph.ndata['h'] = feat0 * norm\n",
    "    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h') * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fd88881a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _masked_float32(mask):\n",
    "    \"\"\"Return the (features, labels) rows selected by `mask` as float32 numpy arrays.\"\"\"\n",
    "    return (feat0[mask].numpy().astype(np.float32),\n",
    "            labels[mask].numpy().astype(np.float32))\n",
    "\n",
    "train_x, train_y = _masked_float32(train_mask)\n",
    "test_x, test_y = _masked_float32(test_mask)\n",
    "val_x, val_y = _masked_float32(val_mask)\n",
    "\n",
    "# Graph node ids of the training nodes (ascending); row i of train_x is node train_node_idx[i].\n",
    "train_node_idx = torch.nonzero(train_mask == 1, as_tuple=True)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d4a0be35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encoder fitted on the training labels and applied to every split.\n",
    "# With handle_unknown='ignore', a label never seen in train_y becomes an all-zero row.\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "# The logistic-regression model (lr_origin) is trained further down, together\n",
    "# with the Hessian computation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "897df57e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_remove_index_train_all(from_indexes, to_indexes, train_mask, seed_val=10):\n",
    "    \"\"\"List every edge whose source node is a training node.\n",
    "\n",
    "    Args:\n",
    "        from_indexes: 1-D tensor of edge source node ids (e.g. graph.edges()[0]).\n",
    "        to_indexes: 1-D tensor of edge destination node ids, aligned with from_indexes.\n",
    "        train_mask: per-node mask; nodes where train_mask == 1 are training nodes.\n",
    "        seed_val: passed to torch.manual_seed. Nothing here is random, but the\n",
    "            global seeding side effect is kept for callers that rely on it.\n",
    "\n",
    "    Returns:\n",
    "        (from_list, to_list): int64 tensors of equal length, grouped by training\n",
    "        node (ascending node id) and, within one node, in original edge order.\n",
    "        Both are empty int64 tensors when no training node has an out-edge.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed_val)\n",
    "    train_index = torch.where(train_mask == 1)[0]\n",
    "    remove_from_parts = []\n",
    "    remove_to_parts = []\n",
    "    for i in tqdm(range(len(train_index))):\n",
    "        f_index = train_index[i]\n",
    "        # Edge ids (ascending) whose source is this training node.\n",
    "        edge_ids = torch.nonzero(from_indexes == f_index, as_tuple=True)[0]\n",
    "        if len(edge_ids) == 0:\n",
    "            continue\n",
    "        remove_from_parts.append(from_indexes[edge_ids].long())\n",
    "        remove_to_parts.append(to_indexes[edge_ids].long())\n",
    "\n",
    "    if not remove_from_parts:\n",
    "        # torch.tensor([]) would be float32; return index-friendly int64 instead.\n",
    "        return torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)\n",
    "    return torch.cat(remove_from_parts), torch.cat(remove_to_parts)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ecabbd84",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████| 120/120 [00:00<00:00, 19091.05it/s]\n"
     ]
    }
   ],
   "source": [
    "from_indexes, to_indexes = graph.edges()\n",
    "\n",
    "# Which edges to evaluate: True -> only out-edges of training nodes,\n",
    "# False -> every edge of the graph (the setting used for the run below).\n",
    "EDGES_FROM_TRAIN_NODES_ONLY = False\n",
    "if EDGES_FROM_TRAIN_NODES_ONLY:\n",
    "    f_l, t_l = generate_remove_index_train_all(from_indexes, to_indexes, train_mask)\n",
    "else:\n",
    "    # generate_remove_index_train_all calls torch.manual_seed(10) by default; seed the\n",
    "    # same way here so downstream torch randomness does not depend on this switch.\n",
    "    torch.manual_seed(10)\n",
    "    f_l, t_l = from_indexes, to_indexes\n",
    "\n",
    "# Alternative edge source: edges listed in a previously saved result file.\n",
    "# df = pd.read_csv('result_data/' + data_set + '_edge_influence.csv', header = None)\n",
    "# df = df.loc[df[0] != 0]\n",
    "# f_l = torch.tensor(df[2].values.astype(int))\n",
    "# t_l = torch.tensor(df[3].values.astype(int))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "273f86ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# split_side = int(len(f_l) / batch_edges)\n",
    "# total_edges = np.arange(len(f_l))\n",
    "# np.random.shuffle(total_edges)\n",
    "# new_index = np.array_split(total_edges, split_side)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c5fd08fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two independent empty accumulators.\n",
    "# NOTE(review): none of the following cells shown here append to them - confirm they are still needed.\n",
    "changed_list, changed_list_change = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "241d2b17",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 119/119 [00:02<00:00, 48.65it/s]\n"
     ]
    }
   ],
   "source": [
    "# Fit the baseline model on the unperturbed SGC features and precompute\n",
    "# loss_grad_hvp from the training Hessian and the validation-loss gradient.\n",
    "# The per-edge loop below reuses loss_grad_hvp for every influence estimate.\n",
    "# convert to one-hot labels\n",
    "# NOTE(review): same encoder and fit data as the earlier one-hot cell.\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "one_hot_labels_train_orig = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# train the original data\n",
    "# calculate the hessian matrix\n",
    "lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "\n",
    "lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "# Baseline logits straight from the fitted weights: X @ coef_.T + intercept_.\n",
    "logits_val_y_origin = val_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "# ori_val_loss is the baseline that the per-edge loop subtracts from each retrained loss.\n",
    "ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "# numpy_theoritic_loss = log_loss(val_y, softmax(logits_val_y_origin, axis=1))\n",
    "\n",
    "# assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "# Validation-loss gradient at the fitted parameters (total and, presumably, per-sample).\n",
    "val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(val_x, \n",
    "                                                                    logits_val_y_origin,\n",
    "                                                                    one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "# Training-loss Hessian; presumably a CuPy array (the result below goes through cp.asnumpy) - confirm.\n",
    "hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg = True)\n",
    "\n",
    "\n",
    "###### write out hessian matrix\n",
    "\n",
    "# hess = cp.asnumpy(hess)\n",
    "\n",
    "# pd.DataFrame(hess).to_csv(\"reddit/hess.csv\", index = False)\n",
    "\n",
    "######\n",
    "\n",
    "\n",
    "# Presumably the inverse-Hessian-vector product H^-1 g_val (Cholesky-based when\n",
    "# cholskey=True) - confirm in model_softmax.fast_get_inv_hvp_cuda.\n",
    "loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True)\n",
    "\n",
    "loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "# Free the Hessian before the long per-edge loop; only loss_grad_hvp is used there.\n",
    "del hess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b7b84551",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # convert to one-hot labels\n",
    "# enc = OneHotEncoder(handle_unknown='ignore')\n",
    "# enc.fit(train_y.reshape(-1, 1))\n",
    "# one_hot_labels_train_orig = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# # train the original data\n",
    "# # calculate the hessian matrix\n",
    "# lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "\n",
    "# lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "# logits_test_y_origin = test_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "# logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "# ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(logits_test_y_origin, one_hot_labels_test, l2_reg=True)\n",
    "\n",
    "# # numpy_theoritic_loss = log_loss(test_y, softmax(logits_test_y_origin, axis=1))\n",
    "\n",
    "# # assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "# val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(test_x, \n",
    "#                                                                     logits_test_y_origin,\n",
    "#                                                                     one_hot_labels_test, l2_reg = True)\n",
    "\n",
    "# hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg = True)\n",
    "\n",
    "# loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True)\n",
    "\n",
    "# loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "# del hess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c83a5a44",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 12431/12431 [24:12<00:00,  8.56it/s]\n"
     ]
    }
   ],
   "source": [
    "# For every candidate edge: remove it, find the training nodes whose propagated\n",
    "# features change, then record\n",
    "#   predicted: influence-function estimate built from loss_grad_hvp\n",
    "#   actual:    val-loss change after retraining with those nodes' rows replaced\n",
    "acctual_influence_1 = []\n",
    "acctual_influence_2 = []\n",
    "\n",
    "predict_influence_1 = []\n",
    "predict_influence_2 = []\n",
    "\n",
    "# Row position of each training node inside train_x / train_y, built once so the\n",
    "# per-edge lookups are O(1) instead of scanning train_node_idx for every node.\n",
    "train_pos = {node: pos for pos, node in enumerate(train_node_idx.tolist())}\n",
    "\n",
    "for k in tqdm(range(len(f_l))):\n",
    "    eis = EdgeInfluenceSGC(graph=graph, feature=feat, from_index=f_l[k], to_index=t_l[k])\n",
    "    eis.remove_edges_sgc_from_influence()\n",
    "    feat_removed1 = eis.calculate_modified_features()\n",
    "    # Node-removal variant: NodeInfluenceSGC(graph=graph, feature=feat, node_index=...)\n",
    "    # followed by remove_edges_sgc() and calculate_modified_features().\n",
    "\n",
    "    # Nodes whose propagated features changed, restricted to training nodes.\n",
    "    extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0])\n",
    "    extra_index_train = np.array(\n",
    "        [node for node in extra_index.tolist() if node in train_pos], dtype=np.int64)\n",
    "    extra_index_train_in_train = [train_pos[node] for node in extra_index_train]\n",
    "\n",
    "    # The edge touches no training feature, so the fitted model cannot change.\n",
    "    # Use len(): comparing a numpy array to [] is elementwise and is never True\n",
    "    # (and may raise on recent NumPy), which made this shortcut dead code.\n",
    "    if len(extra_index_train) == 0:\n",
    "        predict_influence_1.append(0.0)\n",
    "        acctual_influence_1.append(0.0)\n",
    "        continue\n",
    "\n",
    "    feat_to_be_added = feat_removed1[extra_index_train].numpy()\n",
    "    perturb_index = extra_index_train_in_train\n",
    "\n",
    "    # Augmented set: original rows followed by the perturbed copies of the\n",
    "    # affected training nodes, which keep their labels.\n",
    "    train_x_new = feat_to_be_added\n",
    "    train_y_new = train_y[perturb_index]\n",
    "\n",
    "    train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "    train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "\n",
    "    one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "    logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "    train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(train_x_orig,\n",
    "                                            logits_train_y_origin_0,\n",
    "                                            one_hot_labels_train_0, l2_reg = True)\n",
    "\n",
    "    # Presumably one influence score per row of train_x_orig - confirm in model_softmax.\n",
    "    pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "\n",
    "    # Weights 1...0...11: drop the stale rows of the affected nodes, keep their\n",
    "    # appended updated copies, then retrain to measure the actual effect.\n",
    "    weight_3 = np.ones(len(train_x_orig))\n",
    "    weight_3[perturb_index] = 0\n",
    "\n",
    "    lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "    train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "    train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "\n",
    "    lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "    logits_val_y_new_2 = val_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "    new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_val_y_new_2, one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "    # Predicted: summed scores of the replaced original rows minus those of the appended rows.\n",
    "    predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "    acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddab4990",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1af810e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# acctual_influence_1 = []\n",
    "# acctual_influence_2 = []\n",
    "\n",
    "# predict_influence_1 = []\n",
    "# predict_influence_2 = []\n",
    "\n",
    "# # for k in tqdm(range(len(train_node_idx))):\n",
    "# for k in tqdm(range(len(f_l))):\n",
    "\n",
    "#     eis = EdgeInfluenceSGC(graph=graph, feature=feat, from_index=f_l[k], to_index=t_l[k])\n",
    "#     eis.remove_edges_sgc()\n",
    "#     feat_removed1 = eis.calculate_modified_features()\n",
    "    \n",
    "# #     node_id = train_node_idx.numpy()[k]\n",
    "# #     nis = NodeInfluenceSGC(graph = graph, feature=feat, node_index=node_id)\n",
    "# #     nis.remove_edges_sgc()\n",
    "# #     feat_removed1 = nis.calculate_modified_features()\n",
    "\n",
    "#     extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0])\n",
    "    \n",
    "    \n",
    "#     extra_index_train = torch.tensor(\n",
    "#         [extra_index[i] for i in range(len(extra_index)) if extra_index[i] in train_node_idx]).numpy()\n",
    "    \n",
    "#     extra_index_train_in_train = [\n",
    "#         np.where(train_node_idx.numpy() == extra_index_train[j])[0][0] for j in range(len(extra_index_train))]\n",
    "    \n",
    "#     if extra_index_train != []:\n",
    "#         predict_influence_1.append(0.0)\n",
    "#         acctual_influence_1.append(0.0)\n",
    "#         continue\n",
    "    \n",
    "#     feat_to_be_added = feat_removed1[extra_index_train].numpy()\n",
    "#     perturb_index = extra_index_train_in_train\n",
    "    \n",
    "    \n",
    "#     train_x_new = feat_to_be_added\n",
    "#     train_y_new = train_y[perturb_index]\n",
    "    \n",
    "#     train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "#     train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "    \n",
    "#     one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "#     logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "#     train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(train_x_orig, \n",
    "#                                             logits_train_y_origin_0, \n",
    "#                                             one_hot_labels_train_0, l2_reg = True)\n",
    "    \n",
    "    \n",
    "#     pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "# #     weight_orig = np.ones(len(train_x_orig)) # 1...1...11\n",
    "    \n",
    "# #     weight_1 = np.ones(len(train_x_orig))\n",
    "# #     weight_1[len(train_x_orig) - len(perturb_index):] = 0 # 1...1...10\n",
    "    \n",
    "# #     weight_2 = np.ones(len(train_x_orig))\n",
    "# #     weight_2[len(train_x_orig) - len(perturb_index):] = 0 \n",
    "# #     weight_2[perturb_index] = 0 # 1...0...10\n",
    "    \n",
    "#     weight_3 = np.ones(len(train_x_orig))\n",
    "#     weight_3[perturb_index] = 0 # 1...0...11\n",
    "    \n",
    "# #     lr_new_1 = SimplifiedGraphNeuralNetwork(l2_reg=1.0, fit_intercept=True)\n",
    "# #     train_x_delete_1 = train_x_orig[weight_1 == 1]\n",
    "# #     train_y_delete_1 = train_y_orig[weight_1 == 1]\n",
    "    \n",
    "# #     assert(np.allclose(train_x_delete_1, train_x))\n",
    "# #     assert(np.allclose(train_y_delete_1, train_y))\n",
    "    \n",
    "# #     lr_new_1.fit(train_x_delete_1, train_y_delete_1)\n",
    "# #     logits_test_y_new_1 = test_x @ lr_new_1.model.coef_.T + lr_new_1.model.intercept_\n",
    "# #     new_ori_val_loss_1, _ = lr_new_1.log_loss(logits_test_y_new_1, one_hot_labels_test)\n",
    "    \n",
    "    \n",
    "    \n",
    "#     lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "#     train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "#     train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "    \n",
    "#     lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "#     logits_test_y_new_2 = test_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "#     new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_test_y_new_2, one_hot_labels_test, l2_reg = True)\n",
    "    \n",
    "#     predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "#     acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "df4914f7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEGCAYAAABsLkJ6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAApj0lEQVR4nO3dd3xVhf3/8deHEfbeK4S9I7IFq6ioKC5AW2vdVaq1P1ttFRRU3Dg6bK1V3KtWZYgIImpR3AslCSFg2DsMScIIGffz++NevlKEcIHcnCT3/Xw88rj3nHvvOe8cyPncsz7H3B0REYk/lYIOICIiwVABEBGJUyoAIiJxSgVARCROqQCIiMSpKkEHOByNGzf2pKSkoGOIiJR5+YUh1m3fzY49heRvzNzi7k32f0+5KgBJSUl8/fXXQccQESmzikLOC5+t5KF3ltAcGHtGVy4b3G7Vgd5brgqAiIgcXGZWLmOnpvLNqh84sXMT7hvVi1b1a3DZQd6vAiAiUs4VFIV44sNl/P39TGpWq8xffn4MI49thZkV+zkVABGRcixtXTY3TUlh8YYcRiS3YOLZPWhSp1pUn1UBEBEph/IKivjbe9/z5EfLaVQrgScu6cvpPZof1jRUAEREypkvlm9l3LRUVmzZyS/6teHWEd2oV6PqYU9HBUBEpJzIzSvgwTlLePHzVbRpWIOXrxrIkI6Nj3h6KgAiIuXAvCVZjJ+WyoacPK4c0o4/nd6ZmglHtwpXARARKcN+2JnP3W+lM+3bdXRqWpup1w6mT2KDEpm2CoCISBnk7sxK3cAdMxaRvbuA60/pxHUndaBalcolNg8VABGRMmZTTh63vZHG3PRNJLeux0tXDaRbi7olPh8VABGRMsLdee3rNdwzazH5hSFuPbMrVw5pR5XKsenbqQIgIlIGrN66i1ump/BJ5lYGtmvIA6OTSWpcK6bzVAEQEQlQUch57tOVPPzOEipXMu45rycXDUikUqXi2ziUBBUAEZGALN2Uy81TUvhuzXZO7tqUe0f2pEW9GqU2fxUAEZFSll8Y4vEPl/GP/35P7WpVeOTC3pxzTMtDNm8raSoAIiKlaOGa7YydmkLGxlzOOaYld5zdnUa1o2veVtJUAERESsHu/CL+9t5SnvxoOU3qVOOpS/sxrHuzQDMFVgDMrA3wAtAcCAGT3f2RoPKIiMTKZ8u2csu0FFZu3cUvByRyy5ldqVv98Ju3lbQgtwAKgT+6+wIzqwN8Y2bvunt6gJlEREpMTl4Bk97O4N9frKZto5r8++qBDO5w5M3bSlpgBcDdNwAbIs9zzWwx0ApQARCRcu+/GZu4dVoaWbl5XP2zdtx4ahdqJJRcG4eSUCaOAZhZEnAs8MUBXhsDjAFITEws3WAiIodp64493PVWOjO+W0+XZnV4/JK+9G5TP+hYBxR4ATCz2sBU4A/unrP/6+4+GZgM0K9fPy/leCIiUXF33ly4njtnppObV8DvT+nEdSd1JKFKbNo4lIRAC4CZVSW88n/Z3acFmUVE5EhtyN7NhOlpvJ+RxTFt6vPg6GS6NK8TdKxDCvIsIAOeBha7+1+CyiEicqRCIec/X63h/tmLKQiFmDCiG1cMaUflUmjjUBKC3AIYAlwCpJrZd5Fxt7r77OAiiYhEZ+WWnYyblsLny7dxXPtGTBrdi7aNYtu8raQFeRbQx0D5KJMiIhGFRSGe+WQFf567lITKlbh/VC8u7N+m1Ns4lITADwKLiJQXGRtzGDslhYVrsxnWrSn3nNeL5vWqBx3riKkAiIgcwp7CIv45bxmPzcukXo2q/OOXx3JWcoty+a1/XyoAIiLF+Hb1D4ydmsLSTTsYeWwrbjurOw1rJQQdq0SoAIiIHMCu/EL+PHcpz3yyguZ1q/PM5f04uWuwzdtKmgqAiMh+Ps3cwrhpqazetouLByUydnhX6pSB5m0lTQVARCQie3cB989ezH++WkO7xrV4dcwgBrZv
FHSsmFEBEBEB5i7ayIQ30tiyYw+/ObE9NwzrTPWqZat5W0lTARCRuLZlxx4mvrmIt1I20LV5HZ66rB/JresHHatUqACISFxyd974bh13zkxn154ibjy1M9ec2KFMN28raSoAIhJ31m/fzfjpqcxbspljE8PN2zo1K/vN20qaCoCIxI1QyHn5y9U88HYGRSHn9rO6c9ngpHLTvK2kqQCISFxYsWUnY6em8OWKbRzfsTH3j+pFm4Y1g44VKBUAEanQCotCPPXxCv767lISqlTiwdHJXNCvdblv41ASVABEpMJKX5/DzVMXkrYuh9N7NOPuc3vStG75bd5W0lQARKTC2VNYxKP/zeRfHyyjfs2qPParPpzRs7m+9e9HBUBEKpRvVoWbt2Vm7WBUn1bcNqI7DSpI87aSpgIgIhXCzj2FPDx3Cc99upKW9Wrw3BX9GdqladCxyjQVABEp9z76fjO3TEtl7Q+7uWRQW8ae0ZXa1bR6OxQtIREpt7J3FXDv7HRe+3ot7RvX4rXfHMeAdg2DjlVuqACISLk0J20jt81IY9vOfH47tAPXn9KpwjdvK2kqACJSrmzODTdvm5W6ge4t6vLs5f3p2ape0LHKJRUAESkX3J1pC9Zx11vp7M4v4qbTuzDmhPZUrRw/zdtK2iELgJldAMxx91wzmwD0Ae5x9wUxTyciAqz9YRe3Tk9j/tLN9G3bgAdGJ9Oxae2gY5V70WwB3Obur5vZ8cDpwMPAv4CBMU0mInEvFHJe+mIVD7ydgQN3ntODSwa1pVKcNm8radEUgKLI4wjgX+4+w8wmxi6SiAgs27yDcVNT+GrlD/ysU2PuG6nmbSUtmgKwzsyeAIYBD5hZNUA73UQkJgqKQjz50XL+9t731KhamYcvOIbRfVqpjUMMRFMAfg4MBx529+1m1gK4KbaxRCQepa3LZuzUFBatz+HMXs2ZeE4PmtZR87ZYiaYAtABmufseMxsKJAMvxDKUiMSXvIIi/v7+9zwxfzkNaibw+MV9GN6zRdCxKrxoCsBUoJ+ZdQSeBt4E/g2cGctgIhIfvlq5jbFTU1i+eScX9G3NhBHdqVezatCx4kI0BSDk7oVmNgr4m7v/w8y+jXUwEanYduwp5KE5Gbzw+Spa1qvBC1cO4ITOTYKOFVeiKQAFZvZL4FLg7Mg4lWcROWIfLt3MrdNSWZ+9m8uOS+Km07tQS83bSl00S/wK4BrgXndfYWbtgJdiG0tEKqLtu/K56610pi1YR4cmtZhyzXH0bavmbUE5ZAFw93Qz+xPQ2cx6AkvcfVLso4lIRTI7dQO3z0hj+64CfndSR353ckc1bwtYNK0ghgLPAysBA9qY2WXuPj+myUSkQsjKyeP2GYuYs2gjPVvV5fkrB9CjpZq3lQXR7AL6M3Cauy8BMLPOwCtA36OduZk9A5wFZLl7z6OdnoiUHe7O69+s5Z630tlTGGLcGV256vh2VFHztjIjmgJQde/KH8Ddl5pZSR0Efg54FF1XIFKhrNm2i1umpfJx5hYGJDVk0uhetG+i5m1lTTQF4Gszexp4MTL8K+Cbkpi5u883s6SSmJaIBK8o5Lzw2UoenLOESgZ3n9eTXw1IVPO2MiqaAnAtcB1wPeFjAPOBx2IZSkTKn8ysXG6eksKC1dsZ2qUJ947sRav6NYKOJcWI5iygPcBfIj+lzszGAGMAEhMTg4ggIsUoKArxxIfL+Pv7mdSsVpm//uIYzuut5m3lwUELgJmlAn6w1909OSaJfjqfycBkgH79+h00j4iUvtS12dw0ZSEZG3MZkdyCO8/pQePa1YKOJVEqbgvgrFJLISLlSl5BEX9773ue/Gg5jWolMPmSvpzWo3nQseQwHbQAuPuqWM/czF4BhgKNzWwtcIe7Px3r+YrIkfti+VbGTUtlxZadXNi/Dbec2Y16NdQdpjwKtPmGu/8yyPmLSPRy8wp4YE4GL32+mjYNa/DyVQMZ0rFx0LHkKKj7kogc0ryMLMZPT2VD
Th6/Pr4dfzytMzUTtPoo7/QvKCIHtW1nPne/lc70b9fRqWltpl47mD6JDYKOJSUkml5AQ4CJQNvI+w1wd28f22giEhR3Z1bqBu6YsYjs3QVcf0onrjupA9WqqHlbRRLNFsDTwA2Er/4tim0cEQnappw8JryRxrvpm0huXY+XrhpItxZ1g44lMRBNAch297djnkREAuXuvPb1Gu6ZtZj8whDjz+zGFUOS1LytAoumAMwzs4eAacCevSPdfUHMUolIqVq9dRfjpqXw6bKtDGzXkAdGJ5PUuFbQsSTGoikAAyOP/fYZ58DJJR9HREpTUch59pMVPDx3CVUqVeK+kb24sH8bNW+LE9H0AjqpNIKISOlauincvO27Nds5uWtT7h3Zkxb11LwtnhTXC+hid3/JzG480OvuHkhzOBE5OvmFIf71wTIenfc9dapX5ZELe3POMS3VvC0OFbcFsHcHYJ3SCCIisbdwzXbGTk0hY2MuZx/Tkolnd6eRmrfFreJ6AT0Rebyz9OKISCzszi/ir+8t5amPltO0TnWeurQfw7o3CzqWBExXAotUcJ8t28q4aSms2rqLiwYmMu6MrtStruZtogIgUmHl5BVw/+wMXvlyNW0b1eTfVw9kcAc1b5MfqQCIVEDvpW9iwhtpZOXmMeaE9twwrDM1EtTGQf5XNL2AmgH3AS3d/Qwz6w4cp779ImXP1h17uHNmOm8uXE+XZnV4/JK+9G5TP+hYUkZFswXwHPAsMD4yvBR4lXCPIBEpA9ydNxeuZ+Kbi9ixp5AbhnXm2qEdSKiiNg5ycNEUgMbu/pqZ3QLg7oVmpqZwImXEhuzdTJiexvsZWfRuU58Hz0+mczOdvS2HFk0B2GlmjYjcIN7MBgHZMU0lIocUCjmvfLWa+2dnUBgKMWFEN64Y0o7KauMgUYqmANwIvAl0MLNPgCbA+TFNJSLFWrllJ+OmpfD58m0M7tCISaOSSWxUM+hYUs5E0wtogZmdCHQhfDOYJe5eEPNkIvIThUUhnvlkBX+eu5SEKpV4YHQvft6vjdo4yBGJ5iygUfuN6mxm2UCqu2fFJpaI7G/xhhzGTk0hZW02p3Zvxj3n9aRZ3epBx5JyLJpdQL8GjgPmRYaHAp8TLgR3ufuLMcomIsCewiL+OW8Zj83LpF6Nqjx60bGM6NVC3/rlqEVTAEJAN3ffBP93XcC/CN8nYD6gAiASIwtW/8DYKSl8n7WDkce24vazutOgVkLQsaSCiKYAJO1d+UdkAZ3dfZuZ6ViASAzsyi/kz3OX8swnK2hetzrPXt6fk7o2DTqWVDDRFICPzOwt4PXI8GhgvpnVArbHKphIvPokcwvjpqWwZttuLh6UyNjhXamj5m0SA9EUgOsIr/SHED4L6AVgqrs7oLuFiZSQ7N0F3D97Mf/5ag3tGtfi1TGDGNi+UdCxpAKL5jRQB6ZEfkQkBuYu2shtM9LYsiOfa07swB+GdaJ6VTVvk9iK9jTQB4CmhLcAjHBdqBvjbCIV3ubcPUycuYhZKRvo2rwOT13an16t6wUdS+JENLuAHgTOdvfFsQ4jEi/cnTe+W8edM9PZtaeIP53Wmd+c2IGqldW8TUpPNAVgk1b+IiVn3fbdjJ+eygdLNtMnMdy8rWNTNW+T0hdNAfjazF4F3gD27B3p7tNiFUqkIgqFnJe/XM2k2YsJOdxxdncuPS5JzdskMNEUgLrALuC0fcY5oAIgEqXlm3cwbmoqX67cxvEdG3P/qF60aajmbRKsaM4CuqI0gohURIVFIZ76eAV/fXcp1apU4sHzk7mgb2u1cZAyIZqzgKoT7gfUA/i/zlPufmUMc4mUe+nrc7h56kLS1uVweo9m3H1uT5qqeZuUIdHsAnoRyABOB+4CfgXooLDIQeQVFPHofzN5/MNl1K+ZwL9+1YczerUIOpbIT0RzzllHd78N2OnuzwMjgF4lMXMzG25mS8ws08zGlcQ0RYIwfjzUrAk1E7fR848f8+i8TM7t3Yr3
bjxBK38ps6LZAtjb8G27mfUENgJJRztjM6sM/BM4FVgLfGVmb7p7+tFOW6S0JCRAQQFY1ULqn7CEOn1XkpdTg2Zp/fnzJDVvk7ItmgIw2cwaALcRvjVkbeD2Epj3ACDT3ZcDmNl/gHMBFQApF/au/KsnbabR8FSq1NtNzjdt2T6/KxuLovnTEglWNGcBPRV5+iHQvgTn3QpYs8/wWsL3GPgfZjYGGAOQmJhYgrMXOTpFlQpodGY6tXutpWBrLTa+dBx71jUEIKFGwOFEohDNWUDVCHcDTdr3/e5+11HO+0DnwflPRrhPBiYD9OvX7yeviwRhTtoGWly1iMo188n+rAPbP+kERT82b7vhhgDDiUQpmu3UGUA28A37XAlcAtYCbfYZbg2sL8Hpi5S4rNw87pixiLfTNlI5vy4bp/Qnf9P/Nm+79Va4996AAoochmgKQGt3Hx6DeX8FdDKzdsA64ELgohjMR+SouTtTvlnLPbMWs7ugiJtO78KYE9rTbnYl1kXeU7067N4daEyRwxJNAfjUzHq5e2pJztjdC83sd8A7QGXgGXdfVJLzECkJa7bt4tbpqXz0/Rb6tm3AA6OT6di0NgBr1wYcTuQoHLQAmFkq4X3yVYArzGw54V1Ae+8HkHy0M3f32cDso52OSCyEQs4Ln63kwXeWAHDnOT24ZFBbKql5m1QQxW0BnFVqKUTKmMysHYybmsLXq37ghM5NuG9kT1o3UPM2qVgOWgDcfRWAmQ0CFrl7bmS4DtAdWFUqCUVKUUFRiMnzl/PIe99TI6EyD19wDKP7tFLzNqmQojkG8C+gzz7DOw8wTqTcS1uXzc1TUkjfkMOZvZoz8ZweNK2j5m1ScUVTACxyY3gA3D1kZrrMUSqMvIIiHnn/eybPX06Dmgk8fnEfhvdU/x6p+KJZkS83s+sJf+sH+C2wPHaRRErPVyu3MXZKCsu37OSCvq2ZMKI79WpWDTqWSKmIpgBcA/wdmED4rKD3ibRmECmvduwp5ME5Gbzw2Spa1a/BC1cO4ITOTYKOJVKqoukFlEX4Ii2RCuHDpZu5dVoq67N3c/ngJG46vQu1qmmvpsQf/a+XuLF9Vz53vZXOtAXr6NCkFlOuOY6+bRsGHUskMCoAEhdmp27g9hlpbN9VwO9O6sjvTu5I9aqVD/1BkQpMBUAqtKycPG6fsYg5izbSs1Vdnr9yAD1a1jv0B0XiQHGtIG4s7oPu/peSjyNSMtyd179Zyz1vpZNXGGLs8K5c/bN2VKkczV1QReJDcVsAdSKPXYD+hO8GBnA2MD+WoUSOxr7N2wYkNWTS6F60b1I76FgiZU5xrSDuBDCzuUCffVpBTAReL5V0IoehKNK87aF3lmDA3ef24FcD1bxN5GCiOQaQCOTvM5xPCdwUXqQkZWblcvOUFBas3s7QLk24d2QvWtXXfRlFihNNAXgR+NLMphO+EGwk8EJMU4lEqaAoxBMfLuPv72dSs1pl/vLzYxh5rJq3iUQjmgvB7jWzt4GfRUZd4e7fxjaWyKGlrs3mpikLydiYy4jkFtx5Tg8a164WdCyRciPa00BrAjnu/qyZNTGzdu6+IpbBRA4mr6CIv733PU9+tJxGtRJ44pK+nN6jedCxRMqdQxYAM7sD6Ef4bKBngarAS8CQ2EYT+akvV2xj7NQUVmzZyS/6teHWEd2oV0PN20SORDRbACOBY4EFAO6+PnJTGJFSk5tXwINzlvDi56to07AGL181kCEdGwcdS6Rci6YA5Lu7m5kDmFmtGGcS+R/zMrIYPz2VDTl5/Pr4dvzxtM7UTNBF7CJHK5q/otfM7AmgvpldDVwJPBXbWCKwbWc+d7+VzvRv19GpaW2mXjuYPokNgo4lUmFEcxbQw2Z2KpBD+DjA7e7+bsyTSdxyd2albuCOGYvI3l3A9ad04rqTOlCtipq3iZSkaA4CP+DuY4F3DzBOpERtysljwhtpvJu+ieTW9XjpqoF0a1E36FgiFVI0u4BOBfZf
2Z9xgHEiR8zdefWrNdw7ezH5hSFuPbMrVw5R8zaRWCquG+i1hO//28HMUvZ5qQ7waayDSfxYtXUnt0xL5dNlWxnYriEPjE4mqbHONRCJteK2AP4NvA3cD4zbZ3yuu2+LaSqJC0Uh59lPVvDw3CVUqVSJ+0b24sL+bdS8TaSUFNcNNBvINrNHgG37dAOtY2YD3f2L0gopFc+SjbncPDWFhWu2c0rXptwzsict6ql5m0hpiuYYwL+APvsM7zzAOJGo5BeGeOyDTP45L5M61avyyIW9OeeYlmreJhKAaAqAubvvHXD3kJnpKhw5bAvXbOfmKSks2ZTLub1bcvtZ3Wmk5m0igYlmRb7czK4n/K0fwgeGl8cuklQ0u/OL+Mu7S3j64xU0rVOdpy7tx7DuzYKOJRL3oikA1wB/ByYQvh/A+8CYWIaSiuOzZVsZNy2FVVt3cdHARMad0ZW61dW8TaQsiOZK4CzgwlLIIhVITl4B98/O4JUvV9O2UU3+ffVABndQ8zaRsqS46wBudvcHzewfhL/5/w93vz6myaTcei99ExPeSCMrN4+rf9aOG0/tQo0EtXEQKWuK2wJYHHn8ujSCSPm3dcceJs5MZ+bC9XRpVofHL+lL7zb1g44lIgdR3HUAMyOPz5f0TM3sAmAi0A0Y4O4qMuWYu/PmwvXcOTOd3LwCbhjWmWuHdiChito4iJRlxe0CmskBdv3s5e7nHMV804BRwBNHMQ0pAzZk72bC9DTez8iid5v6PHh+Mp2b6X5BIuVBcbuAHo48jgKaE74NJMAvgZVHM1N3Xwzo4p9yLBRyXvlqNZNmZ1AQCjFhRDeuGNKOymrjIFJuFLcL6EMAM7vb3U/Y56WZZjY/5smkzFqxZSfjpqbwxYptDO7QiEmjkklsVDPoWCJymKK5DqCJmbV39+UAZtYOaHKoD5nZe4S3HPY33t1nRBvQzMYQue4gMTEx2o9JDBQWhXjmkxX8ee5SEipXYtKoXvyifxttyYmUU9EUgBuAD8xs79W/ScBvDvUhdx92FLn2nc5kYDJAv379DnpMQmJr8YYcxk5NIWVtNsO6NeOe83rSvF71oGOJyFGI5kKwOWbWCegaGZXh7ntiG0vKij2FRfxz3jIem5dJvRpVefSiYxnRq4W+9YtUANHcErImcCPQ1t2vNrNOZtbF3d860pma2UjgH4R3Jc0ys+/c/fQjnZ7ExoLVPzB2SgrfZ+1g5LGtuP2s7jSolRB0LBEpIdHsAnoW+AY4LjK8FngdOOIC4O7TgelH+nmJrV35hTz8zlKe/XQFzetW59nL+3NS16ZBxxKREhZNAejg7r8ws18CuPtu0/Z/hfVJ5hbGTUthzbbdXDwokbHDu1JHzdtEKqRoCkC+mdUgclGYmXUAdAyggsneXcB9sxbz6tdraNe4Fq+OGcTA9o2CjiUiMRRNAbgDmAO0MbOXgSHA5bEMJaVr7qKNTHgjja0787nmxA78YVgnqldV8zaRiq7YAmBmlYAGhK8GHgQY8Ht331IK2STGNufuYeLMRcxK2UDX5nV4+rL+9GpdL+hYIlJKii0Akds//s7dXwNmlVImiTF3543v1nHnzHR27Snij6d25pqhHahaWc3bROJJNLuA3jWzPwGvEr4hPADuvi1mqSRm1m3fzfjpqXywZDN9EuvzwOhkOql5m0hciqYAXBl5vG6fcQ60L/k4EiuhkPPyF6uY9HYGIYc7zu7OpcclqXmbSByL5krgdqURRGJn+eYdjJuaypcrt3F8x8bcP6oXbRqqeZtIvIvmSuDqwG+B4wl/8/8IeNzd82KcTY5SYVGIpz5ewV/fXUq1KpV48PxkLujbWm0cRASIbhfQC0Au4dYNEL4fwIvABbEKJUcvfX0ON09dSNq6HE7v0Yy7z+1J07pq3iYiP4qmAHRx92P2GZ5nZgtjFUiOTl5BEY/+N5PHP1xG/ZpVeexXfTijZ3N96xeRn4imAHxrZoPc/XMAMxsIfBLbWHIkvlm1
jZunpLBs805G9WnFbSPUvE1EDi6aAjAQuNTMVkeGE4HFZpYKuLsnxyydRGXnnkIeemcJz3+2kpb1avD8lQM4sfMh79kjInEumgIwPOYp5IjNX7qZW6alsj57N5cMasvNw7tSu1o0/6wiEu+iOQ10VWkEkcOTvauAu2elM+WbtbRvUovXfnMc/ZMaBh1LRMoRfVUsh+akbeC2GYvYtjOf3w7twPWnqHmbiBw+FYByJCs3jztmLOLttI30aFmXZy/vT89Wat4mIkdGBaAccHemLljH3W+ls7ugiJuHd+Hqn7VX8zYROSoqAGXc2h92cev0NOYv3Uz/pAZMGp1Mhya1g44lIhWACkAZFQo5L36+igfmZABw17k9uHhgWyqpeZuIlBAVgDIoM2sH46am8PWqHzihcxPuG9mT1g3UvE1ESpYKQBlSUBRi8vzlPPL+99SoWpmHLziG0X1aqY2DiMSECkAZkbYum5unpJC+IYczezVn4jk9aFpHzdtEJHZUAAKWV1DEI+9/z+T5y2lYK4HHL+7D8J4tgo4lInFABSBAX63cxtgpKSzfspML+rZmwoju1KtZNehYIhInVAACsGNPIQ/OyeCFz1bRukENXvz1AH7WSc3bRKR0qQCUsg+WZDF+ehrrs3dz+eAkbjq9C7XUvE1EAqA1Tyn5YWc+d89KZ9qCdXRsWpsp1wymb9sGQccSkTimAhBj7s7s1I3c8WYa23cVcP3JHbnu5I5Uq6LmbSISLBWAGMrKyWPCG2nMTd9Er1b1eOHKgXRvWTfoWCIigApATLg7r3+9lrtnpZNfGOKWM7ry6+PbUUXN20SkDFEBKGFrtu3ilmmpfJy5hQHtGjJpVC/aq3mbiJRBKgAlpCjkPP/pSh56ZwmVKxn3nNeTiwYkqnmbiJRZKgAl4PtNuYydmsKC1ds5qUsT7h3Zi5b1awQdS0SkWCoAR6GgKMTjHyzjH//NpFa1yvztF705t3dLNW8TkXIhkAJgZg8BZwP5wDLgCnffHkSWI5Wydjs3T0khY2MuZyW3YOI5PWhcu1rQsUREohbUaSnvAj3dPRlYCtwSUI7DlldQxP2zF3PePz/hh135PHlpPx69qI9W/iJS7gSyBeDuc/cZ/Bw4P4gch+vz5Vu5ZVoqK7bs5ML+bbjlzG7Uq6HmbSJSPpWFYwBXAq8e7EUzGwOMAUhMTCytTP8jN6+ASW9n8PIXq2nTsAYvXzWQIR0bB5JFRKSkxKwAmNl7QPMDvDTe3WdE3jMeKARePth03H0yMBmgX79+HoOoxZqXkcWt01PZlJPHVce348bTOlMzoSzUTRGRoxOzNZm7DyvudTO7DDgLOMXdS33FfijbduZz18xFvPHdejo1rc1j1w7m2EQ1bxORiiOos4CGA2OBE919VxAZDsbdeStlAxPfXET27gJ+f0onfntSBzVvE5EKJ6h9GY8C1YB3I+fMf+7u1wSU5f9syslj/PQ03lu8iWNa1+PlqwfStbmat4lIxRTUWUAdg5jvwbg7r361hntnL6agKMT4M7tx5fHtqKw2DiJSgcX90cxVW3cybmoqny3fyqD2DZk0KpmkxrWCjiUiEnNxWwCKQs6zn6zg4blLqFqpEveN7MWF/duoeZuIxI24LABLNuZy89QUFq7Zzildm3LPyJ60qKfmbSISX+KqAOQXhnjsg0z+OS+TOtWr8siFvTnnGDVvE5H4FDcF4Ls12xk7JYUlm3I5t3dL7ji7Bw1rJQQdS0QkMBW+AOzOL+Iv7y7h6Y9X0LROdZ6+rB+ndGsWdCwRkcBV6ALw6bItjJuayuptu7hoYCLjzuhK3epq3iYiAhW0AOTkFXD/7Axe+XI1bRvV5JWrB3Fch0ZBxxIRKVMqXAF4L30T499IZXPuHn5zQnv+MKwzNRLUxkFEZH8VpgBs3bGHO2em8+bC9XRtXocnL+1Hcuv6QccSESmzyn0BcHfeXLieiW8uYseeQm48tTPXnNiBhCpB3exMRKR8KNcF
YEP2biZMT+P9jCx6t6nPg+cn07lZnaBjiYiUC+WyAIRCzitfreb+2RkUhZzbzurO5YOT1LxNROQwlLsCsHLLTsZNS+Hz5dsY0rER949MJrFRzaBjiYiUO+WqAGzesYfT/zafhCqVeGB0L37er43aOIiIHKFyVQA2Zudxaecm3HNeT5rVrR50HBGRcs3K4O14D8rMNgOrAo7RGNgScIayQsviR1oWP9Ky+FFZWRZt3b3J/iPLVQEoC8zsa3fvF3SOskDL4kdaFj/SsvhRWV8WOlleRCROqQCIiMQpFYDDNznoAGWIlsWPtCx+pGXxozK9LHQMQEQkTmkLQEQkTqkAiIjEKRWAI2BmD5lZhpmlmNl0M6sfdKagmNkFZrbIzEJmVmZPd4slMxtuZkvMLNPMxgWdJyhm9oyZZZlZWtBZgmZmbcxsnpktjvx9/D7oTAeiAnBk3gV6unsysBS4JeA8QUoDRgHzgw4SBDOrDPwTOAPoDvzSzLoHmyowzwHDgw5RRhQCf3T3bsAg4Lqy+P9CBeAIuPtcdy+MDH4OtA4yT5DcfbG7Lwk6R4AGAJnuvtzd84H/AOcGnCkQ7j4f2BZ0jrLA3Te4+4LI81xgMdAq2FQ/pQJw9K4E3g46hASmFbBmn+G1lME/dAmOmSUBxwJfBBzlJ8pVM7jSZGbvAc0P8NJ4d58Rec94wpt6L5dmttIWzbKIYwdqR6tzqwUAM6sNTAX+4O45QefZnwrAQbj7sOJeN7PLgLOAU7yCX0xxqGUR59YCbfYZbg2sDyiLlCFmVpXwyv9ld58WdJ4D0S6gI2Bmw4GxwDnuvivoPBKor4BOZtbOzBKAC4E3A84kAbPwjUqeBha7+1+CznMwKgBH5lGgDvCumX1nZo8HHSgoZjbSzNYCxwGzzOydoDOVpsjJAL8D3iF8oO81d18UbKpgmNkrwGdAFzNba2a/DjpTgIYAlwAnR9YR35nZmUGH2p9aQYiIxCltAYiIxCkVABGROKUCICISp1QARETilAqAiEicUgGQMs/MhprZ4KOcxo7DeO9zZnb+0cyvpJjZp4f5/jKTXco+FQApD4YCR1UAyit3j8vfW0qHCoAEwszeMLNvIr3Sx+wzfriZLTCzhWb2fqSR1jXADZGLaX62/7fcvd/uzax25DMLzCzVzA7ZldPMLo3c12Ghmb24z0snmNmnZrZ877wONn0zS4r0fX8y8vvMNbMakdf6R6b/WeQ+EmmR8ZUjw19FXv/NQfLt/d2GmtkHZjYlci+KlyNXmxb3u51iZt9Gsj5jZtUi4yeZWXpkvg9Hxl1gZmmR5RCXrb3jkrvrRz+l/gM0jDzWIHxPgUZAE8KdNdvt956JwJ/2+exzwPn7DO+IPFYB6kaeNwYy+fFixx0HyNADWAI03m9+zwGvE/6C1J1wu+eDTh9IItwUsHfktdeAiyPP04DBkeeTgLTI8zHAhMjzasDXe3/v/TLu/d2GAtmEew1VInzF7fEHeP9zwPlA9ciy7BwZ/wLwB6Bh5Hfeu1zqRx5TgVb7jtNPxf/RFoAE5XozW0j4fgptgE6Eb5wx391XALj74faWN+A+M0sB3iPclrlZMe8/GZji7lsOML833D3k7un7TKO46a9w9+8iz78BkiJ3iqvj7nv34/97n+mfBlxqZt8RbhPciPAyKM6X7r7W3UPAd4QLz8F0iWRaGhl+HjgByAHygKfMbBSwt5fVJ8BzZnY1UPkQOaSCUDdQKXVmNhQYBhzn7rvM7APC31iN6FopFxLZfRnZDZIQGf8rwlsRfd29wMxWRqZ70CjFzG/Pfu871PT3fX8R4S2b4nbRGPD/3P1weiftP4/i/n4POG93LzSzAcAphBvX/Q442d2vMbOBwAjgOzPr7e5bDyOblEPaApAg1AN+iKz8uxL+5g/h3Ronmlk7ADNrGBmfS7j53l4rgb6R5+cCVfeZblZk5XwS0PYQOd4Hfm5mjfabX3G5o56+u/8A5JrZ3t/vwn1efge4
NtIyGDPrbGa1DjH/w5FBeCukY2T4EuBDC/enr+fuswnvEuodmX8Hd//C3W8HtvC/La6lgtIWgARhDnBNZFfKEsK7gXD3zZEDwtPMrBKQBZwKzASmRA66/j/gSWCGmX1JeCW+MzLdl4GZZvY14V0kGcWFcPdFZnYv4RVjEfAtcHkxHzms6Uf8GnjSzHYCHxDejw/wFOFdOAsiWzGbgfOimF5U3D3PzK4AXjezKoTbVj9O+BjADDPbu8V1Q+QjD5lZp8i494GFJZVFyi51AxWJITOr7e57z+QZB7Rw998HHEsE0BaASKyNMLNbCP+traL4LQyRUqUtABGROKWDwCIicUoFQEQkTqkAiIjEKRUAEZE4pQIgIhKn/j8eXe9c2Uu+XgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of edges in consider 12431\n"
     ]
    }
   ],
   "source": [
    "# Actual vs. predicted change in loss for every edge considered. Points on the y = x\n",
    "# line are perfect predictions; points outside the symmetric axis window are clipped from view.\n",
    "# (plt comes from the top import cell, so it is not re-imported here.)\n",
    "low_limit = -2.5\n",
    "up_limit = 2.5\n",
    "x = np.linspace(low_limit, up_limit)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(x, x, label='y = x')\n",
    "ax.scatter(acctual_influence_1, predict_influence_1, s=10, color='blue', label='edges')\n",
    "ax.set(xlabel='actual change in loss', ylabel='predicted change in loss',\n",
    "       xlim=(low_limit, up_limit), ylim=(low_limit, up_limit),\n",
    "       title=f'Edge influence on {data_set}: actual vs. predicted')\n",
    "ax.legend()\n",
    "plt.show()\n",
    "print('Number of edges in consider %d'% len(f_l))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "19f1a125",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([33])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect extra_index_train; presumably extra node indices added to the training set -- TODO confirm against the cell that defines it.\n",
    "extra_index_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "4540da32",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.00321799,  0.00434749,  0.00231544, ..., -0.00225762,\n",
       "        0.00493145,  0.00822329])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Entries of the actual influence that are nonzero (edges with a nonzero actual change in loss).\n",
    "np.array(acctual_influence_1)[np.nonzero(acctual_influence_1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "cb572cf8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.9211923910742819, pvalue=0.0)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spearman rank correlation between actual and predicted influence, restricted to edges\n",
    "# where at least one of the two values is nonzero (pairs that are zero in both are dropped).\n",
    "# A single joint mask keeps each (actual, predicted) pair aligned; masking each array\n",
    "# independently pairs values from different edges (or yields unequal lengths) whenever\n",
    "# their zeros fall at different positions.\n",
    "actual_arr = np.asarray(acctual_influence_1)\n",
    "predicted_arr = np.asarray(predict_influence_1)\n",
    "nonzero_pairs = (actual_arr != 0) | (predicted_arr != 0)\n",
    "scipy.stats.spearmanr(actual_arr[nonzero_pairs], predicted_arr[nonzero_pairs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "178ee63f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.920753682604643, pvalue=0.0)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spearman rank correlation over all edges, pairs with zero influence included.\n",
    "scipy.stats.spearmanr(a=acctual_influence_1, b=predict_influence_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "7671e600",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results table: one row per edge with actual/predicted influence and the edge endpoints.\n",
    "# Built from a dict of columns (rather than a list of rows transposed with .T) so each column\n",
    "# keeps its own dtype -- the endpoint ids stay integers instead of being upcast to float64\n",
    "# (which wrote them to the CSV as e.g. '12.0'). The dict constructor also raises if the\n",
    "# columns have different lengths instead of silently padding with NaN.\n",
    "df = pd.DataFrame({\n",
    "    'acctual_influence': acctual_influence_1,\n",
    "    'predict_influence': predict_influence_1,\n",
    "    'from_edges': f_l.numpy().astype(int),\n",
    "    'to_edges': t_l.numpy().astype(int),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "99ea11b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the edge-influence table without a header row or index; read it back with\n",
    "# pd.read_csv(path, header=None). The output directory is created if missing so a fresh\n",
    "# checkout does not fail here.\n",
    "output_dir = 'result_data_adver'\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "df.to_csv(os.path.join(output_dir, data_set + '_edge_influence.csv'), header=False, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7ea6c59b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/zizhang/Desktop/Projects/Project6_influence_function/graph_influence_function/experiments'"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Directory that relative output paths such as 'result_data_adver/' resolve against.\n",
    "os.path.abspath(os.curdir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5d2d4281",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: reload previously saved edge-influence results instead of recomputing them.\n",
    "# NOTE(review): this path points at 'result_data/', whereas the save cell writes to\n",
    "# 'result_data_adver/' -- confirm which directory is intended before re-enabling.\n",
    "# import pandas as pd\n",
    "# df = pd.read_csv('result_data/pubmed_edge_influence.csv', header = None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "81a4fe6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: re-plot actual (column 0) vs. predicted (column 1) change in loss from a\n",
    "# header-less CSV reloaded into df, with y = x as the reference line.\n",
    "# import matplotlib.pyplot as plt\n",
    "# import numpy as np\n",
    "# low_limit = -10\n",
    "# up_limit = 10\n",
    "# x = np.linspace(low_limit, up_limit)\n",
    "# # x = np.linspace(-0.2, 0.15)\n",
    "# plt.plot(x, x)\n",
    "# plt.scatter(df[0], df[1], s = 10, color='blue')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "746bea42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch (commented out, never executed): toy DGL graph for trying dgl.khop_in_subgraph.\n",
    "# g = dgl.graph(([1, 1, 2, 3, 4], [0, 2, 0, 4, 2]))\n",
    "# g.edata['w'] = torch.arange(10).view(5, 2)\n",
    "# sg, inverse_indices = dgl.khop_in_subgraph(g, 0, k=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
