{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cacb6c70",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n",
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from dgl import function as fn\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from dataset import load_graph_dataset\n",
    "from tqdm import tqdm\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from train import train_preprocessed_data\n",
    "from model_edge_influence import EdgeInfluenceSGC\n",
    "from tqdm import tqdm\n",
    "from calculate_edge_influence import generate_edge_influence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b7ff0d09",
   "metadata": {},
   "outputs": [],
   "source": [
    "l2_reg = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10]\n",
    "max_edge_removeed = 200\n",
    "dataname = 'cora'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "19a134ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "# Load the dataset selected by `dataname`; the result is unpacked into seven names.\n",
    "(graph, feat, labels,\n",
    " train_mask, val_mask, test_mask,\n",
    " number_classes) = load_graph_dataset(dataname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f2b75812",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████| 139/139 [00:00<00:00, 210.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13264\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 13264/13264 [05:04<00:00, 43.51it/s]\n"
     ]
    }
   ],
   "source": [
    "# f_l, t_l, pred_infl, _ = generate_edge_influence(graph, feat, labels, train_mask, \n",
    "#                                                  val_mask, test_mask, number_classes, \n",
    "#                                                  l2_regularlization_term = l2_reg[0], num_layer = 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8ac2ab1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# df_infl = pd.DataFrame([np.array(f_l), np.array(t_l), pred_infl]).T\n",
    "# df_infl.columns = ['from_index', 'to_index', 'pred_infl']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a22cf6d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# df_infl.to_csv('improve_acc/cora_temp.csv', index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "91c6bd6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# df_infl = df_infl.loc[df_infl['pred_infl'] < 0]\n",
    "# df_infl = df_infl.sort_values(['pred_infl'])\n",
    "# df_infl.index = range(len(df_infl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "acb22935",
   "metadata": {},
   "outputs": [],
   "source": [
    "l2_reg = [0.0001, 0.0002, 0.0003, 0.0004, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10]\n",
    "num_edge = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7a671457",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid search over the L2 strength and the number of removed edges: for every\n",
    "# pair, rebuild the features without those edges, refit the classifier and\n",
    "# record validation / test accuracy.\n",
    "best_acc_val = 0.0\n",
    "acc_test = 0.0  # test accuracy at the best validation accuracy seen so far\n",
    "\n",
    "# Not filled in this cell; kept defined in case other cells reference them.\n",
    "l2_list_all = []\n",
    "num_edge_removed_list_all = []\n",
    "acc_all = []\n",
    "acc2_all = []\n",
    "\n",
    "l2_list = []                # L2 strength of each run\n",
    "num_edge_removed_list = []  # number of edges actually removed in each run\n",
    "acc = []                    # test accuracy of each run\n",
    "acc2 = []                   # validation accuracy of each run\n",
    "\n",
    "# Targets do not change when edges are removed, so slice them once.\n",
    "# NOTE(review): assumes `labels` is a torch tensor that can be masked the same\n",
    "# way as the feature tensor -- confirm against load_graph_dataset.\n",
    "train_y = labels[train_mask == 1].numpy()\n",
    "val_y = labels[val_mask == 1].numpy()\n",
    "test_y = labels[test_mask == 1].numpy()\n",
    "\n",
    "for reg in tqdm(l2_reg):\n",
    "    f_l, t_l, pred_infl, _ = generate_edge_influence(graph, feat, labels, train_mask,\n",
    "                                                     val_mask, test_mask, number_classes,\n",
    "                                                     l2_regularlization_term=reg, num_layer=2)\n",
    "    df_infl = pd.DataFrame([np.array(f_l), np.array(t_l), pred_infl]).T\n",
    "    df_infl.columns = ['from_index', 'to_index', 'pred_infl']\n",
    "\n",
    "    # Keep only edges with negative predicted influence, most negative first,\n",
    "    # so that the first k rows are the k strongest removal candidates.\n",
    "    # NOTE(review): this restores the filter/sort from the commented-out cell\n",
    "    # above; without it the first k rows are in arbitrary order. Confirm that\n",
    "    # negative influence is the direction that should be removed.\n",
    "    df_infl = (\n",
    "        df_infl.loc[df_infl['pred_infl'] < 0]\n",
    "        .sort_values('pred_infl')\n",
    "        .reset_index(drop=True)\n",
    "    )\n",
    "\n",
    "    # Never ask for more edges than there are candidates.\n",
    "    max_removed = min(num_edge, len(df_infl))\n",
    "    for n_removed in range(1, max_removed + 1):\n",
    "        from_index = df_infl['from_index'].iloc[:n_removed].values.astype(int)\n",
    "        to_index = df_infl['to_index'].iloc[:n_removed].values.astype(int)\n",
    "\n",
    "        # Use the features and indices computed in this notebook; the previous\n",
    "        # `*_cora` names were never defined here.\n",
    "        nis = EdgeInfluenceSGC(graph=graph, feature=feat,\n",
    "                               from_index=from_index, to_index=to_index)\n",
    "        nis.remove_edges_sgc_from_influence()\n",
    "        feat_removed = nis.calculate_modified_features()\n",
    "\n",
    "        # NOTE(review): validation/test inputs are taken from the modified\n",
    "        # features, mirroring how train_x is built -- confirm this is intended.\n",
    "        train_x = feat_removed[train_mask == 1].numpy()\n",
    "        val_x = feat_removed[val_mask == 1].numpy()\n",
    "        test_x = feat_removed[test_mask == 1].numpy()\n",
    "\n",
    "        lr = SimplifiedGraphNeuralNetwork(l2_reg=reg, fit_intercept=True)\n",
    "        lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "\n",
    "        # Predict once per split and reuse the scores below.\n",
    "        test_acc = np.mean(lr.model.predict(test_x) == test_y)\n",
    "        val_acc = np.mean(lr.model.predict(val_x) == val_y)\n",
    "\n",
    "        acc.append(test_acc)\n",
    "        acc2.append(val_acc)\n",
    "        num_edge_removed_list.append(n_removed)\n",
    "        l2_list.append(reg)\n",
    "\n",
    "        if val_acc > best_acc_val:\n",
    "            best_acc_val = val_acc\n",
    "            acc_test = test_acc\n",
    "            print('best accuracy on val set: ', best_acc_val)\n",
    "            print('corresponding accuracy on test set: ', acc_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a99b83a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_actual_edge_influence(df_edge):\n",
    "    \"\"\"Return per-row actual and predicted influence, checking edge symmetry.\n",
    "\n",
    "    For every row of ``df_edge`` the edge ``(from_edges, to_edges)`` and its\n",
    "    reverse ``(to_edges, from_edges)`` are looked up, and both directions are\n",
    "    asserted to carry equal ``actual_influence`` and ``predicted_influence``.\n",
    "    When an edge occurs more than once, its first occurrence is used.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df_edge : pandas.DataFrame\n",
    "        Must provide the columns ``from_edges``, ``to_edges``,\n",
    "        ``actual_influence`` and ``predicted_influence``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of (list, list)\n",
    "        ``(new_infl, new_infl_pred)`` with one entry per row, in row order.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    IndexError\n",
    "        If an edge or its reverse cannot be found in ``df_edge``.\n",
    "    AssertionError\n",
    "        If the two directions of an edge disagree on either influence value.\n",
    "    \"\"\"\n",
    "    if len(df_edge) == 0:\n",
    "        return [], []\n",
    "\n",
    "    sources = df_edge['from_edges'].values\n",
    "    targets = df_edge['to_edges'].values\n",
    "    actual = df_edge['actual_influence'].values\n",
    "    predicted = df_edge['predicted_influence'].values\n",
    "\n",
    "    # One pass builds an (from, to) -> (actual, predicted) table, replacing a\n",
    "    # full-frame boolean scan per lookup. setdefault keeps the first occurrence.\n",
    "    first_seen = {}\n",
    "    for src, dst, act, pred in zip(sources, targets, actual, predicted):\n",
    "        first_seen.setdefault((src, dst), (act, pred))\n",
    "\n",
    "    def _lookup(edge):\n",
    "        # A missing edge keeps raising IndexError, as `.values[0]` did on an empty match.\n",
    "        if edge not in first_seen:\n",
    "            raise IndexError(f'edge {edge} not found in df_edge')\n",
    "        return first_seen[edge]\n",
    "\n",
    "    new_infl = []\n",
    "    new_infl_pred = []\n",
    "    for src, dst in tqdm(zip(sources, targets), total=len(sources)):\n",
    "        act_1, pred_1 = _lookup((src, dst))\n",
    "        act_2, pred_2 = _lookup((dst, src))\n",
    "\n",
    "        assert (act_1 == act_2)\n",
    "        assert (pred_1 == pred_2)\n",
    "        new_infl.append(act_1)\n",
    "        new_infl_pred.append(pred_1)\n",
    "    return new_infl, new_infl_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62083c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2297f2b0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d1cc32",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7adf2a30",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
