{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "99703af2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import torch\n",
    "import dgl\n",
    "import numpy_ml\n",
    "import scipy\n",
    "import cupy as cp\n",
    "import cupyx\n",
    "\n",
    "from dgl import function as fn\n",
    "from dgl.base import DGLError\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from scipy.special import softmax, log_softmax\n",
    "from numpy.linalg import inv, pinv\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from dataset import load_graph_dataset\n",
    "\n",
    "from model_node_influence import NodeInfluenceSGC\n",
    "\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy.linalg import cho_solve, cho_factor\n",
    "from model_edge_influence import EdgeInfluenceSGC, generate_remove_index_train\n",
    "from tqdm import tqdm\n",
    "\n",
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a8d3bee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset hyper-parameters: L2 regularisation strength and number of\n",
    "# feature-propagation layers. Switch datasets by editing `data_set` only.\n",
    "DATASET_PRESETS = {\n",
    "    # for cora, `l2_term = 0.01` and `batch_edges = 1` also appeared as\n",
    "    # commented-out alternatives\n",
    "    'cora': {'l2_term': 0.1, 'num_layer': 2},\n",
    "    'pubmed': {'l2_term': 0.004, 'num_layer': 2},\n",
    "    'citeseer': {'l2_term': 0.003, 'num_layer': 2},\n",
    "    'reddit': {'l2_term': 0.99, 'num_layer': 2},\n",
    "}\n",
    "\n",
    "data_set = 'citeseer'\n",
    "l2_term = DATASET_PRESETS[data_set]['l2_term']\n",
    "num_layer = DATASET_PRESETS[data_set]['num_layer']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aaedbf3",
   "metadata": {},
   "source": [
    "##### 1. Load the data and convert the labels to one-hot encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9e0d3194",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(data_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "55bf908c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-node scaling factor deg^(-1/2); in-degrees are clamped to at least 1\n",
    "# so that nodes without incoming edges do not produce a division by zero.\n",
    "degs = graph.in_degrees().float().clamp(min=1)\n",
    "norm = torch.pow(degs, -0.5).to(feat.device).unsqueeze(1)\n",
    "\n",
    "# Propagate a copy of the raw features `num_layer` times: scale, sum the\n",
    "# messages copied from the source nodes, then scale again.\n",
    "feat0 = feat.clone()\n",
    "for _ in range(num_layer):\n",
    "    graph.ndata['h'] = feat0 * norm\n",
    "    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h') * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0f81969e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph-wide ids of the training nodes.\n",
    "train_node_idx = torch.where(train_mask == 1)[0]\n",
    "\n",
    "# Propagated features and labels of the train / validation splits as float32 arrays.\n",
    "train_x, val_x = (feat0[mask].numpy().astype(np.float32) for mask in (train_mask, val_mask))\n",
    "train_y, val_y = (labels[mask].numpy().astype(np.float32) for mask in (train_mask, val_mask))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5234d1cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert to one-hot labels\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dca8cc84",
   "metadata": {},
   "source": [
    "##### 2. Remove a single node: remove its propagated node feature as well as its connected edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6194a67c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_remove_index_train_all(from_indexes, to_indexes, train_mask):\n",
    "    \"\"\"Collect every edge whose 'from' endpoint is a training node.\n",
    "\n",
    "    Args:\n",
    "        from_indexes: 1-D tensor holding the 'from' node id of each edge.\n",
    "        to_indexes: 1-D tensor holding the 'to' node id of each edge,\n",
    "            aligned with `from_indexes`.\n",
    "        train_mask: tensor whose entries equal to 1 mark the training nodes.\n",
    "\n",
    "    Returns:\n",
    "        Two 1-D tensors `(sources, targets)` describing the selected edges,\n",
    "        grouped by training node (ascending id) and, within a node, in the\n",
    "        order in which the edges appear in `from_indexes`. Both are empty\n",
    "        index tensors when no edge is selected.\n",
    "    \"\"\"\n",
    "    train_index = torch.where(train_mask == 1)[0]\n",
    "    remove_from_list = []\n",
    "    remove_to_list = []\n",
    "    for i in tqdm(range(len(train_index))):\n",
    "        f_index = train_index[i]\n",
    "        # positions of all edges leaving this training node\n",
    "        edge_positions = torch.where(from_indexes == f_index)[0]\n",
    "        remove_from_list.append(f_index.repeat(edge_positions.numel()))\n",
    "        remove_to_list.append(to_indexes[edge_positions])\n",
    "\n",
    "    if not remove_from_list:\n",
    "        # no training node at all: torch.cat rejects an empty list\n",
    "        return train_index.new_empty(0), to_indexes.new_empty(0)\n",
    "    return torch.cat(remove_from_list), torch.cat(remove_to_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "95f1b867",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 60/60 [00:00<00:00, 9909.76it/s]\n"
     ]
    }
   ],
   "source": [
    "# Endpoint ids of every edge in the graph.\n",
    "from_indexes, to_indexes = graph.edges()\n",
    "\n",
    "# Edges that start at a training node.\n",
    "f_l, t_l = generate_remove_index_train_all(from_indexes, to_indexes, train_mask)\n",
    "\n",
    "# Result containers for measured vs. predicted influence.\n",
    "acctual_influence_node_features, acctual_influence_edges = [], []\n",
    "predict_influence_node_features, predict_influence_edges = [], []\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ab82ca8",
   "metadata": {},
   "source": [
    "##### 2.1 Train on the original data and calculate the Hessian matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "abe75622",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the model on the unperturbed training set.\n",
    "lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "# Logits of the fitted model on the validation and training features.\n",
    "logits_val_y_origin = val_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "# Baseline validation loss that the retrained models are compared against.\n",
    "ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(\n",
    "    logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "# Optional sanity check (set l2_reg to False above to verify the calculation):\n",
    "# numpy_theoritic_loss = log_loss(val_y, softmax(logits_val_y_origin, axis=1))\n",
    "# assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "# Total and per-sample gradient of the validation loss.\n",
    "val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(\n",
    "    val_x, logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "# Hessian on the training data, then (presumably) the inverse-Hessian-vector\n",
    "# product with the validation gradient, copied back to a NumPy array.\n",
    "hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg=True)\n",
    "loss_grad_hvp = cp.asnumpy(\n",
    "    fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True))\n",
    "\n",
    "# The Hessian is not needed past this point; drop the reference.\n",
    "del hess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "982d091f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Durations in seconds, one entry per removed node:\n",
    "# influence estimate vs. retraining from scratch.\n",
    "time_infl, time_retrain = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "57aa5a6b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████| 60/60 [00:04<00:00, 13.27it/s]\n"
     ]
    }
   ],
   "source": [
    "acctual_influence_1 = []\n",
    "acctual_influence_2 = []\n",
    "\n",
    "predict_influence_1 = []\n",
    "predict_influence_2 = []\n",
    "\n",
    "# Start from empty timing logs so that re-running this cell does not append\n",
    "# to the measurements of a previous run.\n",
    "time_infl = []\n",
    "time_retrain = []\n",
    "\n",
    "# Loop invariants: graph-wide ids of the training nodes and, for each id,\n",
    "# its row position inside train_x / train_y.\n",
    "train_node_ids = train_node_idx.numpy()\n",
    "train_row_of_node = {int(node): row for row, node in enumerate(train_node_ids)}\n",
    "\n",
    "for k in tqdm(range(len(train_node_ids))):\n",
    "\n",
    "    node_id = train_node_ids[k]\n",
    "    nis = NodeInfluenceSGC(graph = graph, feature=feat, node_index=node_id)\n",
    "\n",
    "    # Remove the node's edges and recompute the propagated features.\n",
    "    nis.remove_edges_sgc()\n",
    "    feat_removed1 = nis.calculate_modified_features()\n",
    "\n",
    "    # Graph-wide ids (ascending) of all rows whose propagated feature changed ...\n",
    "    extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0]).numpy()\n",
    "    # ... restricted to training nodes, and their row positions in train_x.\n",
    "    extra_index_train = extra_index[np.isin(extra_index, train_node_ids)]\n",
    "    extra_index_train_in_train = [train_row_of_node[int(n)] for n in extra_index_train]\n",
    "\n",
    "    # Changed training nodes other than the removed one are re-added with\n",
    "    # their updated feature; the removed node itself is never re-added.\n",
    "    is_other_node = extra_index_train != node_id\n",
    "    rows_to_add = torch.as_tensor(extra_index_train[is_other_node], dtype=torch.long)\n",
    "    feat_to_be_added = feat_removed1[rows_to_add].numpy()\n",
    "    added_index = [row for row in extra_index_train_in_train if row != k]\n",
    "\n",
    "    # Rows dropped from the original training set: every changed training\n",
    "    # node and, always, the removed node k.\n",
    "    perturb_index = list(extra_index_train_in_train)\n",
    "    if k not in perturb_index:\n",
    "        # NOTE(review): node k's own propagated feature can be unchanged by the\n",
    "        # removal; it is still dropped here instead of failing - confirm that\n",
    "        # this matches the intended experiment.\n",
    "        perturb_index.append(k)\n",
    "\n",
    "    train_x_new = feat_to_be_added\n",
    "    train_y_new = train_y[added_index]\n",
    "\n",
    "    # Original samples first, re-added (updated) samples appended at the end.\n",
    "    train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "    train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "\n",
    "    one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "    logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "    train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(\n",
    "        train_x_orig, logits_train_y_origin_0, one_hot_labels_train_0, l2_reg = True)\n",
    "\n",
    "    # Influence estimate; perf_counter is the clock meant for short intervals.\n",
    "    start_time = time.perf_counter()\n",
    "    pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "    time_infl.append(time.perf_counter() - start_time)\n",
    "\n",
    "    # Ground truth: retrain without the dropped rows (appended rows are kept).\n",
    "    keep_mask = np.ones(len(train_x_orig), dtype=bool)\n",
    "    keep_mask[perturb_index] = False\n",
    "\n",
    "    lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "    train_x_delete_2 = train_x_orig[keep_mask]\n",
    "    train_y_delete_2 = train_y_orig[keep_mask]\n",
    "\n",
    "    start_time = time.perf_counter()\n",
    "    lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "    time_retrain.append(time.perf_counter() - start_time)\n",
    "\n",
    "    logits_val_y_new_2 = val_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "    new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_val_y_new_2, one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "    # Predicted change: contribution of the dropped rows minus that of the re-added rows.\n",
    "    predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "    acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "70561438",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASoAAAEQCAYAAAAH2znkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAnUklEQVR4nO3deVxU9f4/8NfMACouFzCgQUnNXFA0FbQUw0AUchstSTLzYffr8nPPHi7U7eGClGG2eLGy61L3Xs2FEkwh5aKmZuWWuYKSiSIMu15RFHDm/P7gMjGyHYZZzhlez8ejx2PmnDMz79PEq8/5zPl8PgpBEAQQEUmY0tYFEBHVh0FFRJLHoCIiyWNQEZHkMaiISPIYVEQkeZIKqpiYGAQHB6Nbt264cuWKYfu1a9cwYcIEhIaGYsKECcjIyKjx9TqdDitWrEBISAiGDRuGuLg4K1VORJYkqaAaOnQotm7dinbt2hltX7ZsGSZOnIj9+/dj4sSJWLp0aY2v37NnD27cuIHk5GTs2LEDsbGxuHnzpjVKJyILklRQ+fv7Q61WG20rLCzEpUuXMGrUKADAqFGjcOnSJRQVFVV7fVJSEsLDw6FUKuHm5oaQkBDs27fPKrUTkeU42LqA+mi1Wnh6ekKlUgEAVCoVPDw8oNVq4ebmVu1YLy8vw3O1Wo2cnBxRn6PX63Hv3j04OjpCoVCY7wSIyEAQBJSXl6Nly5ZQKsW3kyQfVNZy7949o34xIrKcrl27onXr1qKPl3xQqdVq5ObmQqfTQaVSQafTIS8vr9olYuWx2dnZ6N27N4DqLay6ODo6Aqj4F+jk5GS+E5CICxcuwNfX19ZlmB3PS/p0Oh3y8/Oh0+nwl7/8BRkZGYa/N7EkH1Rt27aFj48P9u7dC41Gg71798LHx6faZR8AhIWFIS4uDsOHD8ft27eRkpKCrVu3ivqcyss9JycnNGvWzKznIBU8L3mxh/PS6XQoKCgAAHh5eRn+zhravSKpzvTo6GgEBgYiJycHr7/+OkaOHAkAWL58ObZs2YLQ0FBs2bIFK1asMLxm2rRpOH/+PABAo9Ggffv2GD58OF5++WXMnj0b3t7eNjkXoqZOp9MhJycHDx8+hKenJ5o3b27yeyk4zUuF0tJSQ3PbHv5P9qjTp0/Dz8/P1mWYHc9LmmoLKVP/ziTVoiIi+TNnS6oSg4qIzMYSIQUwqIjITCwVUgCDiojMwJIhBTCoiKiRLB1SAIOKiBrBGiEFMKiIyETWCimAQUVEJrBmSAEMKiJqIGuHFMCgIqIGsEVIAQwqIhLJViEFMKiISARbhhTAoCKietg6pAAGFRHVQQohBTCoiKgWUgkpgEFFRDWQUkgBDCoieoTUQgpgUBFRFVIMKUAGizsAwM2bNzF79mzD8+LiYty9excnTpwwOi42NhZff/01PDw8AAD9+vXDsmXLrForkVxJNaQAmQRV+/btsXv3bsPzd999FzqdrsZjx44diyVLllirNCK7IOWQAmR46VdWVoY9e/bgpZdesnUpRHZB6iEFyDCoDh48CE9PT/Ts2bPG/YmJiRg9ejT++te/4syZM1aujkhe5BBSgAyXy5o2bRqee+45TJ48udq+/Px8uLi4wNHREceOHcPChQuRlJQEV1fXet+3chkfIrK8hi6XJYs+qkq5ubk4efIkVq9eXeN+d3d3w+OAgACo1Wqkp6djwIABoj+D6/rJC8/LNLZqSZnaIJDVpV98fDyGDBlSawspNzfX8Dg1NRVZWVno1KmTtcojkgW5XO5VJasWVXx8PP72t78ZbZs2bRrmzZuHXr164aOPPsLFixehVCrh6OiI1atXG7WyiJo6OYYUILOg2r9/f7VtGzZsMDyOiYmxZjlEsiLXkAJkdulHRKaRc0gBDCoiuyf3kAIYVER2zR5CCmBQEdktewkpgEFFZJfsKaQABhWR3bG3kAIYVER2xR5DCmBQEdkNew0pgEFFZBfs
OaQABhWR7Nl7SAEMKiJZawohBTCoiGSrqYQUwKAikqWmFFIAg4pIdppaSAEMKiJZaYohBTCoiGSjqYYUwKAikoWmHFIAg4pI8pp6SAEMKiJJY0hVkM2c6cHBwXBycjIsZbVw4UI899xzRsfodDpER0fj6NGjUCgUmD59OsLDw21RLlGjMaT+JJugAoC///3v6Nq1a6379+zZgxs3biA5ORm3b9/G2LFjMXDgQLRv396KVRKZB0PqT3Z16ZeUlITw8HAolUq4ubkhJCQE+/bts3VZRA2i0+kAgCFVhaxaVAsXLoQgCPDz88Obb76JNm3aGO3XarXw8vIyPFer1cjJyWnQZ9jzsu6nT5+2dQkWYa/nVVBQgIKCAluXIQmyCaqtW7dCrVajrKwM7777LqKiorBmzRqzfw6XdJcXezqvqn1SBQUFdnNeVdn9ku5qtRoA4OTkhIkTJ+LXX3+t8Zjs7GzDc61Wi8cff9xqNRKZ6tGOczImi6AqKSlBcXExAEAQBCQlJcHHx6facWFhYYiLi4Ner0dRURFSUlIQGhpq7XKJGoS/7tVPFpd+hYWFmDt3LnQ6HfR6PTp37oxly5YBAKZNm4Z58+ahV69e0Gg0OHv2LIYPHw4AmD17Nry9vW1ZOlGdGFLiyCKovL29kZCQUOO+DRs2GB6rVCqsWLHCSlURNQ5DSjxZXPoR2RuGVMMwqIisjCHVcAwqIitiSJmGQUVkJQwp0zGoiKyAIdU4DCoiC2NINZ6ooBIEATt37sTkyZMxevRoAMDJkyeRlJRk0eKI5I4hZR6igmrt2rX45ptvMGHCBGi1WgDA448/jo0bN1q0OCI5Y0iZj6igio+Px/r16zFy5EgoFAoAQPv27ZGZmWnR4ojkiiFlXqKCSqfToWXLlgBgCKp79+7B2dnZcpURyRRDyvxEBdWQIUOwatUqlJWVAajos1q7di2CgoIsWhyR3DCkLENUUL311lvIy8uDn58fiouL0bdvX2RnZ2PhwoWWro9INhhSliNqUHKrVq3w2WefobCwEFlZWVCr1XB3d7d0bUSywZCyLFFB9eOPP6Jdu3bo1KkT2rZtCwD4448/oNVqERAQYNECiaSOIWV5oi79oqKiDJ3plVq2bImoqCiLFEUkFwwp6xAVVIWFhfDw8DDa5uHhgfz8fIsURSQHDCnrERVU3t7e+Pnnn422HT9+nOvlUZPFkLIuUX1Uc+bMwdy5czF+/Hh4e3sjMzMTu3btwnvvvWfp+ogkhyFlfaKCKiQkBJs3b8Y333yDw4cPG4bP9O7d29L1AQBu3bqFxYsX48aNG3ByckKHDh0QFRUFNzc3o+NiY2Px9ddfGy5T+/XrZ5hbncgcGFK2IXrO9N69e1stmB6lUCgwdepUPPPMMwCAmJgYrFmzpsYW3dixY7FkyRJrl0hNAEPKdkQFVVlZGeLj45GamoqSkhKjfatXr7ZIYVW5uLgYQgoA+vTpg23btln8c4kqMaRsS1RQRUZGIi0tDUFBQXjssccsXVOd9Ho9tm3bhuDg4Br3JyYm4scff4S7uzvmzp2Lvn37Nuj9uaS7/Fj7vBqyzLpeL+DH1GLczC9De3cnDO7RGsr/jZetj71+X6ZQCIIg1HdQ//79ceDAAbRp08YaNdVpxYoVyM3Nxbp166BUGv9omZ+fDxcXFzg6OuLYsWNYuHAhkpKS4OrqWu/7Vi41zSXd5cXS5yWmJaXTC/j2YDrSrhehewc3jA/uAqWyIoy2/+cytu5LMxzb2tkRmsDOCB/a1XBMTez1+zL170xUi0qtVhsGJNtSTEwMrl+/jvXr11cLKQBGw3oCAgKgVquRnp6OAQMGWLNMshP1hVRlQB08lYms/LsAgJOXcgEAL4d0BQD8cPqm0WuKS8qxZV8aFAqF4Zh666gjCJsKUUE1duxYzJo1C5MnTzYMoak0cOBAixT2qI8//hgXLlzAP/7xDzg5OdV4TG5uLjw9PQEAqampyMrKQqdOnaxSH9kX
MS2pbw+m49/fp1bbnna9qN73F3NMTZ/zaBA2FaKCasuWLQCAjz76yGi7QqHAgQMHzF/VI9LT07F+/Xp07NgRERERACom7vv000+NlnT/6KOPcPHiRSiVSjg6OmL16tUcPE0NJrbjvLaw6d7hz9tmgvzaY0uVS7+ajqnPo5/TkJCzF6KC6uDBg5auo05dunTB5cuXa9xXdUn3mJgYa5VEdkrs5V5qRhGy8oqN9rVxdsKYwCcxPriLYVv40IqWz6HTmSi+V45Wzo4I9vc2OqY+3Tu4GVpSlc+bGtH3UZWXl+Ps2bPIy8vDiBEjDLcpcJZPshc1hdSj/UM6vR5f76/5f5p3SsqgUCiM+o+USgUmDOuGCcO6NagWvV7AzpQrSLtehG5PuGJSWHdcvnHL0EfV1IgKqsuXL2PmzJlwcnJCbm4uRowYgZMnTyI+Ph6ffPKJhUsksrzaWlKP9g+1ca65f7SSuS7LfkwtxsGzWYbPfe0FHyz9v2fN8t5yJGpQ8vLlyzFv3jzs27cPDg4V2da/f3/e50F2oa7LvUeDp+yhrs73Mtdl2c1841/Zm2K/VFWiWlS///47NBoNgD8Xd3B2dkZpaanlKiOygvr6pB7tH3Jp3Qw5hX+OzlC3dYZn25ZwVCnRvaP5LsvauzvhSvYDozqaMlFB1a5dO1y4cAG9evUybDt37hyeeOIJixVGZGlift0bH9wF568W4LcrFXOv5RSWoE9Xdzg6KC16T9PgHq3Rzqud0b1TTZmooJo/fz5mzJiBiIgIlJeX44svvsD27duxcuVKS9dHZBFif91Lu16E/Fv3jfY5Oigt3l+kbMANoU2BqKAKCgrChg0bEBcXh/79+yMrKwuxsbHw9fW1dH1EZteYmzkBXobZgujbE3r27ImePXtashYiixNzC8L44C61dl736ere5C/DbEFUUK1du7bWffPnzzdbMUSWVFtLKu7AFcPA4ZOXciEIQrVO9EqODsomN85OCkQFVU5OjtHz/Px8nDx5EiEhIRYpisjc6mpJfXsw3ejYQ6dv4rPFwUg5cR3aQuP518of6qHXCwwrKxMVVKtWraq27ciRI0hMTDR7QUTmVltILfvHTzibXvPcUkqlAp8tHorlG3/GhT8KodNVzIb025V8xB240uA7zalxRN3wWZPBgwcjJSXFnLUQmV1dd5zXFlJBfhWrKzk4KBH9/wLwuJvxmpaHHpm6hSxPVIsqMzPT6Pn9+/exd+9eqNVqixRFZA51Xe59d/RqteObOanw8tCu7CyXIFFBNWzYMCgUClROBtqiRQv4+Pjg/ffft2hxRKYSM3bvUd07uNZ4A+ejU7VUtrjIekQFVVpa9fl0iKSqppAqe6hH1Mafcf732uc7P5tegG8Ople70TJ8aFcoFAreJW5Dou+jIpKD2lpSKzf9UmufVFU13T+lVPIucVurNaiGDBliGIBclx9++MGc9RCZTKfTITtbi++PZyO76CF6PFlquJS7lv1fUe/Bu86lqdag+uCDD6xZB1Gj6HQ63LiRhcgN51BSWjEVy+nL+Th/tQCODko4N3PAf+8aT53yuFsL5BT9OY6Pd51LV61BtWbNGuzcuRMAsG7dOsyZM8dqRdXk2rVriIyMxO3bt+Hi4oKYmBh07NjR6BidTofo6GgcPXoUCoUC06dPR3h4uG0KJqvQ6wVsS07DgRMZyL1VfdqhylkPAMDJQYmyh3rDc5XK+O4c3nUuXbXeR5WRkWGYb2rz5s1WK6g2y5Ytw8SJE7F//35MnDgRS5curXbMnj17cOPGDSQnJ2PHjh2IjY3FzZu858Ve6fQCtvxQgK/3X64xpB5VNaQqGIcSL/ukq9YW1dChQxEaGop27dqhtLQUr776ao3Hbd261WLFVSosLMSlS5fw5ZdfAgBGjRqFlStXoqioCG5uf/7HlZSUhPDwcCiVSri5uSEkJAT79u3D1KlTLV4jWV/cgcv4I8f0yRuD/Nrz1zyZqDWoVq1ahVOnTiEr
Kwvnz5/H+PHjrVmXEa1WC09PT6hUKgCASqWCh4cHtFqtUVBptVp4eXkZnqvV6mrjFMk+6HQ6nLvS8O/WGpPekfnVeXuCv78//P39UV5ejnHjxlmrJpu6cOGCrUuwGHub4969tVDn/ratHdC7UwsACmQVlKG9uxMG93CCUqEAUIwzZ361Sp2msrfvqzFE3Uc1fvx4/PHHH0hLSzMsk1V1n6Wp1Wrk5uZCp9NBpVJBp9MhLy+v2hAetVqN7Oxs9O7dG0D1FpYYvr6+aNasmdlql4rTp0/Dz8/P1mU0WtX7pKa96IOM3CMovCugk7oNenRyw+Ez2QAqLuvCh3aVbYvJXr6vR5WWlprUGBAVVOvXr8enn36K7t27G82GqFAorBJUbdu2hY+PD/bu3QuNRoO9e/fCx8fH6LIPAMLCwhAXF4fhw4fj9u3bSElJsUofGllG1cU+H+r0UCmB9m2dENSnLY5fuY8jv/2OrPyKPqrf0gvQ6yl3rI8cauOqyRJEBdU///lPxMXFoXv37paup1bLly9HZGQkPvvsM7Rp08awKnLVJd01Gg3Onj2L4cOHAwBmz54Nb29vm9VMjVPTuLzTAK7lPsC53wurHd/Ul5SyZ6KCqnnz5njyySctXUudOnfujLi4uGrbqy7prlKpsGLFCmuWRRZUW/Bcvn6rxu28vcB+iZqPav78+YiOjkZeXh70er3RP0SWUlvwlJYb/3fX3EmFPl3d8eLzT1mjLLIBUS2qyMhIADBq0QiCAIVCgdTUmqfMIGqsMc91wsmLN3E5sxhCHT/wPSjT4bcr+dj1w+8cPGynRAXVgQMHLF0HkRGdToctiWeRdqNY9GvYR2W/RK+UTGQtlbcgXM2qHlKtnR1RXFJe4+vYR2W/6gyqupbJqsTlssicqt4n1auLJy5k3DHarwnsbBj20u0JV2RlZ+PuwxYcAmPn6gwqDj8ha3p00rsI72ZQqVSGxRRquonz9Om7dnljJBmrM6hqWiaLyBJqm5lzwrBuXJqKTF8ui8hcagspokoMKrIphhSJwaAim2FIkVgMKrIJhhQ1RKOCir8KkikYUtRQjQqqESNGmKsOaiIYUmSKRgVVYmKiueqgJoAhRaZqVFA9OsMmUW0YUtQYtd7wKWb4DMAhNFQ/hhQ1Vq1BVbWjvLS0FMnJyfD19UW7du2QnZ2N8+fPG2bSJKqNOUOqcmpiLm/V9NS5XFalBQsW4MMPP0RoaKhhW3JyMvbt22fZ6kjWzN2Sqjo18clLuQCAzq6NLpNkQFQf1ZEjRxASEmK0bejQoTh8+LBFiqpqxYoVCAsLw5gxYxAREYHz58/XeNzx48fx9NNPQ6PRQKPRcCl3G7PE5d6j801x/qmmQ9R8VB06dMDWrVsxefJkw7avv/4aTzzxhMUKqxQYGIi3334bjo6OOHToEBYsWICUlJQaj+3cuTN27dpl8Zqobpbqk+rewc3Qkqp8DoifWI/kS1RQRUdHY86cOdi4cSM8PT2Rm5sLBwcHxMbGWro+BAUFGR736dMHOTk50Ov1UCp5U70UWbLjvLJPqmofldQXESXzEBVUPXr0wP79+3H27Fnk5eXB3d0dffr0gaOjo6XrM7J161Y8//zztYZURkYGxo0bBwcHB0ycOLHJrO4sFZb+dU+pVHBO9CZKVFA9qn///igpKUF5eTmcnZ0bVcC4ceOQnZ1d476ffvoJKpUKQMXNpXv27Kl1QdGePXvi8OHDaN26NTIzM/H666/D09MTgwYNalA9XNLdPAoKCqz2Wfa69Lm9npdJBBHS0tKEoKAgITQ0VOjTp48gCILwww8/CPPnzxfz8kZLTk4Whg4dKmRmZop+zapVq4TY2FjRxz948EA4deqU8ODBA1NKlLxTp05Z7L0fPnwo3Lx5U8jIyBDu379vsc+piSXPy5bs9bxM/TsT1dGzfPlyzJs3D/v27YODQ0UjrH///lZJ/EOHDmHVqlXYtGkT2rdvX+txeXl5EP63ptLt27dx
7Ngxm67s3FTwZk6yBlGXfr///js0Gg0AQKGomK/a2dkZpaWllqvsf9566y04Ojpi3rx5hm1fffUVXF1dsXbtWnh4eOCVV15BcnIytm3bBgcHB+h0Omg0mmq3VJB5MaTIWkQvl3XhwgX06tXLsO3cuXNWuT3hl19+qXVf1eE7kyZNwqRJkyxeD1VgSJE1iQqq+fPnY8aMGYiIiEB5eTm++OILbN++HStXrrR0fSRBDCmyNlF9VEFBQdiwYQOKiorQv39/ZGVlITY2FoMHD7Z0fWRDOr2AnSlXELXpF+xMuQK9XmBIkU3U26LS6XQIDQ1FUlISli9fboWSSCoeHVunF/R4rkdrhhRZXb0tKpVKBZVKZZWOc5KWR8fSnbvClhTZhqhLv8mTJ+ONN97AiRMncOPGDWRmZhr+IftVMZbuTx09nRlSZBOiOtMrO82PHTtmtF2hUCA1NdX8VZEkjA/uAr2gx7krOejo6YxJI3oxpMgmRAVVWlqapesgCRL+1yc1sGsLtqTIpuoMqvv37+Pzzz/HlStX0LNnT8yYMQNOTk7Wqo1siL/ukZTU2UcVFRWFQ4cO4cknn8T+/fsRExNjrbrIhhhSJDV1BtXRo0exadMmLF68GBs2bMChQ4esVRfZCEOKpKjOoCopKYGHhweAiqWx7t69a5WiyDYYUiRVdfZR6XQ6/PLLL4ZZCR4+fGj0HAAGDhxo2QrJKhhSJGV1BlXbtm3x9ttvG567uLgYPVcoFDhw4IDlqiOrYEiR1NUZVAcPHrRWHWQjDCmSA66Q0IQxpEguGFRNFEOK5IRB1QQxpEhuGFRNDEOK5Mik5bKsKTIyEj/99BNcXV0BAGFhYZg5c2aNx+7cuRMbNmyAIAgIDAzEO++8w4VKq2BIkVxJPqgAYPr06fXOh56ZmYl169YhISEBLi4umDZtGr777juMHTvWOkXKAEOK5Mpumhv79+9HSEgI3NzcoFQqER4ejqSkJFuXJQk6nQ4AGFIkW7IIqi+//BKjR4/GrFmzcPXq1RqP0Wq18PLyMjz38vKCVqu1VomSVXm5B4AhRbJl80u/+pZ0X7BgAdzd3aFUKpGQkICpU6ciJSXFsNS7udnzku4XL160dQkWYa9Ln9vreZnC5kEVHx9f535PT0/D47Fjx2LVqlXIyclBu3btjI5Tq9VGgZednQ21Wt3genx9fdGsWbMGv05qHu04v3jxIvz8/GxdltmdPn2a5yUjpaWlJjUGJH/pl5uba3h89OhRKJVKo/CqFBoaipSUFBQVFUGv1yMuLg4vvPCCNUuVDP66R/bG5i2q+ixZsgSFhYVQKBRo1aoVPv/8czg4VJRddUl3b29vzJo1Cy+//DIAICAgAGPGjLFl6TbBkCJ7JPmg+uqrr2rdV3VJdwCIiIhARESEhSuSLoYU2SvJX/qROAwpsmcMKjvAkCJ7x6CSOYYUNQUMKhljSFFTwaCSKYYUNSUMKhliSFFTw6CSGYYUNUUMKhlhSFFTxaCSCYYUNWUMKhlgSFFTx6CSOIYUEYNK0hhSRBUYVBLFkCL6E4NKghhSRMYYVBLDkCKqjkElIQwpopoxqCSCIUVUOwaVBDCkiOom+amIp0yZglu3bgGo+INOT0/H7t270b17d6Pjjh8/junTp6Njx44AACcnJ8TFxVm73AZjSBHVT/JBVXXO9JSUFHzyySfVQqpS586dsWvXLitV1ngMKSJxZHXp98033+Cll16ydRlmwZAiEk82QVVQUICff/4ZGo2m1mMyMjIwbtw4hIeH17uwqS0xpIgaRiEIgmDLAupb0r1y6fYNGzbg7NmzWLduXY3H3r17F4IgoHXr1sjMzMTrr7+OqKgoDBo0SFQdpq7gSkQN19AVyW3eRyW25bNr1y4sXry41v2tWrUyPPb29kZISAh+/fVX0UFVyZJLutuyJWWvS4TzvOTFbpd0B4Bff/0VxcXFCAwMrPWYvLw8VDYOb9++jWPHjtXa6W4LvNwjMp3NW1Ri
7Nq1C2PHjjVcBlaquqR7cnIytm3bBgcHB+h0Omg0GoSEhNioYmMMKaLGkUVQRUdH17i96pLukyZNwqRJk6xVkmgMKaLGk8Wln1wxpIjMg0FlIQwpIvNhUFkAQ4rIvBhUZsaQIjI/BpUZMaSILINBZSYMKSLLYVCZAUOKyLIYVI3EkCKyPAZVIzCkiKyDQWUihhSR9TCoTMCQIrIuBlUDMaSIrI9B1QAMKSLbYFCJxJAish0GlQgMKSLbYlDVgyFFZHsMqjowpIikgUFVC4YUkXQwqGrAkCKSFkkE1e7duzF69Gj06NEDW7ZsMdp3//59vPHGGxg2bBjCwsJw6NChWt9n586dGDZsGEJCQhAVFQW9Xt/gWhhSRNIjiaDy8fHBxx9/jFGjRlXbt2nTJrRs2RL/+c9/sH79erzzzju4d+9eteMyMzOxbt067NixA8nJybh+/Tq+++67BteSn5/PkCKSGEmsQtO1a1cAgFJZPTe///57vP/++wCAjh07wtfXF0eOHMELL7xgdNz+/fsREhICNzc3AEB4eLhhmS0xKtcELCsrg6enJxQKBUpLS009JUmyt/OpxPOSj7KyMgB//r2JJYmgqkt2djbatWtneK5Wq5GTk1PtOK1WCy8vL8NzLy8vaLVa0Z9TXl4OALhz5w7u3LnTiIqly16XrOd5yU95eXmDrlisElTjxo1DdnZ2jft++umnaguL2kLLli3RtWtXODo6QqFQ2LocIrskCALKy8vRsmXLBr3OKkEVHx9v8mu9vLyQlZVluKTTarV45plnqh2nVquNwjA7OxtqtVr05yiVSrRu3drkOolIHFP6fiXRmV6XsLAw7NixAwCQkZGB8+fP47nnnqt2XGhoKFJSUlBUVAS9Xo+4uLhq/VhEJE8KoaG9Whawd+9erF69Gnfu3IGjoyNatGiBzZs346mnnkJJSQkiIyORmpoKpVKJRYsWISQkBACwdu1aeHh44JVXXgEAbN++HRs3bgQABAQEYOnSpZK4rCSixpFEUBER1UXyl35ERAwqIpI8BhURSR6Diogkj0EF8w2KlrrIyEgEBgZCo9FAo9Hg888/t3VJJrt27RomTJiA0NBQTJgwARkZGbYuyWyCg4MRFhZm+J6OHj1q65JMEhMTg+DgYHTr1g1XrlwxbDfpuxNIuHz5spCeni4sWrRI+Pe//220LzY2Vnj77bcFQRCEa9euCYMGDRLu3r1rizIbbcmSJdXOT65ee+01ISEhQRAEQUhISBBee+01G1dkPkFBQcLly5dtXUajnTx5UsjOzq52PqZ8d2xRoWJQ9FNPPVXroOiIiAgAxoOiyXYKCwtx6dIlw2wbo0aNwqVLl1BUVGTjyqgqf3//aqNDTP3uGFT1EDsoWi6+/PJLjB49GrNmzcLVq1dtXY5JtFotPD09DTfzqlQqeHh4NGgQutQtXLgQo0ePxvLly+1qkLyp353kZ08wBzkMijaH+s5zwYIFcHd3h1KpREJCAqZOnYqUlBS7OX97sXXrVqjVapSVleHdd99FVFQU1qxZY+uybKpJBJU1BkVLQX3n6enpaXg8duxYrFq1Cjk5OUYtRjlQq9XIzc2FTqeDSqWCTqdDXl5egwahS1nleTg5OWHixImYOXOmjSsyH1O/O1761UPsoGg5yM3NNTw+evQolEqlUXjJRdu2beHj44O9e/cCqBgr6uPjY/ifiZyVlJSguLgYQMWUKElJSfDx8bFxVeZj6nfHsX4wfVC03EyZMgWFhYVQKBRo1aoVFi9ejD59+ti6LJNcvXoVkZGRuHPnDtq0aYOYmBg8+eSTti6r0TIzMzF37lzodDro9Xp07twZ77zzDjw8PGxdWoNFR0cjOTkZBQUFcHV1hYuLCxITE0367hhURCR5vPQjIsljUBGR5DGoiEjyGFREJHkMKiKSvCZxwyfZnz/++ANvvvkmrl+/jgULFuDSpUvw9PTEggULbF0aWQBbVCTKa6+9hv79+xtWuq3Prl27
DItuiNWtWzdcv35d1LEbN27EgAEDcObMGUyePLlBn0Pyw6Ciet28eROnTp2CQqHAgQMHbF0OgIrB4l26dLF1GWQlDCqqV0JCAp5++mmMGzcOCQkJRvu0Wi3mzJmDZ599Fs888wyioqJw9epVLFu2DL/99hv69u0Lf3//Bn9mbGws5s+fj8WLF6Nv374YOXIkzp8/DwCYPHkyjh8/jqioKPTt2xfXrl0zem1NrbmqrbWysjLExMTg+eefx6BBg7B06VI8ePAAAHD8+HEEBgZi8+bNGDhwIAYPHoxvv/3W8D4PHjzA+++/j6CgIPj5+eGVV14xvPa3335DREQE/P39MWbMGBw/frzB5001Y1BRvSpnQB09ejR+/PFHFBQUAAB0Oh1mzJgBLy8vHDx4EEeOHMGIESPQuXNnrFixAn369MGZM2dw6tQpkz734MGDGDlyJE6dOoXg4GCsXLkSAPCvf/0L/v7+WLp0Kc6cOYNOnTo16H0/+OADXLt2DQkJCUhOTkZeXh4+/fRTw/6CggIUFxfjyJEjhtkL/vvf/wKomLXy4sWL2L59O06cOIFFixZBqVQiNzcXM2bMwMyZM3HixAksWbIE8+bN4xxZZsKgojqdOnUK2dnZeOGFF+Dr6wtvb2/DgNJz584hLy8PixcvhrOzM5o1a2ZS66k2fn5+GDJkCFQqFTQaDdLS0hr9noIgIC4uDm+//TZcXFzQqlUrzJgxA4mJiYZjHBwcMHv2bDg6OmLIkCFwdnbGtWvXoNfr8e233+Jvf/ubYU6lfv36wcnJCbt370ZgYCCGDBkCpVKJgIAA+Pr64vDhw42umfirH9UjISEBAQEBhtHto0aNQnx8PKZMmQKtVgsvLy84OFjmP6PHHnvM8Lh58+YoLS3Fw4cPG/V5RUVFuH//Pl588UXDNkEQoNfrDc9dXFyMPqNFixYoKSnBrVu3UFpaCm9v72rvm52djX379hnNqf/w4UPJTgkkNwwqqtWDBw/w/fffQ6/XIyAgAEBF/86dO3eQlpYGtVoNrVZbY3goFApblAygIlgq+40AID8/3/DY1dUVzZs3R2JiYoOnuHF1dUWzZs2QmZmJ7t27G+1Tq9XQaDSIjo5uXPFUI176Ua0qZ/9MTExEQkICEhISkJSUBH9/fyQkJKB3795wd3fHhx9+iJKSEpSWluL06dMAKuYdys3NFX07gzl1794d6enpSE1NRWlpKWJjYw37lEolwsPD8d5776GwsBBAxTxdYlZ6USqVeOmll7Bq1SrD5G9nzpxBWVkZxowZg0OHDuHo0aPQ6XQoLS3F8ePHZT1ttZQwqKhW8fHxePHFF+Hl5QV3d3fDP6+++ir27NkDQRCwfv16XL9+HUFBQQgMDMT3338PAHj22Wfx1FNPYfDgwYbLn/Xr12Pq1KkWr7tTp06YPXs2pkyZguHDh8PPz89o/6JFi9ChQwe8/PLL6NevH6ZMmVLtl8PaLFmyBF27dsX48eMxYMAArFmzBnq9Hmq1Gp999hm++OILDBw4EEOGDMGmTZuMLinJdJyPiogkjy0qIpI8BhURSR6Diogkj0FFRJLHoCIiyWNQEZHkMaiISPIYVEQkeQwqIpK8/w8tOXzv0KggLAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set_theme()\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "# Common range of both axes.\n",
    "low_limit, up_limit = -10, 10\n",
    "x = np.linspace(low_limit, up_limit)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# Diagonal y = x as the reference for perfect agreement.\n",
    "ax.plot(x, x, color=\"grey\", alpha=0.25, zorder=0)\n",
    "ax.scatter(acctual_influence_1, predict_influence_1, s = 20, label = 'A', linewidths=0)\n",
    "ax.ticklabel_format(style=\"sci\", scilimits=(-4, 4))\n",
    "ax.axis('square')\n",
    "ax.set_xlabel('Act. Influence')\n",
    "ax.set_ylabel('Pred. Influence')\n",
    "ax.set_ylim(low_limit, up_limit)\n",
    "ax.set_xlim(low_limit, up_limit)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fa643c37",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.9738260627952212, pvalue=5.148011006700848e-39)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spearman rank correlation restricted to the nodes for which both the\n",
    "# measured and the predicted influence are non-zero. One joint mask keeps\n",
    "# the two arrays aligned pair by pair (filtering each array on its own can\n",
    "# shift the pairing or leave the arrays with different lengths).\n",
    "actual_arr = np.asarray(acctual_influence_1)\n",
    "predicted_arr = np.asarray(predict_influence_1)\n",
    "nonzero_pairs = (actual_arr != 0) & (predicted_arr != 0)\n",
    "scipy.stats.spearmanr(actual_arr[nonzero_pairs], predicted_arr[nonzero_pairs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "60c17971",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.9738260627952212, pvalue=5.148011006700848e-39)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scipy.stats.spearmanr(acctual_influence_1, predict_influence_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "67e5288b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ratio of influence-estimation time to retraining time, one row per removed node.\n",
    "data1 = pd.DataFrame([t_infl / t_retrain for t_infl, t_retrain in zip(time_infl, time_retrain)])\n",
    "\n",
    "# Name the file after the dataset actually selected in `data_set` (a fixed\n",
    "# 'pubmed' name would overwrite another dataset's results) and make sure\n",
    "# the output folder exists before writing.\n",
    "os.makedirs('running time', exist_ok=True)\n",
    "data1.to_csv(f'running time/{data_set}_running_time.csv', header = None, index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a08755b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
