{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "bd02fe7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy_ml\n",
    "import tensorflow.compat.v1 as tf\n",
    "import scipy\n",
    "\n",
    "from generate_mid_layer_feature import FeatureExtraction\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from scipy.special import softmax, log_softmax\n",
    "from numpy.linalg import inv, pinv\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "from sklearn.utils import check_array\n",
    "from dataset import load_graph_dataset\n",
    "from scipy.linalg import cho_solve, cho_factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fa8453c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "FeatureExtractor = FeatureExtraction(num_layers=2, num_iter=100, lr=0.02, hidden_feat=20, device='cpu',\n",
    "                                         dataset='cora')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b620ef8a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "accuracy: 0.793\n"
     ]
    }
   ],
   "source": [
    "FeatureExtractor.extract_feature()\n",
    "train_x, train_y, val_x, val_y, test_x, test_y = FeatureExtractor.preprocessed_data(save='feat.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "306e6a97",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.datasets import load_iris\n",
    "iris = load_iris()\n",
    "x, y = iris.data, iris.target\n",
    "\n",
    "# scaler = preprocessing.StandardScaler().fit(x)\n",
    "# x = scaler.transform(x)\n",
    "\n",
    "train_x, test_x, train_y,test_y = train_test_split(x,y,test_size=0.2,random_state=123)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5edac0cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ee5ffb4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 3\n",
    "input_dim = train_x.shape[1]\n",
    "l2_reg = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a04a5052",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression(max_iter=2048, random_state=0)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf = LogisticRegression(random_state=0, \n",
    "                         penalty = 'l2',\n",
    "                         C = 1.0 / l2_reg,\n",
    "                         solver='lbfgs',\n",
    "                         warm_start=False,\n",
    "                         max_iter=2048,\n",
    "                         fit_intercept=True)\n",
    "clf.fit(train_x, train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "01b52b9c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_y = test_x @ clf.coef_.T + clf.intercept_\n",
    "np.allclose(softmax(out_y, axis = 1), clf.predict_proba(test_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ba088e2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_loss(model, logits, one_hot_labels, l2_reg = False, eps = 1e-15):\n",
    "    \n",
    "    sample_weights = np.ones(len(logits))\n",
    "    \n",
    "    softmax_pred = softmax(logits, axis = 1)\n",
    "    \n",
    "    print(softmax_pred)\n",
    "#     softmax_pred = np.clip(softmax_pred, eps, 1-eps)\n",
    "#     log_softmax_pred = np.log(softmax_pred)\n",
    "    \n",
    "    log_softmax_pred = np.log(softmax_pred + eps)\n",
    "    \n",
    "    indiv_cross_entropy = -np.sum( np.multiply(one_hot_labels, log_softmax_pred), axis = 1)\n",
    "    \n",
    "    indiv_cross_entropy = np.multiply(sample_weights, indiv_cross_entropy)\n",
    "    \n",
    "    cross_entropy = np.sum(indiv_cross_entropy)\n",
    "    \n",
    "    \"\"\"\n",
    "    The original way of calculating the log softmax loss is confusing, \n",
    "    \"\"\"\n",
    "#     indiv_loss = cross_entropy\n",
    "    \n",
    "#     total_loss_no_reg = tf.reduce_sum(tf.multiply(cross_entropy, sample_weights),\n",
    "#                                           name='total_loss_no_reg')\n",
    "    \n",
    "#     tf.add_to_collection('losses', indiv_loss)\n",
    "\n",
    "#     total_loss_reg = tf.add_n(tf.get_collection('losses'), name='total_loss_reg')\n",
    "#     avg_loss_reg = total_loss_reg / tf.cast(tf.shape(logits)[0], tf.float64)\n",
    "    if l2_reg:\n",
    "        cross_entropy += 1.0 * np.linalg.norm(model.coef_, ord = 2) / 2\n",
    "    \n",
    "    ave_cross_entropy = cross_entropy / len(logits)\n",
    "    \n",
    "    return cross_entropy, ave_cross_entropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "08974c4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[8.69599785e-04 5.89553333e-01 4.09577068e-01]\n",
      " [3.48875271e-05 8.86886664e-02 9.11276446e-01]\n",
      " [1.36067035e-05 3.84404512e-02 9.61545942e-01]\n",
      " [2.77894694e-02 9.28140911e-01 4.40696194e-02]\n",
      " [9.82051549e-01 1.79484320e-02 1.86674248e-08]\n",
      " [1.99620752e-03 4.94523677e-01 5.03480116e-01]\n",
      " [3.29312696e-03 8.95426065e-01 1.01280808e-01]\n",
      " [9.85231374e-01 1.47686010e-02 2.54012615e-08]\n",
      " [9.74341300e-01 2.56586301e-02 6.99246215e-08]\n",
      " [9.70124382e-03 8.89126744e-01 1.01172013e-01]\n",
      " [1.55092203e-04 1.65441924e-01 8.34402984e-01]\n",
      " [9.84081480e-01 1.59185095e-02 1.07602058e-08]\n",
      " [4.23939638e-02 9.09588505e-01 4.80175314e-02]\n",
      " [8.41081677e-05 1.53798764e-01 8.46117127e-01]\n",
      " [7.69779683e-07 1.72978112e-02 9.82701419e-01]\n",
      " [6.90673754e-06 2.64385253e-02 9.73554568e-01]\n",
      " [9.71950323e-01 2.80496553e-02 2.17388771e-08]\n",
      " [9.89262997e-01 1.07369935e-02 9.62986208e-09]\n",
      " [3.47953949e-03 7.85999547e-01 2.10520914e-01]\n",
      " [9.66727347e-01 3.32725636e-02 8.88758495e-08]\n",
      " [9.70183701e-01 2.98162605e-02 3.84588060e-08]\n",
      " [6.61336581e-04 4.94935569e-01 5.04403094e-01]\n",
      " [9.39424330e-01 6.05752713e-02 3.98972011e-07]\n",
      " [1.29780423e-03 4.39000233e-01 5.59701963e-01]\n",
      " [9.73104738e-01 2.68952111e-02 5.07118151e-08]\n",
      " [9.33310700e-01 6.66891217e-02 1.78685948e-07]\n",
      " [9.54436465e-01 4.55634174e-02 1.18066745e-07]\n",
      " [8.41983643e-04 2.36278472e-01 7.62879544e-01]\n",
      " [2.04439186e-07 6.79622957e-03 9.93203566e-01]\n",
      " [9.80009778e-01 1.99901963e-02 2.56041138e-08]]\n"
     ]
    }
   ],
   "source": [
    "_, calculated_loss = get_loss(clf, out_y, one_hot_labels_test, l2_reg = False)\n",
    "numpy_theoritic_loss = log_loss(test_y, softmax(out_y, axis = 1))\n",
    "# calculated_loss, numpy_theoritic_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "76a60bc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradients(model, x, logits, one_hot_labels, l2_reg = None, fit_intercept = True):\n",
    "    \"\"\"\n",
    "    Explicitly computes the softmax gradients.\n",
    "    grad_theta_i loss(x, y) = -([i == y] - softmax_i) * x\n",
    "    grad_b_i loss(x, y) = -([i == y] - softmax_i)\n",
    "    \"\"\"\n",
    "    if x.ndim == 1:\n",
    "        x = x.reshape(1, -1)\n",
    "\n",
    "    K = one_hot_labels.shape[1] # num_classes\n",
    "    \n",
    "    D = x.shape[1] # num_dimensions \n",
    "    \n",
    "    sample_weights = np.ones(len(logits)) # to be modify\n",
    "    sample_weights = sample_weights.reshape(-1, 1)\n",
    "    \n",
    "    softmax_pred = softmax(logits, axis = 1)\n",
    "    \n",
    "    factor = -(one_hot_labels - softmax_pred)\n",
    "    assert(factor.ndim == 2)\n",
    "    expand_factor = np.expand_dims(factor, axis = 2) # (n, num_classes, 1)\n",
    "    expand_x = np.expand_dims(x, 1) # (n, 1, num_dimension)\n",
    "    \n",
    "    indiv_grad = np.multiply(expand_factor, expand_x) # (n, num_classes, num_dimension)\n",
    "    indiv_grad = indiv_grad.reshape(-1, K * D) # (n, num_classes * num_dimension)\n",
    "    \n",
    "    weighted_indiv_grad = indiv_grad * sample_weights\n",
    "    \n",
    "    grad_reg = l2_reg * model.coef_.reshape(-1, K * D)\n",
    "    \n",
    "    if fit_intercept:\n",
    "        weighted_indiv_grad = np.concatenate([weighted_indiv_grad, factor], axis=1)\n",
    "        grad_reg = np.concatenate([grad_reg, np.zeros(K).reshape(1, -1)], axis = 1)\n",
    "    \n",
    "    total_grad = np.sum(weighted_indiv_grad, axis=0).reshape(1, -1)\n",
    "    \n",
    "    if l2_reg:\n",
    "        total_grad += grad_reg\n",
    "    \n",
    "    return total_grad, weighted_indiv_grad, grad_reg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c8a99e8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_grad, weighted_indiv_grad, grad_reg = gradients(clf, test_x, out_y, one_hot_labels_test, l2_reg=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "dabb3ba2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-1.49497502, -1.0886188 , -0.22489386,  0.03806188,  9.40424401,\n",
       "        5.0642749 ,  6.89862405,  2.62086456, -7.90926898, -3.97565611,\n",
       "       -6.6737302 , -2.65892643, -0.30326407,  1.55535829, -1.25209422])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sum(weighted_indiv_grad, axis = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "16937548",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hessian(model, x, logits, l2_reg = None, fit_intercept = True):\n",
    "    \"\"\"\n",
    "    Explicitly computes the softmax hessian.\n",
    "    grad_theta_i grad_theta_j loss(x, y)\n",
    "        = softmax_i ([i == j] - softmax_j) x x^T\n",
    "    grad_theta_i grad_b_j loss(x, y)\n",
    "        = softmax_i ([i == j] - softmax_j) x\n",
    "    grad_b_i grad_b_j loss(x, y)\n",
    "        = softmax_i ([i == j] - softmax_j)\n",
    "    \"\"\"\n",
    "    if x.ndim == 1:\n",
    "        x = x.reshape(1, -1)\n",
    "    n = len(logits)    \n",
    "    \n",
    "    K = logits.shape[1] # num_classes\n",
    "    \n",
    "    D = x.shape[1] # num_dimensions \n",
    "    \n",
    "    KD = K*D\n",
    "    sample_weights = np.ones(n) # to be modify\n",
    "#     sample_weights = sample_weights.reshape(-1, 1)\n",
    "    \n",
    "    softmax_pred = softmax(logits, axis = 1)\n",
    "    \n",
    "    factor = tf.linalg.diag(softmax_pred) - \\\n",
    "        tf.einsum('ai,aj->aij', softmax_pred, softmax_pred)               # (?, Kp, Kp)\n",
    "    indiv_hessian = tf.reshape(\n",
    "        tf.einsum('aij,ak,al->aikjl', factor, x, x),  # (?, Kp, D, Kp, D)\n",
    "        (-1, KD, KD))                                         # (?, KpD, KpD)\n",
    "\n",
    "    # Hessian of l2 regularization\n",
    "    hess_reg = l2_reg * tf.eye(KD, KD)\n",
    "\n",
    "    if fit_intercept:\n",
    "        off_diag = tf.reshape(\n",
    "            tf.einsum('aij,ak->aijk', factor, x),          # (?, Kp, Kp, D)\n",
    "            (-1, K, KD))                                      # (?, Kp, KpD)\n",
    "\n",
    "        top_row = tf.concat([indiv_hessian,\n",
    "                             tf.transpose(off_diag, (0, 2, 1))], axis=2)\n",
    "        bottom_row = tf.concat([off_diag, factor], axis=2)\n",
    "        indiv_hessian = tf.concat([top_row, bottom_row], axis=1)\n",
    "\n",
    "        hess_reg = tf.pad(hess_reg, [[0, K], [0, K]],\n",
    "                          mode=\"CONSTANT\", constant_values=0.0).numpy()\n",
    "\n",
    "    hessian_no_reg = tf.einsum('aij,a->ij', indiv_hessian, sample_weights).numpy()\n",
    "    \n",
    "    hessian_reg = hessian_no_reg + hess_reg\n",
    "    hessian_reg_term = hess_reg\n",
    "    \n",
    "    return hessian_no_reg, hessian_reg, hessian_reg_term"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1ab29186",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-04-14 19:00:33.242291: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
      "2022-04-14 19:00:33.242884: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
      "2022-04-14 19:00:33.244764: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory\n",
      "2022-04-14 19:00:33.244773: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n",
      "Skipping registering GPU devices...\n",
      "2022-04-14 19:00:33.244941: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "hessian_no_reg, hess_reg, hess_reg_term = hessian(clf, test_x, out_y, l2_reg=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "87bc0192",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 9.41253975e-01, -1.25176086e-02, -4.14596166e-02,\n",
       "        -1.82112107e-02,  4.38308088e-02,  8.97983162e-03,\n",
       "         2.88585436e-02,  1.13210741e-02,  1.49152160e-02,\n",
       "         3.53777695e-03,  1.26010730e-02,  6.89013655e-03,\n",
       "        -4.21425718e+00,  1.75265054e-01,  4.43192273e-01],\n",
       "       [-1.25176086e-02,  9.57676107e-01,  4.86700389e-02,\n",
       "         2.45135945e-02,  1.46514078e-02,  4.14574663e-02,\n",
       "        -4.27933722e-02, -2.14199791e-02, -2.13379926e-03,\n",
       "         8.66426535e-04, -5.87666671e-03, -3.09361538e-03,\n",
       "        -3.21746271e+00, -1.36899498e-01, -1.33757409e-01],\n",
       "       [-4.14596166e-02,  4.86700389e-02,  7.63040122e-01,\n",
       "        -9.82247080e-02,  1.16994764e-02, -5.68777363e-02,\n",
       "         1.95496908e-01,  7.44529005e-02,  2.97601402e-02,\n",
       "         8.20769746e-03,  4.14629700e-02,  2.37718075e-02,\n",
       "        -6.87339453e-01,  4.60871113e-01,  1.01006556e+00],\n",
       "       [-1.82112107e-02,  2.45135945e-02, -9.82247080e-02,\n",
       "         9.52184924e-01,  5.66831073e-03, -2.80247310e-02,\n",
       "         8.07567304e-02,  3.70911573e-02,  1.25428999e-02,\n",
       "         3.51113650e-03,  1.74679776e-02,  1.07239188e-02,\n",
       "         1.18458077e-01,  2.10640693e-01,  4.36506074e-01],\n",
       "       [ 4.38308088e-02,  1.46514078e-02,  1.16994764e-02,\n",
       "         5.66831073e-03,  7.88099958e-01, -2.30721836e-02,\n",
       "        -9.46438815e-02, -3.58129029e-02,  1.68069233e-01,\n",
       "         8.42077581e-03,  8.29444051e-02,  3.01445922e-02,\n",
       "        -2.52827605e-01, -3.73138605e+00, -9.77335366e-01],\n",
       "       [ 8.97983162e-03,  4.14574663e-02, -5.68777363e-02,\n",
       "        -2.80247310e-02, -2.30721836e-02,  8.87879372e-01,\n",
       "         2.17597608e-02, -2.16621948e-02,  1.40923520e-02,\n",
       "         7.06631618e-02,  3.51179755e-02,  4.96869258e-02,\n",
       "         6.34866597e-01, -2.14584612e+00, -2.36479999e-01],\n",
       "       [ 2.88585436e-02, -4.27933722e-02,  1.95496908e-01,\n",
       "         8.07567304e-02, -9.46438815e-02,  2.17597608e-02,\n",
       "         6.90100734e-01, -1.20226288e-01,  6.57853379e-02,\n",
       "         2.10336113e-02,  1.14402358e-01,  3.94695574e-02,\n",
       "         1.56228185e+00,  1.07075994e+00,  2.63441483e+00],\n",
       "       [ 1.13210741e-02, -2.14199791e-02,  7.44529005e-02,\n",
       "         3.70911573e-02, -3.58129029e-02, -2.16621948e-02,\n",
       "        -1.20226288e-01,  8.56781983e-01,  2.44918288e-02,\n",
       "         4.30821739e-02,  4.57733873e-02,  1.06126859e-01,\n",
       "        -4.77228095e-01, -2.02001399e-01, -2.98028734e-01],\n",
       "       [ 1.49152160e-02, -2.13379926e-03,  2.97601402e-02,\n",
       "         1.25428999e-02,  1.68069233e-01,  1.40923520e-02,\n",
       "         6.57853379e-02,  2.44918288e-02,  8.17015551e-01,\n",
       "        -1.19585528e-02, -9.55454780e-02, -3.70347287e-02,\n",
       "         7.61315891e-01, -1.56450166e-01, -3.19437716e+00],\n",
       "       [ 3.53777695e-03,  8.66426535e-04,  8.20769746e-03,\n",
       "         3.51113650e-03,  8.42077581e-03,  7.06631618e-02,\n",
       "         2.10336113e-02,  4.30821739e-02, -1.19585528e-02,\n",
       "         9.28470412e-01, -2.92413088e-02, -4.65933104e-02,\n",
       "         6.89945784e-01,  3.91451949e-01, -1.51687048e+00],\n",
       "       [ 1.26010730e-02, -5.87666671e-03,  4.14629700e-02,\n",
       "         1.74679776e-02,  8.29444051e-02,  3.51179755e-02,\n",
       "         1.14402358e-01,  4.57733873e-02, -9.55454780e-02,\n",
       "        -2.92413088e-02,  8.44134672e-01, -6.32413649e-02,\n",
       "         4.69787555e-01, -1.82635247e-01, -2.30762907e+00],\n",
       "       [ 6.89013655e-03, -3.09361538e-03,  2.37718075e-02,\n",
       "         1.07239188e-02,  3.01445922e-02,  4.96869258e-02,\n",
       "         3.94695574e-02,  1.06126859e-01, -3.70347287e-02,\n",
       "        -4.65933104e-02, -6.32413649e-02,  8.83149222e-01,\n",
       "        -8.19391380e-03, -3.74407720e-01, -5.06875332e-01],\n",
       "       [-4.23401829e+00, -3.21628321e+00, -6.87879188e-01,\n",
       "         1.21431237e-01, -6.94772676e-01,  6.19167669e-01,\n",
       "         1.55896514e+00, -2.79716658e-01,  6.74500372e-01,\n",
       "         7.02623631e-01,  4.33557830e-01, -9.39450177e-03,\n",
       "        -4.61113389e+14, -4.61113389e+14, -4.61113389e+14],\n",
       "       [ 1.50656070e-01, -1.41391812e-01,  4.58080786e-01,\n",
       "         2.13222940e-01, -4.16450916e+00, -2.15755120e+00,\n",
       "         1.06510077e+00, -5.93848070e-03, -2.40437509e-01,\n",
       "         4.04451109e-01, -2.18537767e-01, -3.74964382e-01,\n",
       "        -4.61113389e+14, -4.61113389e+14, -4.61113389e+14],\n",
       "       [ 4.33969608e-01, -1.36365238e-01,  1.01197537e+00,\n",
       "         4.40221687e-01, -1.41076498e+00, -2.52738583e-01,\n",
       "         2.63630210e+00, -9.99831835e-02, -3.27749522e+00,\n",
       "        -1.50538809e+00, -2.34363369e+00, -5.07918426e-01,\n",
       "        -4.61113389e+14, -4.61113389e+14, -4.61113389e+14]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.linalg.inv(hess_reg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d9a0bce0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from generate_mid_layer_feature import FeatureExtraction\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork\n",
    "from sklearn.metrics import log_loss\n",
    "from scipy.special import softmax, log_softmax\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7f6550b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>14</th>\n",
       "      <th>15</th>\n",
       "      <th>16</th>\n",
       "      <th>17</th>\n",
       "      <th>18</th>\n",
       "      <th>19</th>\n",
       "      <th>20</th>\n",
       "      <th>21</th>\n",
       "      <th>22</th>\n",
       "      <th>23</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4.0</td>\n",
       "      <td>0.543700</td>\n",
       "      <td>0.708406</td>\n",
       "      <td>-0.892010</td>\n",
       "      <td>-0.620374</td>\n",
       "      <td>-1.058723</td>\n",
       "      <td>-0.966741</td>\n",
       "      <td>-1.040328</td>\n",
       "      <td>-0.622832</td>\n",
       "      <td>-0.596182</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.865241</td>\n",
       "      <td>0.611353</td>\n",
       "      <td>-0.611133</td>\n",
       "      <td>-0.917186</td>\n",
       "      <td>-1.020263</td>\n",
       "      <td>-0.989763</td>\n",
       "      <td>0.784655</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.0</td>\n",
       "      <td>0.548127</td>\n",
       "      <td>0.679302</td>\n",
       "      <td>-0.846138</td>\n",
       "      <td>-0.579255</td>\n",
       "      <td>-0.979052</td>\n",
       "      <td>-0.889304</td>\n",
       "      <td>-0.981078</td>\n",
       "      <td>-0.570468</td>\n",
       "      <td>-0.565127</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.817142</td>\n",
       "      <td>0.550100</td>\n",
       "      <td>-0.560664</td>\n",
       "      <td>-0.842593</td>\n",
       "      <td>-0.956245</td>\n",
       "      <td>-0.921767</td>\n",
       "      <td>0.704382</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.0</td>\n",
       "      <td>0.475942</td>\n",
       "      <td>0.642293</td>\n",
       "      <td>-0.787892</td>\n",
       "      <td>-0.554242</td>\n",
       "      <td>-0.978643</td>\n",
       "      <td>-0.897948</td>\n",
       "      <td>-0.946107</td>\n",
       "      <td>-0.530967</td>\n",
       "      <td>-0.529651</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.773411</td>\n",
       "      <td>0.555541</td>\n",
       "      <td>-0.546471</td>\n",
       "      <td>-0.839700</td>\n",
       "      <td>-0.941623</td>\n",
       "      <td>-0.909231</td>\n",
       "      <td>0.731721</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.0</td>\n",
       "      <td>0.560623</td>\n",
       "      <td>0.709882</td>\n",
       "      <td>-0.934610</td>\n",
       "      <td>-0.703681</td>\n",
       "      <td>-1.142573</td>\n",
       "      <td>-0.996351</td>\n",
       "      <td>-1.045104</td>\n",
       "      <td>-0.629294</td>\n",
       "      <td>-0.576431</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.960099</td>\n",
       "      <td>0.623167</td>\n",
       "      <td>-0.625813</td>\n",
       "      <td>-0.940791</td>\n",
       "      <td>-1.033401</td>\n",
       "      <td>-0.998360</td>\n",
       "      <td>0.838899</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4.0</td>\n",
       "      <td>0.507652</td>\n",
       "      <td>0.585290</td>\n",
       "      <td>-0.765928</td>\n",
       "      <td>-0.612692</td>\n",
       "      <td>-0.991120</td>\n",
       "      <td>-0.815581</td>\n",
       "      <td>-0.850270</td>\n",
       "      <td>-0.422101</td>\n",
       "      <td>-0.434241</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.849878</td>\n",
       "      <td>0.450731</td>\n",
       "      <td>-0.476241</td>\n",
       "      <td>-0.736935</td>\n",
       "      <td>-0.876561</td>\n",
       "      <td>-0.815947</td>\n",
       "      <td>0.691426</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2703</th>\n",
       "      <td>3.0</td>\n",
       "      <td>1.029542</td>\n",
       "      <td>1.126835</td>\n",
       "      <td>-0.953499</td>\n",
       "      <td>-0.293646</td>\n",
       "      <td>-0.478533</td>\n",
       "      <td>-1.115570</td>\n",
       "      <td>-0.930043</td>\n",
       "      <td>-0.792080</td>\n",
       "      <td>-0.530297</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.819958</td>\n",
       "      <td>0.771034</td>\n",
       "      <td>-0.032367</td>\n",
       "      <td>-0.436413</td>\n",
       "      <td>-0.523535</td>\n",
       "      <td>-1.086509</td>\n",
       "      <td>0.463543</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2704</th>\n",
       "      <td>4.0</td>\n",
       "      <td>0.858308</td>\n",
       "      <td>1.062309</td>\n",
       "      <td>-0.748620</td>\n",
       "      <td>-0.102593</td>\n",
       "      <td>-0.528077</td>\n",
       "      <td>-0.932282</td>\n",
       "      <td>-0.855013</td>\n",
       "      <td>-0.477829</td>\n",
       "      <td>-0.504929</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.455939</td>\n",
       "      <td>0.510934</td>\n",
       "      <td>-0.003291</td>\n",
       "      <td>-0.415769</td>\n",
       "      <td>-0.476681</td>\n",
       "      <td>-0.874379</td>\n",
       "      <td>0.469643</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2705</th>\n",
       "      <td>4.0</td>\n",
       "      <td>0.550985</td>\n",
       "      <td>0.740175</td>\n",
       "      <td>-0.463918</td>\n",
       "      <td>0.000533</td>\n",
       "      <td>-0.477887</td>\n",
       "      <td>-0.573837</td>\n",
       "      <td>-0.607120</td>\n",
       "      <td>-0.150783</td>\n",
       "      <td>-0.322581</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.062911</td>\n",
       "      <td>0.216904</td>\n",
       "      <td>-0.000892</td>\n",
       "      <td>-0.274293</td>\n",
       "      <td>-0.343395</td>\n",
       "      <td>-0.538367</td>\n",
       "      <td>0.330357</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2706</th>\n",
       "      <td>3.0</td>\n",
       "      <td>1.396694</td>\n",
       "      <td>0.757272</td>\n",
       "      <td>-1.619163</td>\n",
       "      <td>0.541828</td>\n",
       "      <td>0.241305</td>\n",
       "      <td>-0.372257</td>\n",
       "      <td>-0.554431</td>\n",
       "      <td>-0.664398</td>\n",
       "      <td>0.554659</td>\n",
       "      <td>...</td>\n",
       "      <td>0.154074</td>\n",
       "      <td>0.366714</td>\n",
       "      <td>1.190225</td>\n",
       "      <td>0.905767</td>\n",
       "      <td>-1.066021</td>\n",
       "      <td>-0.529856</td>\n",
       "      <td>-1.175097</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2707</th>\n",
       "      <td>3.0</td>\n",
       "      <td>1.396694</td>\n",
       "      <td>0.757272</td>\n",
       "      <td>-1.619163</td>\n",
       "      <td>0.541828</td>\n",
       "      <td>0.241305</td>\n",
       "      <td>-0.372257</td>\n",
       "      <td>-0.554431</td>\n",
       "      <td>-0.664398</td>\n",
       "      <td>0.554659</td>\n",
       "      <td>...</td>\n",
       "      <td>0.154074</td>\n",
       "      <td>0.366714</td>\n",
       "      <td>1.190225</td>\n",
       "      <td>0.905767</td>\n",
       "      <td>-1.066021</td>\n",
       "      <td>-0.529856</td>\n",
       "      <td>-1.175097</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2708 rows × 24 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       0         1         2         3         4         5         6   \\\n",
       "0     4.0  0.543700  0.708406 -0.892010 -0.620374 -1.058723 -0.966741   \n",
       "1     4.0  0.548127  0.679302 -0.846138 -0.579255 -0.979052 -0.889304   \n",
       "2     4.0  0.475942  0.642293 -0.787892 -0.554242 -0.978643 -0.897948   \n",
       "3     4.0  0.560623  0.709882 -0.934610 -0.703681 -1.142573 -0.996351   \n",
       "4     4.0  0.507652  0.585290 -0.765928 -0.612692 -0.991120 -0.815581   \n",
       "...   ...       ...       ...       ...       ...       ...       ...   \n",
       "2703  3.0  1.029542  1.126835 -0.953499 -0.293646 -0.478533 -1.115570   \n",
       "2704  4.0  0.858308  1.062309 -0.748620 -0.102593 -0.528077 -0.932282   \n",
       "2705  4.0  0.550985  0.740175 -0.463918  0.000533 -0.477887 -0.573837   \n",
       "2706  3.0  1.396694  0.757272 -1.619163  0.541828  0.241305 -0.372257   \n",
       "2707  3.0  1.396694  0.757272 -1.619163  0.541828  0.241305 -0.372257   \n",
       "\n",
       "            7         8         9   ...        14        15        16  \\\n",
       "0    -1.040328 -0.622832 -0.596182  ... -0.865241  0.611353 -0.611133   \n",
       "1    -0.981078 -0.570468 -0.565127  ... -0.817142  0.550100 -0.560664   \n",
       "2    -0.946107 -0.530967 -0.529651  ... -0.773411  0.555541 -0.546471   \n",
       "3    -1.045104 -0.629294 -0.576431  ... -0.960099  0.623167 -0.625813   \n",
       "4    -0.850270 -0.422101 -0.434241  ... -0.849878  0.450731 -0.476241   \n",
       "...        ...       ...       ...  ...       ...       ...       ...   \n",
       "2703 -0.930043 -0.792080 -0.530297  ... -0.819958  0.771034 -0.032367   \n",
       "2704 -0.855013 -0.477829 -0.504929  ... -0.455939  0.510934 -0.003291   \n",
       "2705 -0.607120 -0.150783 -0.322581  ... -0.062911  0.216904 -0.000892   \n",
       "2706 -0.554431 -0.664398  0.554659  ...  0.154074  0.366714  1.190225   \n",
       "2707 -0.554431 -0.664398  0.554659  ...  0.154074  0.366714  1.190225   \n",
       "\n",
       "            17        18        19        20   21   22   23  \n",
       "0    -0.917186 -1.020263 -0.989763  0.784655  0.0  0.0  1.0  \n",
       "1    -0.842593 -0.956245 -0.921767  0.704382  0.0  0.0  1.0  \n",
       "2    -0.839700 -0.941623 -0.909231  0.731721  0.0  1.0  0.0  \n",
       "3    -0.940791 -1.033401 -0.998360  0.838899  1.0  0.0  0.0  \n",
       "4    -0.736935 -0.876561 -0.815947  0.691426  0.0  0.0  1.0  \n",
       "...        ...       ...       ...       ...  ...  ...  ...  \n",
       "2703 -0.436413 -0.523535 -1.086509  0.463543  0.0  1.0  0.0  \n",
       "2704 -0.415769 -0.476681 -0.874379  0.469643  0.0  0.0  1.0  \n",
       "2705 -0.274293 -0.343395 -0.538367  0.330357  0.0  0.0  0.0  \n",
       "2706  0.905767 -1.066021 -0.529856 -1.175097  0.0  0.0  0.0  \n",
       "2707  0.905767 -1.066021 -0.529856 -1.175097  0.0  0.0  0.0  \n",
       "\n",
       "[2708 rows x 24 columns]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reload the features saved by preprocessed_data(save='feat.csv') for inspection.\n",
    "saved_features = pd.read_csv('feat.csv', header=None)\n",
    "saved_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ac8efdd6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n",
      "accuracy: 0.793\n"
     ]
    }
   ],
   "source": [
    "# Train the 2-layer extractor on Cora and export its features to feat.csv.\n",
    "# FeatureExtraction is already imported in the first cell.\n",
    "FeatureExtractor = FeatureExtraction(\n",
    "    num_layers=2, num_iter=100, lr=0.02, hidden_feat=20, device='cpu', dataset='cora'\n",
    ")\n",
    "FeatureExtractor.extract_feature()\n",
    "train_x, train_y, val_x, val_y, test_x, test_y = FeatureExtractor.preprocessed_data(save='feat.csv')\n",
    "\n",
    "# Cast every split (validation included) to float64 so all splits share one dtype\n",
    "# for the downstream gradient/Hessian computations.\n",
    "train_x, val_x, test_x = (split.astype(np.float64) for split in (train_x, val_x, test_x))\n",
    "train_y, val_y, test_y = (split.astype(np.float64) for split in (train_y, val_y, test_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4d4c5da9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check on a small dense dataset: set USE_IRIS = True to replace the\n",
    "# Cora split above with a fixed 80/20 Iris split. The default leaves Cora untouched.\n",
    "USE_IRIS = False\n",
    "if USE_IRIS:\n",
    "    from sklearn.datasets import load_iris\n",
    "    from sklearn.model_selection import train_test_split\n",
    "\n",
    "    iris = load_iris()\n",
    "    x, y = iris.data, iris.target\n",
    "    train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.2, random_state=123)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "57b4c8c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode labels with an encoder fitted on the training labels only.\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# Train the logistic-regression model and compute raw (pre-softmax) logits.\n",
    "lr = SimplifiedGraphNeuralNetwork(l2_reg=1.0, fit_intercept=True)\n",
    "lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "logits_test_y = test_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "logits_train_y = train_x @ lr.model.coef_.T + lr.model.intercept_\n",
    "\n",
    "# NOTE: despite the `val` in the name, this loss is computed on the test split.\n",
    "ori_val_loss, ave_ori_val_loss = lr.log_loss(logits_test_y, one_hot_labels_test)\n",
    "\n",
    "# Cross-check the model's average loss against sklearn's log_loss on softmax probabilities.\n",
    "numpy_theoritic_loss = log_loss(test_y, softmax(logits_test_y, axis=1))\n",
    "assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss), (numpy_theoritic_loss, ave_ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "eb3cb47e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradients at the fitted parameters: `total` and `indiv` (presumably per-sample)\n",
    "# terms for the training set and for the test set.\n",
    "train_total_grad, train_indiv_grad = lr.grad(train_x, logits_train_y, one_hot_labels_train)\n",
    "val_loss_total_grad, val_loss_indiv_grad = lr.grad(test_x, logits_test_y, one_hot_labels_test)\n",
    "\n",
    "# lr.hess returns three matrices; only the middle one (`hess`) enters the solve below.\n",
    "hessian_no_reg, hess, hessian_reg_term = lr.hess(train_x, logits_train_y)\n",
    "\n",
    "# Predicted influence of each training point: its gradient dotted with H^-1 * (test-loss gradient).\n",
    "loss_grad_hvp = lr.get_inv_hvp(hess, val_loss_total_grad.T)\n",
    "pred_infl = list(train_indiv_grad.dot(loss_grad_hvp).ravel())\n",
    "\n",
    "# Bookkeeping for the leave-one-out comparison.\n",
    "num_train = len(train_x)\n",
    "act_infl = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "2ebee3a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Leave-one-out ground truth: retrain without each training point and record the\n",
    "# change in test loss relative to the baseline `ori_val_loss`.\n",
    "# Reset the list here so re-running this cell does not append a second set of results.\n",
    "act_infl = []\n",
    "for i in range(num_train):\n",
    "    loo_train_x = np.delete(train_x, i, axis=0)\n",
    "    loo_train_y = np.delete(train_y, i)\n",
    "\n",
    "    # Same hyper-parameters and fit options as the baseline model `lr`.\n",
    "    lr_new = SimplifiedGraphNeuralNetwork(l2_reg=1.0, fit_intercept=True)\n",
    "    lr_new.fit(loo_train_x, loo_train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "    logits_test_y_new = test_x @ lr_new.model.coef_.T + lr_new.model.intercept_\n",
    "    new_ori_val_loss, new_ave_ori_val_loss = lr_new.log_loss(logits_test_y_new, one_hot_labels_test)\n",
    "    act_infl.append(new_ori_val_loss - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e32c8cfe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEWCAYAAABv+EDhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAxNklEQVR4nO3deXwU9f3H8dcnEI5wyikiIcghtxeCSrWoeB9Ura0Wbyta9afVtoiiVatWbD2q1gsvPKLWAoqi9cCKeCtYINwihvs+w53j8/tjJrrEbLKQ7G6SfT8fj33szuzszGd2k/nMfGfm8zV3R0REUk9asgMQEZHkUAIQEUlRSgAiIilKCUBEJEUpAYiIpCglABGRFKUEkELMrLWZTTKzPDO7z8xuM7MXkx1XrErGn+Blbzaz/RK5zMpmZo+b2S2VMJ+JZvbbyohJkksJoJozs1wzGxjj5EOANUBjd/9DHMOKl4TEX9oGzt0buvuCeC2zoswsy8zczGpHm8bdr3D3OxIc1+78fVb55dQ0SgCppT0wy6vv3X/VPf6kMbNayY5Bqh4lgBrEzC4ys0/M7F4zW29m35vZSeF7o4ALgaFhc8bAEp8dYGZLSoz7Ya/KzNLMbJiZfWdma83sVTNrFr5XvPd5oZktMrM1ZjY8Yj61zOym8LN5ZjbFzNqF73U1s/fNbJ2ZzTWzX0VZt5/Eb2ajzOzOaOsQxv9HM5tuZhvN7F9mVi/i/UFmNtXMNoWxnWhmdwFHAv8Ml/PPcFo3s07h6yZm9ryZrTazhWZ2s5mllfcbRFmvbuERxwYzm2lmp0eus5k9YmZvhd/bl2bWMdq8Sn5fZvaYmb1tZluAoyO/LzNrYWbjw+WuM7OPi9ehlHkdZ2Zzwu/wn4BFvNfRzP4b/k2sMbNsM2savvcCkAm8GX6XQ8Px/zazFeH8JplZj4j5nWxms8L1XWpmf4x479Tw99pgZp+ZWe+yliMxcHc9qvEDyAUGhq8vAvKBy4BawO+AZYCF748C7oz47G3Ai+HrAcCSMub9e+ALYF+gLvAE8HL4XhbgwJNAfeAAYAfQLXz/T0AOsD/BxuMAoDnQAFgMXAzUBg4maOLpEWVdS8ZfcniXdQjj/wrYB2gGzAauCN/rC2wEjiPYEWoLdA3fmwj8tsSyHegUvn4eGAc0Ctd9HnBpLL9BiXmmA/OBm4A6wDFAHrB/xPqtC2OtDWQDr0T5bop/g9oRn90I9A/Xr17k9wXcDTwexpBOkPRKi7EFsAn4ZTjddUBB8fcDdAq/w7pAS2AS8I/S/oYixl0Sfnd1gX8AUyPeWw4cGb7eCzg4fH0wsAroF36vF4bzrhttOXqU/9ARQM2z0N2fdPdC4DmgDdC6EuZ7OTDc3Ze4+w6C5PFL27XN+XZ33+bu04BpBBt6gN8CN7v7XA9Mc/e1wKlArrs/6+4F7v4NMIZgY1NZHnL3Ze6+DngTODAcfynwjLu/7+5F7r7U3eeUNzMLmlJ+Ddzo7nnungvcB5wfMVmsv8FhQENghLvvdPf/AuOBcyOmGevuX7l7AUECOPCns4lqnLt/Gq7f9hLv5YdxtXf3fHf/2MMtaQknEzS7jXb3fIIN9oriN919fvgd7nD31cD9wM/LCsrdnwm/u+K/owPMrElEXN3NrLG7rw//JiBIqE+4+5fuXujuzxHsZBy2G9+HlKAEUPNE/nNuDV82rIT5tgdeCw+/NxDsTRey64ZtRcTrrRHLbQd8F2We/YrnGc53MLB3JcS7pzGVpwXB3vrCiHELCY4gfrLMcn6DfYDF7l4Uy7zYNf5YLC7jvb8THH28Z2YLzGxYlOn2iZxPmCR+GDazVmb2Sthcswl4keA7KlXYHDgibHLbRLDnTsRnziJIOgvN7CMzOzwc3x74Q4m/lXZhfLKHlACk2BYgo3gg3NNtGfH+YuAkd28a8ajn7ktjmPdioLS268XARyXm2dDdf7cnMbN7iSNaTBA0pUSz
hmAvtX3EuEwglu+hpGVAuxJt73s6r9JEXY9wD/wP7r4fcBpwvZkdW8qkywk2tACYmUUOEzQlOdDb3RsD5xFxjqCUGH4DDAIGAk0Imq4o/oy7f+3ug4BWwOvAq+H7i4G7SvytZLj7y+Wtq0SnBCDF5gH1zOwUM0sHbiZooy32OHCXmbUHMLOWZjYoxnk/BdxhZp0t0NvMmhM0d3Qxs/PNLD18HGpm3WKc71TgZDNrZmZ7E5yniNXTwMVmdqwFJ7jbmlnX8L2VQKnX/IfNOq8SfBeNwu/jeoI93931JUESGxqu+wCCjfErezCv3RKeUO0UbtA3ERzNFZYy6VtADzM7M2zuu4ZdE20jYDOwwczaEpzviVTyu2xE0HSzliB5/zUipjpmNtjMmoTNTcVxQXB+6Qoz6xf+DTUI/1YbRVmOxEAJQABw943AlQQb66UEG6bIq4IeBN4gaDLIIzgh3C/G2d9PsNF8j+Cf+mmgvrvnAccD5xDsDa8A7mHXxFOWFwjONeSG8/5XjJ/D3b8iOPn8AMHJ0o/4ca/+QYLzG+vN7KFSPv5/BN/PAuAT4CXgmViXHRHDTuB04CSCI4tHgQtiORdRCToDEwg23p8Dj7r7xFJiXAOcDYwg2Gh3Bj6NmOR2ghO0GwmSxdgSs7gbuDlstvkjwQn0hQR/Y7MI/o4inQ/khs1DVxAcUeDukwnOA/wTWE/QfHVRGcuRGBRfHSIiIilGRwAiIilKCUBEJEUpAYiIpCglABGRFBW1cmBV1KJFC8/Kykp2GCIi1cqUKVPWuHvLkuOrVQLIyspi8uTJyQ5DRKRaMbOFpY1XE5CISIpSAhARSVFKACIiKUoJQEQkRSkBiIikqLgnADNrZ2YfmtlsC7q8uzYc38yCrgC/DZ/3incsIiLyo0QcARQAf3D3bgS991xlZt2BYcAH7t4Z+CAcFhGRBIl7AnD35cXduoXlf2cT9Hg0iKC7PMLnX8Q7FhGR6mb9lp3c/uZMNm3Pr/R5J/QcgJllAQcRdITR2t2XQ5AkCHoAKu0zQ8xssplNXr16dcJiFRFJJnfnrenLOe6Bj3jh84V8tWBdpS8jYXcCm1lDgg6/f+/um4KOiMrn7iOBkQB9+vRR5wUiUuOt3LSdW16fwXuzVtKrbRNeuLQf3do0rvTlJCQBhF0MjgGy3b24x6CVZtbG3ZebWRtgVSJiERGpqtydVycv5s63ZrOzoIgbT+rKpT/rQO1a8WmsiXsCCPscfRqY7e73R7z1BnAhQVdzFwLj4h2LiEhVtWjtVm58bTqfzl9L3w7NuOes3nRo0SCuy0zEEUB/gn4+c8xsajjuJoIN/6tmdimwiKDfURGRlFJY5Iz6LJd7351LrTTjzl/05Dd9M0lLi62ZvCLingDc/RMg2pocG+/li4hUVfNW5jF09HSmLt7AMV1bcecverJP0/oJW361KgctIlIT7Cwo4vGPvuPh/35Lw7q1efCcAzn9gH2I9eKYyqIEICKSQNOXbGDo6OnMWZHHaQfsw22ndad5w7pJiUUJQEQkAbbtLOQfE+bx5McLaNmoLk9e0IfjurdOakxKACIicfbFgrUMGzOd3LVbObdvJjee3JXG9dKTHZYSgIhIvORtz2fEf+aQ/eUi2jfP4KXL+nFExxbJDusHSgAiInHw3zkrGf7aDFZu2s5lR3bg+uP2p36dWskOaxdKACIilWjt5h38Zfwsxk1dRpfWDXnsvP4c2K5pssMqlRKAiEglcHfenL6c296YSd72fH4/sDNXDuhEndpVt98tJQARkQpavnEbt7w+gwmzV3FAu6b87aze7L93o2SHVS4lABGRPVRU5Lzy9WLufns2+UVF3HxKNy7u34FaCSjjUBmUAERE9kDumi0MGzudLxas4/D9mjPirF60bx7f4m2VTQlARGQ3FBY5z3zyPfe9P5f0tDRGnNmLXx/aLuFlHCqDEoCISIzmrshj6OhpTFuykYHdWnHnL3qxd5N6yQ5rjykBiIiU
Y2dBEY98OJ9HJ86ncb10Hj73IE7t3aZa7vVHUgIQESnD1MUbGDp6GvNWbuaMg9pyy6ndadagTrLDqhRKACIipdi2s5D73pvLM59+T+vG9Xjmoj4c0zW5xdsqmxKAiEgJn323hmFjcli0bivnHZbJDSd2pVEVKN5W2ZQARERCm7bnc/fbs3n5q8VkNc/glSGHcdh+zZMdVtwoAYiIABNmrWT46zmsztvB5T/fj+sGdqFeetUq3lbZlABEJKWt3byD296cxZvTltF170Y8eUEfeu/bNNlhJYQSgIikJHfnjWnLuO2NmWzZUcgfjuvC5T/vWKWLt1U2JQARSTnLNmzj5tdn8N85qzgoMyje1rl11S/eVtmUAEQkZRQVOS99tYgR/5lDYZHz51O7c+ERWdWmeFtlUwIQkZTw/Zot3DBmOl99v47+nZpz9xm9yWyekeywkkoJQERqtILCIp765HseeH8edWqn8bezenN2n32rfRmHyqAEICI11qxlm7hhzHRylm7k+O6tueMXPWnduPoWb6tsSgAiUuPsKCjkn/+dz2MTv6NpRjqP/OZgTu61t/b6S1ACEJEaZcrC9dwwZjrzV23mzIPbcssp3dmrhhRvq2xKACJSI2zZUcC9781l1Ge57NOkPqMuPpQB+7dKdlhVWtwTgJk9A5wKrHL3nuG424DLgNXhZDe5+9vxjkVEaqaPv13NjWNzWLJ+Gxcc3p6hJ3alYV3t35YnEbe8jQJOLGX8A+5+YPjQxl9EdtvGrfkMHT2N85/+ijq10nj18sP5y6CeSd34Z2dDVhakpQXP2dlJC6Vccf+W3H2SmWXFezkiklrembGCW8bNYN2WnVw5oCPXHNs56cXbsrNhyBDYujUYXrgwGAYYPDh5cUWTzKIXV5vZdDN7xsz2ijaRmQ0xs8lmNnn16tXRJhORFLE6bwdXZX/DFS9OoWXDuoy7qj9DT+ya9I0/wPDhP278i23dGoyviszd47+Q4AhgfMQ5gNbAGsCBO4A27n5JefPp06ePT548OZ6hikgV5e6M/WYpfxk/i235hVx7bGeGHLUf6bWqTvG2tDQobZNqBkVFiY/nx+XbFHfvU3J8uU1AZnY28I6755nZzcDBwJ3u/s2eBuPuKyPm/yQwfk/nJSI135L1Wxn+2gw+mreaQ9rvxT1n9aZTq4bJDusnMjODZp/SxldFsaTOW8KN/8+AE4DngMcqslAzaxMxeAYwoyLzE5GaqajIef7zXE54YBJf567jttO68+/LD6+SG3+Au+6CjBLlhTIygvFVUSwngQvD51OAx9x9XHgZZ0zM7GVgANDCzJYAtwIDzOxAgiagXODy2EMWkVTw3erNDBszna9z13Nk5xb89YxetGtWtYu3FZ/oHT4cFi0K9vzvuqtqngCGGM4BmNl4YCkwEDgE2AZ85e4HxD+8XekcgEjNl19YxMhJC3jwg2+pn16LW07tzlkHt1UZhwrY43MAwK8IruO/1903hM03f6rsAEVEZizdyA1jpjNz2SZO6rk3tw/qQatGKt4WL7EkgDbAW+6+w8wGAL2B5+MZlIiklu35hTz0wbc8MWkBe2XU4bHBB3NSrzblf1AqJJYEMAboY2adgKeBN4CXgJPjGZiIpIbJuesYOmY6C1Zv4exD9mX4Kd1omqHibYkQSwIocvcCMzsT+Ie7P2xm/4t3YCJSs23eUcDf35nD818sZJ8m9Xn+kr4c1aVlssNKKbEkgHwzOxe4ADgtHJcev5BEpKb7aN5qbhqbw7KN27jw8Cz+dML+NFDxtoSL5Ru/GLgCuMvdvzezDsCL8Q1LRGqiDVt3csf42Yz5ZgkdWzZg9BWHc0j7ZskOK2WVmwDcfZaZ/RHoYmY9gbnuPiL+oYlITfKfnOXcMm4mG7bu5OqjO3H1MZ2qRP2eVBZLKYgBBHf/5gIGtDOzC919UlwjE5EaYdWm7fx53EzembmCnm0b89wlh9JjnybJDkuIrQnoPuB4d58LYGZdgJcJ
bgoTESmVuzN6yhLuGD+L7QVF3HBiVy47sgO1q1DxtlQXSwJIL974A7j7PDPTSWARiWrxuq3c9FoOH3+7hr5Zzbj7rF50bFk16/ekslgSwGQzexp4IRweDEyJX0giUl0VFjkvfJ7L396diwF3DOrB4H7tSUtTGYeqKJYE8DvgKuAagnMAk4BH4xmUiFQ/81flccOYHKYsXM/Pu7Tkr2f2om3T+skOS8oQy1VAO4D7w4eIyC5+KN424Vsy6tbi/l8dwBkHqXhbdRA1AZhZDkG55lK5e++4RCQi1UbOko0MHTOd2cs3cUqvNtx2eg9aNqqb7LAkRmUdAZyasChEpFrZnl/IPyZ8y5MfL6B5gzo8cf4hnNBj72SHJbspagJw91I6NhORVPflgrUMG5vD92u28Os+7bjplG40qa8LA6sjFd8QkZjkbc/nnnfm8OIXi2jXrD7Zv+1H/04tkh2WVIASgIiU68M5qxj+Wg7LN23nkv4d+OMJXcioo81HdadfUESiWrdlJ3eMn8Vr/1tK51YNGfO7Izg4c69khyWVJJZaQP2B24D24fQGuLvvF9/QRCRZ3J23cpZz67iZbNyWzzXHduaqoztSt7aKt9UksRwBPA1cR3D3b2F8wxGRZFu5aTs3vz6D92etpPe+TXjxt/3o1qZxssOSOIglAWx09//EPRIRSSp359XJi7nzrdnsLCjippO7ckl/FW+ryWJJAB+a2d+BscCO4pHu/k3cohKRhFq0divDxk7ns+/W0q9DM+45qzdZLRokOyyJs1gSQL/wuU/EOAeOqfxwRCSRCoucZz/9nnvfm0vttDT+ekYvzjm0nYq3pYhYagEdnYhARCSx5q3MY+jo6UxdvIFjurbirjN60qaJirelkrJqAZ3n7i+a2fWlve/uKg4nUg3tLCjisYnf8c8Pv6Vh3do8eM6BnH7APireloLKOgIobgBslIhARCT+pi3ewA1jpjNnRR6nH7APt57WneYNVbwtVZVVC+iJ8Pn2xIUjIvGwbWchD0yYx1MfL6BVo3o8dUEfBnZvneywJMl0J7BIDff5d2u5cex0ctdu5dy+mdx4clca11PxNklAAjCzZwhKS69y957huGbAv4AsIBf4lbuvj3csIqlk0/Z8RvxnDi99uYjMZhm8dFk/juio4m3yo0Tc4TEKOLHEuGHAB+7eGfggHBaRSvLB7JUcf/8kXvlqEZcd2YF3f3+UNv7yE7HUAmoN/BXYx91PMrPuwOHu/nQsC3D3SWaWVWL0IGBA+Po5YCJwQ4wxi0gUazfv4PY3Z/HGtGXs37oRj59/CAe2a5rssKSKiqUJaBTwLDA8HJ5H0HwTUwKIorW7Lwdw9+Vm1irahGY2BBgCkJmZWYFFitRc7s4b05Zx+5uzyNuez3UDu/C7AR2pU1tlHCS6WBJAC3d/1cxuBHD3AjNLWFE4dx8JjATo06dP1D6KRVLV8o3buPm1GXwwZxUHtGvK33/Zmy6tdfW2lC+WBLDFzJoTdhBvZocBGyu43JVm1ibc+28DrKrg/ERSTlGR88rXi7n77dnkFxVx8ynduLh/B2qpjIPEKJYEcD3wBtDRzD4FWgK/rOBy3wAuBEaEz+MqOD+RlJK7ZgvDxk7niwXrOKJjc0ac2ZvM5hnJDkuqmVhqAX1jZj8H9ifoDGauu+fHugAze5nghG8LM1sC3Eqw4X/VzC4FFgFn70HsIimnoLCIZz/N5b7355KelsaIM3vx60PbqYyD7JFYrgI6s8SoLma2Echx93Kbbtz93ChvHRtDfCISmrNiEzeMns60JRsZ2K01d/6iJ3s3qZfssKQai6UJ6FLgcODDcHgA8AVBIviLu78Qp9hEBNhRUMgjH37Hox/Op0n9dB4+9yBO7d1Ge/1SYbEkgCKgm7uvhB/uC3iMoJ+ASYASgEic/G/Rem4YM515KzdzxkFtueXU7jRrUCfZYUkNEUsCyCre+IdWAV3cfZ2ZxXwuQERit3VnAfe/N49nPv2e
1o3r8cxFfTimq4q3SeWKJQF8bGbjgX+Hw2cBk8ysAbAhXoGJpKrP5q9h2NgcFq3byuB+mQw7qSuNVLxN4iCWBHAVwUa/P8FVQM8DY9zdAfUWJlJJNm7L5+63Z/PK14vJap7BK0MO47D9mic7LKnBYrkM1IHR4UNE4uD9WSu5+fUcVuft4PKf78d1A7tQL71WssOSGi7Wy0DvAVoRHAEYQV5oHOfYRGq8NZt3cNsbMxk/fTld927Ekxf0ofe+TZMdlqSIWJqA/gac5u6z4x2MSKpwd16fupTb35zFlh0FXH9cF674uYq3SWLFkgBWauMvUnmWbtjG8NdymDh3NQdlNuVvZ/Wms4q3SRLEkgAmm9m/gNeBHcUj3X1svIISqYmKipzsrxYx4u3ZFDncelp3Ljg8S8XbJGliSQCNga3A8RHjHFACEInRgtWbGTYmh69y1/GzTi24+8xetGum4m2SXLFcBXRxIgIRqYkKCot46pPveeD9edStncbfzurN2X32VRkHqRJiuQqoHkE9oB7AD5Wn3P2SOMYlUu3NWraJoWOmMWPpJk7o0Zo7BvWkVWMVb5OqI5YmoBeAOcAJwF+AwYBOCotEsaOgkH/+dz6PTfyOphl1eGzwwZzUq02ywxL5iVgSQCd3P9vMBrn7c2b2EvBuvAMTqY6mLFzHDWNymL9qM2cdvC+3nNqNphkq3iZVUywJoLjg2wYz6wmsALLiFpFINbRlRwF/f3cuz32eyz5N6jPq4kMZsH+rZIclUqZYEsBIM9sLuIWgK8eGwJ/jGpVINfLxt6u5cWwOS9Zv48LD2/OnE7vSsG4s/1oiyRXLVUBPhS8/AvaLbzgi1cfGrfnc+dYs/j1lCfu1bMC/rzicQ7OaJTsskZjFchVQXYJqoFmR07v7X+IXlkjV9s6MFdwybgbrtuzkygEduebYzireJtVOLMep44CNwBQi7gQWSUWr8rZz2xszeTtnBd3bNObZiw6lZ9smyQ5LZI/EkgD2dfcT4x6JSBXm7oz5Zil3jJ/FtvxC/nTC/gw5aj/Sa6l4m1RfsSSAz8ysl7vnxD0akSpoyfqt3PTaDCbNW02f9nsx4qzedGrVMNlhiVRY1ARgZjkENX9qAxeb2QKCJqDi/gB6JyZEkeQoKnJe+GIh97wzB4DbT+/B+Ye1J03F26SGKOsI4NSERSFSxXy3ejM3jJ7O5IXrOapLS/56Rk/23UvF26RmiZoA3H0hgJkdBsx097xwuBHQHViYkAhFEii/sIgnP17APyZ8S/30Wtx39gGceXBbFW+TGimWcwCPAQdHDG8pZZxItTdj6UZuGDOdmcs2cXKvvbn99J60bFQ32WGJxE0sCcDCjuEBcPciM9NtjlJjbM8v5KEPvuWJSQto1qAOj593MCf2VPE2qfli2ZAvMLNrCPb6Aa4EFsQvJJHEmZy7jqFjprNg9RbOPmRfbj6lO00y0pMdlkhCxJIArgAeAm4muCroA2BIPIMSibfNOwr4+ztzeP6LhbRtWp8XLu3LkZ1bJjsskYSKpRbQKuCceCzczHKBPKAQKHD3PvFYjkikj+at5qaxOSzbuI0LD8/iTyfsTwMVb5MUVBX+6o929zXJDkJqvg1bd3LH+NmM+WYJHVs2YPQVh3NIexVvk9Sl+9glJbyds5yB93/EuKlLufroTrx1zZHM+aQZWVmQlgZZWZCdnewoRRIr2UcADrxnZg484e4jS05gZkMIzzlkZmYmODyp7lZt2s6fx83knZkr6Nm2Mc9d0pce+zQhOxuGDIGtW4PpFi4MhgEGD05evCKJZBFXeO76htn1ZX3Q3e+v8MLN9nH3ZWbWCngf+D93nxRt+j59+vjkyZMrulhJAe7Ov6cs4c7xs9hRUMR1x3Xhtz/rQO2weFtWVrDRL6l9e8jNTWioInFnZlNKO8da1hFAo/B5f+BQgt7AAE4Dom6kd4e7LwufV5nZa0Dfypq3pK7F67Zy02s5fPztGvpmNWPEWb3Yr+Wu
xdsWLSr9s9HGi9REZZWCuB3AzN4DDo4oBXEb8O+KLtjMGgBp7p4Xvj4eUCczsscKi5znP8/lb+/MJc3gjl/0ZHDfzFKLt2Vmln4EoFZGSSWxnAPIBHZGDO+kcjqFbw28FtZYqQ285O7vVMJ8JQV9uzKPG8ZM55tFGxiwf0vuOqMXbZvWjzr9XXfteg4AICMjGC+SKmJJAC8AX4VNNA6cATxf0QW7+wLggIrOR1JbfmERj0/8jof/O5+MurV44NcH8IsDyy/eVnyid/jwoNknMzPY+OsEsKSSqCeBd5nI7GDgyHBwkrv/L65RRaGTwBIpZ8lG/jR6GnNW5NE5rztz3sxi6RLTxlykhD05CRwpA9jk7s+aWUsz6+Du31duiCKx2Z5fyAMT5vHkpAW0aFiXX+/Vn38+3FSXdIrspnITgJndCvQhuBroWSAdeBHoH9/QRH7qywVrGTY2h+/XbOGcQ9tx48ndOKBb+i5t+RC07Q8frgQgUpZYjgDOAA4CvoHg0s2wUxiRhMnbns8978zhxS8W0a5ZfbJ/24/+nVoAuqRTZE/FkgB2uruHd+sWX74pkjAfzlnF8NdyWL5pO5f+rAN/OL4LGXV+/NPVJZ0ieyaWWkCvmtkTQFMzuwyYADwV37BEYN2WnVz3r6lcPOprGtStzZjfHcEtp3bfZeMPwQnfjBLd9eqSTpHyxVIO+l4zOw7YRHAe4M/u/n7cI5OU5e78ccQGHr6nHvkbD2CvVj04795aHJxZ+v6KLukU2TPlXgZqZve4+w3ljUsEXQZa8z361E6uv87Ysbk28OO1/BkZcOGF8Pbb2siL7K5ol4HG0gR0XCnjTqp4SCI/cneu/ssarr6yFjs2pxO58Yfgqp7HHw/a+t1/vNRTJZxF9lzUBGBmvzOzHKCrmU2PeHwP5CQuRKnpHnpiBw1b7OSRW5vj+bWiTlfyYLX4Uk8R2TNlnQN4CfgPcDcwLGJ8nruvi2tUkhIKi5xjz9rER683puQef6x0qafInot6BODuG909F3gQWOfuC919IZBvZv0SFaDUTPNW5tH3knkxb/yjlfbRpZ4iey6WcwCPAZsjhreE40R2286CIh6c8C2nPPQxOa9nEsvGv3lzuOIKXeopUtliSQDmEZcKuXsRye9KUqqhaYs3cNrDn/DAhHmc1LMNBZvqlTl9rfB0QMOG0L8/jBwZ9NhlFjyPHKmrgEQqIpYEsMDMrjGz9PBxLbAg3oFJzbFtZyF3vTWLMx79lI3b8nnqgj48dO5BZGZG3/uvUwcKC4PXkcXdcnOhqCh41sZfpGJiSQBXAEcAS4ElQD/CTtpFypKdDXu3LSSjbhq3/iaLXtt78d71RzGwe2ug9Dt4zaBBA9i5c9fxuuJHpPKVmwDcfZW7n+Purdy9tbv/xt1XJSI4qb6eGlXARZcWsXJZLcAo3JTBhJGZvDkm/YdpBg/+abPOCy/wk8qexXTFj0jlinonsJkNdfe/mdnDBD2B7cLdr4l3cCXpTuCqLTu7uByD4zj4T/cv2rcPmm/KkpVVenG3WD4rIj+1Jx3CzA6ftcWVcmVnw2WXOdu2GcGVPaW378eyF6/+ekUSI2oCcPc3w+fnEheOVEfZ2c4FF0BRUfmXdMZy3b6Ku4kkRtQEYGZvUkrTTzF3Pz0uEUm18shTO7n2qloUFUUv4VBsd/biBw/WBl8k3spqAro3fD4T2JugG0iAc4HcOMYk1UBRkfPy14u47o+tKNxZJ+p0tWoFl21qL16k6imrCegjADO7w92PinjrTTObFPfIpMrKXbOFc29Yytej96WwjJu5MjJ0s5ZIVRbLfQAtzWy/4gEz6wC0jF9IUtVceSXUrg1mTq1aTrc+W/nyhY4Ubsog2sneWrW08Rep6mIp6XAdMNHMiu/+zQIuj1tEUqVceSU89kPlJ6OoCLYvbEFZNXy05y9SPcTSJeQ7ZtYZ6BqOmuPuO+IbliRTdjZcey2sXQvBdQAlN/bRN/7t26utX6S6
KDcBmFkGcD3Q3t0vM7POZra/u4+Pf3iSSNnZcPnlsGVL5NjY6/TrRi2R6iWWcwDPAjuBw8PhJcCdcYtIkuLKK+G880pu/GOnG7VEqp9YzgF0dPdfm9m5AO6+zSxa9xxS3WRkwLZtu/+5Y4+F+fN1o5ZIdRZLAthpZvUJbwozs45ApZwDMLMTCXocqwU85e4jKmO+Ur49TeFmQecsjz5aufGISOLFkgBuBd4B2plZNtAfuKiiCzazWsAjwHEEzUpfm9kb7j6rovOWslXk+C0zUxt/kZqizHMAZpYG7EVwN/BFwMtAH3efWAnL7gvMd/cF7r4TeAUYVAnzlSiuvLJiG39QSWaRmqTMIwB3LzKzq939VeCtSl52W2BxxHBxZzO7MLMhhB3QZKoH8D3WowfMqoRjK/0EIjVHLFcBvW9mfzSzdmbWrPhRCcsubV+0tH4HRrp7H3fv07KlbkDeXcV7/ZWx8Qdd6SNSk8RyDuCS8PmqiHEO7FfKtLtjCdAuYnhfYFkF5ykRdr2Lt+KaN9eVPiI1SSx3AneI07K/BjqHtYWWAucAv4nTslLGrnfx7pk6dcAd8vN/HJeRAQ8+WPH4RKTqiOVO4HrAlcDPCPb8PwYed/ftFVmwuxeY2dXAuwSXgT7j7jMrMs9Ul50NF1+864Z7dzVv/uOGXh2yiNRsUfsE/mECs1eBPHbtD2Avdz87zrH9hPoELlu0vnRjoRo+IjXXnvQJXGx/dz8gYvhDM5tWeaFJRRV3xr67G/9jj4UJE+ITk4hUfbFcBfQ/MzuseMDM+gGfxi8k2R3Z2UEH6rFu/Bs0gBdfDNr4tfEXSW2xHAH0Ay4ws+JbgDKB2WaWA7i7945bdFKu4cNh69byp0tPh2efVROPiPwolgRwYtyjkD22aFFp9fp31aABPPGENv4isqtYLgPdw9OKEk8bt+Zz51uzSGvUOeyaMboWLbTxF5GfiuUcgFQx78xYwcAHPmLs/5byy8vXU8oN1LtQ/R4RKU0sTUBSRazK285tb8zk7ZwVdG/TmGcvOpSebZsw4Zmyb/xS/R4RKY2OAKqo7Ozguv60NGjf3rn2zrUcd/8kJsxexdAT92fc1f3p2bYJENy4lZ5e+nzS01W/R0RKV+6NYFVJqtwIVnxpZ+TVPVa7gEPPX8BL9+xDx5YNf/KZFi1KPwpo3hzWrIljsCJS5VXkRjBJsNIu7fSC2qz4oDMdW5Z+xc+6daXPK9p4ERE1AVVBwaWdP7V4cfTLPaO186v9X0SiUQJIksg2/qysYDi/sIhHPpxPrUal19kra2N+111Bxc5IGRlq/xeR6JQAkiCyfIN78Pzby5w+F83j7+/OZcDgFdSvv+tRQHkb88GDYeTIoKibWfA8cqSu/xeR6HQSOAmiVe2s02Qb4z7ZyIk99/6hwJvKMYtIRUU7CawEkARpacGef0lmTlFRBXttFxEpIVoCUBNQEuzbrvSkm5mpjb+IJI4SQIJNnLuKeofNwmoX7DJeJ2xFJNGUABJk/ZadXP/qVC569mva9V3NX+7drhO2IpJUuhEsztyd/8xYwZ/HzWDD1nyuProTVx/TiXrptbj52mRHJyKpTAkgjlZt2s4t42bw7syV9GzbmOcu6UuPfZokOywREUBNQBVW2g1d7s6rkxcz8P6PmDh3NcNO6srrV/bXxl9EqhQdAVRAyaJtCxfCZZc5j374HUtbzKVvh2aMOLMX+5VSvE1EJNmUACqgtKJt27YZX49uyzPvpPObvpmkpenSThGpmpQAKiBaT1sFm+px3mHtExuMiMhu0jmACminG7pEpBpTAthDOUs20uTIObqhS0SqLSWA3bQ9v5C7/zObQY98Qv3uSxl61xbd0CUi1ZLOAZShZEXOi6/NY2LhFL5fs4Vz+7Zj2EndaFI/nRFDkx2piMjuUwKIorRLPG8fWp8uZ7XkpTt7ckSnFskNUESkgpLSBGRmt5nZUjObGj5OTkYcZYnWL+/Wz7tr4y8iNUIyjwAe
cPd7k7j8MgX98v70ap4lZfTLKyJSnegkcAnuzpvTllG78e73yysiUp0kMwFcbWbTzewZM9sr2kRmNsTMJpvZ5NWrV8c1oJWbtnPZ81P4v5f/R69BC6m3m/3yiohUJ3FLAGY2wcxmlPIYBDwGdAQOBJYD90Wbj7uPdPc+7t6nZcuWcYnV3Xnlq0UMvP8jPpm/muEnd+OrZ/fnqSdNl3iKSI2V9D6BzSwLGO/uPcubNh59Ai9au5VhY6fz2XdrOWy/Zow4szdZLRpU6jJERJIpWp/ASTkJbGZt3H15OHgGMCPRMRQWOc9++j33vjeX9LQ0/npGL845tJ2Kt4lIykjWVUB/M7MDAQdygcsTufC5K/IYOmY60xZv4NiurbjzjJ60aVI/kSGIiCRdUhKAu5+fjOXuLCji0YnzeeTD+TSql85D5x7Eab3bYKa9fhFJPSlzJ/DUxRu4YfR05q7MY9CB+3DraT1o1qBOssMSEUmalEgAD3/wLQ9MmEerRvV4+sI+HNutdbJDEhFJupRIAJnNMzinbybDTupK43rpyQ5HRKRKSIkEMOjAtgw6sG2ywxARqVJUCkJEJEUpAYiIpCglABGRFKUEICKSopQARERSlBKAiEiKUgIQEUlRSgAiIikq6f0B7A4zWw0sTNDiWgBrErSsZEqV9YTUWddUWU/Qusaqvbv/pEetapUAEsnMJpfWgUJNkyrrCamzrqmynqB1rSg1AYmIpCglABGRFKUEEN3IZAeQIKmynpA665oq6wla1wrROQARkRSlIwARkRSlBCAikqKUAMpgZreZ2VIzmxo+Tk52TJXJzE40s7lmNt/MhiU7nngys1wzywl/x8nJjqeymNkzZrbKzGZEjGtmZu+b2bfh817JjLGyRFnXGvc/ambtzOxDM5ttZjPN7NpwfKX/rkoA5XvA3Q8MH28nO5jKYma1gEeAk4DuwLlm1j25UcXd0eHvWJOuGx8FnFhi3DDgA3fvDHwQDtcEo/jpukLN+x8tAP7g7t2Aw4Crwv/NSv9dlQBSV19gvrsvcPedwCvAoCTHJLvJ3ScB60qMHgQ8F75+DvhFImOKlyjrWuO4+3J3/yZ8nQfMBtoSh99VCaB8V5vZ9PDws0YcSofaAosjhpeE42oqB94zsylmNiTZwcRZa3dfDsHGBGiV5Hjirab+j2JmWcBBwJfE4XdN+QRgZhPMbEYpj0HAY0BH4EBgOXBfMmOtZFbKuJp8TXB/dz+YoMnrKjM7KtkBSaWosf+jZtYQGAP83t03xWMZteMx0+rE3QfGMp2ZPQmMj3M4ibQEaBcxvC+wLEmxxJ27LwufV5nZawRNYJOSG1XcrDSzNu6+3MzaAKuSHVC8uPvK4tc16X/UzNIJNv7Z7j42HF3pv2vKHwGUJfySi50BzIg2bTX0NdDZzDqYWR3gHOCNJMcUF2bWwMwaFb8Gjqdm/ZYlvQFcGL6+EBiXxFjiqib+j5qZAU8Ds939/oi3Kv131Z3AZTCzFwgOLR3IBS4vboOrCcJL5v4B1AKecfe7khtRfJjZfsBr4WBt4KWasq5m9jIwgKBU8ErgVuB14FUgE1gEnO3u1f7kaZR1HUAN+x81s58BHwM5QFE4+iaC8wCV+rsqAYiIpCg1AYmIpCglABGRFKUEICKSopQARERSlBKAiEiKUgKQKs/MBpjZERWcx+bdmHaUmf2yIsurLGb22W5OX2Vil6pPCUCqgwFAhRJAdeXuKbnekhhKAJIUZvZ6WJhtZmRxtrCPgm/MbJqZfRAWw7oCuC6s935kyb3c4r17M2sYfuabsPZ/udVNzeyCsJDYtPDGv2JHmdlnZrageFnR5m9mWWHt9ifD9XnPzOqH7x0azv9zM/t7cS17M6sVDn8dvn95lPiK122AmU00s9FmNsfMssM7Rstat2PN7H9hrM+YWd1w/AgzmxUu995w3NlhDaxpZlZTS2RISe6uhx4JfwDNwuf6BLfvNwdaElQo7VBimtuAP0Z8dhTw
y4jhzeFzbaBx+LoFMJ8fb3bcXEoMPYC5QIsSyxsF/JtgB6k7QdnsqPMHsghquB8YvvcqcF74egZwRPh6BDAjfD0EuDl8XReYXLzeJWIsXrcBwEaCmk1pwOfAz0qZfhTwS6Be+F12Ccc/D/weaBauc/H30jR8zgHaRo7To+Y/dAQgyXKNmU0DviAoSteZoPOLSe7+PYDv/m3uBvzVzKYDEwjKW7cuY/pjgNHuvqaU5b3u7kXuPitiHmXN/3t3nxq+ngJkmVlToJG7F7fjvxQx/+OBC8xsKsEt/s0JvoOyfOXuS9y9CJhKkHii2T+MaV44/BxwFLAJ2A48ZWZnAlvD9z8FRpnZZQSlQSQFpHw1UEk8MxsADAQOd/etZjaRYI/ViK0kdQFh82XYDFInHD+Y4CjiEHfPN7PccL5RQyljeTtKTFfe/COnLyQ4simricaA/3P3d8uYpqyYCin7/7fUZbt7gZn1BY4lKAB4NXCMu19hZv2AU4CpZnagu6/djdikGtIRgCRDE2B9uPHvSrDnD0Gzxs/NrAMEfaCG4/OARhGfzwUOCV8PAtIj5rsq3DgfDbQvJ44PgF+ZWfMSyysr7pjn7+7rgTwzK16/cyLefhf4XVj2FzPrElYqrSxzCI5COoXD5wMfWVBjvokHXSf+nqCQGmbW0d2/dPc/A2vYtVS41FA6ApBkeAe4ImxKmUvQDIS7rw5PCI81szSCeufHAW8Co8OTrv8HPAmMM7OvCDbiW8L5ZgNvWtDp+1SCjWBU7j7TzO4i2DAWAv8DLirjI7s1/9ClwJNmtgWYSNCOD/AUQRPON+FRzGoqsetGd99uZhcD/zaz2gTlvx8nOAcwzsyKj7iuCz/ydzPrHI77AJhWWbFI1aVqoCJxZGYN3b34Sp5hQBt3vzbJYYkAOgIQibdTzOxGgv+1hZR9hCGSUDoCEBFJUToJLCKSopQARERSlBKAiEiKUgIQEUlRSgAiIinq/wFMmmEjF9eWjQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Actual (leave-one-out) vs. predicted (influence-function) change in test loss;\n",
    "# points on the y = x line are predicted exactly. The reference line spans the\n",
    "# data instead of a hard-coded range.\n",
    "lo = min(np.min(act_infl), np.min(pred_infl))\n",
    "hi = max(np.max(act_infl), np.max(pred_infl))\n",
    "diag = np.linspace(lo, hi)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(diag, diag, label='y = x')\n",
    "ax.plot(act_infl, pred_infl, 'o', color='blue', label='training point')\n",
    "ax.set_xlabel('actual change in loss')\n",
    "ax.set_ylabel('predicted change in loss')\n",
    "ax.set_title('Influence function on Cora dataset')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "90f2deb2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.961251667359122, pvalue=4.191598756724728e-79)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy import stats\n",
    "\n",
    "# Rank agreement between actual and predicted influence. Explicit import: in older\n",
    "# SciPy versions a bare `import scipy` does not load the `scipy.stats` submodule.\n",
    "stats.spearmanr(act_infl, pred_infl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 353,
   "id": "389ee34b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-4.747267353047391"
      ]
     },
     "execution_count": 353,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ranges of actual vs. predicted influence, returned together as the cell output.\n",
    "{\n",
    "    'actual': (np.min(act_infl), np.max(act_infl)),\n",
    "    'predicted': (np.min(pred_infl), np.max(pred_infl)),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 355,
   "id": "195721df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "664.2333037358865"
      ]
     },
     "execution_count": 355,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline test-split loss of `lr` as (total, average); the average is comparable across split sizes.\n",
    "ori_val_loss, ave_ori_val_loss"
   ]
  },
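  {
   "cell_type": "markdown",
   "id": "dd962e6b",
   "metadata": {},
   "source": [
    "#### Edge perturbation: measuring group influence\n",
    "\n",
    "Prototyped on Iris: a group of training points is duplicated with its features scaled by `1 + perturb_ratio`, and 0/1 weight vectors (`weight_1`, `weight_2`) drop either the perturbed copies or the original rows."
   ]
  },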
  {
   "cell_type": "code",
   "execution_count": 320,
   "id": "404213f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small dense dataset for prototyping group influence.\n",
    "# NOTE: this overwrites the Cora train/test split (train_x, test_x, train_y, test_y) used above.\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "iris = load_iris()\n",
    "x, y = iris.data, iris.target\n",
    "train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.2, random_state=123)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 321,
   "id": "eca8b879",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append perturbed copies of a group of training points (the first five):\n",
    "# features scaled by (1 + perturb_ratio), labels unchanged.\n",
    "perturb_index = np.arange(0, 5, 1)\n",
    "perturb_ratio = 0.1\n",
    "train_x_new = train_x[perturb_index] + train_x[perturb_index] * perturb_ratio\n",
    "train_y_new = train_y[perturb_index]\n",
    "train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "train_y_orig = np.concatenate([train_y, train_y_new])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 322,
   "id": "770f2433",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder fitted on the augmented training labels, then applied to train and test.\n",
    "enc = OneHotEncoder(handle_unknown='ignore').fit(train_y_orig.reshape(-1, 1))\n",
    "one_hot_labels_train_orig, one_hot_labels_test = (\n",
    "    enc.transform(labels.reshape(-1, 1)).toarray() for labels in (train_y_orig, test_y)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 323,
   "id": "b079b19a",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_augmented = len(train_x_orig)\n",
    "n_perturbed = len(perturb_index)\n",
    "\n",
    "# Baseline: every row of the augmented set counts.\n",
    "weight_orig = np.ones(n_augmented)\n",
    "# Zero out the appended perturbed copies (the last n_perturbed rows).\n",
    "weight_1 = np.ones(n_augmented)\n",
    "weight_1[n_augmented - n_perturbed:] = 0\n",
    "# Zero out the original rows that were perturbed.\n",
    "weight_2 = np.ones(n_augmented)\n",
    "weight_2[perturb_index] = 0\n",
    "\n",
    "# With the copies removed, the remaining rows must match the unperturbed training set.\n",
    "assert np.allclose(train_x_orig[weight_1 == 1], train_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "id": "0922e904",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline model on the full augmented set (originals + perturbed copies), no sample weights.\n",
    "lr_origin = SimplifiedGraphNeuralNetwork(fit_intercept=True, l2_reg=1.0)\n",
    "lr_origin.fit(train_x_orig, train_y_orig, verbose=False, sample_weight=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 325,
   "id": "1efa9d1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw logits from the fitted linear model: X @ W.T + b.\n",
    "coef_origin = lr_origin.model.coef_\n",
    "intercept_origin = lr_origin.model.intercept_\n",
    "logits_test_y_origin = test_x @ coef_origin.T + intercept_origin\n",
    "logits_train_y_origin = train_x_orig @ coef_origin.T + intercept_origin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 326,
   "id": "d3e199d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-split loss of the baseline model, unpacked as (total, average).\n",
    "origin_losses = lr_origin.log_loss(logits_test_y_origin, one_hot_labels_test)\n",
    "ori_val_loss, ave_ori_val_loss = origin_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 327,
   "id": "ad835bfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sklearn reference: log_loss of the softmax probabilities against the integer labels.\n",
    "probs_test_origin = softmax(logits_test_y_origin, axis=1)\n",
    "numpy_theoritic_loss = log_loss(test_y, probs_test_origin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 328,
   "id": "61f061bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The model's average test loss must agree with the sklearn reference value.\n",
    "losses_agree = np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "assert losses_agree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 346,
   "id": "af1e41ba",
   "metadata": {},
   "outputs": [
    {
     "ename": "LinAlgError",
     "evalue": "15-th leading minor of the array is not positive definite",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mLinAlgError\u001b[0m                               Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_7789/4139505529.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mloss_grad_hvp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlr_origin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_inv_hvp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loss_total_grad_orig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     13\u001b[0m \u001b[0;31m# loss_grad_hvp = lr_origin.get_inv_hvp(hess, val_loss_total_grad_orig, cho = False)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/Projects/Project6_influence_function/graph_influence_function/model.py\u001b[0m in \u001b[0;36mget_inv_hvp\u001b[0;34m(hessian, vectors, cho)\u001b[0m\n\u001b[1;32m     59\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mget_inv_hvp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhessian\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectors\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcho\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muse_eps\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mfloat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1e-15\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     60\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mcho\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mcho_solve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcho_factor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhessian\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     62\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0muse_eps\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     63\u001b[0m             \u001b[0mhess_inv\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhessian\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/scipy/linalg/decomp_cholesky.py\u001b[0m in \u001b[0;36mcho_factor\u001b[0;34m(a, lower, overwrite_a, check_finite)\u001b[0m\n\u001b[1;32m    150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    151\u001b[0m     \"\"\"\n\u001b[0;32m--> 152\u001b[0;31m     c, lower = _cholesky(a, lower=lower, overwrite_a=overwrite_a, clean=False,\n\u001b[0m\u001b[1;32m    153\u001b[0m                          check_finite=check_finite)\n\u001b[1;32m    154\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlower\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/scipy/linalg/decomp_cholesky.py\u001b[0m in \u001b[0;36m_cholesky\u001b[0;34m(a, lower, overwrite_a, clean, check_finite)\u001b[0m\n\u001b[1;32m     35\u001b[0m     \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpotrf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlower\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moverwrite_a\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moverwrite_a\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclean\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclean\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m         raise LinAlgError(\"%d-th leading minor of the array is not positive \"\n\u001b[0m\u001b[1;32m     38\u001b[0m                           \"definite\" % info)\n\u001b[1;32m     39\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mLinAlgError\u001b[0m: 15-th leading minor of the array is not positive definite"
     ]
    }
   ],
   "source": [
    "train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(train_x_orig, \n",
    "                                                    logits_train_y_origin, \n",
    "                                                    one_hot_labels_train_orig)\n",
    "\n",
    "val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(test_x, \n",
    "                                                                    logits_test_y_origin,\n",
    "                                                                    one_hot_labels_test)\n",
    "\n",
    "hessian_no_reg, hess, hessian_reg_term = lr_origin.hess(train_x_orig, logits_train_y_origin)\n",
    "\n",
    "\n",
    "loss_grad_hvp = lr_origin.get_inv_hvp(hess, val_loss_total_grad_orig.T)\n",
    "# loss_grad_hvp = lr_origin.get_inv_hvp(hess, val_loss_total_grad_orig, cho = False)\n",
    "\n",
    "pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 330,
   "id": "ed83b836",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measured (retraining-based) changes in test loss; one entry is appended per removal experiment.\n",
    "acc_infl = list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 331,
   "id": "521b329c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrain from scratch on only the training rows that weight_1 keeps (weight == 1).\n",
    "lr_new_1 = SimplifiedGraphNeuralNetwork(l2_reg=1.0, fit_intercept=True)\n",
    "train_x_delete_1, train_y_delete_1 = train_x_orig[weight_1 == 1], train_y_orig[weight_1 == 1]\n",
    "\n",
    "# Sanity check: the kept subset must coincide with the current train_x / train_y split.\n",
    "assert np.allclose(train_x_delete_1, train_x)\n",
    "assert np.allclose(train_y_delete_1, train_y)\n",
    "\n",
    "lr_new_1.fit(train_x_delete_1, train_y_delete_1)\n",
    "\n",
    "# Test-set logits of the retrained model, then its test log-loss via lr_new_1.log_loss.\n",
    "logits_test_y_new_1 = test_x @ lr_new_1.model.coef_.T + lr_new_1.model.intercept_\n",
    "new_ori_val_loss, new_ave_ori_val_loss = lr_new_1.log_loss(logits_test_y_new_1, one_hot_labels_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 332,
   "id": "568a42b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Actual influence = retrained test loss minus ori_val_loss (presumably the original model's test loss).\n",
    "# NOTE: not idempotent - re-running this cell appends the same value again.\n",
    "acc_infl += [new_ori_val_loss - ori_val_loss]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 333,
   "id": "4eba8934",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.04573299716190893"
      ]
     },
     "execution_count": 333,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted (influence-function) total for training rows 120 onward; the measured value is in acc_infl.\n",
    "# NOTE(review): assumes rows 120: are exactly the rows weight_1 removes - confirm, or use pred_infl[weight_1 == 0].\n",
    "np.sum(pred_infl[120:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 334,
   "id": "e61dab8c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.05185286378919951]"
      ]
     },
     "execution_count": 334,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Measured influence(s) collected so far, to compare with the predicted sum above.\n",
    "acc_infl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 295,
   "id": "df5db270",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear system A x = b used to compare inverse-HVP solvers below:\n",
    "# A is the Hessian returned by lr_origin.hess on the training set, b the transposed test-loss gradient.\n",
    "A = hess\n",
    "b = val_loss_total_grad_orig.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 300,
   "id": "d107df48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diagonal jitter for the regularized Cholesky solve below.\n",
    "# NOTE(review): 1e-15 is at float64 round-off level for entries of order 1, so it barely changes A;\n",
    "# a larger ridge may be needed if A is (near-)singular - confirm the intended value.\n",
    "eps = 1e-15\n",
    "eps_mat = eps * np.eye(len(A))\n",
    "\n",
    "# Reference solution of A x = b; solving directly is more accurate than forming inv(A) explicitly.\n",
    "d1 = np.linalg.solve(A, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 304,
   "id": "25bf2014",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cholesky solves of A x = b with (d2) and without (d3) the diagonal jitter.\n",
    "chol_jittered = cho_factor(A + eps_mat)\n",
    "d2 = cho_solve(chol_jittered, b)\n",
    "chol_plain = cho_factor(A)\n",
    "d3 = cho_solve(chol_plain, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b5691e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# d1 and d2 did not agree here (a hard assert failed), presumably because A is (near-)singular.\n",
    "# Report the discrepancy instead of raising, so the notebook can run top to bottom.\n",
    "max_abs_diff_d1_d2 = np.max(np.abs(d1 - d2))\n",
    "np.allclose(d1, d2), max_abs_diff_d1_d2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 305,
   "id": "a34de0e1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[ 0.0080887 ],\n",
       "        [-0.00856946],\n",
       "        [ 0.11161323],\n",
       "        [ 0.07652811],\n",
       "        [-0.05295199],\n",
       "        [ 0.14766832],\n",
       "        [-0.00116725],\n",
       "        [ 0.13448772],\n",
       "        [ 0.0448633 ],\n",
       "        [-0.13909886],\n",
       "        [-0.11044598],\n",
       "        [-0.21101583],\n",
       "        [-0.51772567],\n",
       "        [-0.28079501],\n",
       "        [ 0.88888889]]),\n",
       " array([[ 0.0080887 ],\n",
       "        [-0.00856946],\n",
       "        [ 0.11161323],\n",
       "        [ 0.07652811],\n",
       "        [-0.05295199],\n",
       "        [ 0.14766832],\n",
       "        [-0.00116725],\n",
       "        [ 0.13448772],\n",
       "        [ 0.0448633 ],\n",
       "        [-0.13909886],\n",
       "        [-0.11044598],\n",
       "        [-0.21101583],\n",
       "        [-0.56811899],\n",
       "        [-0.33118833],\n",
       "        [ 0.83849558]]))"
      ]
     },
     "execution_count": 305,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Side-by-side view of the jittered (d2) and plain (d3) Cholesky solutions.\n",
    "(d2, d3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 345,
   "id": "f056a018",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.4082, 0.4472, 0.4472,  ..., 0.7071, 0.7071, 0.7071])"
      ]
     },
     "execution_count": 345,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inverse square root of node in-degrees; clamping to >= 1 avoids 0 ** -0.5 = inf for nodes without in-edges.\n",
    "# Presumably the D^{-1/2} factor for symmetric degree normalization - confirm against the model code.\n",
    "degs = graph.in_degrees().float().clamp(min=1)\n",
    "degs ** -0.5"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
