{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fd16a5f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import math\n",
    "import torch\n",
    "import dgl\n",
    "import numpy_ml\n",
    "import torch.nn.functional as F\n",
    "import tensorflow.compat.v1 as tf\n",
    "import scipy\n",
    "import cupy as cp\n",
    "import cupyx\n",
    "\n",
    "from dgl import function as fn\n",
    "from dgl.base import DGLError\n",
    "from generate_mid_layer_feature import FeatureExtraction\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from scipy.special import softmax, log_softmax\n",
    "from numpy.linalg import inv, pinv\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "from sklearn.utils import check_array\n",
    "from dataset import load_graph_dataset\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from model_node_influence import NodeInfluenceSGC\n",
    "\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy.linalg import cho_solve, cho_factor\n",
    "from model_edge_influence import EdgeInfluenceSGC, generate_remove_index_train\n",
    "from tqdm import tqdm\n",
    "import networkx as nx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e5c1c14a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/zizhang/Desktop/Projects/Project6_influence_function/graph_influence_function/experiments'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d1301c9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ea5fe5dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_set = 'cora'\n",
    "# # l2_term = 0.01\n",
    "# l2_term = 0.01\n",
    "# # batch_edges = 1\n",
    "# num_layer = 2\n",
    "\n",
    "# data_set = 'pubmed'\n",
    "# l2_term = 0.004\n",
    "# num_layer = 2\n",
    "\n",
    "# data_set = 'citeseer'\n",
    "# l2_term = 0.003\n",
    "# num_layer = 2\n",
    "\n",
    "data_set = 'reddit'\n",
    "l2_term = 0.99\n",
    "num_layer = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5ceb9ec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(data_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0a081a58",
   "metadata": {},
   "outputs": [],
   "source": [
    "feat0 = feat.clone()\n",
    "degs = graph.in_degrees().float().clamp(min = 1)\n",
    "norm = torch.pow(degs, -0.5)\n",
    "norm = norm.to(feat0.device).unsqueeze(1)\n",
    "\n",
    "for _ in range(num_layer):\n",
    "    feat0 = feat0 * norm\n",
    "    graph.ndata['h'] = feat0\n",
    "    graph.update_all(fn.copy_u('h', 'm'),\n",
    "                     fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h')\n",
    "    feat0 = feat0 * norm\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b4901190",
   "metadata": {},
   "outputs": [],
   "source": [
    "del norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3daacab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = feat0[train_mask].numpy().astype(np.float32)\n",
    "train_y = labels[train_mask].numpy().astype(np.float32)\n",
    "\n",
    "# test_x = feat0[test_mask].numpy().astype(np.float32)\n",
    "# test_y = labels[test_mask].numpy().astype(np.float32)\n",
    "\n",
    "val_x = feat0[val_mask].numpy().astype(np.float32)\n",
    "val_y = labels[val_mask].numpy().astype(np.float32)\n",
    "\n",
    "train_node_idx = torch.where(train_mask == 1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b0c250ea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' Train Logistic Regression '"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "# one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "\"\"\" Train Logistic Regression \"\"\"\n",
    "# lr = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "# lr.fit(train_x, train_y, sample_weight=None, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fef5d5ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_remove_index_train_all(from_indexes, to_indexes, train_mask, seed_val=10):\n",
    "    torch.manual_seed(seed_val)\n",
    "    train_index = torch.where(train_mask == 1)[0]\n",
    "    remove_from_list = []\n",
    "    remove_to_list = []\n",
    "    for i in tqdm(range(len(train_index))):\n",
    "        f_index = train_index[i]\n",
    "\n",
    "        to_index_list = torch.where(from_indexes == f_index)[0]\n",
    "        \n",
    "#         print(f_index)\n",
    "#         random_index = torch.randint(0, len(to_index_list), (1,))[0]\n",
    "        for to_index_e in to_index_list:\n",
    "#             print(to_index_list)\n",
    "#             print(to_index_e.item())\n",
    "#             j = to_index_list[to_index_e[0]]\n",
    "            j = to_index_e\n",
    "            t_index = to_indexes[j]\n",
    "\n",
    "            remove_from_list.append(f_index)\n",
    "            remove_to_list.append(t_index)\n",
    "\n",
    "    return torch.tensor(remove_from_list), torch.tensor(remove_to_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbd098e5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e3124cc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from_indexes, to_indexes = graph.edges()\n",
    "\n",
    "# f_l, t_l = generate_remove_index_train_all(from_indexes, to_indexes, train_mask)\n",
    "f_l, t_l = from_indexes, to_indexes\n",
    "\n",
    "# acctual_influence_1 = []\n",
    "# acctual_influence_2 = []\n",
    "\n",
    "# predict_influence_1 = []\n",
    "# predict_influence_2 = []\n",
    "\n",
    "# df = pd.read_csv('result_data/' + data_set + '_edge_influence.csv', header = None)\n",
    "# df = df.loc[df[0] != 0]\n",
    "# f_l = torch.tensor(df[2].values.astype(int))\n",
    "# t_l = torch.tensor(df[3].values.astype(int))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4f25efaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "random.seed(42)\n",
    "num_rand = 10000\n",
    "rand_idx = random.sample(range(0, len(f_l)), 10000)\n",
    "f_l = f_l[rand_idx]\n",
    "t_l = t_l[rand_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d3ffca7",
   "metadata": {},
   "source": [
    "##### Step by Step to get the hessian vector product, to prevent out of memory. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3f5cb752",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# split_side = int(len(f_l) / batch_edges)\n",
    "# total_edges = np.arange(len(f_l))\n",
    "# np.random.shuffle(total_edges)\n",
    "# new_index = np.array_split(total_edges, split_side)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4a27c429",
   "metadata": {},
   "outputs": [],
   "source": [
    "# changed_list = []\n",
    "# changed_list_change = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d995df0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert to one-hot labels\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "one_hot_labels_train_orig = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# train the original data\n",
    "# calculate the hessian matrix\n",
    "lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "\n",
    "lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "logits_val_y_origin = val_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "# numpy_theoritic_loss = log_loss(val_y, softmax(logits_val_y_origin, axis=1))\n",
    "\n",
    "# assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(val_x, \n",
    "                                                                    logits_val_y_origin,\n",
    "                                                                    one_hot_labels_val, l2_reg = True)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "88672a89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # pd.DataFrame(np.array(val_loss_total_grad_orig.T)).to_csv('reddit/val_loss_total_grad_orig.csv', index = False)\n",
    "\n",
    "# pd.DataFrame([ori_val_loss]).to_csv('reddit/ori_val_loss.csv', index = False)\n",
    "# pd.read_csv('reddit/ori_val_loss.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "bb40324c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hess = cp.asarray(pd.read_csv('reddit/hess.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "75c82061",
   "metadata": {},
   "outputs": [],
   "source": [
    "# val_loss_total_grad_orig = cp.asarray(pd.read_csv('reddit/val_loss_total_grad_orig.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "cb5b8c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hess = cp.asarray(pd.read_csv('reddit/hess.csv'))\n",
    "\n",
    "\n",
    "# loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig, cholskey=True)\n",
    "# loss_grad_hvp = cp.asnumpy(loss_grad_hvp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "cb3f6426",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pd.DataFrame(loss_grad_hvp).to_csv('reddit/loss_grad_hvp.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "4db0f3d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg = True)\n",
    "\n",
    "\n",
    "# ###### write out hessian matrix\n",
    "\n",
    "# hess = cp.asnumpy(hess)\n",
    "\n",
    "# pd.DataFrame(hess).to_csv(\"reddit/hess.csv\", index = False)\n",
    "\n",
    "# ######\n",
    "\n",
    "\n",
    "# # loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True)\n",
    "\n",
    "# # loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "# del hess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6068d1de",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_grad_hvp = np.array(pd.read_csv('reddit/loss_grad_hvp.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5199878c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # convert to one-hot labels\n",
    "# enc = OneHotEncoder(handle_unknown='ignore')\n",
    "# enc.fit(train_y.reshape(-1, 1))\n",
    "# one_hot_labels_train_orig = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# # train the original data\n",
    "# # calculate the hessian matrix\n",
    "lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "\n",
    "lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "# logits_test_y_origin = test_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "# logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "# ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(logits_test_y_origin, one_hot_labels_test, l2_reg=True)\n",
    "\n",
    "# # numpy_theoritic_loss = log_loss(test_y, softmax(logits_test_y_origin, axis=1))\n",
    "\n",
    "# # assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "# val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(test_x, \n",
    "#                                                                     logits_test_y_origin,\n",
    "#                                                                     one_hot_labels_test, l2_reg = True)\n",
    "\n",
    "# hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg = True)\n",
    "\n",
    "# loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True)\n",
    "\n",
    "# loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "# del hess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "41d604b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                 | 0/10000 [00:34<?, ?it/s]\n"
     ]
    },
    {
     "ename": "MemoryError",
     "evalue": "Unable to allocate 35.7 GiB for an array with shape (194205, 41, 602) and data type float64",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mMemoryError\u001b[0m                               Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_7990/4223083987.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     46\u001b[0m     \u001b[0mlogits_train_y_origin_0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_x_orig\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mlr_origin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcoef_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlr_origin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mintercept_\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m     train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(train_x_orig, \n\u001b[0m\u001b[1;32m     49\u001b[0m                                             \u001b[0mlogits_train_y_origin_0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     50\u001b[0m                                             one_hot_labels_train_0, l2_reg = True)\n",
      "\u001b[0;32m~/Desktop/Projects/Project6_influence_function/graph_influence_function/experiments/../model_softmax.py\u001b[0m in \u001b[0;36mgrad\u001b[0;34m(self, x, logits, one_hot_labels, sample_weight, l2_reg, fit_intercept)\u001b[0m\n\u001b[1;32m    120\u001b[0m         \u001b[0mexpand_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexpand_dims\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# (n, 1, num_dimension)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    121\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 122\u001b[0;31m         \u001b[0mindiv_grad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpand_factor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpand_x\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# (n, num_classes, num_dimension)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    123\u001b[0m         \u001b[0mindiv_grad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mindiv_grad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mK\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mD\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# (n, num_classes * num_dimension)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    124\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mMemoryError\u001b[0m: Unable to allocate 35.7 GiB for an array with shape (194205, 41, 602) and data type float64"
     ]
    }
   ],
   "source": [
    "acctual_influence_1 = []\n",
    "acctual_influence_2 = []\n",
    "\n",
    "predict_influence_1 = []\n",
    "predict_influence_2 = []\n",
    "\n",
    "# for k in tqdm(range(len(train_node_idx))):\n",
    "for k in tqdm(range(len(f_l))):\n",
    "# for k in tqdm(range(len(new_index))):\n",
    "#     eis = EdgeInfluenceSGC(graph=graph, feature=feat, from_index=f_l[new_index[k]], to_index=t_l[new_index[k]])\n",
    "#     eis.remove_edges_sgc_from_influence()\n",
    "    eis = EdgeInfluenceSGC(graph=graph, feature=feat, from_index=f_l[k], to_index=t_l[k])\n",
    "    eis.remove_edges_sgc_from_influence()\n",
    "    feat_removed1 = eis.calculate_modified_features()\n",
    "    \n",
    "#     node_id = train_node_idx.numpy()[k]\n",
    "#     nis = NodeInfluenceSGC(graph = graph, feature=feat, node_index=node_id)\n",
    "#     nis.remove_edges_sgc()\n",
    "#     feat_removed1 = nis.calculate_modified_features()\n",
    "\n",
    "    extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0])\n",
    "    \n",
    "    \n",
    "    extra_index_train = torch.tensor(\n",
    "        [extra_index[i] for i in range(len(extra_index)) if extra_index[i] in train_node_idx]).numpy()\n",
    "    \n",
    "    extra_index_train_in_train = [\n",
    "        np.where(train_node_idx.numpy() == extra_index_train[j])[0][0] for j in range(len(extra_index_train))]\n",
    "    \n",
    "    if extra_index_train == []:\n",
    "        predict_influence_1.append(0.0)\n",
    "        acctual_influence_1.append(0.0)\n",
    "        continue\n",
    "    \n",
    "    feat_to_be_added = feat_removed1[extra_index_train].numpy()\n",
    "    perturb_index = extra_index_train_in_train\n",
    "    \n",
    "    \n",
    "    train_x_new = feat_to_be_added\n",
    "    train_y_new = train_y[perturb_index]\n",
    "    \n",
    "    train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "    train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "    \n",
    "    one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "    logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "    train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(train_x_orig, \n",
    "                                            logits_train_y_origin_0, \n",
    "                                            one_hot_labels_train_0, l2_reg = True)\n",
    "    \n",
    "    \n",
    "    pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "#     weight_orig = np.ones(len(train_x_orig)) # 1...1...11\n",
    "    \n",
    "#     weight_1 = np.ones(len(train_x_orig))\n",
    "#     weight_1[len(train_x_orig) - len(perturb_index):] = 0 # 1...1...10\n",
    "    \n",
    "#     weight_2 = np.ones(len(train_x_orig))\n",
    "#     weight_2[len(train_x_orig) - len(perturb_index):] = 0 \n",
    "#     weight_2[perturb_index] = 0 # 1...0...10\n",
    "    \n",
    "    weight_3 = np.ones(len(train_x_orig))\n",
    "    weight_3[perturb_index] = 0 # 1...0...11\n",
    "    \n",
    "#     lr_new_1 = SimplifiedGraphNeuralNetwork(l2_reg=1.0, fit_intercept=True)\n",
    "#     train_x_delete_1 = train_x_orig[weight_1 == 1]\n",
    "#     train_y_delete_1 = train_y_orig[weight_1 == 1]\n",
    "    \n",
    "#     assert(np.allclose(train_x_delete_1, train_x))\n",
    "#     assert(np.allclose(train_y_delete_1, train_y))\n",
    "    \n",
    "#     lr_new_1.fit(train_x_delete_1, train_y_delete_1)\n",
    "#     logits_val_y_new_1 = val_x @ lr_new_1.model.coef_.T + lr_new_1.model.intercept_\n",
    "#     new_ori_val_loss_1, _ = lr_new_1.log_loss(logits_val_y_new_1, one_hot_labels_val)\n",
    "    \n",
    "    \n",
    "    \n",
    "    lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "    train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "    train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "    \n",
    "    lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "    logits_val_y_new_2 = val_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "    new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_val_y_new_2, one_hot_labels_val, l2_reg = True)\n",
    "    \n",
    "    predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "    acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "804d2f40",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([114848857])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_x.shape, train_x_new.shape, train_x_orig.shape\n",
    "graph.edges()[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f4efbf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    " \n",
    "counts = numpy.array([10, 20, 30, 40, 50, 60, 70, 80])\n",
    "memory_in_bytes = counts.itemsize * len(counts)\n",
    "print(\"This array uses\", memory_in_bytes, \"bytes.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "5eb16e45",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4817255025"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "194205 * 41 * 605\n",
    "size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "078dd1a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# acctual_influence_1 = []\n",
    "# acctual_influence_2 = []\n",
    "\n",
    "# predict_influence_1 = []\n",
    "# predict_influence_2 = []\n",
    "\n",
    "# # for k in tqdm(range(len(train_node_idx))):\n",
    "# for k in tqdm(range(len(f_l))):\n",
    "\n",
    "#     eis = EdgeInfluenceSGC(graph=graph, feature=feat, from_index=f_l[k], to_index=t_l[k])\n",
    "#     eis.remove_edges_sgc()\n",
    "#     feat_removed1 = eis.calculate_modified_features()\n",
    "    \n",
    "# #     node_id = train_node_idx.numpy()[k]\n",
    "# #     nis = NodeInfluenceSGC(graph = graph, feature=feat, node_index=node_id)\n",
    "# #     nis.remove_edges_sgc()\n",
    "# #     feat_removed1 = nis.calculate_modified_features()\n",
    "\n",
    "#     extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0])\n",
    "    \n",
    "    \n",
    "#     extra_index_train = torch.tensor(\n",
    "#         [extra_index[i] for i in range(len(extra_index)) if extra_index[i] in train_node_idx]).numpy()\n",
    "    \n",
    "#     extra_index_train_in_train = [\n",
    "#         np.where(train_node_idx.numpy() == extra_index_train[j])[0][0] for j in range(len(extra_index_train))]\n",
    "    \n",
    "#     if extra_index_train != []:\n",
    "#         predict_influence_1.append(0.0)\n",
    "#         acctual_influence_1.append(0.0)\n",
    "#         continue\n",
    "    \n",
    "#     feat_to_be_added = feat_removed1[extra_index_train].numpy()\n",
    "#     perturb_index = extra_index_train_in_train\n",
    "    \n",
    "    \n",
    "#     train_x_new = feat_to_be_added\n",
    "#     train_y_new = train_y[perturb_index]\n",
    "    \n",
    "#     train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "#     train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "    \n",
    "#     one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "#     logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "#     train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(train_x_orig, \n",
    "#                                             logits_train_y_origin_0, \n",
    "#                                             one_hot_labels_train_0, l2_reg = True)\n",
    "    \n",
    "    \n",
    "#     pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "# #     weight_orig = np.ones(len(train_x_orig)) # 1...1...11\n",
    "    \n",
    "# #     weight_1 = np.ones(len(train_x_orig))\n",
    "# #     weight_1[len(train_x_orig) - len(perturb_index):] = 0 # 1...1...10\n",
    "    \n",
    "# #     weight_2 = np.ones(len(train_x_orig))\n",
    "# #     weight_2[len(train_x_orig) - len(perturb_index):] = 0 \n",
    "# #     weight_2[perturb_index] = 0 # 1...0...10\n",
    "    \n",
    "#     weight_3 = np.ones(len(train_x_orig))\n",
    "#     weight_3[perturb_index] = 0 # 1...0...11\n",
    "    \n",
    "# #     lr_new_1 = SimplifiedGraphNeuralNetwork(l2_reg=1.0, fit_intercept=True)\n",
    "# #     train_x_delete_1 = train_x_orig[weight_1 == 1]\n",
    "# #     train_y_delete_1 = train_y_orig[weight_1 == 1]\n",
    "    \n",
    "# #     assert(np.allclose(train_x_delete_1, train_x))\n",
    "# #     assert(np.allclose(train_y_delete_1, train_y))\n",
    "    \n",
    "# #     lr_new_1.fit(train_x_delete_1, train_y_delete_1)\n",
    "# #     logits_test_y_new_1 = test_x @ lr_new_1.model.coef_.T + lr_new_1.model.intercept_\n",
    "# #     new_ori_val_loss_1, _ = lr_new_1.log_loss(logits_test_y_new_1, one_hot_labels_test)\n",
    "    \n",
    "    \n",
    "    \n",
    "#     lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "#     train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "#     train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "    \n",
    "#     lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "#     logits_test_y_new_2 = test_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "#     new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_test_y_new_2, one_hot_labels_test, l2_reg = True)\n",
    "    \n",
    "#     predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "#     acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71b4e830",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "low_limit = -10\n",
    "up_limit = 10\n",
    "x = np.linspace(low_limit, up_limit)\n",
    "# x = np.linspace(-0.2, 0.15)\n",
    "plt.plot(x, x)\n",
    "plt.scatter(acctual_influence_1, predict_influence_1, s = 10, color='blue')\n",
    "plt.xlabel('actual change in loss')\n",
    "plt.ylabel('predicted change in loss')\n",
    "# plt.title('Influence function on edges of Cora dataset perturb edges:' + str(batch_edges))\n",
    "# plt.title('Influence function on Complete Node of Citeseer dataset')\n",
    "plt.ylim(low_limit, up_limit)\n",
    "plt.xlim(low_limit, up_limit)\n",
    "# plt.title('Influence function on Iris dataset')\n",
    "plt.show()\n",
    "print('Number of edges in consider %d'% len(f_l))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc429de4",
   "metadata": {},
   "outputs": [],
   "source": [
    "extra_index_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "097e637e",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.array(acctual_influence_1)[np.array(acctual_influence_1) !=0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea5dc610",
   "metadata": {},
   "outputs": [],
   "source": [
    "scipy.stats.spearmanr(np.array(acctual_influence_1)[np.array(acctual_influence_1) !=0], np.array(predict_influence_1)[np.array(predict_influence_1) !=0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a2a0406",
   "metadata": {},
   "outputs": [],
   "source": [
    "scipy.stats.spearmanr(acctual_influence_1, predict_influence_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74de54c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame([acctual_influence_1, predict_influence_1, f_l.numpy().astype(int), t_l.numpy().astype(int)]).T\n",
    "df.columns = ['acctual_influence', 'predict_influence', 'from_edges', 'to_edges']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "269550ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv('result_data/' + data_set + '_edge_influence_001.csv', header = None, index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "020837da",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_csv('result_data/pubmed_edge_influence.csv', header = None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fc3963b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "low_limit = -10\n",
    "up_limit = 10\n",
    "x = np.linspace(low_limit, up_limit)\n",
    "# x = np.linspace(-0.2, 0.15)\n",
    "plt.plot(x, x)\n",
    "plt.scatter(df[0], df[1], s = 10, color='blue')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07163977",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
