{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0fce7c48",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import os\n",
    "import dgl\n",
    "from dgl import function as fn\n",
    "from dataset import load_graph_dataset\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "import tensorflow.compat.v1 as tf\n",
    "from graph_neural_networks import SGC_layer1, SGC_layer2\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from scipy.special import softmax, log_softmax\n",
    "from scipy.linalg import cho_solve, cho_factor\n",
    "from tqdm import tqdm\n",
    "import cupy as cp\n",
    "\n",
    "from dgl.data import RedditDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "43006593",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "# Regularisation strength of the logistic-regression model, per dataset.\n",
    "# Changing `dataname` automatically selects the matching L2 term.\n",
    "l2_reg_by_dataset = {'cora': 0.01, 'pubmed': 0.1}\n",
    "\n",
    "dataname = 'cora'\n",
    "l2_regularlization_term = l2_reg_by_dataset[dataname]\n",
    "\n",
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(dataname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "558b8ee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Untrained placeholder; the model is re-created and fitted in the training cell below.\n",
    "lr = SimplifiedGraphNeuralNetwork(fit_intercept=True, l2_reg=l2_regularlization_term)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e8e53bcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled alternative input pipeline: instead of the propagated features built below,\n",
    "# extract hidden-layer features with FeatureExtraction and cast them to float64.\n",
    "# Every line in this cell is commented out, so running it has no effect.\n",
    "\n",
    "# from generate_mid_layer_feature import FeatureExtraction\n",
    "\n",
    "# FeatureExtractor = FeatureExtraction(num_layers=2, num_iter=100, lr=0.02, hidden_feat=20, device='cpu',\n",
    "#                                          dataset='pubmed')\n",
    "# FeatureExtractor.extract_feature()\n",
    "# train_x, train_y, val_x, val_y, test_x, test_y = FeatureExtractor.preprocessed_data(save='feat.csv')\n",
    "\n",
    "\n",
    "# train_x = train_x.astype(np.float64)\n",
    "# test_x = test_x.astype(np.float64)\n",
    "# train_y = train_y.astype(np.float64)\n",
    "# test_y = test_y.astype(np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c15b5323",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two hops of degree-normalised neighbour aggregation (SGC-style pre-propagation).\n",
    "# Each hop scales node features by deg^{-1/2}, sums the messages arriving over in-edges,\n",
    "# and scales the result by deg^{-1/2} again. Degrees are clamped to 1 to avoid 0^{-1/2}.\n",
    "degs = graph.in_degrees().float().clamp(min=1)\n",
    "norm = torch.pow(degs, -0.5).to(feat.device).unsqueeze(1)\n",
    "\n",
    "feat0 = feat.clone()\n",
    "for _ in range(2):\n",
    "    graph.ndata['h'] = feat0 * norm\n",
    "    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h') * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d9342378",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-split float32 numpy arrays: propagated features (x) and class labels (y).\n",
    "train_x, train_y = (t[train_mask].numpy().astype(np.float32) for t in (feat0, labels))\n",
    "test_x, test_y = (t[test_mask].numpy().astype(np.float32) for t in (feat0, labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4fbb1154",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot targets; columns follow the sorted training labels (enc.categories_[0]).\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# Train the logistic-regression model on the full training set.\n",
    "lr = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "logits_test_y = test_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "logits_train_y = train_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "\n",
    "# Baseline regularised test loss; the leave-one-out experiment measures changes relative to it.\n",
    "ori_val_loss, ave_ori_val_loss = lr.log_loss(logits_test_y, one_hot_labels_test, l2_reg=True)\n",
    "\n",
    "# Unregularised cross-check with scikit-learn. Passing the training classes as `labels` keeps\n",
    "# the probability columns aligned even if a class is absent from the test split\n",
    "# (assumes the logit columns follow the sorted training classes -- TODO confirm for lr.model).\n",
    "numpy_theoritic_loss = log_loss(test_y, softmax(logits_test_y, axis=1), labels=enc.categories_[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "857d813a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradients of the L2-regularised loss (l2_reg=True) on the training and test splits;\n",
    "# presumably the *_indiv_grad arrays hold one row per sample (dotted with H^{-1} g_test below).\n",
    "train_total_grad, train_indiv_grad = lr.grad(\n",
    "    train_x, logits_train_y, one_hot_labels_train, l2_reg=True)\n",
    "val_loss_total_grad, val_loss_indiv_grad = lr.grad(\n",
    "    test_x, logits_test_y, one_hot_labels_test, l2_reg=True)\n",
    "\n",
    "# Hessian of the training objective via the CUDA path.\n",
    "# Alternatives: lr.hess(train_x, logits_train_y) -> (hessian_no_reg, hess, hessian_reg_term),\n",
    "# or fast_hess_cuda(train_x, logits_train_y).\n",
    "hess = lr.hess_cuda(train_x, logits_train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "569fedf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inverse-Hessian-vector product H^{-1} g_test (Cholesky path; cholskey=False selects the\n",
    "# other solver), copied from the GPU back into a numpy array.\n",
    "loss_grad_hvp = cp.asnumpy(fast_get_inv_hvp_cuda(hess, val_loss_total_grad.T, cholskey=True))\n",
    "\n",
    "# Predicted influence of each training node: its individual gradient dotted with H^{-1} g_test.\n",
    "pred_infl = list(train_indiv_grad.dot(loss_grad_hvp).reshape(-1))\n",
    "\n",
    "num_train = len(train_x)\n",
    "act_infl = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "deddcb70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Leave-one-out ground truth: retrain without each training node and record the change in\n",
    "# the regularised test loss relative to the full-data model (ori_val_loss).\n",
    "# act_infl is initialised here so re-running this cell never appends a second copy of the results.\n",
    "num_train = len(train_x)\n",
    "act_infl = []\n",
    "for i in tqdm(range(num_train), desc='leave-one-out retraining'):\n",
    "    lr_new = SimplifiedGraphNeuralNetwork(l2_reg=l2_regularlization_term, fit_intercept=True)\n",
    "    train_x_new = np.delete(train_x, i, axis=0)\n",
    "    train_y_new = np.delete(train_y, i)\n",
    "    # Same fit arguments as the full-data model, so the removed node is the only difference.\n",
    "    lr_new.fit(train_x_new, train_y_new, sample_weight=None, verbose=False)\n",
    "\n",
    "    logits_test_y_new = test_x @ lr_new.model.coef_.T + lr_new.model.intercept_\n",
    "\n",
    "    new_ori_val_loss, new_ave_ori_val_loss = lr_new.log_loss(logits_test_y_new, one_hot_labels_test, l2_reg=True)\n",
    "    act_infl.append(new_ori_val_loss - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f0b26013",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAEWCAYAAACNJFuYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAzgUlEQVR4nO3dd5wU9f3H8dfnqFIsFAsCRxFRNIqKYg9GYy/RqNGgYiWWRGyJLSqaH4kaozEmJqLBgqeCxm5iTRS7goIURRFpilTp0u4+vz++s7rs7e7t3W29ez8fj33s7szszGe2zGdnPjPfr7k7IiIi8coKHYCIiBQfJQcREalGyUFERKpRchARkWqUHEREpBolBxERqUbJIY6ZbWFmY8xsuZn9ycyGmtmDhY4rU4nx53nZK8ysRz6XmW1m9h8zG1ToOOrLzNzMtil0HMXOzGaY2UGFjqNYNfjkUMsvwGBgIbCxu1+aw7ByJS/xm9mrZnZ2/DB3b+Pu03O1zGwws+ZRwv/MzFZG340RZtYNwN0Pc/f7o2lPN7M3ChpwjplZtyiRNC325ZjZgGgef0sY/oaZnV7vIGsfS1X0hyh2eyYL85yTrRizocEnh1oqB6Z46V4ZWOrx59pjwNHAz4FNgJ2BccCBhQwqU7neiJeAlcBpsWReYF9Ff4hit6MKGUxOvhvu3qBvwAzgoOjx6cAbwC3AN8AXwGHRuPuAdcBaYAVwEDAUeDAaPwCYk2beZcAVwOfAImA00C4a1w1wYBAwi/Dv/uq4+TQBropeu5ywweoSjdsOeAlYDEwFTkyxnsnivw/4v7hpNliHKP7LgI+ApcAooGXc+GOA8cCyKLZDgWFAJbA6Ws5fo2kd2CZ6vAnwALAAmAn8Fiir6TNIsV7bA68CS4DJwNEJ6/w34LnofXsX6JliPgcB38be1xTTvAqcHS1zdbSeK4Al0fgWUdyzgHnAP4CNonEdgGejOBcDr8etcyfgX9H78QVwYdwyM/nenBUtc0yKuH8NzAW+As5M+CyOAD6MPsPZwNC4182Kpl0R3fYCegL/jWJZCFQAm8a95nLgy+j9ngocmMF6VFtOHX7HA4A5wB3AvXHD3wBOj4vht4Tv3HzCd3CTuGlPjcYtAq4mw99vqlhSjNsTeCv6HkwABsSNOwP4OHrvpgO/iIa3Jnw3q+Leo05k9vu9nPD7XQM0Tbf8Wr/n9dnwlsKN6slhHXAOYYN8XvSDsmh84ocxlMyTw0XAO0BnwkbkLuBh3/BHfjewEeEf6xpg+7gf90SgN2DR+PbRl2Z29KVqCuxK+MHukGJdE+PP5Mv1XvRFbBd9cc+Nxu1BSBg/Jvxwtga2i8a9CpydsOz4DdIDwFNA22jdPwXOyuQzSJhnM2AaIXE2B35E+GH1jlu/xVGsTQkbskdSvDc3Aq/V8F35br2iON9IGP9n4OnovWoLPAP8IRr3B0KyaBbd9os+yzJCsr82WocehA3DIbX43jwQfRc2ShLzoYREtWM0zUMJn8UA4AdRHDtF0/4kYf5N4+a3TfSZtwA6AmOAP0fjehO+j53iXt+zFuvRNN37X8NnM4CQHLYkJLrYdyA+OZwZfV96AG2Ax4GR0bg+hI3u/lF8twLryeD3myqWJMO3JiSWw6P3+8fR847R+CMIydeAHwKrgF3TbF/uo+bf73igC2G7knb5tX7P6/phlcqN6slhWty4VtGXdssUH8ZQMk8OHxP9i4qeb0XYCDaN+3F0jhv/HnBS9HgqcEyS2H8GvJ4w7C7guhTrmhh/Jl+uU+Ke3wz8I245t6VYzqukSA6EDf4aoE/cuF8Ar2byGSTMcz/ga6J/4NGwh4n+/Ubrd0/cuMOBT1LEfDcpEkey9SIhORB+0CuJ2zMh/NP+Inp8AyEhbpMwz/7ArIRhVxL9+83we9MjTcwjgBvjnm9LXHJIMv2fY58rGWy0gZ8AH0aPtyH8Iz8IaJYwXSbrUe/kEPc9HRU9
jk8OrwDnx72md1wM18Z//oREupYMfr8pYqki/DuP3U4k/IsfmTDtC8CgFOv0JDAk2W8z7vtd0+/3zLjntVp+TbfGeAzz69gDd19lZhD+ZdRXOfCEmVXFDasEtki2bMK/hthyuxB2Z5PNs7+ZLYkb1hQYWe9oU8fUKS6mf9dhfh0I/5Bnxg2bSfhXU22ZNXwGnYDZ7h7/nqacFxu+p4kWETacddWRkMjGRfFCSBhNosd/JPyZeDEaP9zdbyR8hp0SPsMmhMNOkNn3ZnaauDoR9kxi4t93zKw/Ya9pR8Ln0gJ4NNXMzGxz4C+ExNyW8A/0GwB3n2ZmF0XruYOZvQBc4u5fZbgeKZnZirinfdx9VprJbwI+N7OdE4Z3ovr3rmkUQyfi3kd3X2lmi+KmTRf/l0li+MrdOyesw53ACWYWX39oBvwvGn8YcB3he1hG+D5NTLOemYj/bpSnW35tqSCduZWEDxMAM2tC2GDEzCYcO9807tbS3ZN9sRLNJuxuJhv+WsI827j7eXWJmbBLnqlUMUH4F5jKQsI/rvK4YV1J/gOryVdAFzOL/57WdV4vA3uYWecapwwS13Eh4bjwDnGfxSbu3gbA3Ze7+6Xu3gM4CrjEzA4kvI9fJHyGbd398Gi+mXxv0r3fcwmJPKZrwviHCIfCurj7JoRDX7Hslmy+f4iG7+TuGwOnxE2Puz/k7vsSPl8nbKhrWo908cfmG1/cTZcYcPdFhD2g3yWMiiWpmK6EQ0fzSHifzKwV4dBtTH1+v/HzGJkwj9bufqOZtSDUnW4BtnD3TQl/vtJ9Fpn8fuNfl3L5tViH7yg5ZO5ToKWZHWFmzQiFrxZx4/8BDDOzcgAz62hmx2Q473uA35lZLwt2MrP2hALntmZ2qpk1i267m9n2Gc53PHC4mbUzsy0Jx1Uz9U/gDDM70MzKzGxrM9suGjePcFy3GnevJBTzhplZ2+j9uASoy/Ui7xJ+IL+J1n0AYcP7SG1n5O4vEwr7T5jZbmbWNIrvXDM7M8lL5gGdzax59PoqwqGp26J/10TvySHR4yPNbBsLuw3LCP86KwmHD5eZ2eVmtpGZNTGzHc1s92g59fneQHivTzezPtEG77qE8W2Bxe6+2sz2IJypFbOAcHikR8L0K4AlZrY1oR5GFFtvM/tRtKFbTUiWlRmsR7Ll1NetwN6EkwdiHgYuNrPuZtYG+D3h8NN6wplqR5rZvtFnegMbbv/q+zlA+I4fZWaHRJ9zy+gU1c58v9e2AFgf7UUcHPfaeUB7M9skbth4avf7Tbf8WlNyyJC7LwXOJ2zIvyRstOLPS76d8A/tRTNbTihu9c9w9rcSfuQvEjYs/yQUH5cTvkAnEf4VfU34p9YixXwSjSScsTAjmveoDF+Hu79HKITfRihMv8b3/8puB443s2/M7C9JXv4rwvsznXBM+CHCsfFacfe1hFNPDyP8c78TOM3dP6ntvCLHE/6tjSKs0ySgH2GvItF/CWdHfW1mC6NhlxMKnu+Y2bLodb2jcb2i5yuAt4E73f3VKFkeBfQlnKm0kPAdim0E6vO9wd3/Q/gX/d8otv8mTHI+cEM072sJ37PYa1cRzj5708yWmNmewPWEEx+WEs4CezxuXi0Ih6gWEr6LmxNOFki7HimWUy/uvoxQe2gXN3gE4Ts/hvBeryZ8F3H3ycAFhO/iXMKhsmz9fmMxzSac4XcVIQnMJiTXsui3fCHh/f+GkKSfjnvtJ4TkNj16jzpRy99vuuXXZj1iYmfpiIiIfEd7DiIiUo2Sg4iIVKPkICIi1Sg5iIhINQ3iIrgOHTp4t27dCh2GiEhJGTdu3EJ375hsXINIDt26dWPs2LGFDkNEpKSY2cxU43RYSUREqlFyEBGRapQcRESkGiUHERGpRslBRESqUXIQESkhFRXQrRuUlYX7iorcLKdBnMoqItIYVFTA4MGwalV4PnNmeA4wcGB2l6U9BxGREnH11d8nhphVq8LwbFNyEBEpEbNS9JGX
anh9KDmIiJSIromdwNYwvD6UHERESsSwYdCq1YbDWrUKw7NNyUFEpEQMHAjDh0N5OZiF++HDs1+MBp2tJCJSUgYOzE0ySKQ9BxERqUbJQUREqlFyEBGRapQcRESkGiUHERGpRslBRESqUXIQEZFqlBxERKQaJQcREalGyUFERKpRchARkWqUHEREpBolBxERqUbJQUREqlFyEBGRagqaHMxshJnNN7NJccPamdlLZvZZdL9ZIWMUEWmMCr3ncB9waMKwK4BX3L0X8Er0XERE8qigycHdxwCLEwYfA9wfPb4f+Ek+YxKR0jVt/nI+mPVNocNoEAq955DMFu4+FyC63zzZRGY22MzGmtnYBQsW5DVAESku7s7Id2ZyxF/e4NqnJuHuhQ6p5BVjcsiIuw93937u3q9jx46FDkdECmTRijWc88A4rnlyEv17tGfE6btjZoUOK+cqKqBbNygrC/cVFdmdf9Pszi4r5pnZVu4+18y2AuYXOiARKU5jPl3ApY9OYOmqdVx7ZB9O37sbZWWNIzEMHgyrVoXnM2eG5wADB2ZnGcW45/A0MCh6PAh4qoCxiEgRWrO+kt89O4XTRrzHZq2a8dQv9+HMfbs3isQAcPXV3yeGmFWrwvBsKeieg5k9DAwAOpjZHOA64EZgtJmdBcwCTihchCJSbD6bt5xfPfwhn3y9nNP2Kueqw7enZbMmhQ4rr2bNqt3wuihocnD3k1OMOjCvgYhI0XN3HnxnJv/33Me0adGUEaf340fbbVHosAqia9dwKCnZ8GwpxsNKIiIbWLhiDWffP5ZrnprMnj3a85+L9mu0iQFg2DBo1WrDYa1aheHZUowFaRGR77z26QIuHT2BZavXcd1RfRi0V+MoOqcTKzpffXU4lNS1a0gM2SpGg5KDiBSp1esqufn5qYx48wu23aIND569B9ttuXGhwyoaAwdmNxkkUnIQkaLz6bzlXBgVnQftVc6VjbDoXGiqOYhI0XB3Hnh7Bkfd8QYLV6zh3tN35/pjdlRiSJDrC+BAew4iUiQWrljD5Y99xCufzGdA74788fid6di2RaHDKjr5uAAOwBpCGyT9+vXzsWPHFjoMEamjV6fO57JHP2LZ6nVcddh2DNq7W6NoAqMuunVLfhpreTnMmFG7eZnZOHfvl2yc9hxEpGBWr6vkpuc/4d43Z9B7i7YqOmcgHxfAgZKDiBTI1K+XM+SRUHQ+fe9uXHHYdqotZCAfF8CBCtIikmfuzv1vzeCov35fdB569A5KDBnKxwVwoD0HEcmjhSvW8JvHPuK/n8zngN4duVlF51rLxwVwkEFyMLMTgOfdfbmZ/RbYFfg/d/8gu6GISEMWX3S+/ugdOG2vchWd6yjXF8BBZoeVrokSw77AIYSuO/+e27BEpKFYva6SoU9P5vR736d96+Y888t9dTZSCcjksFJldH8E8Hd3f8rMhuYuJBFpKKZ+Ha50njpvOWfs043LD1XRuVRkkhy+NLO7gIOAm8ysBSpki0gasaLz7//zCRu3bMq9Z+zOAb2TdgcvRSqT5HAicChwi7svibru/HVuwxKRUrVg+Rp+89gE/jd1AQf07sgfT9iZDm1UdC41mewBbAU85+6fmdkAQs9s7+UyKBEpTf/7ZD6H3T6GNz9fxPVH78CI03dv8IkhH+0cFUImyeFfQKWZbQP8E+gOPJTTqESkpMSKzmfc9z4d2rTg2V8Vvuicj412rJ2jmTPB/ft2jhpCgsjksFKVu683s+OAP7v7HWb2Ya4DE5HSUIxF53w1Tnf11d8vI2bVqjA816ea5lomew7rzOxk4DTg2WhYs9yFJCKlwN25980vOOqvb7Bo5VruO2N3rjuqOK50TrfRzqZ8tXNUCJnsOZwBnAsMc/cvzKw78GBuwxKRYrZg+Roue3QCr326gAO325ybjt+paGoLFRXJ2x6C7G+089XOUSHUuOfg7lOAy4CJZrYjMMfdb8x5ZCJSlP77yTwO/fMY3pm+iBuO2YF7
BvXLaWKoqXYQP75DBzjjjNTzKivLbj0gX+0cFUImzWcMIFwVPQMwoIuZDXL3MTmNTESKyup1lfz+3x/zwNsz2W7Ltjw8eE+23aJtTpdZU+0gcfyiRennV1mZ3dpDvto5KoQaO/sxs3HAz919avR8W+Bhd98tD/FlRJ39iOTWx3OXMeSRD/l03oq8Fp1r6tgm1fiapOsYp6KiYW7sk6lvZz/NYokBwN0/NTMVpEUagaoq5763ZnDj85+wcctm3HfG7gzI45XONRV861pDSPW6fJ3lVAoyOVtprJn908wGRLe7gXG5DkxECmv+8tWcft/73PDsFPbbpgMvXLRfXhMDpC7sxobXtfCb6nX5OsupFGSSHM4DJgMXAkOAKYSzl0SkgXrl43kc9ufXeXf6In73kx25Z1A/2hfgbKSaCr7JxjdvDu3bg1m4b9Ys9esTNeRTU2srk7OV1rj7re5+nLsf6+63ufuafAQnIvm1el0l1z41ibPuH8vmG7fk2V/ty6l7Fq7fhYEDYfjwUCMwC/fDh39/iCfZ+BEjYOFCqKoK9/fem/r1iWraU2lMUhakzWwikLJa7e475Sqo2lJBWqT+pnwVis6fzV/B2ft259eH9qZF08Jf0JZPiTUHCHsa6RJKKatrQfrIHMUjIkWkqsq5960Z3PSfT9ikVTMeOHMP9t+2Y6HDKoiGfGpqbaVMDu5ehxPERKSUzF+2msse+4gxny7goO0356af7lSQ2kIxyUcXnKUgk1NZRaQBeuXjefz6sY9YtXY9//eTHRnYv6u67pTvKDmINDLfrg1XOo98Zybbb7Uxd5zcl202z+2VzlJ61N2nSCMy5atlHP3XNxj5zkzO2a87T16wd1ElhobacU4pyqRtpX2AoUB5NL0B7u49chuaiGRLVZUz4s0vuPn5qWzaqhkjz9qD/XoVV9FZVycXl0zaVvoEuJhwVXRlbLi719DEVT0DM5sBLI+WuT7V6VagU1lF0pm/bDWXPjqB1z9byEHbb8HNx+9Eu9bNCx1WNTW1oyTZV9+2lZa6+3+yHFOmDnD3hQVatkjJe3nKPH7zr1B0Hnbsjvx8j+ItOuvq5OKSSXL4n5n9EXgc+O7KaHf/IGdRiUi9fLu2kmH/nsKD78yiz1Yb85cSKDo35I5zSlEmyaF/dB+/6+HAj7IfzgYceNHMHLjL3YfHjzSzwcBggK769oh8Z/JXSxnyyHimzV/BOft157JDSuNK52HDql+d3KwZrFgRCtSN+YK0Qqix5lAoZtbJ3b8ys82Bl4BfpepgSDUHkepF51tP7Mu+vToUOqxaie9LoV07WL4c1q79fnxDbsqiENLVHNK1rXSKuz9oZpckG+/ut2YxxrTMbCiwwt1vSTZeyUEau/ii84/7bMFNPy3OonNtqECde3UtSLeO7vN+oNLMWgNl7r48enwwcEO+4xApBS9NmcdvHpvAt+sqi77oXBsqUBdWyovg3P2u6P76ZLccx7UF8IaZTQDeA55z9+dzvEyRkvLt2kqufmIi5zwwlk6bbsSzv9qPgf3r3rx2ugvQCnFxmprPLqyibD7D3acDOxc6DpFiNenLpQx55EM+X7CSwfv34NKDt61X0TndBWhQmIvTkhWoIRSoKypUd8g1NZ8hUkKqqpy7x0zn2DvfZPnq9Tx4Vn+uOnz7ep+NlK57zFTjhgzJ7d5ErCOf9u03HL5oUUgaNS1PTXHUT9GerVQbKkhLYzBv2Woui4rOB0dF582yVHQuK4Nkm4LYEapMNhO5OpOoLoXpxtZpT13V6WyluBdvAfwe6OTuh5lZH2Avd/9n9kOtGyUHaehenPw1l//rI75dV8m1R+7AyXt0yWrROd0GGJKPSyYXZxKlS1xVVclfozOdMpMuOWRyWOk+4AWgU/T8U+CirEQmIml9u7aSq56YyOCR49h6s1B0/nkO+l0YNiz8s47XqlUYnmxcKrk4k6guhWmd6VR/mSSHDu4+GqgCcPf1xDXAJyK5MenLpRxx
x+s8/N4sfrF/Dx4/bx+22bxNTpYVO75fXh7+kZeXf38IJtm4xDpATC7OJEqXuFLRmU71l0lyWGlm7QnNWWBmewJLcxqVSCNWVeUMH/M5x975JivXrKfirP5cefj2NG+a2/NHBg4Mh1yqqsJ9/LH5xHG33177DXZ94kqVuFKpS0KRDWVyKuslwNNATzN7E+gIHJ/TqEQaqXnLVnPp6Am8MW0hh+ywBTcel72iczbFNsyxpi5y3e5Rbft1znd8DVFGZyuZWVOgN6Gjn6nuvi7XgdWGCtLSELwQFZ3XrKvi2qP6cNLu2S06iySqV38OZnZcwqBtzWwpMNHd52cjQJHGbNXa9fzu2Y95+L1Z7Lj1xtx+0i707Jib2oJIpjI5iHkWcA8wMLrdTTjU9KaZnZrD2EQavElfLuXIO97gkfdnce4Pe/L4efsULDHoojGJl0lyqAK2d/efuvtPgT6ETn/6A5fnMjiRhqqqyrnrtVB0XrWmkoqz+3PFYdulLTpnsvGu6wY+dtHYzJnhmoJYExlKEI1XJhfBTXT3H8Q9N8IhpR3N7EN33yXXQdZENQcpJV8vXc0lo8fz1ueLOGzHLfnDcT9g01bpi86ZXPGbbBozOPdcuPPO9DHporHGqb5XSN8JdAUejQb9FJgD/Bp41t0PyGKsdaLkIKXi+Ulfc8Xjoeg89Og+nNgvs6JzJhvvVNOYwciR4XGqs3fqchWylL76XiF9AeEq6b7ALsADwAXuvrIYEoNIKVi1dj1X/Osjzn1wHF02a8VzF+7L+k+70r27VTsEFDs0ZAZNm4b7VM1XxF/xm+rqX/fQSF66w0a6aEwSqeE9kRybOCc0r/3FopWc+8OeXHzQtjw6qizpYaJBg+D++6u3gppKkybhn33XrqEp60WLahdb+/bQpk1IFmYb7j3ko6G6+G5BdS1C/tX3sNJxwE3A5oTrHAxwd98424HWlZKDFKOqKmf469P504tTad+6Bbf+bGf27hn6dE51CKhJE6isY+M0zZtv2N9yXcQSRHl57jfUajm18OqbHKYBR7n7x7kILhuUHKTYzF36LZeMmsDb0xdx+A+25PfHfl90rqiAU06p+7zNQo0gWRJp0wZWrqy+B7DRRpnvVeSrCK0ieOHVt+Ywr5gTg0ixeX7SXA798+tMmLOEm3+6E3/7+a4bJIb4HtYSNamhz57y8nAYKVWReOXKUHxObIcoWVtIqeSr5VK1nFrcMmlbaayZjQKeJFzfAIC7P56roERK0co167nhmSmMGjubnTpvwu0n7UL3Dq03mCZZr2oxNdUc4huO69o1+b/url3Tt0MUf3w/VY0iX0XodOsghZfJnsPGwCrgYOCo6HZkLoMSKTUfzVnCkXe8wehxszl/QE/+dd7e1RIDpP9XPHw47LNPOAQUUxb9QhNbIq1Lq6OFbFk1GbWcWuTcveRvu+22m4sUwvrKKr/zf9O855XP+Z6/f9nfmraw2jQPPuheXu5u5t6kiXuoCCS/mW34vFWr8Ppk4udbXp56unSyMY/6KPTyGztgrKfYrmZSkG5JaF9pB6BlXFI5M6dZqxZUkJZCSFd0jqmogDPPrN9ZRCrQSq7UtyA9EtgSOAR4DegMLM9eeCKl5z8TNyw678Ou9O3TvNoFbUOG1P/0UhVopRAyKUhv4+4nmNkx7n6/mT1E6FNapNFZuWY91z8zmdFj57Bz503480m78NYLrTfYO5g5M+wtQO0vSktGBVophEySQ6xjnyVmtiPwNdAtZxGJFKkJs5dw0ajxzFi0kgsO6MlFB21LsyZl7J5k72Dt2tDgXX2pQCuFkslhpeFmthlwDaG70CnAzTmNSqSIVFY5f/vfNH7697eY/V4HvrnrMH5zaGheu02b1HsHK1bUbXmxdvgy6StZJFdq3HNw93uih68BPXIbjkhx+WrJt1w8ajyvPNuSFa/+mDUrmhJakAlWrsz+MkeOVEKQwsukm9AWhGa6u8VP7+435C4skcL798S5nHrm
OhaP3TMakp/+nJUYpBhkcljpKeAYYD2wMu4m0iCNuL+STTqu5YidtmTx2C58395k7rVvX/vXqHtPyYVMCtKd3f3QnEciUkCxpqNnznTCf6YaGjnKgWbNwlXLtZHYsmmsnwbQHojUTyZ7Dm+Z2Q9qnkyktMR3qnPqqR6185O/vQTYsPh8772136Ana6tp1aowXKQ+Uu45mNlEwKNpzjCz6YSG92L9OeyUnxBFsi/xH7d7/hJCvFjfCXW9Alotm0qupDuspMb1pMEaMiTz3tZyrT4bcrVsKrmS8rCSu89095nAVsDiuOeLCc1piJScigqiaxOKp3vc+mzI1bKp5EomNYe/A/GX86yMhomUlIoKOO202LUJ+Tv7yCzct67egne9N+QDB4YL5RI791ExWuork+RgHtd0q7tXkdlZTvViZoea2VQzm2ZmV+R6edKwVVSEonOqHtRyoX17WLgw9J+wcGG4YvrBB7O/IU/sp0GJQbIhk+Qw3cwuNLNm0W0IMD2XQZlZE+BvwGFAH+BkM+uTy2VKw1RRAW3bwimneE6Kzq1awXnnQfMNW+qmefPkp6VqQy6lIpPkcC6wN/AlMAfoD6TpBTcr9gCmuft0d18LPEK4EE8kIxUV0KEDnHJKrI2j7CeGNm3CP/8774QRIzbcIxgxQht+KW2ZtK00HzgpD7HE2xqYHfc8lpS+Y2aDiZJUV52aIXEST1PNhfPOC0khJl2/zSKlKJM9h0JI9jdvg9NL3H24u/dz934dO3bMU1hS7CoqYNCg7CeG2B7Bgw+GaxPiE4NIQ5TzwnIdzQG6xD3vDHxVoFikRIQ9BqeyMruHkNRNpzRGxbrn8D7Qy8y6m1lzwmGtpwsckxS5y6+oYtWq7CYGXTMgjVW65jMuSfdCd781++F8N+/1ZvZLQnekTYAR7j45V8uT0hRrLG/WLOiw5XoWzM1uY3nl5SExqJYgjVG6w0pto/vewO58/8/9KGBMLoMCcPd/A//O9XKkNCUWnRfMbUpCWapOWreGu+5SQhBJmRzc/XoAM3sR2NXdl0fPhwKP5iU6kRSStUZa29NVW7eGli1h8eLQhIX2EkS+l0lBuisQ3336WkKvcCIFUVnlzJwFdbl2wUyJQCQTmRSkRwLvmdlQM7sOeBd4ILdhiSTv4Wz24lWcNPxtmrT9ttbzKy/XlckimaoxObj7MOAM4BtgCXCGu/8+x3FJA5Vpl5bnnw+nnhqao3YP92edXcXeZ3/Gx3OX8+ODK2u1XJ11JFI7mV7n0ApY5u73mllHM+vu7l/kMjBpeDLt0rKiAv7xj5AU4q1ZXcaiV3vz5t292H+3hHaqk2jSJOwp6DCSSO2ZJ/4CEycIh5L6Ab3dfVsz6wQ86u775CPATPTr18/Hjh1b6DCkBt26Je+YJvEis1TTAZg5VVVGWVn15BGvVSs1XS1SEzMb5+79ko3LpOZwLHA0oR8H3P0rvj/NVSRjqXo8mzlzw0NNqRIDQLt2Rrdu6ROD+jQQqb9MDiutdXc3MwcwsyRdlojULFWXlmbfD585M+wdpGpe+5tvYNGi5PPX3oJI9mSy5zDazO4CNjWzc4CXgXtyG5Y0RMm6tDSrvhcQEkPyXYNUnfVob0EkuzJpsvsWM/sxsIxwtfS17v5SziOTBie24Y41eZFqTyKmSROozOCkJDM1jCeSbTXuOZjZTe7+krv/2t0vc/eXzOymfAQnDU9iT2jl5cmnKy+3jLv0VHceItmXyWGlHycZdli2A5HGZ31lFfuePBdrun6D4bFrEjLZ6Ov6BZHcSJkczOw8M5sIbGdmH8XdvgAm5i9EaYjClc7v8IZ/wCHnzqFLF/+uQ51Y7SBZjaJ5c2jfnmrTikh2pas5PAT8B/gDcEXc8OXuvjinUUmD9uSHX3LNk5MAuP2kvhzTd2u4o/p0yWoUuphNJD8yuQhuT2ByXKusbYE+7v5uHuLLiC6CKw3LVq/j2icn8eT4r+hXvhm3/awvXdrVfKWziORG
uovgMrnO4e/ArnHPVyYZJpLWuJmLGfLIeOYuXc3FB23LBQf0pGmTYu2IUEQySQ7mcbsX7l5lZsXa97QUmfWVVdzx32nc8d/P2HqzjRj9i73YrXyzQoclIjXIZCM/3cwuJOwtAJwPTM9dSNJQzF68iotGjWfczG84dpetueGYHWjbslmhwxKRDGSyX38usDfwJTAH6A8MzmVQUpwybW4bQtH58Ntf59Ovl3P7SX257Wd9lRhESkgmV0jPB07KQyxSxDJtbnvZ6nVc8+Qknhr/Fbt324xbT1TRWaQUpTxbycx+4+43m9kdJGnoxt0vzHVwmdLZSrmXSXPbY2cs5qJRoeg85MBenD9ARWeRYlbXs5U+ju611ZWUzW3PmhWKzn/57zT+GhWdHz13L3btqqKzSClLmRzc/Zno/v78hSPFKlUjeZ22ruKEu97mw1lLOG7Xrbn+aBWdRRqClMnBzJ4hVbvJgLsfnZOIpCgNG7ZhzQGgRcsq2H0i0+av4C8n78LRO3cqXIAiklXpDivdEt0fB2wJPBg9PxmYkcOYpAht2JSF06b9OprvOZm9Dv6W2362H503U9FZpCFJd1jpNQAz+5277x836hkzG5PzyKToDBwI2+6zmIseGc/Xy1Zz0YG9OP+AvjQpS95rm4iUrkwugutoZj3cfTqAmXUHOuY2LCk26yur+Msrn/HX/02j82ateOzcvdhFRWeRBiuT5HAx8KqZxa6K7gb8ImcRSdGZuWglF40az4ezlvDTXTtz/TE70KaFWlARacgyuQjueTPrBWwXDfrE3dfkNiwpBu7O4x98ybVPTaKszLjj5F04SkVnkUahxuRgZq2AS4Bydz/HzHqZWW93fzb34UmhLP12HVc/MZFnP5rLHt3acdtJfdl6040KHZaI5EkmxwbuBcYBe0XP5wCPAkoODdR7Xyzm4lGh6PzrQ3pz7g97qugs0shkkhx6uvvPzOxkAHf/1sy0pWiA1kVF57/9bxpd2qnoLNKYZZIc1prZRkQXxJlZT0A1hwZm5qKVDHlkPONnL+H43Toz9GgVnUUas0x+/dcBzwNdzKwC2Ac4PZdBSf64O//64Euue2oSTcqMv/58F47cSUVnkcYubXIwszJgM8JV0nsCBgxx94W5CsjMhgLnAAuiQVe5+79ztbzGbIOic/d23PYzFZ1FJEibHKIuQX/p7qOB5/IUE8Bt7n5LzZNJXb07fRGXjJ7APBWdRSSJTA4rvWRmlwGjgJWxge6+OGdRSc6sq6zi9pc/485Xp9G1XSv+dd7e7Nxl00KHJSJFJpPkcGZ0f0HcMAd6ZD+c7/zSzE4j9CVxqbt/kziBmQ0m6q60a9euOQyl4ZixcCVDRo1nwuwlnBAVnVur6CwiSaTsCS6nCzV7mdDSa6KrgXeAhYQE9DtgK3c/M8m031FPcOm5O4+Nm8PQpyfTpMz4w3E7ccROWxU6LBEpsLr2BBd7cUvgfGBfwgb7deAf7r66rgG5+0GZTGdmd6OL7epl6ap1XPXkRJ77aC79o6JzJxWdRaQGmRxTeABYDtwRPT8ZGAmckIuAzGwrd58bPT0WmJSL5TQG705fxMWjxjN/+Rp+c2hvfrG/is4ikplMkkNvd9857vn/zGxCrgICbjazvoS9lBmoBdhaixWd//bqNLq1b62is4jUWibJ4UMz29Pd3wEws/7Am7kKyN1PzdW8G4MZC1cy5JEPmTBnKSf268x1R6noLCK1l8lWoz9wmpnNip53BT42s4mAu/tOOYtOMubuPBoVnZs1KePOgbty+A9UdBaRuskkORya8yikXpauWsdVT0zkuYlz2bNHO249UUVnEamfTDr7mZmPQKRu3omKzgtUdBaRLNLB6BK1rrKK2176lL+/9rmKziKSdUoOJeiLhSu5KCo6/6xfF649qo+KziKSVdqilJDEovPfB+7KYSo6i0gOKDmUiCWr1nLVExP598Sv2atHe2792c5stYmKziKSG0oOJeDtzxdxyehQ
dL780O0YvH8PFZ1FJKeUHIrY2vVV3Pbyp/zjtc/p3r41T5y/Dz/ovEmhwxKRRkDJoUhNX7CCIY+MZ+KXSzlp91B0btVcH5eI5Ie2NkXG3Rk9djZDn55C86YqOotIYSg5FJElq9Zy5eMT+c+kr9m7Z3v+dKKKziJSGEoOReKtzxdyyagJLFq5hisP245z9utBmYrOIlIgSg4FtnZ9Fbe+9Cl3jQlF57tPU9FZRApPyaGAPl+wgouiovPJe3TlmiO3V9FZRIqCtkQF4O6Men821z8zhRbNyvjHKbtx6I7JutQWESkMJYc8+2blWq54/CNemDyPfbZpz59O6MuWm7QsdFgiIhtQcsijt6Yt5JLRKjqLSPFTcsiDteur+NNLUxk+ZjrdO7TmnkH7sOPWKjqLSPFScsix+KLzz/t35bdHqOgsIsVPW6kccXceeX82NzwzhZbNyrjr1N04ZAcVnUWkNCg55EBi0fnWE/uyxcYqOotI6VByyLI3py3kktHjWbxyLVcfvj1n7dtdRWcRKTlKDlmydn0Vf3pxKsNfn06PDq3556DdVXQWkZKl5JAF0+avYMgjHzL5q2X8vH9XrjmiDxs1b1LosERE6kzJoR7cnYffm80Nz05mo2ZNGH7qbhysorOINABKDnUUX3Ter1cHbjlhZxWdRaTBUHKogzc+W8ilj4ai82+P2J4z91HRWUQaFiWHWlizvpI/vfgpw8dMp2fH1ow4fXd26KSis4g0PEoOGYovOg/s35XfqugsIg2YkkMN3J2H3pvF756doqKziDQaSg5pLF65lsv/9REvTVHRWUQaFyWHFN74LFzpvGTVOhWdRaTRUXJIsGZ9Jbe8MJW7X/+CbTZvw71nqOgsIo2PkkOcafNXcOHDHzJl7jJO2bMrVx+uorOINE5lhViomZ1gZpPNrMrM+iWMu9LMppnZVDM7JB/xuDsV787kyDte5+tlq7nntH78309+oMQgIo1WofYcJgHHAXfFDzSzPsBJwA5AJ+BlM9vW3StzFUhi0flPJ+zM5io6i0gjV5Dk4O4fA5hVK/AeAzzi7muAL8xsGrAH8HYu4pg4Zyln3v8+S1et45oj+3DG3t1UdBYRofhqDlsD78Q9nxMNq8bMBgODAbp27VqnhXXebCO227ItVx62PX06bVyneYiINEQ5Sw5m9jKQ7Gqxq939qVQvSzLMk03o7sOB4QD9+vVLOk1NNmvdnJFn9a/LS0VEGrScJQd3P6gOL5sDdIl73hn4KjsRiYhIpgpytlIaTwMnmVkLM+sO9ALeK3BMIiKNTqFOZT3WzOYAewHPmdkLAO4+GRgNTAGeBy7I5ZlKIiKSXKHOVnoCeCLFuGHAsPxGJCIi8YrtsJKIiBQBJQcREalGyUFERKpRchARkWrMvU7XjxUVM1sAzKzHLDoAC7MUTqnQOjcOWufGoa7rXO7uHZONaBDJob7MbKy796t5yoZD69w4aJ0bh1yssw4riYhINUoOIiJSjZJDMLzQARSA1rlx0Do3DllfZ9UcRESkGu05iIhINUoOIiJSTaNNDmZ2gplNNrMqM+uXMO5KM5tmZlPN7JBCxZhLZjbUzL40s/HR7fBCx5QrZnZo9FlOM7MrCh1PvpjZDDObGH2+YwsdTy6Y2Qgzm29mk+KGtTOzl8zss+h+s0LGmG0p1jnrv+dGmxyAScBxwJj4gWbWBzgJ2AE4FLjTzJrkP7y8uM3d+0a3fxc6mFyIPru/AYcBfYCTo8+4sTgg+nwb6nn/9xF+p/GuAF5x917AK9HzhuQ+qq8zZPn33GiTg7t/7O5Tk4w6BnjE3de4+xfANGCP/EYnWbQHMM3dp7v7WuARwmcsDYC7jwEWJww+Brg/enw/8JN8xpRrKdY56xptckhja2B23PM50bCG6Jdm9lG0m9qgdr3jNKbPM5EDL5rZODMbXOhg8mgLd58LEN1vXuB48iWrv+cGnRzM7GUzm5Tklu6foyUZVpLn+9aw/n8HegJ9gbnA
nwoZaw41mM+zDvZx910Jh9QuMLP9Cx2Q5EzWf88F6QkuX9z9oDq8bA7QJe55Z+Cr7ESUX5muv5ndDTyb43AKpcF8nrXl7l9F9/PN7AnCIbYx6V/VIMwzs63cfa6ZbQXML3RAuebu82KPs/V7btB7DnX0NHCSmbUws+5AL+C9AseUddGPJuZYQoG+IXof6GVm3c2sOeFkg6cLHFPOmVlrM2sbewwcTMP9jBM9DQyKHg8CnipgLHmRi99zg95zSMfMjgXuADoCz5nZeHc/xN0nm9loYAqwHrjA3SsLGWuO3GxmfQmHWGYAvyhoNDni7uvN7JfAC0ATYIS7Ty5wWPmwBfCEmUH4nT/k7s8XNqTsM7OHgQFABzObA1wH3AiMNrOzgFnACYWLMPtSrPOAbP+e1XyGiIhUo8NKIiJSjZKDiIhUo+QgIiLVKDmIiEg1Sg4iIlKNkoOUNDMbYGZ713MeK2ox7X1mdnx9lpctZvZWLacvmtil+Ck5SKkbANQrOZQqd2+U6y35oeQgRcfMnowai5sc32Bc1C/DB2Y2wcxeMbNuwLnAxVEb9vsl/juO7RWYWZvoNR9EfRzU2DKrmZ0WNWQ2wcxGxo3a38zeMrPpsWWlmr+ZdTOzj83s7mh9XjSzjaJxu0fzf9vM/hhrn9/MmkTP34/GJ72gKW7dBpjZq2b2mJl9YmYVFl39lmbdDjSzD6NYR5hZi2j4jWY2JVruLdGwE6I2uSaYWWNofkMA3F033YrqBrSL7jciNAPQnnAl+2yge8I0Q4HL4l57H3B83PMV0X1TYOPocQdCU+wWP01CDDsAU4EOCcu7D3iU8MeqD6E58JTzB7oRrrTvG40bDZwSPZ4E7B09vhGYFD0eDPw2etwCGBtb74QYY+s2AFhKaDeqDHgb2DfJ9PcBxwMto/dy22j4A8BFQLtonWPvy6bR/URg6/hhujX8m/YcpBhdaGYTgHcIjeb1AvYExnjoYwN3r2179gb83sw+Al4mNNu9RZrpfwQ85u4LkyzvSXevcvcpcfNIN/8v3H189Hgc0M3MNgXaunusbvBQ3PwPBk4zs/HAu4Tk2KuG9XvP3ee4exUwnpCUUukdxfRp9Px+YH9gGbAauMfMjgNWRePfBO4zs3MITZBII9Bo21aS4mRmA4CDgL3cfZWZvUr4p2tk1tT2eqLDpdGhlebR8IGEvY/d3H2dmc2I5psylDTLW5MwXU3zj5++krBHlO6wjwG/cvcX0kyTLqZK0v+2ky7bQztUewAHEhoo/CXwI3c/18z6A0cA482sr7svqkVsUoK05yDFZhPgmygxbEfYY4BwqOSHUUu5mFm7aPhyoG3c62cAu0WPjwGaxc13frThPgAoryGOV4ATzax9wvLSxZ3x/N39G2C5mcXW76S40S8A55lZs2jZ20Ytq2bLJ4S9l22i56cCr5lZG2ATD11MXkToGwAz6+nu77r7tcBCNmwCXRoo7TlIsXkeODc6PDOVcGgJd18QFacfN7MyQhv9PwaeAR6LCsC/Au4GnjKz9wgb+JXRfCuAZ8xsLOGwyyfpgvDQOu8wwkazEvgQOD3NS2o1/8hZwN1mthJ4lVA3ALiHcFjog2jvZwFZ7OrS3Veb2RnAo2bWlNCs+T8INYenzCy2p3Zx9JI/mlmvaNgrwIRsxSLFS62yihSImbVx99gZR1cAW7n7kAKHJQJoz0GkkI4wsysJv8OZpN8zEckr7TmIiEg1KkiLiEg1Sg4iIlKNkoOIiFSj5CAiItUoOYiISDX/D+TMgym6WDipAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# y = x reference line spanning the observed values: points on it are perfect predictions.\n",
    "lo = min(min(act_infl), min(pred_infl))\n",
    "hi = max(max(act_infl), max(pred_infl))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot([lo, hi], [lo, hi])\n",
    "ax.plot(act_infl, pred_infl, 'o', color='blue')\n",
    "ax.set_xlabel('actual change in loss')\n",
    "ax.set_ylabel('predicted change in loss')\n",
    "# Title follows the dataset actually loaded in the config cell.\n",
    "ax.set_title('Influence function on ' + dataname.capitalize() + ' dataset - Node Feature')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d0da5007",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.9525660930222389, pvalue=3.58041027161336e-73)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `import scipy` alone does not guarantee the stats submodule is loaded; import it explicitly.\n",
    "import scipy.stats\n",
    "\n",
    "# Rank agreement between actual (leave-one-out) and predicted influence.\n",
    "scipy.stats.spearmanr(act_infl, pred_infl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ed5b6c3d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "948.6219225775423"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: regularised test loss of the model trained on all training nodes.\n",
    "ori_val_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "33895450",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the predicted influence scores (one per training node) with no index or header.\n",
    "# Create the output directory first so the cell also works on a fresh checkout.\n",
    "os.makedirs('result_tune', exist_ok=True)\n",
    "pd.DataFrame(pred_infl).to_csv(os.path.join('result_tune', dataname + '_tune.csv'), index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5b60f53a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled reference copy of a per-sample Hessian accumulator named fast_hess_cuda.\n",
    "# The notebook itself calls lr.hess_cuda; a fast_hess_cuda is also imported from model_softmax,\n",
    "# so un-commenting this cell would shadow that import.\n",
    "# NOTE(review): the final term adds 1.0 * I to the weight block only (intercept rows/cols are\n",
    "# padded with 0), not l2_regularlization_term -- confirm against model_softmax before reusing it.\n",
    "\n",
    "# def fast_hess_cuda(x, logits):\n",
    "#     n = len(logits)\n",
    "\n",
    "#     K = logits.shape[1]\n",
    "\n",
    "#     D = x.shape[1]\n",
    "\n",
    "#     KD = K * D\n",
    "\n",
    "#     softmax_pred = softmax(logits, axis=1)\n",
    "\n",
    "#     temp_pred = softmax_pred[0]\n",
    "#     temp_x = x[0]\n",
    "\n",
    "#     temp_factor = cp.diag(temp_pred) - cp.einsum('i, j -> ij', temp_pred, temp_pred)\n",
    "\n",
    "#     temp_indiv_hessian = cp.einsum('ij, k, l -> ikjl', temp_factor, temp_x, temp_x).reshape(KD, KD)\n",
    "\n",
    "#     temp_off_diag = cp.einsum('ij, k -> ijk', temp_factor, temp_x).reshape(K, KD)\n",
    "\n",
    "#     temp_top_row = cp.concatenate([temp_indiv_hessian, temp_off_diag.T], axis=1)\n",
    "\n",
    "#     temp_bottom_row = cp.concatenate([temp_off_diag, temp_factor], axis=1)\n",
    "\n",
    "#     temp_indiv_hessian_all = cp.concatenate([temp_top_row, temp_bottom_row])\n",
    "\n",
    "#     del temp_off_diag\n",
    "#     del temp_top_row\n",
    "#     del temp_bottom_row\n",
    "#     del temp_indiv_hessian\n",
    "#     del temp_pred\n",
    "#     del temp_x\n",
    "#     del temp_factor\n",
    "\n",
    "#     for i in range(1, n):\n",
    "#         temp_pred = softmax_pred[i]\n",
    "\n",
    "#         temp_x = x[i]\n",
    "\n",
    "#         temp_factor = cp.diag(temp_pred) - cp.einsum('i, j -> ij', temp_pred, temp_pred)\n",
    "\n",
    "#         temp_indiv_hessian = cp.einsum('ij, k, l -> ikjl', temp_factor, temp_x, temp_x).reshape(KD, KD)\n",
    "\n",
    "#         # with intercept\n",
    "#         temp_off_diag = cp.einsum('ij, k -> ijk', temp_factor, temp_x).reshape(K, KD)\n",
    "\n",
    "#         temp_top_row = cp.concatenate([temp_indiv_hessian, temp_off_diag.T], axis=1)\n",
    "\n",
    "#         temp_bottom_row = cp.concatenate([temp_off_diag, temp_factor], axis=1)\n",
    "\n",
    "#         temp_indiv_hessian = cp.concatenate([temp_top_row, temp_bottom_row])\n",
    "\n",
    "#         temp_indiv_hessian_all += temp_indiv_hessian\n",
    "\n",
    "#         del temp_factor\n",
    "#         del temp_indiv_hessian\n",
    "#         del temp_off_diag\n",
    "#         del temp_top_row\n",
    "#         del temp_bottom_row\n",
    "#         del temp_pred\n",
    "#         del temp_x\n",
    "\n",
    "#     temp_indiv_hessian_all += cp.pad(cp.eye(KD, KD) * 1.0,\n",
    "#                                      [[0, K], [0, K]], mode='constant', constant_values=0)\n",
    "\n",
    "#     return temp_indiv_hessian_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f8499369",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tuned hyper-parameter table for the dataset selected in the config cell.\n",
    "df = pd.read_csv(os.path.join('hyper_parameter', dataname + '.csv'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
