{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "45fa3296",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n",
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import torch\n",
    "import dgl\n",
    "import numpy_ml\n",
    "import scipy\n",
    "import cupy as cp\n",
    "import cupyx\n",
    "\n",
    "from dgl import function as fn\n",
    "from dgl.base import DGLError\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from scipy.special import softmax, log_softmax\n",
    "from numpy.linalg import inv, pinv\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.metrics import log_loss\n",
    "from dataset import load_graph_dataset\n",
    "\n",
    "from model_node_influence import NodeInfluenceSGC\n",
    "\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork, fast_hess, fast_hess_cuda, fast_get_inv_hvp_cuda\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy.linalg import cho_solve, cho_factor\n",
    "from model_edge_influence import EdgeInfluenceSGC, generate_remove_index_train\n",
    "from tqdm import tqdm\n",
    "\n",
    "from sklearnex import patch_sklearn, config_context\n",
    "from dataset_wikics_amazon import load_wikics, load_amazon\n",
    "patch_sklearn()\n",
    "\n",
    "from dgl.data import AmazonCoBuyComputerDataset, AmazonCoBuyPhotoDataset\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9a18fb1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_amazon(dataname='computer', seed = 1)\n",
    "# l2_term = 0.05\n",
    "# num_layer = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "27581ea6",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_amazon(dataname='photo', seed = 1)\n",
    "l2_term = 0.05\n",
    "num_layer = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b9a69142",
   "metadata": {},
   "outputs": [],
   "source": [
    "# graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_wikics()\n",
    "# mask_idx = 0\n",
    "# train_mask = train_mask[mask_idx]\n",
    "# val_mask = val_mask[mask_idx]\n",
    "# l2_term = 0.1\n",
    "# num_layer = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52315f52",
   "metadata": {},
   "source": [
    "##### 1. Load data and convert the labels to one-hot encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c5afa228",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Propagate the raw features `num_layer` times: scale by deg^-1/2, sum the\n",
    "# messages arriving over each node's incoming edges, then scale by deg^-1/2 again.\n",
    "# In-degrees are clamped to >= 1 so zero-degree nodes do not produce inf.\n",
    "feat0 = feat.clone()\n",
    "degs = graph.in_degrees().float().clamp(min = 1)\n",
    "norm = torch.pow(degs, -0.5).to(feat0.device).unsqueeze(1)\n",
    "\n",
    "for _ in range(num_layer):\n",
    "    graph.ndata['h'] = feat0 * norm\n",
    "    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n",
    "    # pop() removes the temporary 'h' field so the graph is left unchanged\n",
    "    feat0 = graph.ndata.pop('h') * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1093ab01",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = feat0[train_mask].numpy().astype(np.float32)\n",
    "train_y = labels[train_mask].numpy().astype(np.float32)\n",
    "\n",
    "val_x = feat0[val_mask].numpy().astype(np.float32)\n",
    "val_y = labels[val_mask].numpy().astype(np.float32)\n",
    "\n",
    "# val_x = feat0[test_mask].numpy().astype(np.float32)\n",
    "# val_y = labels[test_mask].numpy().astype(np.float32)\n",
    "\n",
    "train_node_idx = torch.where(train_mask == 1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f56d207f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert to one-hot labels\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_val = enc.transform(val_y.reshape(-1, 1)).toarray()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41c20b6b",
   "metadata": {},
   "source": [
    "##### 2. Remove a single node: remove its propagated node feature as well as its connected edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "126bcdfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_remove_index_train_all(from_indexes, to_indexes, train_mask):\n",
    "    \"\"\"Retrieve all edges whose source endpoint is a training node.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    from_indexes, to_indexes : torch.Tensor\n",
    "        Parallel 1-D tensors holding the source / destination node id of every edge.\n",
    "    train_mask : torch.Tensor\n",
    "        Node mask; entries equal to 1 (or True) mark training nodes.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (torch.Tensor, torch.Tensor)\n",
    "        Source and destination ids of the selected edges, grouped by training\n",
    "        node (in mask order) and, within a node, in original edge order.\n",
    "        Both tensors are empty when there is no training node.\n",
    "    \"\"\"\n",
    "    train_index = torch.where(train_mask == 1)[0]\n",
    "    remove_from_list = []\n",
    "    remove_to_list = []\n",
    "    for i in tqdm(range(len(train_index))):\n",
    "        # positions (edge ids) of all edges leaving this training node\n",
    "        edge_ids = torch.where(from_indexes == train_index[i])[0]\n",
    "        # slice both endpoint tensors at once instead of appending edge by edge\n",
    "        remove_from_list.append(from_indexes[edge_ids])\n",
    "        remove_to_list.append(to_indexes[edge_ids])\n",
    "\n",
    "    if not remove_from_list:\n",
    "        # torch.cat rejects an empty list, so return explicit empty tensors\n",
    "        return from_indexes.new_empty(0), to_indexes.new_empty(0)\n",
    "\n",
    "    return torch.cat(remove_from_list), torch.cat(remove_to_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "eb6654c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 160/160 [00:00<00:00, 5395.47it/s]\n"
     ]
    }
   ],
   "source": [
    "from_indexes, to_indexes = graph.edges()\n",
    "\n",
    "f_l, t_l = generate_remove_index_train_all(from_indexes, to_indexes, train_mask)\n",
    "\n",
    "acctual_influence_node_features = []\n",
    "acctual_influence_edges = []\n",
    "\n",
    "predict_influence_node_features = []\n",
    "predict_influence_edges = []\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9291aec",
   "metadata": {},
   "source": [
    "##### 2.1 Train on the original data and compute the Hessian matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "27d87821",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████| 159/159 [00:00<00:00, 534.78it/s]\n"
     ]
    }
   ],
   "source": [
    "lr_origin = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "\n",
    "lr_origin.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "logits_val_y_origin = val_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "logits_train_y_origin = train_x @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "ori_val_loss, ave_ori_val_loss = lr_origin.log_loss(logits_val_y_origin, one_hot_labels_val, l2_reg=True)\n",
    "\n",
    "# numpy_theoritic_loss = log_loss(val_y, softmax(logits_val_y_origin, axis=1))\n",
    "# # set l2_reg to False, verify the correctness of calculations\n",
    "# assert np.allclose(numpy_theoritic_loss, ave_ori_val_loss)\n",
    "\n",
    "val_loss_total_grad_orig, val_loss_indiv_grad_orig = lr_origin.grad(val_x, \n",
    "                                                                    logits_val_y_origin,\n",
    "                                                                    one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "hess = lr_origin.hess_cuda(train_x, logits_train_y_origin, l2_reg = True)\n",
    "\n",
    "loss_grad_hvp = fast_get_inv_hvp_cuda(hess, val_loss_total_grad_orig.T, cholskey=True)\n",
    "\n",
    "loss_grad_hvp = cp.asnumpy(loss_grad_hvp)\n",
    "del hess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1a2e65d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "time_infl = []\n",
    "time_retrain = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ba180b45",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 160/160 [00:32<00:00,  4.87it/s]\n"
     ]
    }
   ],
   "source": [
    "# For every training node: remove its edges, estimate the resulting change in\n",
    "# validation loss with the influence function, and compare it with the change\n",
    "# measured by actually retraining.\n",
    "acctual_influence_1 = []\n",
    "acctual_influence_2 = []\n",
    "\n",
    "predict_influence_1 = []\n",
    "predict_influence_2 = []\n",
    "\n",
    "# The training node ids never change inside the loop, so convert them once.\n",
    "train_node_ids = train_node_idx.numpy()\n",
    "\n",
    "for k in tqdm(range(len(train_node_ids))):\n",
    "\n",
    "    node_id = train_node_ids[k]\n",
    "    nis = NodeInfluenceSGC(graph = graph, feature=feat, node_index=node_id)\n",
    "\n",
    "    # Remove the node's edges and recompute the propagated features.\n",
    "    nis.remove_edges_sgc()\n",
    "    feat_removed1 = nis.calculate_modified_features()\n",
    "\n",
    "    # Ids (sorted, unique) of all nodes whose propagated feature changed.\n",
    "    extra_index = torch.unique(torch.where(feat0 != feat_removed1)[0]).numpy()\n",
    "\n",
    "    # Keep only the training nodes among them (global node ids) ...\n",
    "    extra_index_train = extra_index[np.isin(extra_index, train_node_ids)]\n",
    "    # ... and translate each one to its row position inside train_x / train_y.\n",
    "    perturb_index = [np.where(train_node_ids == n)[0][0] for n in extra_index_train]\n",
    "\n",
    "    # Nothing in the training set changed: skip this node. This check has to come\n",
    "    # BEFORE feat_removed1 is indexed; previously it ran afterwards, once an empty\n",
    "    # float index array had already been used to index the feature tensor.\n",
    "    if not perturb_index:\n",
    "        continue\n",
    "\n",
    "    # Perturbed features of the affected training nodes, excluding the removed\n",
    "    # node itself so that it is not added back as a new sample.\n",
    "    is_other_node = extra_index_train != node_id\n",
    "    feat_to_be_added = feat_removed1[extra_index_train[is_other_node]].numpy()\n",
    "\n",
    "    # Matching row positions (same order as feat_to_be_added). Filtering instead of\n",
    "    # list.remove(k) avoids a ValueError when k is not among the perturbed rows.\n",
    "    # NOTE(review): in that case the node itself stays in the retrained set -- confirm\n",
    "    # whether such nodes should be skipped instead.\n",
    "    added_index = [i for i in perturb_index if i != k]\n",
    "\n",
    "    train_x_new = feat_to_be_added\n",
    "    train_y_new = train_y[added_index]\n",
    "\n",
    "    # Augmented set: the original rows followed by the perturbed copies.\n",
    "    train_x_orig = np.concatenate([train_x, train_x_new])\n",
    "    train_y_orig = np.concatenate([train_y, train_y_new])\n",
    "\n",
    "    one_hot_labels_train_0 = enc.transform(train_y_orig.reshape(-1, 1)).toarray()\n",
    "    logits_train_y_origin_0 = train_x_orig @ lr_origin.model.coef_.T + lr_origin.model.intercept_\n",
    "\n",
    "    train_total_grad_orig, train_indiv_grad_orig = lr_origin.grad(train_x_orig,\n",
    "                                            logits_train_y_origin_0,\n",
    "                                            one_hot_labels_train_0, l2_reg = True)\n",
    "\n",
    "    # Influence estimate: per-sample gradients times the precomputed H^-1 * grad(val loss).\n",
    "    start_time = time.perf_counter()\n",
    "    pred_infl = train_indiv_grad_orig.dot(loss_grad_hvp)\n",
    "    time_infl.append(time.perf_counter() - start_time)\n",
    "\n",
    "    # Retraining set: drop the original rows that were perturbed (weight 0) and keep\n",
    "    # everything else, including the appended perturbed copies (weight 1).\n",
    "    weight_3 = np.ones(len(train_x_orig))\n",
    "    weight_3[perturb_index] = 0 # 1...0...11\n",
    "\n",
    "    lr_new_2 = SimplifiedGraphNeuralNetwork(l2_reg=l2_term, fit_intercept=True)\n",
    "    train_x_delete_2 = train_x_orig[weight_3 == 1]\n",
    "    train_y_delete_2 = train_y_orig[weight_3 == 1]\n",
    "\n",
    "    start_time = time.perf_counter()\n",
    "    lr_new_2.fit(train_x_delete_2, train_y_delete_2)\n",
    "    time_retrain.append(time.perf_counter() - start_time)\n",
    "\n",
    "    logits_val_y_new_2 = val_x @ lr_new_2.model.coef_.T + lr_new_2.model.intercept_\n",
    "    new_ori_val_loss_2, _ = lr_new_2.log_loss(logits_val_y_new_2, one_hot_labels_val, l2_reg = True)\n",
    "\n",
    "    # Predicted change = influence of the removed rows minus influence of the added rows.\n",
    "    predict_influence_1.append(np.sum(pred_infl[perturb_index]) - np.sum(pred_infl[len(train_x):]))\n",
    "    acctual_influence_1.append(new_ori_val_loss_2 - ori_val_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3cb241e4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(89.3351881005478, 88.69704634731524)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_ori_val_loss_2, ori_val_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "85f57408",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARYAAAEQCAYAAAB1FFtSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAicUlEQVR4nO3de1xUdf4/8NcwXBTUBe+gpFbiJVQMyExWBNs0Q0nXykr9tWsrtbGS+ygle+QFSqS0Mq1oM7VV2h5aQasWamJqm5LwxQsqaorGOtwEvA06XOb8/mAZmZHLGThnzpmZ1/Px6PFobmfeM37Oi898zud8jkYQBAFERBJyUboAInI8DBYikhyDhYgkx2AhIskxWIhIcgwWIpKcq9IFNGYwGLB8+XIcPHgQHh4eCAoKQmJiotJlEZGVVBUs77zzDjw8PLBz505oNBpcvnxZ6ZKIqA00apkgp9frER4ejn379sHLy0vpcoioHVTTYyksLIS3tzfWrl2LrKwseHl5IS4uDiEhIa2+1mg0Qq/Xw83NDRqNxgbVEjkXQRBQU1MDLy8vuLi0PjSrmmCpra1FYWEhhg4dioULF+Lo0aN44YUXsHv3bnTq1KnF1+r1epw5c8ZGlRI5r4CAAHTu3LnV56kmWPz8/ODq6oqoqCgAwIgRI+Dj44OCggIMGzasxde6ubkBqP/Q7u7ustcqRl5eHgIDA5Uuw0Rt9QDqq4n13KmqqgoVFRXQarUoLy837WutUU2wdO3aFaNGjcJ//vMfhIWFoaCgAOXl5ejXr1+rr234+ePu7g4PDw+5SxVNTbUA6qsHUF9NrOc2vV6Pq1evwtPTE97e3igvLxc91KCaYAGAZcuWYdGiRUhOToarqyvefvttdOnSRemyiJyOXq9HWVkZPDw80KtXL9TU1Fj1elUFi7+/PzZt2qR0GUROzTJUxAzWWuLMWyIykSJUAAYLEf2PVKECMFiICNKGCsBgIXJ6UocKwGAhcmpyhArAYCFyWnKFCsBgIXJKcoYKwGAhcjpyhwrAYCFyKrYIFYDBQuQ0bBUqAIOFyCnYMlQABguRw7N1qAAMFiKHpkSoAAwWIoelVKgADBYih6RkqAAMFiKHo3SoAAwWIoeihlABGCxEDkMtoQIwWIgcgppCBWCwENk9tYUKwGAhsmtqDBWAwUJkt9QaKgCDhcguqTlUAAYLkd1Re6gADBYiu2IPoQIwWIjshr2ECsBgIbIL9hQqgEqDZe3atRg0aBDOnDmjdClEirO3UAFUGCwnTpzAkSNH4Ofnp3QpRIqzx1ABVBYs1dXVSEhIwJIlS6DRaJQuh0hx9hgqgMqCZfXq1ZgyZQr8/f2VLoVIUXq9HgDsMlQAwFXpAhrk5ubi+PHjeOWVV9q8jby8PAkrar+cnBylSzCjtnoA9dWktnp0Oh10Op3SZVhNNcFy+PBhnD9/HuPHjwcAFBcXY86cOUhKSkJYWJiobQQGBsLDw0POMkXLyclBcHCw0mWYqK0eQH01qaGexmMqOp1O8XoaGAwGq/5wqyZY5s6di7lz55puR0ZGIiUlBQEBAQpWRWQ7lgO19thTaWBfP9yIHJS9Hv1pjmp6LJYyMzOVLoHIJhwtVAD2WIgU5YihAjBYiBTjqKECMFiIFOHIoQIwWIhsztFDBWCwENmUM4QKwGAhshlnCRWAwUJkE84UKgCDhUh2zhYqAIOFSFbOGCoAg4VINs4aKgCDhUgWzhwqAIOFSHLOHioAg4VIUgyVes75qYlkwFC5zXk/OZGEGCrmnPvTE0mAoXInfgNE7cBQaRq/BaI2Yqg0j98EURswVFrGb4PISgyV1vEbIbICQ0UcfitEIjFUxOM3QyQCQ8U6/HaIWsFQsR6/IaIWMFTaRtS3JAgCtmzZgtmzZ2Py5MkA6i/i/t1338laHJGSGCptJ+qbWr16Nb766is89dRTKCoqAgD07t0b69atk7U4IqUwVNpH1LeVlpaG
lJQUPPbYY9BoNACAvn37orCwUNbiiJTAUGk/UReFr6urg5eXFwCYgkWv18PT01OyQiorK7FgwQL89ttvcHd3R79+/ZCQkICuXbtK9h5ErWGoSEPUtxYeHo6kpCRUV1cDqB9zWb16NSIiIiQrRKPR4Pnnn8fOnTuxbds2+Pv7Y+XKlZJtn0gMhoo0RH1zr732GkpLSxEcHIzr169j5MiR0Ol0eOWVVyQrxNvbG6NGjTLdDgoKgk6nk2z7RC3R6/UAwFCRiEYQBEHsk8vLy3Hp0iX4+vqiR48eshVlNBrx5z//GZGRkZg9e3arzzcYDMjLy5OtHiKqFxgYCA8Pj1afJ2qM5aeffkKfPn0wYMAAdOvWDQBw/vx5FBUVYcyYMe2rtAmJiYnw9PTEzJkzrXqd2A9tCzk5OQgODla6DBO11QOoo6bGYyo6nU7xehpTw/fTwNo/3qL6ewkJCabB2wZeXl5ISEiwrjoRkpOTcfHiRbz//vvsjpKsLAdqSTqieizl5eXo2bOn2X09e/ZEWVmZpMW89957yMvLwz/+8Q+4u7tLum2ixnj0R16ivk1/f38cPHjQ7L6srCz07dtXskLOnj2LlJQUlJaWYsaMGYiOjsZLL70k2faJGjBU5CeqxxIbG4u//e1vmD59Ovz9/VFYWIhvvvkGy5cvl6yQgQMH4vTp05Jtj6gpDBXbEPWtPvzww1i/fj2qqqqwb98+VFVVYd26dXj44Yflro9IMgwV2xHVYwGA4cOHY/jw4XLWQiQbhoptiQqW6upqpKWl4dSpU6iqqjJ77O2335alMCKpMFRsT1SwxMfHIz8/HxEREejevbvcNRFJhqGiDFHBcuDAAezZswddunSRux4iyTBUlCPqm/b19TWdgEhkDxgqyhLVY3n88cfx17/+FbNnzzZN6W8wevRoWQojaiuGivJEBcvmzZsBAO+++67Z/RqNBnv27JG+KqI2Yqiog6hgyczMlLsOonZjqKiH6G++pqYG2dnZpgW0q6qq7jj0TKQUhoq6iOqxnD59Gi+++CLc3d1RUlKCSZMm4fDhw0hLS8P7778vc4lELWOoqI+of4GlS5di3rx5yMjIgKtrfRaFhoYiJydH1uKIWuNooVJnFLDlhzNI+OwQ9p+4BqNR9DpsqiKqx/Lrr78iOjoawO3FtD09PWEwGOSrjKgVjhYqAPB15lls+v6U6XYfv7N48uEABStqG1H/En369Llj9ahjx47hrrvukqUootY4YqgAQP7FihZv2wtR/xpxcXGIiYnBBx98gJqaGnzyySeIi4vDyy+/LHN5RHdy1FABgMH9urZ4216I+ikUERGBTz/9FFu3bkVoaCguXbqENWvWIDAwUO76iMw4cqgAwPTIgQDqeyqd3W6abtsb0csm3HfffbjvvvvkrIWoRY4eKgDg4qIxjank5OTAxUWjcEVtIypYVq9e3exjcXFxkhVD1BxnCBVHIipYiouLzW6XlZXh8OHDXEGObKK9oVJnFPB15lnkX6zA4H5dMT1yoN32BOyFqGBJSkq64779+/djx44dkhdE1JgUPZXGh3APnywBgFYP4UoVRs4aaqLHWCyFhYVh/vz5UtZCZEaqnz9tOYTbljCSczv2RlSwFBYWmt2+efMmtm/fDl9fX1mKIhIbKmJ6BAF3+Zh2agAYdJdPq+8v1XwSR5mXYi1RwfKHP/wBGo0GDZd57tixI4YMGYIVK1bIWhw5J2t6KmJ6BG354TG4X1ezMGrrfBKptmNvRAVLfn6+3HUQAbgdKm5u7tifdw2nv7vY4tjEqQsWPYILd/YITv9W2eLtpjSeT9Lw/m0h1XbsTZvHWIikdv36DaR+fwLZZ67gpqEO16pqALQ8NlFbZzS7XWNxG2hbr6HxfJKWtPZTTOx2HE2zwRIeHm464bAlP/74o5T1kBPS36rF82/uxI2btc0+59SFcmz54cwdO7Cr1ryNWt4GgKnj7sXxc5dRoLuKAX6/w7Rx90pWu7MOzram
2WB55513bFkHOSGjUcAXu/Lxr52tX1q3ptbY5A48pH83ZJ8qNT1vSP9ud7w27cdfceRMGQDgyJkyfLX3LFw0GkkOATvr4Gxrmg2WlStXYsuWLQCAtWvXIjY2VvZiCgoKEB8fjytXrsDb2xvJycno37+/7O9Ltnezug7JX1+CoUbc88sqb5rdPnWhHIC4MQzLnf3HnP/iUtkNAO3vZTjr4Gxrmh1uv3Dhgmm9lfXr19ukmCVLluCZZ57Bzp078cwzz2Dx4sU2eV+yrdxjJXjyte2iQwUASivMl0Gtras/QtkwhvH6n0YBAN7ckIUtP5wxWyCptZ29Pb2M6ZEDMevRIQgd2guzHh3iNIOzrWm2xzJ+/HhMmDABffr0gcFgwLPPPtvk81JTUyUppLy8HCdPnsSGDRsAAFFRUUhMTERFRQW6duVfAUdxqfgGFn9+yOrX1VqspOamNf+b2NJYh2WvRhAEbM64faSzPb0MZx2cbU2zwZKUlITs7GxcunQJx48fx/Tp02UtpKioCL169YJWqwUAaLVa9OzZE0VFRQwWB1F5zYAX3pHmcjGD+5u3iZbGOix3fqNRgMZijIWk1eLh5pCQEISEhKCmpgZTp061VU1tZrnKndLUtiawkvXU1hrx5hadVa/p1sUVv/PU4nzx7SVQu3V2xYi7PTHA+5rZ5+nsZj4G09ntZouf9x4f4B4fNwDXkZv7f6b7+W8mDVHzWKZPn47z588jPz//jkt+SNWT8fX1RUlJCerq6qDValFXV4fS0lKrThsIDAyEh4eHJPW0V05ODoKDg5Uuw0SJehrmeJy6UI5T58tFvaaLpxs6e3kgIrgvnhhf38v4SsRJfCNHCujj176T/fhv1jyDwWDVH25RwZKSkoIPP/wQgwcPRocOHUz3azQayYKlW7duGDJkCLZv347o6Ghs374dQ4YM4c8gFWpuUpjl/ZZjGWIM6t8Vi+c8aHafmDEMjnWoi6hg+fzzz7F161YMHjxY1mKWLl2K+Ph4fPTRR+jSpQuSk5NlfT9qm+YGSi3v79OjU6vbGnZPVxw/d3s8RK2Ha511+YO2EhUsHTp0wN133y13LbjnnnuwdetW2d+H2qZh5/r3gXNm9zcMlFoOoJaU65vdlkYDjAvsjHkzx+CbH39V/UAqZ9haR1SwxMXF4c0330RsbCy6d+9u9hiXCHQelte8adDQy7CcLGZ5iLhBp46u+HzxRBw/fgSuri52sYNyhq11RAVLfHw8AJj1JgSh/pDdqVN3NjRyTA2zXRt09nTFPX19kH+hAlt+OINp4+6FUajv1dyqrmtyGyMGdkfC3Ifs7mcEZ9haR1Sw7NkjzdwDsm8Ns10bdOrobjoH5/Cp+p3ORaO5I1Q83LTo7t3RdKTH3kIFcN7lD9pKVLD06dNH7jrIDljOdq0ymJ+NnH+xAnW1d/ZUht3bHUuef/CO++0JjzpZp8VgaemyHw14+Q/nMbh/V1PPBAAG+P3O1GMBgLt9O0Gv1+P/LF5nuWYKOb4Wg8Xysh/k3Cx/Dkwbd6/piM7dvp3w0BAveHh0Re6v13Hp8u0jQk2tkUKOrcVgaeqyH+S8Gn4ONBx2Xv75LxjcryvmPxmI8vLLpjVqI0NvmB09amqNFHJsXJqSzIiZCGY5p6OyshJTwu4yLXzNgU5isJAZMRPBLOdw/FZ6y2w1fQ50Eme3kRkxE8Es53AMD+jNiZJkhj0WMiNmItijo/xQWVmJ30pvYXhAb9NZyKQOajivqV3BUlxcjN69e0tVC6nA9MiBMAoCfsz5L4D6GdZGo2BqmHq9HuXll83GVEhd1HBeU7taxaRJk6Sqg1TCxUUDF40Gl8pu4FLZDWzOyMdXmWcBSHctZZKXGs5ralfL2LFjh1R1kIo01TAZKvbD8uerEuc1teunEC8K75gsx1nu9u3EULEjajjc32ywiJnOD3BKvyNq3DBvz6hlqNgLNRzubzZYGk/nNxgM2LVrFwIDA9GnTx/odDocP34c
jzzyiE2KJPk1dSTh5s0q9lSoTVq8/EeD+fPnY9WqVZgwYYLpvl27diEjI0Pe6shmLI8kVFdXI2xoJ4YKtYmo1rJ//348/PDDZveNHz8e+/btk6UoZ1NnFLDlhzNI+OzQHVfxsxXLAdsT5+TpqTT+rPtPXFPks5L8RLWYfv363XHFwy+++AJ33XWXLEU5m4bewuGTJdj0/SnT4V1bsjxycG/fLrL0VBp/1syj1xT5rCQ/UUeFGta7XbduHXr16oWSkhK4urpizZo1ctfnFNQw72B65EBUV1fjxLky3Nu3C/7f5CBZfv6o4bOS/EQFy9ChQ7Fz504cPXoUpaWl6NGjB4KCguDm5iZ3fU5BDeup3rxZhbChnTB+ZDdZx1TU8FlJfm2axxIaGoqqqirU1NTA09NT6pqcjtLzDmw5+a3xZ+3sdpNLKjgoUcFy+vRpvPjii3B3d0dJSQkmTZqEw4cPIy0tDe+//77MJTo+Jecd2HpGbePPmpOTY5cLa1PrRLWipUuXYt68ecjIyICra30WhYaG2u0Fq6kep+mTXES1pF9//RXR0dEA6q/XDACenp4wGAzyVUayYqiQnERf/iMvLw/Dhg0z3Xfs2DEebraBpmbECkC71ttgqJDcRF9iNSYmBjNmzEBNTQ0++eQTfPnll0hMTJSkiGXLluHgwYNwd3eHp6cnXn/9dbMQc2ZNra0BoM3rbTBUyBZEtaqIiAh8+umnqKioQGhoKC5duoQ1a9YgLCxMkiLGjh2Lbdu24d///jdiYmIwf/58SbbrCJqa99HWuSAMFbKVVnssdXV1mDBhAr777jssXbpUliIiIiJM/x8UFITi4mIYjUY2fDQ/76Mtc0EYKmQrrQaLVquFVquFwWCAu7u77AWlpqZi3LhxbPj/09IcF7HzXvT6+ouHMVTIVjSCILR6FlhqaioyMzMRExOD3r17m44MAYC/v3+rbzJ16lTodLomH/v555+h1WoB1K9I98EHHyA1NRXdu3cX+xlgMBiQl5cn+vlE1DaBgYHw8PBo9XmigmXw4MFNv1ijwalTp5p8zFq7d+9GcnIyNm7ciL59+1r12oZgEfuhbSEnJwfBwcGK1tB4TEWn0ylejyU1fEeNsZ7mWbuPiToqlJ+f3+7CWrJ3714kJSVhw4YNVocKNc1yoLa5HiORHFoMlps3b+Ljjz/GmTNncN999yEmJkaWcZbXXnsNbm5umDdvnum+jRs3wsfHR/L3cgY8+kNKazFYEhISkJeXh9///vfYuXMnrly5gjfeeEPyIg4dOiT5Np0VQ4XUoMVWd+DAAXz22WdYsGABPv30U+zdu9dWdVEbMFRILVpseVVVVejZsyeA+kt93LhxwyZFkfUYKqQmLf4Uqqurw6FDh9Bw4Ki2ttbsNgCMHj1a3gqpVQwVUpsWg6Vbt25YtGiR6ba3t7fZbY1Ggz179shXHbWKoUJq1GKwZGZm2qoOagOGCqkVW6KdYqiQmrE12iGGCqkdW6SdYaiQPWCrtCMMFbIXbJl2gqFC9oSt0w4wVMjesIWqHEOF7BFbqYoxVMhesaWqFEOF7BlbqwoxVMjescWqDEOFHAFbrYowVMhRsOWqBEOFHAlbrwowVMjRsAUrjKFCjoitWEEMFXJUbMkKYaiQI2NrVgBDhRwdW7SNMVTIGbBV2xBDhZwFW7aNMFTImbB12wBDhZyNqlp4VlYWhgwZgs2bNytdimQYKuSMVNPKb9y4gZUrV2Ls2LFKlyIZhgo5K9W09BUrVmDOnDnw8fFRuhTJMFTIWamite/btw/Xrl3DxIkTlS5FEnq9HgAYKuS0WrzEqlSmTp0KnU7X5GMZGRlYtWoVNmzY0O73ycvLa/c2pKTT6Zr93ErIyclRuoQ7qK0m1iMNmwRLWlpas49lZ2ejrKwMTzzxBACgsrISe/fuxZUrVxAbG2vV+wQGBsLDw6NdtbZH4zEVnU6H4OBgxWqxlJOTo6p6APXVxHqaZzAYrPrDbZNgaUlI
SAgOHjxouh0fH4/AwEDMnDlTwaqsZzlQq6aeCpGt8ce/BHj0h8ic4j0WSytWrFC6BKswVIjuxL2gHRgqRE3jntBGDBWi5nFvaAOGClHLuEdYiaFC1DruFVZgqBCJwz1DJIYKkXjcO0RgqBBZh3tIKxgqRNbjXtIChgpR23BPaQZDhajtuLc0gaFC1D7cYywwVIjaj3tNIwwVImlwz/kfhgqRdLj3gKFCJDWn34MYKkTSc+q9iKFCJA+n3ZMYKkTyccq9iaFCJC+n26MYKkTyc6q9iqFCZBtOs2cxVIhsxyn2LoYKkW05/B7GUCGyPYfeyxgqRMpw2D2NoUKkHIfc2xgqRMpyuD2OoUKkPNVcFH7Tpk1ITU2Fm5sbtFot0tPTrd5GVVUVrl69ylAhUpgqgmXXrl3IyMjAV199hU6dOqGsrKxN26moqICnpydDhUhhqgiW9evXIy4uDp06dQIA9OjRw6rXC4IAANBqtfD29kZNTY3kNbaFwWBQugQzaqsHUF9NrKdp1dXVAG7va63RCGKfKaPQ0FDMmTMHP/74I6qrqzFjxgw8+eSTol9//fp1nDlzRsYKiQgAAgIC0Llz51afZ5Mey9SpU6HT6Zp87Oeff0ZdXR2KiorwxRdfoLKyEk8//TQGDBiA0NBQUdv38vJCQEAA3NzcoNFopCydiFDfU6mpqYGXl5eo59skWNLS0lp83M/PD1FRUXBxcUG3bt3w0EMP4dixY6KDxcXFRVSKElHbdejQQfRzVTHCGRUVhQMHDgCoP7KTk5ODwYMHK1wVEbWVKsZYbt26hTfeeAMnT54EAERHR2Pu3LkKV0VEbaWKYCEix6KKn0JE5FgYLEQkOQYLEUmOwUJEknO4YNm0aRMmTpyIyZMn4/HHH1e6HABAVlYWhgwZgs2bNytax7JlyzBx4kRMmTIFM2bMwPHjxxWpo6CgAE899RQmTJiAp556ChcuXFCkDgCorKzEX/7yF0yYMAGTJ09GbGwsKioqFKunsbVr12LQoEGqmFVuMBiwZMkSPPLII5g8eTLeeOONll8gOJCdO3cKzzzzjHD9+nVBEAShtLRU4YoE4fr168L06dOFuXPnCps2bVK0lszMTKG6utr0/+PHj1ekjlmzZgnp6emCIAhCenq6MGvWLEXqEARBqKysFA4dOmS6vWLFCuG1115TrJ4GeXl5wpw5c4Rx48YJp0+fVrocITExUXjrrbcEo9EoCIIglJWVtfh8h+qxrF+/HrGxsW0+mVEOK1aswJw5c+Dj46N0KYiIiICbmxsAICgoCMXFxTAajTatoby8HCdPnkRUVBSA+smRJ0+eVKyX4O3tjVGjRpluBwUFNXv6ia1UV1cjISEBS5YsUcUpKnq9Hunp6YiLizPV07179xZf41DBcu7cORw9ehQzZszAtGnTsGXLFkXr2bdvH65du4aJEycqWkdTUlNTMW7cOJsvL1FUVIRevXpBq9UCqD8jvWfPnigqKrJpHU0xGo3417/+hcjISEXrWL16NaZMmQJ/f39F62hQWFgIb29vrF27FtOmTcOsWbOQnZ3d4mtUsWyCWHKfzChlPRkZGVi1ahU2bNggy3tbW8/PP/9s2pl37NiBbdu2ITU11Wa12YPExER4enpi5syZitWQm5uL48eP45VXXlGsBku1tbUoLCzE0KFDsXDhQhw9ehQvvPACdu/ebfp1YMmugkXukxmlrCc7OxtlZWV44oknANQPEu7duxdXrlxBbGyszetpsHv3brz33nvYuHFjq91ZOfj6+qKkpAR1dXXQarWoq6tDaWkpfH19bV5LY8nJybh48SJSUlIUXSTs8OHDOH/+PMaPHw8AKC4uxpw5c5CUlISwsDBFavLz84Orq6vp5+uIESPg4+ODgoICDBs2rOkX2WLgx1Y+/vhjYdWqVYIgCIJerxeioqKEn376SeGq6i1cuFAVg7cRERHChQsXFK1j5syZZoO3M2fOVLSed999V5g5c6ZQVVWlaB1N
iYiIUMXg7Z/+9CfhwIEDgiAIwvnz54UHHnhAuHr1arPPd6hzhdR8MmN8fDwCAwMV7WY/+OCDcHNzQ9euXU33bdy40eYDy+fOnUN8fDyuXbuGLl26IDk5GXfffbdNa2hw9uxZREVFoX///qZlAfr27YsPP/xQkXosRUZGIiUlBQEBAYrWUVhYiEWLFuHKlStwdXXFyy+/jPDw8Gaf71DBQkTq4FBHhYhIHRgsRCQ5BgsRSY7BQkSSY7AQkeTsaoIcOYbz58/j73//Oy5evIj58+fj5MmT6NWrF+bPn690aSQR9licwKxZsxAaGmq6ml1rvvnmGzz99NNWvcegQYNw8eJFUc9dt24dHnjgAeTm5mL27NlWvQ/ZBwaLg/vvf/+L7OxsaDQa7NmzR+lyAAA6nQ4DBw5UugySEYPFwaWnp2PEiBGYOnUq0tPTzR4rKipCbGwsHnzwQYwaNQoJCQk4d+4clixZgiNHjmDkyJEICQmx+j3XrFmDuLg4LFiwACNHjsRjjz1mWlRq9uzZyMrKQkJCAkaOHImCggKz1zbVW2rcG6qurkZycjLGjRuHhx56CIsXL8atW7cA1C+oNXbsWKxfvx6jR49GWFgYvv76a9N2bt26hRUrViAiIgLBwcF4+umnTa89cuQIZsyYgZCQEEyZMgVZWVlWf266jcHi4L799ltMnjwZkydPxk8//YTLly8DAOrq6hATEwM/Pz9kZmZi//79mDRpEu655x4sW7YMQUFByM3NbfX0+OZkZmbiscceQ3Z2NiIjI5GYmAgA+Oc//4mQkBAsXrwYubm5GDBggFXbfeedd1BQUID09HTs2rULpaWlZtPvL1++jOvXr2P//v146623kJCQgKtXrwKoP9HwxIkT+PLLL/HLL7/g1VdfhYuLC0pKShATE4MXX3wRv/zyCxYuXIh58+apZiU5e8RgcWDZ2dnQ6XR49NFHERgYCH9/f2zfvh0AcOzYMZSWlmLBggXw9PSEh4dHm3onzQkODkZ4eDi0Wi2io6ORn5/f7m0KgoCtW7di0aJF8Pb2RqdOnRATE4MdO3aYnuPq6oqXXnoJbm5uCA8Ph6enJwoKCmA0GvH111/j9ddfN60Hc//998Pd3R3ffvstxo4di/DwcLi4uGDMmDEIDAzEvn372l2zs+JRIQeWnp6OMWPGmE46jIqKQlpaGp577jkUFRWZToeXQ+MlGTp06ACDwYDa2tp2vV9FRQVu3ryJadOmme4TBMFsFTxvb2+z9+jYsSOqqqpQWVkJg8HQ5OJJOp0OGRkZ2Lt3r+m+2tpas5XlyDoMFgd169YtfP/99zAajRgzZgyA+vGJa9euIT8/H76+vigqKmpyZ1dyOcSOHTuaxj0AoKyszPT/Pj4+6NChA3bs2IFevXpZtV0fHx94eHigsLDwjuuC+/r6Ijo6Gm+++Wb7iicT/hRyUD/88AO0Wi127NiB9PR0pKen47vvvkNISAjS09MxfPhw9OjRA6tWrUJVVRUMBgNycnIAAN26dUNJSYnow9NSGjx4MM6ePYtTp07BYDBgzZo1psdcXFzwxBNPYPny5SgvLwcAlJSU4MCBA61u18XFBX/84x+RlJRkWmgqNzcX1dXVmDJlCvbu3YsDBw6grq4OBoMBWVlZKC4ulu1zOjoGi4NKS0vDtGnT4Ofnhx49epj+e/bZZ7Ft2zYIgoCUlBRcvHgRERERGDt2LL7//nsA9eu23HvvvQgLCzP9HEhJScHzzz8ve90DBgzASy+9hOeeew6PPPIIgoODzR5/9dVX0a9fPzz55JO4//778dxzz91xZKk5CxcuREBAAKZPn44HHngAK1euhNFohK+vLz766CN88sknGD16NMLDw/HZZ5/ZfKFxR8L1WIhIcuyxEJHkGCxEJDkGCxFJjsFCRJJjsBCR5BgsRCQ5BgsRSY7BQkSSY7AQkeT+P6XhCrlgscY4AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Axis range shared by both axes and by the y = x reference line.\n",
    "# (An earlier -10/10 pair was overwritten before use and has been dropped.)\n",
    "low_limit = -6\n",
    "up_limit = 6\n",
    "\n",
    "sns.set_theme()\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "# Reference diagonal: points on it are perfectly predicted influences.\n",
    "x = np.linspace(low_limit, up_limit)\n",
    "plt.plot(x, x, color=\"grey\", alpha=0.25, zorder=0)\n",
    "plt.scatter(acctual_influence_1, predict_influence_1, s = 20, label = 'A', linewidths=0)\n",
    "plt.ticklabel_format(style=\"sci\", scilimits=(-4, 4))\n",
    "# 'square' must be applied before the explicit limits below, which then fix the range.\n",
    "plt.axis('square')\n",
    "plt.xlabel('Act. Influence')\n",
    "plt.ylabel('Pred. Influence')\n",
    "plt.ylim(low_limit, up_limit)\n",
    "plt.xlim(low_limit, up_limit)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8c09c366",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.905866446703341, pvalue=1.0489816745401113e-59)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rank correlation restricted to pairs where neither value is exactly zero.\n",
    "# A single shared mask keeps the two arrays aligned pair-by-pair; filtering each\n",
    "# array with its own mask can shift pairs or produce arrays of different length.\n",
    "actual_arr = np.asarray(acctual_influence_1)\n",
    "predicted_arr = np.asarray(predict_influence_1)\n",
    "nonzero_pairs = (actual_arr != 0) & (predicted_arr != 0)\n",
    "scipy.stats.spearmanr(actual_arr[nonzero_pairs], predicted_arr[nonzero_pairs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1c0da19c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SpearmanrResult(correlation=0.905866446703341, pvalue=1.0489816745401113e-59)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scipy.stats.spearmanr(acctual_influence_1, predict_influence_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "9141db14",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.4017750050418414, 1.833626249096279e-07)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scipy.stats.pearsonr(acctual_influence_1, predict_influence_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b49271c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data1 = pd.DataFrame([time_infl[i] / time_retrain[i] for i in range(len(time_infl))])\n",
    "# data1.to_csv('running time/citeseer_running_time.csv', header = None, index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "32fc690d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# df = pd.DataFrame([acctual_influence_1, predict_influence_1]).T\n",
    "# df.columns = ['acctual_influence', 'predict_influence']\n",
    "# df.to_csv('wiki-cs-dataset/' + 'complete_node_influence' + '.csv', index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "83998f23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# df = pd.DataFrame([acctual_influence_1, predict_influence_1]).T\n",
    "# df.columns = ['acctual_influence', 'predict_influence']\n",
    "# df.to_csv('amazon_dataset/' + 'computer_complete_node_influence' + '.csv', index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f245be0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the paired (actual, predicted) influences, one row per evaluated node.\n",
    "output_dir = 'amazon_dataset'\n",
    "# Create the folder if needed so the export does not fail on a fresh checkout.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "df = pd.DataFrame({\n",
    "    'acctual_influence': acctual_influence_1,\n",
    "    'predict_influence': predict_influence_1,\n",
    "})\n",
    "df.to_csv(os.path.join(output_dir, 'photo_complete_node_influence' + '.csv'), index = False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
