{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7c5336d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy\n",
    "import torch\n",
    "import dgl\n",
    "import cupy as cp\n",
    "import collections\n",
    "import random\n",
    "from dgl import function as fn\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from dataset import load_graph_dataset\n",
    "from tqdm import tqdm\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "from torch_geometric.datasets import Planetoid \n",
    "from pygod.utils import gen_attribute_outliers\n",
    "from pygod.models import CoLA, DOMINANT\n",
    "from dgl.data import FraudDataset\n",
    "from pygod.utils.metric import eval_roc_auc, eval_recall_at_k, eval_precision_at_k\n",
    "from sklearn.metrics import precision_score, accuracy_score\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "from sklearn.neighbors import NearestCentroid\n",
    "\n",
    "from outlier_generator import inject_edge_outlier, generate_node_feature_outliers\n",
    "from model_edge_influence import EdgeInfluenceSGC\n",
    "\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6db67b05",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f1a97a23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "data_dgl = 'cora'\n",
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(data_dgl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "47b4890d",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_graph, edge_oulier_labels = inject_edge_outlier(graph, labels,num_edges=126)\n",
    "graph = new_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1248d4ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# l2_term = 0.01\n",
    "l2_term = 0.01\n",
    "# batch_edges = 1\n",
    "num_layer = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "94e330ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "feat0 = feat.clone()\n",
    "degs = graph.in_degrees().float().clamp(min = 1)\n",
    "norm = torch.pow(degs, -0.5)\n",
    "norm = norm.to(feat0.device).unsqueeze(1)\n",
    "\n",
    "for _ in range(num_layer):\n",
    "    feat0 = feat0 * norm\n",
    "    graph.ndata['h'] = feat0\n",
    "    graph.update_all(fn.copy_u('h', 'm'),\n",
    "                     fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h')\n",
    "    feat0 = feat0 * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ee489d23",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = feat0.numpy().astype(np.float32)\n",
    "train_y = labels.numpy().astype(np.float32)\n",
    "\n",
    "val_x = feat0.numpy().astype(np.float32)\n",
    "val_y = labels.numpy().astype(np.float32)\n",
    "\n",
    "train_node_idx = graph.nodes()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e0bc341f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' Train Logistic Regression '"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "# one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "\"\"\" Train Logistic Regression \"\"\"\n",
    "# lr = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "# lr.fit(train_x, train_y, sample_weight=None, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c17e1ed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from_indexes, to_indexes = graph.edges()\n",
    "\n",
    "f_l, t_l = from_indexes, to_indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5fb12df8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13516"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(from_indexes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c06bafca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert to one-hot labels\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "one_hot_labels_train_orig = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# train the original data\n",
    "# calculate the hessian matrix\n",
    "lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "\n",
    "lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "logits_val_y_origin = val_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "# numpy_theoritic_loss = log_loss(val_y, softmax(logits_val_y_origin, axis=1))\n",
    "\n",
    "# assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(val_x, \n",
    "                                                                    logits_val_y_origin,\n",
    "                                                                    one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg = True)\n",
    "\n",
    "loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True)\n",
    "\n",
    "loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "del hess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "09dcbcd8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 13516/13516 [31:11<00:00,  7.22it/s]\n"
     ]
    }
   ],
   "source": [
    "acctual_influence_1 = []\n",
    "acctual_influence_2 = []\n",
    "\n",
    "predict_influence_1 = []\n",
    "predict_influence_2 = []\n",
    "\n",
    "# for k in tqdm(range(len(train_node_idx))):\n",
    "for k in tqdm(range(len(f_l))):\n",
    "# for k in tqdm(range(len(new_index))):\n",
    "#     eis = EdgeInfluenceSGC(graph=graph, feature=feat, from_index=f_l[new_index[k]], to_index=t_l[new_index[k]])\n",
    "#     eis.remove_edges_sgc_from_influence()\n",
    "    eis = EdgeInfluenceSGC(graph=graph, feature=feat, from_index=f_l[k], to_index=t_l[k])\n",
    "    eis.remove_edges_sgc_from_influence()\n",
    "    feat_removed1 = eis.calculate_modified_features()\n",
    "    \n",
    "#     node_id = train_node_idx.numpy()[k]\n",
    "#     nis = NodeInfluenceSGC(graph = graph, feature=feat, node_index=node_id)\n",
    "#     nis.remove_edges_sgc()\n",
    "#     feat_removed1 = nis.calculate_modified_features()\n",
    "\n",
    "    extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0])\n",
    "    \n",
    "    \n",
    "    extra_index_train = torch.tensor(\n",
    "        [extra_index[i] for i in range(len(extra_index)) if extra_index[i] in train_node_idx]).numpy()\n",
    "    \n",
    "    extra_index_train_in_train = [\n",
    "        np.where(train_node_idx.numpy() == extra_index_train[j])[0][0] for j in range(len(extra_index_train))]\n",
    "    \n",
    "    if extra_index_train == []:\n",
    "        predict_influence_1.append(0.0)\n",
    "        acctual_influence_1.append(0.0)\n",
    "        continue\n",
    "    \n",
    "    feat_to_be_added = feat_removed1[extra_index_train].numpy()\n",
    "    perturb_index = extra_index_train_in_train\n",
    "    \n",
    "    \n",
    "    train_x_new = feat_to_be_added\n",
    "    train_y_new = train_y[perturb_index]\n",
    "    \n",
    "    train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "    train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "    \n",
    "    one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "    logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "    train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(train_x_orig, \n",
    "                                            logits_train_y_origin_0, \n",
    "                                            one_hot_labels_train_0, l2_reg = True)\n",
    "    \n",
    "    \n",
    "    pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "#     weight_orig = np.ones(len(train_x_orig)) # 1...1...11\n",
    "    \n",
    "#     weight_1 = np.ones(len(train_x_orig))\n",
    "#     weight_1[len(train_x_orig) - len(perturb_index):] = 0 # 1...1...10\n",
    "    \n",
    "#     weight_2 = np.ones(len(train_x_orig))\n",
    "#     weight_2[len(train_x_orig) - len(perturb_index):] = 0 \n",
    "#     weight_2[perturb_index] = 0 # 1...0...10\n",
    "    \n",
    "    weight_3 = np.ones(len(train_x_orig))\n",
    "    weight_3[perturb_index] = 0 # 1...0...11\n",
    "    \n",
    "#     lr_new_1 = SimplifiedGraphNeuralNetwork(l2_reg=1.0, fit_intercept=True)\n",
    "#     train_x_delete_1 = train_x_orig[weight_1 == 1]\n",
    "#     train_y_delete_1 = train_y_orig[weight_1 == 1]\n",
    "    \n",
    "#     assert(np.allclose(train_x_delete_1, train_x))\n",
    "#     assert(np.allclose(train_y_delete_1, train_y))\n",
    "    \n",
    "#     lr_new_1.fit(train_x_delete_1, train_y_delete_1)\n",
    "#     logits_val_y_new_1 = val_x @ lr_new_1.model.coef_.T + lr_new_1.model.intercept_\n",
    "#     new_ori_val_loss_1, _ = lr_new_1.log_loss(logits_val_y_new_1, one_hot_labels_val)\n",
    "    \n",
    "    \n",
    "    \n",
    "#     lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "#     train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "#     train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "    \n",
    "#     lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "#     logits_val_y_new_2 = val_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "#     new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_val_y_new_2, one_hot_labels_val, l2_reg = True)\n",
    "    \n",
    "    predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "#     acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeca9fda",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c79970f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([13516])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_graph.edges()[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "2a81ed93",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAS7UlEQVR4nO3df5Dcd33f8edLpx/+gYOc+EyEJEdqELZVAsG5CqchGcqPieTSCNrQkdNiYtJRXXBiOnSCCFM6aaZtaDtp6rFjxyVuMGHwQGBqhQocJyk4HWzQCf/AQigcwuCzhX02/gG4Rr/e/eM2cDrO1kq7p93T5/mYWWm/3+/nu/v6zN3sa/d7u99NVSFJateiQQeQJA2WRSBJjbMIJKlxFoEkNc4ikKTGLR50gBNxzjnn1Jo1awYdQ5IWlF27dj1aVaOz1y/IIlizZg3j4+ODjiFJC0qSr8+13kNDktQ4i0CSGmcRSFLjLAJJapxFIEmNswgkqXHNFMH+rz3Mv/6Ff8uvnv8bfOajdww6jiQNjWaK4Lf/yX9l92f38uBX9vOff/UaHnng0UFHkqSh0EwRTE0+Rh2Z/u6FkZFFfGv/4wNOJEnDoZkieNM7/xHLzljG6Wedxgtf9OO86OVrBx1JkobCgjzFxInY8q43ctFrX8qTU0/xslf9XRYvaWbqkvScmno0fPHP/OSgI0jS0Gnm0JAkaW4WgSQ1ziKQpMZZBJLUOItAkhpnEUhS4ywCSWqcRSBJjetLESTZmGRvkokk2+bYniRXd7bfm+SiWdtHktyV5BP9yCNJ6l7PRZBkBLgW2ASsBy5Nsn7WsE3Aus5lK3DdrO1XAXt6zSJJOn79eEWwAZioqn1VdQC4Gdg8a8xm4KaadiewPMkKgCSrgH8IvL8PWSRJx6kfRbASeGDG8mRnXbdjfh/4TeDIc91Jkq1JxpOMT01N9RRYkvQD/SiCzLGuuhmT5PXAI1W161h3UlU3VNVYVY2Njo6eSE5J0hz6UQSTwOoZy6uAh7oc83PALyW5n+lDSq9O8id9yCRJ6lI/imAnsC7J2iRLgS3A9lljtgOXdd49dDHwZFXtr6p3V9WqqlrT2e+vquqf9yGTJKlLPX8fQVUdSnIlcCswAtxYVbuTXNHZfj2wA7gEmACeBi7v9X4lSf2RqtmH84ff2NhYjY+PDzqGJC0oSXZV1djs9X6yWJIaZxFIUuMsAklqnEUgSY2zCCSpcRaBJDXOIpCkxlkEktQ4i0CSGmcRSFLjLAJJapxFIEmNswgkqXEWgSQ1ziKQpMZZBJLUOItAkhpnEUhS4ywCSWqcRSBJjbMIJKlxFoEkNc4ikKTGWQSS1DiLQJIa15ciSLIxyd4kE0m2zbE9Sa7ubL83yUWd9acl+XySe5LsTvLb/cgjSepez0WQZAS4FtgErAcuTbJ+1rBNwLrOZStwXWf994BXV9XLgJ8GNia5uNdMkqTu9eMVwQZgoqr2VdUB4GZg86wxm4GbatqdwPIkKzrL3+mMWdK5VB8ySZK61I8iWAk8MGN5srOuqzFJRpLcDTwC3FZVn5vrTpJsTTKeZHxqaqoPsSVJ0J8iyBzrZj+rf9YxVXW4qn4aWAVsSPKSue6kqm6oqrGqGhsdHe0lryRphn4UwSSwesbyKuCh4x1TVU8AnwY29iGTJKlL/SiCncC6JGuTLAW2ANtnjdkOXNZ599DFwJNVtT/JaJLlAElOB14LfLkPmSRJXVrc6w1U1aEkVwK3AiPAjVW1O8kVne3XAzuAS4AJ4Gng8s7uK4APdN55tAj4SFV9otdMkqTupWrhvUlnbGysxsfHBx1DkhaUJLuqamz2ej9ZLEmNswgkqXEWgSQ1ziKQpMZZBJLUOItAkhpnEUhS4ywCSWqcRSBJjbMIJKlxFoEkNc4ikKTGWQSS1DiLQJIaZxFIUuMsAklqnEUgSY2zCCSpcRaBJDXOIpCkxlkEktQ4i0CSGmcRSFLjLAJJapxFIEmN60sRJNmYZG+SiSTb5tieJFd3tt+b5KLO+tVJ/k+SPUl2J7mqH3kkSd3ruQiSjADXApuA9cClSdbPGrYJWNe5bAWu66w/BLyzqi4ELgbePse+kqR51I9XBBuAiaraV1UHgJuBzbPGbAZuqml3AsuTrKiq/VX1BYCq+jawB1jZh0ySpC71owhWAg/MWJ7khx/MjzkmyRrg5cDn5rqTJFuTjCcZn5qa6jWzJKmjH0WQOdbV8YxJ8jzgY8A7quqpue6kqm6oqrGqGhsdHT3hsJKko/WjCCaB1TOWVwEPdTsmyRKmS+BDVfXxPuSRJB2HfhTBTmBdkrVJlgJbgO2zxmwHLuu8e+hi4Mmq2p8kwB8Be6rq9/qQRZJ0nBb3egNVdSjJlcCtwAhwY1XtTnJFZ/v1wA7gEmACeBq4vLP7zwFvBr6Y5O7Out+qqh295pIkdSdVsw/nD7+xsbEaHx8fdAxJWlCS7Kqqsdnr/WSxJDXOIpCkxlkEktQ4i0CSGmcRSFLjLAJJapxFIEmNswgkqXEWgSQ1ziKQpMb1fK6hheLAMwf42H/7BN/65hNsfvtGVr34hYOOJElDoZki+I+/8t/Z+am7OPC9g/zFB2/npq9ew1lnP2/QsSRp4Jo5NHTPZ3Zz4JmDUHDkyBEe+PKDg44kSUOhmSJ4+Wt+iqWnLSGLwqKRRZx34apBR5KkodDMoaFtH/wNbrnmkzz+8JO8/l++juctP3PQkSRpKDRTBEuXLeFN7/ylQceQpKHTzKEhSdLcLAJJapxFIEmNswgkqXEWgSQ1ziKQpMY18/bRgwcO8mfX/zlPPPwkm/7Fa1ix9gWDjiRJQ6GZInjfW67hjlt2cujAIT7xh7fxwa9ew5nP90NlktTMoaFdf34PB545yJEjxeFDh/n6lyYHHUmShkJfiiDJxiR7k0wk2TbH9iS5urP93iQXzdh2Y5JHktzXjyzP5qd+/kKWnLYEAklYfcHK+bw7SVowej40lGQEuBZ4HTAJ7Eyyvaq+NGPYJmBd5/IK4LrO/wB/DFwD3NRrlufyng+/g4/8l+1865tP8IZf3+QpqCWpox9/I9gATFTVPoAkNwObgZlFsBm4qaoKuDPJ8iQrqmp/Vd2eZE0fcjynZacv483vfdN8340kLTj9ODS0EnhgxvJkZ93xjnlOSbYmGU8yPjU1dUJBJUk/rB9FkDnW1QmMeU5VdUNVjVXV2Ojo6PHsKkl6Dv0ogklg9YzlVcBDJzBGkjQA/SiCncC6JGuTLAW2ANtnjdkOXNZ599DFwJNVtb8P9y1J6lHPRVBVh4ArgVuBPcBHqmp3kiuSXNEZtgPYB0wA/wN429/un+TDwB3A+Ukmk/xar5kkSd3L9Bt5FpaxsbEaHx8fdAxJWlCS7Kqqsdnrm/lksSRpbhaBJDXOIpCkxlkEktQ4i0CSGmcRSFLjLAJJapxFIEmNswgkqXHNFMGjDz7Ge17/n3jb2G/y+U/eNeg4kjQ0mimCd2/6D3x+xxf4yhe+xnvf8D4effCxQUeSpKHQTBHcf98Pvhfn8MHD3P3p3QNMI0nDo5kimO3/feeZQUeQpKHQTBEsXjpy1PJLXnnBgJJI0nBppggOHzx81PLMQ0WS1LJmimD21y4cOXxkMEEkacg0UwSzPfXYtwcdQZKGQrNFsGhRBh1BkoZCs0Vw+5/eOegIkjQUmi2C7z719KAjSNJQaLYI1v/s+YOOIElDodki2H3H3kFHkKSh0GwR7Lvr/kFHkKSh0GwRSJKmLR50AEnSsdX3/i/1vb8mSzeQ017T19vuyyuCJBuT7E0ykWTbHNuT5OrO9nuTXNTtvpLUuiNP76Aefys8/T+pJ/4VR779J329/Z6LIMkIcC2wCVgPXJpk/axhm4B1nctW4Lrj2FeS2vbUO45e/u6/7+vN9+MVwQZgoqr2VdUB4GZg86wxm4GbatqdwPIkK7rcV5I0j/pRBCuBmafynOys62ZMN/sCkGRrkvEk41NTUz2HliRN60cRzHXSnupyTDf7Tq+suqGqxqpqbHR09DgjSpKeTT/eNTQJrJ6xvAp4qMsxS7vYV5I0j/rximAnsC7J2iRLgS3A9lljtgOXdd49dDHwZFXt73JfSWrb8ltmLX+orzff8yuCqjqU5ErgVmAEuLGqdie5orP9emAHcAkwATwNXP5c+/aaSZJOJYtOu5A6dxd1cA9Zso4sOruvt9+XD5RV1Q6mH+xnrrt+xvUC3t7tvpKko2XRWWTZhnm57WZPMbHk9CWDjiBJQ6HZIlh2mkUgSdBwEfzIOWcNOoIkDYVmi+DMHzlj0BEkaSg0WwQbLnn5oCNI0lBotgge+PL+QUeQpKHQbBGcu8bTVEgSNFwEL33lBYOOIElDoZkiWLLs6M/Onbn8zAElkaTh0kwR/NjKHz1q+bwLVw0oiSQNl2aK4OH7j/4Ogy/cdu+AkkjScGmmCGZ/y8EyTzEhSUBDRbD09KXfv75oZBHLX7B8cGEkaYg0UwQXvOJF378+smSEn1jv3wgkCRoqggf/5gcfIEvC1DceHWAaSRoezRRBFv3g65GTv/1HktRMEfybP3oby85YxsjiRWx866tZ+5LzBh1JkoZCX76hbCG46LUv5ZYnPsDBA4c47Yxlg44jSUOjmSIAGFk8wsjikUHHkKSh0syhIUnS3CwCSWqcRSBJjbMIJKlxFoEkNc4ikKTG9VQESX40yW1JvtL5/+xnGbcxyd4kE0m2zVj/piS7kxxJMtZLFknSien1FcE24C+rah3wl53loyQZAa4FNgHrgUuTrO9svg/4x8DtPeaQJJ2gXotgM/CBzvUPAG+YY8wGYKKq9lXVAeDmzn5U1Z6q2ttjBklSD3otghdU1X6Azv/nzjFmJfDAjOXJzjpJ0hA45ikmkvwF8ONzbHpPl/cx12k+a451x8qxFdgKcN55njBOkvrlmEVQVa99tm1JHk6yoqr2J1kBPDLHsElg9YzlVcBDxxu0qm4AbgAYGxs77iKRJM2t10ND24G3dK6/BbhljjE7gXVJ1iZZCmzp7CdJGgK9FsHvAq9L8hXgdZ1lkrwwyQ6AqjoEXAncCuwBPlJVuzvj3phkEvhZ4H8nubXHPJKk45SqhXeUZWxsrMbHxwcdQ5IWlCS7quqHPrPlJ4slqXEWgSQ1ziKQpMZZBJLUuKaK4NuPf4eHvz7FQvwDuSTNl2aK4DMfvYMtK7fy1guv4nf+6e9ZBpLU0UwR/MFVN3LgmYMceOYgOz91F1+9+/5BR5KkodBMESw7Y+n3rx85UkctS1LLmimCd3/oHSw/9/ksXrqYLdvewOrzPQGqJEEXJ507VVz4inV89JvvH3QMSRo6zbwikCTNzSKQpMZZBJLUOItAkhpnEUhS4ywCSWqcRSBJjVuQ31CWZAr4+gnufg7waB/jLATOuQ3OuQ29zPknqmp09soFWQS9SDI+11e1ncqccxuccxvmY84eGpKkxlkEktS4FovghkEHGADn3Abn3Ia+z7m5vxFIko7W4isCSdIMFoEkNe6ULYIkG5PsTTKRZNsc25Pk6s72e5NcNIic/dTFnP9ZZ673JvlskpcNImc/HWvOM8b9vSSHk/zyyczXb93MN8mrktydZHeSz5zsjP3Wxe/185P8WZJ7OnO+fBA5+ynJjUkeSXLfs2zv7+NXVZ1yF2AE+Crwd4ClwD3A+lljLgE+CQS4GPjcoHOfhDn/feDszvVNLcx5xri/AnYAvzzo3PP8M14OfAk4r7N87qBzn4Q5/xbwvs71UeBbwNJBZ+9x3r8AXATc9yzb+/r4daq+ItgATFTVvqo6ANwMbJ41ZjNwU027E1ieZMXJDtpHx5xzVX22qh7vLN4JrDrJGfutm58zwK8DHwMeOZnh5kE38/0V4ONV9Q2AqmphzgWclSTA85gugkMnN2Z/VdXtTM/j2fT18etULYKVwAMzlic76453zEJyvPP5NaafUSxkx5xzkpXAG4HrT2Ku+dLNz/jFwNlJPp1kV5LLTlq6+dHNnK8BLgQeAr4IXFVVR05OvIHp6+PXqfqdxZlj3ez3yXYzZiHpej5J/gHTRfDKeU00/7qZ8+8D76qqw9NPGBe0bua7GPgZ4DXA6cAdSe6sqr+Z73DzpJs5/yJwN/Bq4CeB25L8dVU9Nc/ZBqmvj1+nahFMAqtnLK9i+tnC8Y5ZSLqaT5KXAu8HNlXVYycp23zpZs5jwM2dEjgHuCTJoar6XyclYX91+3v9aFV9F/huktuBlwELtQi6mfPlwO/W9MHziSRfAy4APn9yIg5EXx+/TtVDQzuBdUnWJlkKbAG2zxqzHbis89f3i4Enq2r/yQ7aR8ecc5LzgI8Db17AzxBnOuacq2ptVa2pqjXAnwJvW6AlAN39Xt8C/HySxUnOAF4B7DnJOfupmzl/g+lXQCR5AXA+sO+kpjz5+vr4dUq+IqiqQ0muBG5l+l0HN1bV7iRXdLZfz/Q7SC4BJoCnmX5WsWB1Oef3Aj8G/EHnGfKhWsBnbuxyzqeMbuZbVXuSfAq4FzgCvL+q5nwL4kLQ5c/4d4A/TvJFpg+ZvKuqFvSpqZN8GHgVcE6SSeDfAUtgfh6/PMWEJDXuVD00JEnqkkUgSY2zCCSpcRaBJDXOIpCkxlkEktQ4i0CSGvf/Aes4Wh/w7pUvAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of edges in consider 13516\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "low_limit = -10\n",
    "up_limit = 10\n",
    "x = np.linspace(low_limit, up_limit)\n",
    "# x = np.linspace(-0.2, 0.15)\n",
    "# plt.plot(x, x)\n",
    "plt.scatter(edge_oulier_labels, predict_influence_1, s = 10, c=edge_oulier_labels)\n",
    "# plt.xlabel('actual change in loss')\n",
    "# plt.ylabel('predicted change in loss')\n",
    "# plt.title('Influence function on edges of Cora dataset perturb edges:' + str(batch_edges))\n",
    "# plt.title('Influence function on Complete Node of Citeseer dataset')\n",
    "# plt.ylim(low_limit, up_limit)\n",
    "# plt.xlim(low_limit, up_limit)\n",
    "# plt.title('Influence function on Iris dataset')\n",
    "plt.show()\n",
    "print('Number of edges in consider %d'% len(f_l))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "64a16c0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "old_graph = graph.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ae697d0c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13516"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(graph.edges()[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "04e51317",
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "unexpected EOF while parsing (3554093118.py, line 1)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  File \u001b[0;32m\"/tmp/ipykernel_6656/3554093118.py\"\u001b[0;36m, line \u001b[0;32m1\u001b[0m\n\u001b[0;31m    np.concatenate([np.repeat(0, 13264), np.repeat(1, 252)]\u001b[0m\n\u001b[0m                                                           ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m unexpected EOF while parsing\n"
     ]
    }
   ],
   "source": [
    "np.concatenate([np.repeat(0, 13264), np.repeat(1, 252)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9589adf1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-1.82944748e-04],\n",
       "       [-1.69400380e-04],\n",
       "       [-2.39681068e-04],\n",
       "       ...,\n",
       "       [ 6.52064262e-05],\n",
       "       [-3.58584247e-05],\n",
       "       [ 6.01029250e-05]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_infl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9612fce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(predict_influence_1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
