{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bfcba399",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import os\n",
    "import dgl\n",
    "from dgl import function as fn\n",
    "from dataset import load_graph_dataset\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "import tensorflow.compat.v1 as tf\n",
    "from graph_neural_networks import SGC_layer1, SGC_layer2\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from scipy.special import softmax, log_softmax\n",
    "from scipy.linalg import cho_solve, cho_factor\n",
    "from tqdm import tqdm\n",
    "import cupy as cp\n",
    "\n",
    "from dgl.data import RedditDataset\n",
    "from dataset_wikics_amazon import load_wikics, load_amazon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "46d099e7",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Alternative configuration: WikiCS.\n",
     "# To use it, uncomment the lines below and comment out the active dataset cell.\n",
     "# The WikiCS masks are indexed with mask_idx (presumably one entry per predefined split).\n",
     "# graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_wikics()\n",
     "# mask_idx = 0\n",
     "# train_mask = train_mask[mask_idx]\n",
     "# val_mask = val_mask[mask_idx]\n",
     "# l2_regularlization_term = 0.1\n",
     "# num_layer = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a692bd97",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Alternative configuration: Amazon 'computer' dataset.\n",
     "# To use it, uncomment the lines below and comment out the active dataset cell.\n",
     "# graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_amazon(dataname='computer', seed = 1)\n",
     "# l2_regularlization_term = 0.05\n",
     "# num_layer = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4d9c85c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_amazon(dataname='photo', seed = 1)\n",
    "l2_regularlization_term = 0.05\n",
    "num_layer = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "abdf24cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a344c7d2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([True, True, True,  ..., True, True, True])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Quick look at the test mask (a boolean tensor, per the displayed output).\n",
     "# NOTE(review): test_mask is not used by later cells; evaluation below uses val_mask.\n",
     "test_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a744a09d",
   "metadata": {},
   "outputs": [],
   "source": [
    "feat0 = feat.clone()\n",
    "degs = graph.in_degrees().float().clamp(min = 1)\n",
    "norm = torch.pow(degs, -0.5)\n",
    "norm = norm.to(feat0.device).unsqueeze(1)\n",
    "\n",
    "for _ in range(2):\n",
    "    feat0 = feat0 * norm\n",
    "    graph.ndata['h'] = feat0\n",
    "    graph.update_all(fn.copy_u('h', 'm'),\n",
    "                     fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h')\n",
    "    feat0 = feat0 * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e5ea6461",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = feat0[train_mask].numpy().astype(np.float32)\n",
    "train_y = labels[train_mask].numpy().astype(np.float32)\n",
    "\n",
    "test_x = feat0[val_mask].numpy().astype(np.float32)\n",
    "test_y = labels[val_mask].numpy().astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe5478f8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8447e869",
   "metadata": {},
   "outputs": [],
   "source": [
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "\"\"\" Train Logistic Regression \"\"\"\n",
    "lr = SimplifiedGraphNeuralNetwork(l2_reg = l2_regularlization_term, fit_intercept=True)\n",
    "lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "logits_test_y = test_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "logits_train_y = train_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "\n",
    "ori_val_loss, ave_ori_val_loss = lr.log_loss(logits_test_y, one_hot_labels_test, l2_reg = True)\n",
    "\n",
    "numpy_theoritic_loss = log_loss(test_y, softmax(logits_test_y, axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "559bfe67",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████| 159/159 [00:00<00:00, 534.64it/s]\n"
     ]
    }
   ],
   "source": [
    "train_total_grad, train_indiv_grad = lr.grad(train_x, logits_train_y, \n",
    "                                             one_hot_labels_train, l2_reg=True)\n",
    "val_loss_total_grad, val_loss_indiv_grad = lr.grad(test_x, logits_test_y, \n",
    "                                                   one_hot_labels_test, l2_reg=True)\n",
    "\n",
    "# hessian_no_reg, hess, hessian_reg_term = lr.hess(train_x, logits_train_y)\n",
    "# hess = fast_hess_cuda(train_x, logits_train_y)\n",
    "hess = lr.hess_cuda(train_x, logits_train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "36bd1693",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, a = lr.grad_one_batch(train_x, logits_train_y, one_hot_labels_train, l2_reg=True, batch_index=np.array([1, 2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ccfdfc5d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.73565354,  0.87005292,  0.66535825, ..., -2.79608765,\n",
       "         3.28151389, -1.96517005]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Inspect the total validation-loss gradient (the vector passed to the inverse-HVP solver below).\n",
     "val_loss_total_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdcedbfc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "144e6050",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad.T, cholskey=True)\n",
    "# loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad.T, cholskey=False)\n",
    "loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "\n",
    "pred_infl = train_indiv_grad.dot(loss_grad_hvp)\n",
    "\n",
    "pred_infl = list(pred_infl.reshape(-1))\n",
    "#\n",
    "num_train = len(train_x)\n",
    "act_infl = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "bc833cde",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 160/160 [00:23<00:00,  6.81it/s]\n"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(num_train)):\n",
    "    lr_new = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "    train_x_new = np.delete(train_x, i, axis = 0)\n",
    "    train_y_new = np.delete(train_y, i)\n",
    "    lr_new.fit(train_x_new, train_y_new)\n",
    "    \n",
    "    logits_test_y_new = test_x @ lr_new.model.coef_.T + lr_new.model.intercept_\n",
    "    \n",
    "    \n",
    "    new_ori_val_loss, new_ave_ori_val_loss = lr_new.log_loss(logits_test_y_new, one_hot_labels_test, l2_reg = True)\n",
    "    act_infl.append(new_ori_val_loss - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "dd0ecd4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(90.59952319029907, 88.69704634731524)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Validation loss after the last leave-one-out refit vs. the full-data model's validation loss.\n",
     "new_ori_val_loss, ori_val_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1aed5717",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEWCAYAAABv+EDhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAwVElEQVR4nO3dd7wU9b3/8debXoVQlXpogmjUyLFHLxErttxc488EWzQhahJRU0wkETQh8Wo08SaxEPWqiGnGRMWOij0aUBSpIl2kd5B2zuf3x8xe1mV3z5xzds9s+Twfj33s7szszGd2Z+cz853vfL8yM5xzzpWfRnEH4JxzLh6eAJxzrkx5AnDOuTLlCcA558qUJwDnnCtTngCcc65MlWUCkNRV0iuSNku6VdJYSQ/FHVdUqfE38LK3SOrbkMvMNUlPS7oo7jjqS5JJ6h93HIVO0iJJJ8YdRyEqmQRQyx95JLAG2MfMvp/HsPKlQeKXNEXSN5OHmVkbM1uQr2XmgqRmYVL/UNLWcNu4T1IFgJmdZmYPhNNeLOm1WAPOM0kVYbJoUujLkTQ0nMcfUoa/JuniegdZ+1iqw4OexOOJHMxzWa5irK+SSQC11BuYZcV7F1yxx59vjwBnAV8H2gGHANOAYXEGFVW+d9RFYCtwYSJhx2x5eNCTeJwZZzA53zbMrCQewCLgxPD1xcBrwK+B9cBC4LRw3P3ALmAnsAU4ERgLPBSOHwosyzLvRsCPgY+AtcBfgQ7huArAgIuAJQRH6aOT5tMYuC787GaCnVLPcNwg4HlgHTAXODfDeqaL/37gF0nTfGYdwvh/ALwPbAT+ArRIGn82MB3YFMZ2KjAOqAK2h8v5fTitAf3D1+2AB4HVwGLgp0Cjmn6DDOt1ADAF2ADMBM5KWec/AE+G39tbQL8M8zkR+DTxvWaYZgrwzXCZ28P13AJsCMc3D+NeAqwE7gJahuM6AZPCONcBryatczfg7+H3sRC4MmmZUbabS8NlvpIh7h8CnwDLgUtSfovTgXfD33ApMDbpc0vCabeEj6OBfsCLYSxrgIlA+6TPXAt8HH7fc4FhEdZjr+XU4X88FFgG/A7436ThrwEXJ8XwU4JtbhXBNtguadoLwnFrgdFE/P9miiXDuKOAN8Lt4D1gaNK4bwCzw+9uAfDtcHhrgm2zOuk76ka0/++1BP/fHUCTbMuv1fdd3x1voTzYOwHsAr5FsNO9PPzTKByf+oWPJXoCuAr4F9CDYEdxN/CnlD/yH4GWBEeeO4ADkv7AM4CBgMLxHcMNY2m44TQBDiP4Ux6YYV1T44+yAb0dbmwdwo3zsnDcEQRJ4SSCP0d3YFA4bgrwzZRlJ+90HgQeA9qG6z4PuDTKb5Ayz6bAfILk2Aw4geDPMzBp/daFsTYh2Fn9OcN3cxPwcg3byv+tVxjnaynjfws8Hn5XbYEngF+F435FkBCaho/jwt+yEUFCvz5ch74Ef/5TarHdPBhuCy3TxHwqQTI6KJzm4ZTfYijw+TCOg8Npv5wy/yZJ8+sf/ubNgc7AK8Bvw3EDCbbHbkmf71eL9WiS7fuv4bcZSpAA9iVIZoltIDkBXBJuL32BNsCjwIRw3GCCHevxYXy3AbuJ8P/NFEua4d0Jksfw8Ps+KXzfORx/OkGCFfAfwDbgsCz7l/up+f87HehJsF/Juvxafd/12ekW0oO9E8D8pHGtwg1z3wxf+FiiJ4DZhEdD4fv9CHZ0TZL+AD2Sxr8NnBe+ngucnSb2/we8mjLsbmBMhnVNjT/KBnR+0vubgbuSlvObDMuZQoYEQLBT3wEMThr3bWBKlN8gZZ7HASsIj6TDYX8iPIoN1++epHHDgTkZYv4jGZJDuvUiJQEQ/Gm3knSGQXDEvDB8fSNB0uufMs8jgSUpw35CeBQbcbvpmyXm+4Cbkt7vT1ICSDP9bxO/KxF2zMCXgXfD1/0JjqxPBJqmTBdlPeqdAJK207+Er5MTwAvA
FUmfGZgUw/XJvz9BstxJhP9vhliqCY6yE49zCY7GJ6RM+yxwUYZ1+icwKt1/M2n7run/e0nS+1otP9ujlMsaVyRemNk2SRAcLdRXb+AfkqqThlUBXdMtmyD7J5bbk+DUM908j5S0IWlYE2BCvaPNHFO3pJieqsP8OhEc6S5OGraY4Ohkr2XW8Bt0A5aaWfJ3mnFefPY7TbWWYOdYV50JktW0MF4IkkLj8PUtBAcMz4Xjx5vZTQS/YbeU37AxQRERRNtulmaJqxvBGUZC8veOpCMJzn4OIvhdmgN/yzQzSV2A/yFIvm0JjiTXA5jZfElXhet5oKRngWvMbHnE9chI0pakt4PNbEmWyf8b+EjSISnDu7H3dtckjKEbSd+jmW2VtDZp2mzxf5wmhuVm1iNlHe4Aviop+XpAU+ClcPxpwBiC7bARwfY0I8t6RpG8bfTOtvzaKNeLwNlsJfjBAJDUmGCnkLCUoCy7fdKjhZml23hSLSU4NUw3/OWUebYxs8vrEjPB6XNUmWKC4GgukzUER069k4b1Iv2fqCbLgZ6SkrfHus5rMnCEpB41ThlIXcc1BOW0Byb9Fu3MrA2AmW02s++bWV/gTOAaScMIvseFKb9hWzMbHs43ynaT7fv+hCBZJ/RKGf8wQbFVTzNrR1BMlchg6eb7q3D4wWa2D3B+0vSY2cNm9kWC39cIdsY1rUe2+BPzTb6gmm3nj5mtJTiT+XnKqEQiSuhFUMyzkpTvSVIrgmLWhPr8f5PnMSFlHq3N7CZJzQmuA/0a6Gpm7QkOsLL9FlH+v8mfy7j8WqwD4AkgnXlAC0mnS2pKcLGpedL4u4BxknoDSOos6eyI874H+LmkAQocLKkjwUXF/SVdIKlp+Dhc0gER5zsdGC6pg6R9Cco5o7oX+IakYZIaSeouaVA4biVBOetezKyK4ALaOEltw+/jGqAu91O8RfAn+FG47kMJdq5/ru2MzGwywcX0f0gaIqlJGN9lki5J85GVQA9JzcLPVxMUI/0mPEom/E5OCV+fIam/gsP/TQRHj1UERX2bJF0rqaWkxpIOknR4uJz6bDcQfNcXSxoc7tTGpIxvC6wzs+2SjiCoAZWwmqAoo2/K9FuADZK6E1yfIoxtoKQTwp3ZdoKEWBVhPdItp75uA44huGCf8Cfgakl9JLUBfklQVLSboAbYGZK+GP6mN/LZ/Vx9fwcItvEzJZ0S/s4twuqdPdhz9rUa2B2eDZyc9NmVQEdJ7ZKGTad2/99sy68VTwApzGwjcAXBzvpjgh1Tcr3d2wmOtJ6TtJnggtKREWd/G8Ef+TmCnce9BBf8NhNsJOcRHN2sIDjiap5hPqkmENQEWBTO+y8RP4eZvU1w8fk3BBeDX2bP0dXtwDmS1kv6nzQf/x7B97OAoIz2YYKy6loxs50E1TZPIzgCvwO40Mzm1HZeoXMIjrr+QrBOHwCVBGcHqV4kqHW0QtKacNi1BBcZ/yVpU/i5geG4AeH7LcCbwB1mNiVMiGcChxLUAFpDsA0l/uj12W4ws6cJjoZfDGN7MWWSK4Abw3lfT7CdJT67jaBW1+uSNkg6CriBoLLBRoLaVY8mzas5QXHSGoJtsQvBBfqs65FhOfViZpsIrgV0SBp8H8E2/wrBd72dYFvEzGYC3yHYFj8hKNbK1f83EdNSgppz1xHs6JcSJNBG4X/5SoLvfz1BIn486bNzCBLYgvA76kYt/7/Zll+b9YA9tWKcc86VGT8DcM65MuUJwDnnypQnAOecK1OeAJxzrkwV1Y1gnTp1soqKirjDcM65ojJt2rQ1ZtY5dXhRJYCKigqmTp0adxjOOVdUJC1ON9yLgJxzrkx5AnDOuTLlCcA558qUJwDnnCtTngCcc65MxZoAJLWX9IikOZJmSzo6znicc66cxF0N9HbgGTM7J2y6tVVNH3DOOZcbsZ0BSNqHoN/OeyFoEtjMNsQVjysei9Zs5ZZn57B1x+64Q3GuqMVZBNSXoC3r/5X0rqR7
JLVOnUjSSElTJU1dvXp1w0fpCs7PJ83igTcWewJwrp7iTABNCDqkuNPMvkDQsciPUycys/FmVmlmlZ0773UnsyszL81ZxQtzVnHlsP502adF3OE4V9TiTADLgGVm9lb4/hGChOBcWjt2V3HjpFn07dSai4/pE3c4zhW92BKAma0AlkpKdLU3DJgVVzyu8P3v64tYuGYr1585mGZNvAazc/UVdy2g7wETwxpACwj6pnVuLys3bed3L3zIiQd0ZejALnGH41xJiDUBmNl0gs66ncvqpqfnsKvK+NkZB8QdinMlw8+jXcGbumgd/3j3Y0Ye35feHfeqKOacqyNPAK6gVVUbYx6fyX7tWnDFl/rFHY5zJcUTgCtof/n3UmYu38R1ww+gVbO4L1k5V1o8AbiCtWHbTm55dg5H9OnAGQfvF3c4zpUcTwCuYP3m+Xls/HQXN5x1IJLiDse5kuMJwBWkOSs2MeFfizn/qN4csN8+cYfjXEnyBOAKjpkx5rGZtGvZlGtO2j/ucJwrWZ4AXMF5csYnvLVwHd8/eSDtWzWLOxznSpYnAFdQtu3czbgnZ3Ngt3342hG94g7HuZLm9epcQblzykd8snE7v/vaF2jcyC/8OpdPfgbgCsaStdu4+5UFfPnQblRWdIg7HOdKnicAVzB+/uQsmjQSPxnu7f041xA8AbiC8PK81Tw/ayXfO2EAXb2jF+cahCcAF7udu6u54YmZ9OnUmku+WBF3OM6VDU8ALnb3v7GQBau3cv0Zg2nepHHc4ThXNjwBuFit2rSd2yd/yLBBXfjSIO/oxbmG5AnAxeq/n5kbdvQyOO5QnCs7ngBcbKYtXs/f31nGpcf1oaKTd/TiXEPzBOBiUV1t3PDETLru05zvfql/3OE4V5Y8AbhY/HXqUt5ftpHrhh9A6+Z+Q7pzcfAE4Brcxm27uPnZuRxe8TnOOqRb3OE4V7Y8AbgG95vJ89iwbSdjvaMX52LlCcA1qLkrNjPhX4s574heHNitXdzhOFfWPAG4BmNmjH18Jm2aN+EHJw+MOxznyl7sCUBSY0nvSpoUdywuv57+YAVvLljL90/enw6tvaMX5+IWewIARgGz4w7C5denO6sY9+RsBu3blq97Ry/OFYRYE4CkHsDpwD1xxuHy766XP+LjDZ8y9qwDadK4EI47nHNx/xN/C/wIqM40gaSRkqZKmrp69eoGC8zlztJ127jr5Y8485BuHNW3Y9zhOOdCsSUASWcAq8xsWrbpzGy8mVWaWWXnzp0bKDqXS+OenE0jieuGD4o7FOdckjjPAI4FzpK0CPgzcIKkh2KMx+XBax+u4ZmZK/juCf3Zr13LuMNxziWJLQGY2U/MrIeZVQDnAS+a2flxxeNyb1dVNWOfmEmvDq249It94g7HOZci7msAroQ98MYi5q/awvVnDKZFU+/oxblCUxCtcJnZFGBKzGG4HFq9eQe3T/6QoQM7M+wA7+jFuULkZwAuL25+Zg7bd1fxszMGe3s/zhWoGhOApK9Kahu+/qmkRyUdlv/QXLGavnQDf5u2jEuO7UO/zm3iDsc5l0GUM4CfmdlmSV8ETgEeAO7Mb1iuWFVXG2Me+4DObZvzvWED4g7HOZdFlARQFT6fDtxpZo8B3pCLS+uRd5bx3rKN/OS0QbTxjl6cK2hREsDHku4GzgWektQ84udcmdm0fRc3PzOHw3q15z+/0D3ucJxzNYiyIz8XeBY41cw2AB2AH+YzKFecbp/8IWu37uSGsw7yC7/OFYEo5+j7AU+a2Q5JQ4GDgQfzGZQrPh+u3MwDbyzivMN78vke3tGLc8UgyhnA34EqSf2Be4E+wMN5jcoVFTNj7BMzadWssXf04lwRiZIAqs1sN/AV4LdmdjXBWYFzADw7cyWvz1/LNSftT8c2zeMOxzkXUZQEsEvS14ALgUSvXU3zF5IrJtt3VfGLJ2cxsGtbzj+qd9zhOOdqIUoC+AZwNDDOzBZK6gN4q50OgLtfXsCy9d7R
i3PFqMZ/rJnNAn4AzJB0ELDMzG7Ke2Su4C1bv407pszn9IP34+h+3tGLc8WmxlpAYc2fB4BFgICeki4ys1fyGpkreL98ajYSXDf8gLhDcc7VQZRqoLcCJ5vZXABJ+wN/AobkMzBX2N6Yv4anZqzg+yftT/f23tGLc8UoSqFt08TOH8DM5uEXgctaoqOXnh1a8q3j+8YdjnOujqKcAUyVdC8wIXw/Asjaj68rbRPeXMy8lVsYf8EQ7+jFuSIWJQFcDnwHuJLgGsArwB35DMoVrjVbdvCbyfM4bkAnThrcNe5wnHP1UGMCMLMdwG3hw5W5W56Zy6c7qxhzpnf04lyxy5gAJM0ALNN4Mzs4LxG5gvXe0g38ddpSLj22D/27tI07HOdcPWU7AzijwaJwBa+6Omjvp2Pr5ow60Tt6ca4UZEwAZra4IQNxhe0f737Mu0s2cMs5B9O2hVcCc64U+L37rkabt+/ipmfmcGjP9vzXYT3iDsc5lyPeZ5+r0e9enM/qzTv444WVNGrkF36dKxV+BuCymr9qC/e9tpBzK3twaM/2cYfjnMuhGhOApGMlPS9pnqQFkhZKWlDfBUvqKeklSbMlzZQ0qr7zdLllZtw4aRYtmzXmR6cOijsc51yORSkCuhe4muDu36ocLns38H0ze0dSW2CapOfD1kddAZg8exWvzFvN9WcMppN39OJcyYmSADaa2dO5XrCZfQJ8Er7eLGk20B3wBFAAtu+q4ueTZjGgSxsuONo7enGuFEVJAC9JugV4FNiRGGhm7+QqCEkVwBeAt9KMGwmMBOjVq1euFulqcM+rC1iybhsTv3kkTb2jF+dKUpQEcGT4XJk0zIATchGApDYEHc9fZWabUseb2XhgPEBlZWXGO5Nd7izf8Cl/eOkjTjtoX47t3ynucJxzeRKlLaAv5WvhkpoS7Pwnmtmj+VqOq51fPjWbajNGn+4dvThXyrK1BXS+mT0k6Zp0482sXo3DKWhJ7F5gdn3n5XLnzY/WMun9Txg1bAA9Ptcq7nCcc3mU7Qygdficr1a/jgUuIOhreHo47DozeypPy3M12F1VzQ1PzKR7+5ZcPrRf3OE45/IsW1tAd4fPN+RjwWb2GkH/Aq5ATHxrCXNWbOau8w/zjl6cKwNevcMBsG7rTm59bi7H9u/IKQfuG3c4zrkG4AnAAXDLs3PZurOKMWce6B29OFcmPAE4Pvh4I3/+9xIuOrqC/bt6Ry/OlYsobQF1lXSvpKfD94MlXZr/0FxDMDPGPD6Tjq2bcdVJ3tGLc+UkyhnA/cCzQLfw/TzgqjzF4xrYP6d/zLTF6/nRKYPYxzt6ca6sREkAnczsr0A1gJntJreNwrmYbNmxm189NYdDerTjnCHe0Ytz5SZKUxBbJXUk7CBe0lHAxrxG5RrE7178kFWbd3D3BUO8oxfnylCUBHAN8DjQT9LrQGfgnLxG5fJuweqgo5dzhvTgC70+F3c4zrkY1FgEFLb6+R/AMcC3gQPN7P18B+byJ9HRS4smjbk2zx29TJwIFRXQqFHwPHFiXhfnnKuFGs8AJH0lZdD+kjYCM8xsVX7Ccvn04pxVTJm7mp+efgCd2+avo5eJE2HkSNi2LXi/eHHwHmDEiLwt1jkXUZSLwJcC9wAjwscfCYqFXpd0QR5jc3mwY3cVN06aRf8ubbjomIq8Lmv06D07/4Rt24Lhzrlo8nkWHeUaQDVwgJmthOC+AOBOgn4CXgEm5C4cl2/3vraQxWu3MeHSI/Le0cuSJbUb7pz7rHyfRUfZA1Qkdv6hVcD+ZrYO2FX/EFxDWbFxO79/cT4nD+7KcQM65315mTpw847dnIsm32fRURLAq5ImSbpI0kXAY8ArkloDG3IThmsIv3p6NrurjZ+ePrhBljduHLRK6VKgVatguHOuZvk+i46SAL5DcDfwoQT99j4IfMfMtuaztzCXW28vXMdj05dz2fF96dWxYTp6GTECxo+H3r1BCp7Hj/cLwM5Fle+z6CjV
QM3MHjGzq83sqvC1981bRKqqg/Z+urVrweVD+6edJtuFpigXoTJNM2IELFoE1dXBs+/8nYsu72fRZpb1AXwF+JDg7t9NwGZgU02fy8djyJAh5mrvwTcXWe9rJ9mk95anHf/QQ2atWpnBnkerVsHwbOOyfV4yu/zyBlpB50rYQw+Z9e4d/Kd69/7sfy8qYKql2afKajiYlzQfONPMZuco59RZZWWlTZ06Ne4wisr6rTv50q1TGLRvW/70raN4+GExenRQhtirV3AkMWoUrF2792d79w6eFy9OP27RouB1RUX6aSSYMMGP+p2Lm6RpZla51/AICeB1Mzs2b5HVgieA2vvpP2fwp7eX8uSVX2TaC/t8pkoZQLNmsHNn+s8m+oVJt4lIQbEOBMU+mTaj5EThnItHpgQQ5T6AqZL+AvwT2JEYaGaP5i48lw///ftt3DS6H1WbD+LUh8WWLXtXKcu084c9F5rSHd0nX4Tq1Sv9NOB1/p0rZFESwD7ANuDkpGEGeAIoYCeeaLzwQksgOIzPtIPOJnGhKfWsoVUrGD48KPpZsgQ6dMg8D6/z71zhqjEBmNk3GiIQlxsTJ8IFFySKZOrexHPr1p8tu0++bjB8ODzwwJ6ksHZtUAyUKBJK8Dr/zhW2KI3BtSBoD+hAoEViuJldkse4XB1MnAjnn594V7/2/bduhSuugDvuCBJBcjKoqNi7KKm6Gjp2hDZtPnuB2S8AO1e4otwINgHYFzgFeBnoQVAV1BWYPTv/3LjrrvR1/jOV669b53X+nSsmURJAfzP7GbDVzB4ATgc+n4uFSzpV0lxJ8yX9OBfzLFfKQ4deZunbHPE2fpwrDVESQKLBtw2SDgLaARX1XbCkxsAfgNOAwcDXJDVMIzUlpr47/44dM49Ld7Tvbfw4VxqiJIDxkj4H/Iyga8hZwM05WPYRwHwzW2BmO4E/A2fnYL5l5Yor6vf53r3h9tszJ5F0R/Xexo9zpSFKLaB7wpcvA31zuOzuwNKk98sI+hj4DEkjgZEAvbyM4TMmToQ776z75xNH7SNGwOuvB2X+yTd0ZTuqT70w7JwrPjWeAUhqLunrkq6TdH3ikYNlpzvm3Ot+UjMbb2aVZlbZuXP+27AvVMmNrXXqBG3b1u+ib+PGnz1qv+OOoNmGbEf13r+vc6Ulyo1gjxE0BDeNpDuBc2AZ0DPpfQ9geQ7nXzJSewVK125PbbRqlb7IJttRvffv61zpiXINoIeZ/T8zu9nMbk08crDsfwMDJPWR1Aw4j+Aag0sxatTe9e6jatwYhg2rXXl9uiN979/XudIT5QzgDUmfN7MZuVywme2W9F3gWaAxcJ+ZzczlMkrBxIl1P+KvS68NmY70MyUgb+vHueKVMQFImkFQJt8E+IakBQRFQCLoJ+bg+i7czJ4CnqrvfErZqFF1+9zgOlaozXSk37gxVFXtPb1fl3eueGU7AzijwaJwGdXl6L99e5hZx3OpTEf0VVXBtYPURuG87r9zxSvjNQAzW2xmi4H9gHVJ79cRNA3hCtDgwbB+fd0/n+mIPnHtwOv+O1c6olwEvhPYkvR+azjMNYBsd+mmU9cj/4Rsd/l6/77OlZYoCUCW1G2YmVUT7eKxy4Hbbw967Yoi6nTZ+F2+zpWPKAlggaQrJTUNH6OABfkOrJwlV8McPRqOOy64CBsw0twvB8CuXWkH15of6TtXHqIkgMuAY4CP2dNcw8h8BlXOEtUwFy8OqnEuXgwvvJBcAydzy29eI8c5Vxs1JgAzW2Vm55lZFzPramZfN7NVDRFcOUpXDXNvQvrsWUCx18jxZiaca3hRzgBcA4p6Y5WZSqacPt1Zz8iRngScyzdZXW4XjUllZaVNnTo17jDyqqIiWgfuvXsH5fOlINM6l9I6OhcnSdPMrDJ1uJ8BFJjhw2ueptiLe1JlOuvxZiacy69sTUFck+2DZnZb7sNxT2VqGEPVgOjdSyXX2XqvXunPAPyitnP5la0+f9vweSBwOHta
6jwTeCWfQZWzjEe9JrbvqqZ5k8YZJihe48bt3eBcqZ3lOFeIMiYAM7sBQNJzwGFmtjl8Pxb4W4NEV4Y6dEjf/k+XbqW584c9ZzOjRwcJsFcvSu4sx7lCFOWO3l7AzqT3O8lBp/BubxMnwqZNew9v1KSa224uzZ1/gncx6VzDi3IReALwtqSxksYAbwEP5jes4pOLeuyjR6e/m3eftjXvHL0evXOutqJ0Cj9O0tPAceGgb5jZu/kNq7jkqrvETOX/Gzdkz9PeXaNzri6iVgNtBWwys9uBZZL65DGmopOr7hIz1XqpqTaMd9fonKuLGhNAWOxzLfCTcFBT4KF8BlVsclWPPVtTzA2xfOdceYlyBvCfwFkE/QBgZsvZU0XUUfcj91QjRsDtv99N03afAkavXhapiYdcLd85V16iJICdYX8ABiCpdX5DKj51PXJP5+NOc+l++YvMWLaJxYsVqQw/l8t3zpWPKAngr5LuBtpL+hYwGbgnv2EVl1x1ojJv5WYefHMxXz+iFwd1b9fgy3fOlZcozUH/GngE+DvBXcHXm9n/5DuwYpPaiQrsqZbZqVPwyFZF08wY+/hM2jRvwvdPHljv5fvO3zlXkxqrgUr6bzO7Fng+zTCXRmq1zOQ7ezNV0XzmgxW88dFabjz7QDq0zkHfjs45V4MoRUAnpRl2Wq4DKSU1deqSWkXz051V/OLJ2Qzaty1fP8Kv3DrnGka21kAvB64A+kl6P2lUW+CN+ixU0i0EjcrtBD4iuLlsQ33mWUiiVL9Mnuaulz/i4w2f8ueRR9GksbfQ7ZxrGNn2Ng8T7KQfC58TjyFmVt8S5ueBg8zsYGAee+4xKAlRql8mplm6bht3vfwRZx7SjaP6dsxvYM45lyRjAjCzjWa2CLgdWGdmi81sMbBL0pH1WaiZPWdmu8O3/wJ61Gd+hSZdtcxkyVU0xz05m0YS1w0f1DDBOedcKEp5w53AlqT3W8NhuXIJ8HSmkZJGSpoqaerq1atzuNj8Sa2W2bFj8Eitovnah2t4ZuYKvntCf/Zr1zLusJ1zZabGPoElTTezQ1OGvR8W32T73GRg3zSjRpvZY+E0o4FK4CsWoXPiUuoTeFdVNafd/io7d1fz3NXH06JpaTf37JyLT6Y+gaP0B7BA0pXsOeq/AlhQ04fM7MQaAroIOAMYFmXnX2oeeGMR81dt4Z4LK33n75yLRZQioMuAY4CPgWXAkcDI+ixU0qkEDcydZWZZKkyWptWbd3D75A8ZOrAzww7oEnc4zrkyFaU/gFXAeTle7u+B5sDzkgD+ZWaX5XgZBevmZ+awfXcVPztjMOH6O+dcg8t2H8CPzOxmSb8jbAgumZldWdeFmln/un622E1fuoG/TVvGt4/vS7/ObeIOxzlXxrKdAcwOn0vjqmsBqK42xjw+k85tm/O9YQPiDsc5V+YyJgAzeyJ8fqDhwiltj7yzjPeWbuC2cw+hTfMo19+dcy5/shUBPUGaop8EMzsrLxGVqE3bd3HzM3M4rFd7vnxo97jDcc65rEVAvw6fv0JQnz/RDeTXgEV5jKkk/c/kD1m7dSf3XXw4jRr5hV/nXPyyFQG9DCDp52Z2fNKoJyS9kvfISsj8VZu5/41FnHd4Tw7u0T7ucJxzDoh2H0BnSX0TbyT1ATrnL6TSEnT0MotWzRrzgzp09OKcc/kS5Urk1cAUSYm7fyuAb+ctohLz3KyVvDZ/DWPPHEzHNs3jDsc55/5PlBvBnpE0AEg0VznHzHbkN6zSsH1XFT+fNIuBXdty/lG94w7HOec+o8YiIEmtgB8C3zWz94Beks7Ie2Ql4O6XF7Bs/aeMOWuwd/TinCs4UfZK/0vQc9fR4ftlwC/yFlGJWLZ+G3dMmc/pn9+PY/p1ijsc55zbS5QE0M/MbgZ2AZjZp4DXY6zBL5+ajQTXnX5A3KE451xaURLATkktCW8Kk9QP8GsAWbwxfw1PzVjBFUP70729d/TinCtMURLAGOAZoKekicAL
wI/yGlUR21VVzdgnZtKzQ0tGHt+35g/k0MSJUFEBjRoFzxMnNujinXNFJmstIEmNgM8R3A18FEHRzygzW9MAsRWlh/61mHkrt3D3BUMatKOXiRNh5EjYFvausHhx8B6C7iedcy5V1jMAM6smqP2z1syeNLNJvvPPbO2WHdz2/DyOG9CJkwd3bdBljx69Z+efsG1bMNw559KJUgT0vKQfSOopqUPikffIitAtz87l051VjDmz4Tt6WbKkdsOdcy7KncCXhM/fSRpmQMMWcBe4Gcs28pepS7n02D7079K2wZffq1dQ7JNuuHPOpVPjGYCZ9Unz8J1/kqCjlw/o2Lo5o06Mp6OXceOgVavPDmvVKhjunHPpRLkTuIWkayQ9Kunvkq6S1KIhgisW/3j3Y95ZsoFrTx1I2xZNY4lhxAgYPx569wYpeB4/3i8AO+cyi1IE9CCwGfhd+P5rwATgq/kKqphs3r6Lm56Zw6E92/Nfh/WINZYRI3yH75yLLkoCGGhmhyS9f0nSe/kKqNj87sX5rN68g3surPSOXpxzRSVKLaB3JR2VeCPpSOD1/IVUPOav2sJ9ry3k3MoeHNKzfdzhOOdcrUQ5AzgSuFBSokJhL2C2pBmAmdnBeYuugJkZN06aRcumjfnhKYNq/oBzzhWYKAng1LxHUYQmz17FK/NW87MzBtO5rXf04pwrPlE6hElTuzw3JP0AuAXoXEx3GCc6ehnQpQ0XHu0dvTjnilOUM4C8kNQTOAkountV73l1AUvWbWPiN4+kqXf04pwrUnHuvX5D0KqoxRhDrS3f8Cl/eOkjTjtoX47t7x29OOeKVywJQNJZwMdhF5M1TTtS0lRJU1evXt0A0WX3y6dmU23GdcO9oxfnXHHLWxGQpMnAvmlGjQauA06OMh8zGw+MB6isrIz1bOHNj9Yy6f1PGDVsAD07tKr5A845V8DylgDM7MR0wyV9HugDvBe2mNkDeEfSEWa2Il/x1NfuqmpueGIm3du35PKh/eIOxznn6q3BLwKb2QygS+K9pEVAZaHXAnr47SXMWbGZO0cc1qAdvTjnXL54FZYI1m3dya3PzePY/h059aB0pVrOOVd8YqsGmmBmFXHHUJNfPzeXLTt2M+bMAxu8oxfnnMsXPwOowQcfb+RPby/hwqN7s3/Xhu/oxTnn8sUTQBZmxtjHZ9KhVTOuOnH/uMNxzrmc8gSQxWPTlzN18Xp+dOpA2rWMp6MX55zLF08AGWzZsZtfPjWbg3u046tDesYdjnPO5VzsF4EL1e9fnM+qzTu4+4Ih3tGLc64k+RlAGgvXbOXe1xZwzpAefKHX5+IOxznn8sITQBo3PjGT5k0a86NTB8YdinPO5Y0ngBQvzlnJS3NXM2rYALq0bRF3OM45lzeeAJLs2F3FjU/Mol/n1lx0TEXc4TjnXF75ReAk97y6kEVrtzHh0iNo1sRzo3OutPleLrRi43b+8NJ8Th7cleMGdI47HOecyztPAKFfPT2b3dXGT08fHHcozjnXIDwBAG8vXMdj05dz2fF96dXRO3pxzpWHsk8AVdXGmMdn0q1dCy4f2j/ucJxzrsGUfQJ4+O0lzP5kE6NPH0zLZt7Ri3OufJR1Ali/dSe3PjeXo/p2YPjnvaMX51x5KesEcOvzc9m8fTdjz/KOXpxz5adsE8DM5Rt5+K0lXHBUbwbtu0/c4TjnXIMrywRgZtzw+Czat2rG1d7Ri3OuTJVlAnj8veW8vWgdPzxlIO1aeUcvzrnyVHYJYOuO3fzqqTl8vns7zq30jl6cc+Wr7NoC+sNL81mxaTt/GHEYjb2jF+dcGSurM4BFa7Zyz6sL+cph3RnS2zt6cc6Vt7JKAD+fNIumjcWPTx0UdyjOORe72BKApO9JmitppqSb8728l+as4oU5qxh14gC67OMdvTjnXCzXACR9CTgbONjMdkjqks/l7dxdzY2TZtG3U2suPqZPPhflnHNFI64zgMuBm8xsB4CZrcrnwu57fSEL
12zl+jMHe0cvzjkXimtvuD9wnKS3JL0s6fBME0oaKWmqpKmrV6+u08K6tG3OuZU9GDowrycazjlXVGRm+ZmxNBlI18LaaGAc8CIwCjgc+AvQ12oIprKy0qZOnZrrUJ1zrqRJmmZmlanD83YNwMxOzBLM5cCj4Q7/bUnVQCegbof4zjnnai2uIqB/AicASNofaAasiSkW55wrS3HdCXwfcJ+kD4CdwEU1Ff8455zLrVgSgJntBM6PY9nOOecCXifSOefKlCcA55wrU54AnHOuTHkCcM65MpW3G8HyQdJqYHGeF9OJ4q6SWuzxQ/Gvg8cfv2Jfh1zH39vMOqcOLKoE0BAkTU13x1yxKPb4ofjXweOPX7GvQ0PF70VAzjlXpjwBOOdcmfIEsLfxcQdQT8UePxT/Onj88Sv2dWiQ+P0agHPOlSk/A3DOuTLlCcA558qUJ4A0JI2V9LGk6eFjeNwxRSHpVElzJc2X9OO446ktSYskzQi/86Lo+UfSfZJWhS3bJoZ1kPS8pA/D58/FGWM2GeIvmu1fUk9JL0maLWmmpFHh8GL6DTKtQ95/B78GkIakscAWM/t13LFEJakxMA84CVgG/Bv4mpnNijWwWpC0CKg0s6K5gUfS8cAW4EEzOygcdjOwzsxuChPx58zs2jjjzCRD/GMpku1f0n7Afmb2jqS2wDTgy8DFFM9vkGkdziXPv4OfAZSOI4D5ZrYgbG77z8DZMcdU8szsFWBdyuCzgQfC1w8Q/JkLUob4i4aZfWJm74SvNwOzge4U12+QaR3yzhNAZt+V9H54ilywp49JugNLk94vo4E2ohwy4DlJ0ySNjDuYeuhqZp9A8OcGusQcT10U2/aPpArgC8BbFOlvkLIOkOffoWwTgKTJkj5I8zgbuBPoBxwKfALcGmesESnNsGIr3zvWzA4DTgO+ExZPuIZXdNu/pDbA34GrzGxT3PHURZp1yPvvEFeXkLHL1ml9Mkl/BCblOZxcWAb0THrfA1geUyx1YmbLw+dVkv5BUKz1SrxR1clKSfuZ2Sdh+e6quAOqDTNbmXhdDNu/pKYEO86JZvZoOLiofoN069AQv0PZngFkE24wCf8JfJBp2gLyb2CApD6SmgHnAY/HHFNkklqHF8CQ1Bo4meL43tN5HLgofH0R8FiMsdRaMW3/kgTcC8w2s9uSRhXNb5BpHRrid/BaQGlImkBw2mXAIuDbifLEQhZWE/st0Bi4z8zGxRtRdJL6Av8I3zYBHi6G+CX9CRhK0HzvSmAM8E/gr0AvYAnwVTMryAutGeIfSpFs/5K+CLwKzACqw8HXEZShF8tvkGkdvkaefwdPAM45V6a8CMg558qUJwDnnCtTngCcc65MeQJwzrky5QnAOefKlCcAV/AkDZV0TD3nsaUW094v6Zz6LC9XJL1Ry+kLJnZX+DwBuGIwFKhXAihWZlaW6+0ahicAFwtJ/wwbfZuZ3PBb2KfBO5Lek/RC2DjWZcDVYZvox6Ue5SaO7iW1CT/zjoJ+BWpsDVXShWFjW++FNwAmHC/pDUkLEsvKNH9JFWFb7n8M1+c5SS3DcYeH839T0i0K292X1Dh8/+9w/LczxJdYt6GSpkh6RNIcSRPDO0izrdswSe+Gsd4nqXk4/CZJs8Ll/joc9tWwLaz3JBVj8xuuLszMH/5o8AfQIXxuSXCLe0egM0GLpn1SphkL/CDps/cD5yS93xI+NwH2CV93Auaz52bHLWliOBCYC3RKWd79wN8IDpAGEzSznXH+QAWwGzg0HPdX4Pzw9QfAMeHrm4APwtcjgZ+Gr5sDUxPrnRJjYt2GAhsJ2nhqBLwJfDHN9PcD5wAtwu9y/3D4g8BVQIdwnRPfS/vweQbQPXmYP0r/4WcALi5XSnoP+BdBI3YDgKOAV8xsIYDV/tZ9Ab+U9D4wmaA57K5Zpj8BeMTCDmhSlvdPM6u2oEOdxDyyzX+hmU0PX08DKiS1B9qaWaIc/+Gk+Z8MXChpOkGzBR0JvoNs
3jazZWZWDUwnSDyZDAxjmhe+fwA4HtgEbAfukfQVYFs4/nXgfknfImhKxJWBsm0N1MVH0lDgROBoM9smaQrBEauI1oT1bsLiy7AYpFk4fATBWcQQM9uloIexFtlCybK8HSnT1TT/5OmrCM5sshXRCPiemT2bZZpsMVWR/f+bdtlmtlvSEcAwggYDvwucYGaXSToSOB2YLulQM1tbi9hcEfIzABeHdsD6cOc/iODIH4Jijf+Q1AeCfl3D4ZuBtkmfXwQMCV+fDTRNmu+qcOf8JaB3DXG8AJwrqWPK8rLFHXn+ZrYe2CwpsX7nJY1+Frg8bAYYSfuHraDmyhyCs5D+4fsLgJcVtDnfzsyeIigSOjRcfj8ze8vMrgfW8NmmxV2J8jMAF4dngMvCopS5BMVAmNnq8ILwo5IaEbThfhLwBPBIeNH1e8AfgcckvU2wE98aznci8ISCDuWnE+wEMzKzmZLGEewYq4B3CfqSzaRW8w9dCvxR0lZgCkE5PsA9BEU474RnMavJYbeFZrZd0jeAv0lqQtBc+F0E1wAek5Q447o6/MgtkgaEw14A3stVLK5weWugzuWRpDZmlqjJ82OCzr9HxRyWc4CfATiXb6dL+gnBf20x2c8wnGtQfgbgnHNlyi8CO+dcmfIE4JxzZcoTgHPOlSlPAM45V6Y8ATjnXJn6/7ckI23nCZzMAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "limit_xy = 6\n",
    "x = np.linspace(-limit_xy, limit_xy)\n",
    "# x = np.linspace(-0.2, 0.15)\n",
    "plt.plot(x, x)\n",
    "plt.plot(act_infl, pred_infl, 'o', color='blue')\n",
    "plt.xlabel('actual change in loss')\n",
    "plt.ylabel('predicted change in loss')\n",
    "# plt.title('Influence function on Cora dataset')\n",
    "plt.title('Influence function on Citeseer dataset - Node Feature')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7511b006",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.8682848939411696, pvalue=5.72264640046934e-50)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import scipy\n",
    "scipy.stats.spearmanr(act_infl, pred_infl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a6dd84ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "88.69704634731524"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Validation loss (with the L2 term) of the model trained on the full training split.\n",
     "ori_val_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d966e25c",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame([pred_infl, act_infl]).T\n",
    "df.columns = ['predicted influence', 'actual influence']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "235d1ffe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# df.to_csv('wiki-cs-dataset/' + 'node_feature_influence' +'.csv', index = False)\n",
    "# df.to_csv('amazon_dataset/' + 'computer_node_feature_influence.csv', index = False)\n",
    "df.to_csv('amazon_dataset/' + 'photo_node_feature_influence.csv', index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f770b2b1",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Not executed: examples of loading previously saved result CSVs for inspection.\n",
     "# # df = pd.read_csv('hyper_parameter/cora.csv')\n",
     "# pd.read_csv('node_feature/cora.csv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
