{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "22bb5f53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from dgl import function as fn\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from dataset import load_graph_dataset\n",
    "from tqdm import tqdm\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "# import daal4py as d4p\n",
    "# from daal4py.sklearn import patch_sklearn\n",
    "# from daal4py.oneapi import sycl_context\n",
    "import dgl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d1dc7c24",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()\n",
    "\n",
    "# patch_sklearn() only replaces the estimators exposed by sklearn's modules. A name that was\n",
    "# already bound earlier with `from sklearn... import X` (LogisticRegression, in the imports\n",
    "# cell) keeps pointing at the stock class, so re-import it to pick up the patched version.\n",
    "# NOTE(review): model_softmax was also imported before patching; if it uses sklearn\n",
    "# estimators internally they may still be unpatched -- confirm.\n",
    "from sklearn.linear_model import LogisticRegression"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e872d49",
   "metadata": {},
   "source": [
    "##### Tune hyperparameters on the validation set first"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c4986b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset to load; also used to build the file names of saved results.\n",
    "dataname = \"cora\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "2d15d7e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "# Load the graph plus node features, labels, the train/val/test masks and the class count.\n",
    "(graph, feat, labels,\n",
    " train_mask, val_mask, test_mask,\n",
    " number_classes) = load_graph_dataset(dataname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "123551cb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([   3,    5,    6,  ..., 2705, 2706, 2707]),\n",
       " tensor([   0,    0,    0,  ..., 2705, 2706, 2707]))"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the edge list as (source ids, destination ids).\n",
    "src_ids, dst_ids = graph.edges()\n",
    "src_ids, dst_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "0b15498f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([   3,    7,   24,   34,   52,   65,   67,   69,   70,   72,   74,   79,\n",
       "         130,  138,  140,  168,  184,  186,  240,  242,  266,  310,  328,  346,\n",
       "         352,  403,  417,  419,  429,  431,  432,  452,  468,  479,  484,  501,\n",
       "         542,  554,  575,  580,  592,  602,  623,  624,  637,  645,  666,  736,\n",
       "         741,  760,  776,  812,  818,  821,  845,  850,  859,  866,  899,  910,\n",
       "         925, 1002, 1047, 1049, 1082, 1101, 1162, 1209, 1230, 1281, 1286, 1308,\n",
       "        1327, 1337, 1343, 1372, 1410, 1415, 1428, 1454, 1487, 1488, 1516, 1544,\n",
       "        1587, 1682, 1683, 1716, 1754, 1793, 1794, 1795, 1819, 1826, 1852, 1879,\n",
       "        1898, 1921, 1924, 1938, 1939, 1941, 1958, 1981, 1985, 2001, 2004, 2013,\n",
       "        2020, 2029, 2032, 2064, 2068, 2079, 2141, 2156, 2160, 2169, 2208, 2209,\n",
       "        2223, 2224, 2226, 2232, 2287, 2296, 2337, 2367, 2401, 2445, 2468, 2469,\n",
       "        2518, 2528, 2543, 2569, 2658, 2669, 2673, 2681])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Node ids that belong to the training split.\n",
    "torch.nonzero(train_mask == 1, as_tuple=True)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "741db368",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Propagate the raw features NUM_HOPS times. Each hop scales node features by deg^-1/2,\n",
    "# sums them over incoming edges, then scales the result by deg^-1/2 again (in-degrees,\n",
    "# clamped to >= 1 so nodes without in-edges do not produce inf).\n",
    "NUM_HOPS = 2\n",
    "\n",
    "feat0 = feat.clone()\n",
    "degs = graph.in_degrees().float().clamp(min=1)\n",
    "norm = torch.pow(degs, -0.5).to(feat0.device).unsqueeze(1)\n",
    "\n",
    "for _ in range(NUM_HOPS):\n",
    "    # local_scope() discards the temporary 'h' field on exit (also on error),\n",
    "    # so graph.ndata is left untouched.\n",
    "    with graph.local_scope():\n",
    "        graph.ndata['h'] = feat0 * norm\n",
    "        graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n",
    "        feat0 = graph.ndata['h'] * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "9a4ec170",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Slice the propagated features and labels per split, as float64 NumPy arrays\n",
    "# for the scikit-learn estimators used below.\n",
    "def _split_arrays(mask):\n",
    "    return (feat0[mask].numpy().astype(np.float64),\n",
    "            labels[mask].numpy().astype(np.float64))\n",
    "\n",
    "train_x, train_y = _split_arrays(train_mask)\n",
    "val_x, val_y = _split_arrays(val_mask)\n",
    "test_x, test_y = _split_arrays(test_mask)\n",
    "\n",
    "train_node_idx = torch.nonzero(train_mask == 1, as_tuple=True)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "29a68b4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' Train Logistic Regression '"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot encode the class labels; the encoder is fitted on the training labels only,\n",
    "# and handle_unknown='ignore' encodes labels unseen in training as all-zero rows.\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "# NOTE(review): neither one-hot matrix is used by the cells that follow.\n",
    "\n",
    "# Next: train logistic regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "086869e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameter grids: L2 regularisation strength and tolerance values to try.\n",
    "l2_reg = [1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1, 1, 5, 10]\n",
    "tol = [1e-2, 1e-3, 5e-4, 1e-4, 5e-5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "33afa572",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(140, 1433)"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (number of training samples, number of features)\n",
    "np.shape(train_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "aac0d019",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 9/9 [00:01<00:00,  4.51it/s]\n"
     ]
    }
   ],
   "source": [
    "# Grid-search SimplifiedGraphNeuralNetwork over (l2_reg, tol): test accuracy goes to acc,\n",
    "# validation accuracy to acc2, and the settings to l / t (one entry per configuration).\n",
    "acc, acc2, l, t = [], [], [], []\n",
    "for reg_value in tqdm(l2_reg):\n",
    "    for tol_value in tol:\n",
    "        lr = SimplifiedGraphNeuralNetwork(l2_reg=reg_value, tol=tol_value, fit_intercept=True)\n",
    "        lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "        test_pred = lr.model.predict(test_x)\n",
    "        val_pred = lr.model.predict(val_x)\n",
    "        acc.append(np.mean(test_pred == test_y))\n",
    "        acc2.append(np.mean(val_pred == val_y))\n",
    "        l.append(reg_value)\n",
    "        t.append(tol_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "4aa9ba2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.811 0.811 0.811 0.811]\n",
      "[0.794 0.794 0.794 0.794]\n",
      "[0.005 0.005 0.005 0.005]\n",
      "[1.e-03 5.e-04 1.e-04 5.e-05]\n"
     ]
    }
   ],
   "source": [
    "# Indices of every configuration tied for the best validation accuracy.\n",
    "# np.asarray makes the comparison element-wise regardless of element type: a plain list\n",
    "# compared to a float is just False (empty selection); it only worked before because\n",
    "# max() returned a NumPy scalar whose reflected == broadcast over the list.\n",
    "acc2_arr = np.asarray(acc2)\n",
    "idx = np.flatnonzero(acc2_arr == acc2_arr.max())\n",
    "for values in (acc, acc2, l, t):  # test acc, val acc, l2_reg, tol\n",
    "    print(np.asarray(values)[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "ad3ccbb2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 99/99 [1:10:52<00:00, 42.96s/it]\n"
     ]
    }
   ],
   "source": [
    "# Grid-search sklearn LogisticRegression over (l2_reg, tol) on the same features.\n",
    "# C is the inverse regularisation strength, hence C = 1 / reg_value. t is filled alongside l\n",
    "# so the later selection (np.array(t)[index_tune1], tol_term) indexes a list of matching\n",
    "# length instead of an empty one, and tol_term is a tolerance LogisticRegression actually used.\n",
    "acc, acc2, l, t = [], [], [], []\n",
    "for reg_value in tqdm(l2_reg):\n",
    "    for tol_value in tol:\n",
    "        lr = LogisticRegression(C=1 / reg_value, tol=tol_value, fit_intercept=True, max_iter=2048)\n",
    "        lr.fit(train_x, train_y, sample_weight=None)\n",
    "        acc.append(np.mean(lr.predict(test_x) == test_y))\n",
    "        acc2.append(np.mean(lr.predict(val_x) == val_y))\n",
    "        l.append(reg_value)\n",
    "        t.append(tol_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "0c1365a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.94770479, 0.94770479, 0.94779455])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test accuracy of every LogisticRegression setting tied for the best validation accuracy.\n",
    "acc2_arr = np.asarray(acc2)\n",
    "np.asarray(acc)[np.flatnonzero(acc2_arr == acc2_arr.max())]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "656b4c88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positions of the configurations with the highest validation accuracy (all ties kept).\n",
    "acc2_arr = np.asarray(acc2)\n",
    "index_tune1 = np.flatnonzero(acc2_arr == acc2_arr.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "47079aaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save (test accuracy, validation accuracy) per configuration, named after the current\n",
    "# dataset instead of a hard-coded 'reddit' that mislabels / overwrites other runs.\n",
    "pd.DataFrame([acc, acc2]).T.to_csv(f'result_data/{dataname}_acc.csv', header=False, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "ae496ee7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.795]\n",
      "[0.806]\n",
      "[0.001]\n",
      "[0.01]\n"
     ]
    }
   ],
   "source": [
    "# Test acc, validation acc, l2_reg and tol of the selected configuration(s).\n",
    "for values in (acc, acc2, l, t):\n",
    "    print(np.array(values)[index_tune1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "3ab83916",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.01,\n",
       " 0.02,\n",
       " 0.03,\n",
       " 0.04,\n",
       " 0.05,\n",
       " 0.060000000000000005,\n",
       " 0.06999999999999999,\n",
       " 0.08,\n",
       " 0.09,\n",
       " 0.09999999999999999,\n",
       " 0.11,\n",
       " 0.12,\n",
       " 0.13,\n",
       " 0.14,\n",
       " 0.15000000000000002,\n",
       " 0.16,\n",
       " 0.17,\n",
       " 0.18000000000000002,\n",
       " 0.19,\n",
       " 0.2,\n",
       " 0.21000000000000002,\n",
       " 0.22,\n",
       " 0.23,\n",
       " 0.24000000000000002,\n",
       " 0.25,\n",
       " 0.26,\n",
       " 0.27,\n",
       " 0.28,\n",
       " 0.29000000000000004,\n",
       " 0.3,\n",
       " 0.31,\n",
       " 0.32,\n",
       " 0.33,\n",
       " 0.34,\n",
       " 0.35000000000000003,\n",
       " 0.36000000000000004,\n",
       " 0.37,\n",
       " 0.38,\n",
       " 0.39,\n",
       " 0.4,\n",
       " 0.41000000000000003,\n",
       " 0.42000000000000004,\n",
       " 0.43,\n",
       " 0.44,\n",
       " 0.45,\n",
       " 0.46,\n",
       " 0.47000000000000003,\n",
       " 0.48000000000000004,\n",
       " 0.49,\n",
       " 0.5,\n",
       " 0.51,\n",
       " 0.52,\n",
       " 0.53,\n",
       " 0.54,\n",
       " 0.55,\n",
       " 0.56,\n",
       " 0.5700000000000001,\n",
       " 0.5800000000000001,\n",
       " 0.59,\n",
       " 0.6,\n",
       " 0.61,\n",
       " 0.62,\n",
       " 0.63,\n",
       " 0.64,\n",
       " 0.65,\n",
       " 0.66,\n",
       " 0.67,\n",
       " 0.68,\n",
       " 0.6900000000000001,\n",
       " 0.7000000000000001,\n",
       " 0.7100000000000001,\n",
       " 0.72,\n",
       " 0.73,\n",
       " 0.74,\n",
       " 0.75,\n",
       " 0.76,\n",
       " 0.77,\n",
       " 0.78,\n",
       " 0.79,\n",
       " 0.8,\n",
       " 0.81,\n",
       " 0.8200000000000001,\n",
       " 0.8300000000000001,\n",
       " 0.8400000000000001,\n",
       " 0.85,\n",
       " 0.86,\n",
       " 0.87,\n",
       " 0.88,\n",
       " 0.89,\n",
       " 0.9,\n",
       " 0.91,\n",
       " 0.92,\n",
       " 0.93,\n",
       " 0.9400000000000001,\n",
       " 0.9500000000000001,\n",
       " 0.9600000000000001,\n",
       " 0.97,\n",
       " 0.98,\n",
       " 0.99]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Regularisation values recorded by the most recent grid search.\n",
    "list(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "12e5645f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per configuration; columns 0-3 = test acc, val acc, l2_reg, tol.\n",
    "tune_table = [acc, acc2, l, t]\n",
    "df_ = pd.DataFrame(tune_table).T\n",
    "df_.to_csv(f'{dataname}_tune.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "23259e37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-training-sample influence scores, read from the first column of a header-less CSV.\n",
    "# NOTE(review): this reads result_tune/<dataname>_tune.csv, while the tune-export cell writes\n",
    "# <dataname>_tune.csv to the working directory -- confirm this is the intended input file.\n",
    "actual_influence_weight = pd.read_csv('result_tune/' + dataname + '_tune.csv', header=None)[0].values\n",
    "# The scores become sample weights for train_x below, so they must align one-to-one.\n",
    "assert len(actual_influence_weight) == len(train_x), (\n",
    "    f'expected {len(train_x)} influence scores, got {len(actual_influence_weight)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d5bebb3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first of the tied best configurations.\n",
    "best_pos = index_tune1[0]\n",
    "reg_term = np.array(l)[best_pos]\n",
    "tol_term = np.array(t)[best_pos]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "ed79511c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "40.0"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the training matrix and the largest class id among training labels\n",
    "# (both shown; previously the shape expression was evaluated and discarded).\n",
    "train_x.shape, np.max(train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "6cfa3645",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid over two weights: training samples with negative influence get weight `down`, those\n",
    "# with positive influence get `up`, zero influence stays at 1. Re-enabled because the next\n",
    "# two cells read acc_val_w1 / acc_test_w2 and raise NameError while this cell is commented out.\n",
    "acc_val_w1 = []\n",
    "acc_test_w2 = []\n",
    "idx1_list = []\n",
    "idx2_list = []\n",
    "value1_down_arr = np.arange(0.5, 1, 0.01)\n",
    "value2_up_arr = np.arange(1, 1.5, 0.01)\n",
    "\n",
    "for down in tqdm(value1_down_arr):\n",
    "    for up in value2_up_arr:\n",
    "        w1 = np.ones(len(train_x))\n",
    "        w1[actual_influence_weight < 0] = down\n",
    "        w1[actual_influence_weight > 0] = up\n",
    "        lr_tune = LogisticRegression(C=1 / reg_term, tol=tol_term, fit_intercept=True)\n",
    "        lr_tune.fit(train_x, train_y, sample_weight=w1)\n",
    "\n",
    "        acc_val_w1.append(np.mean(lr_tune.predict(val_x) == val_y))\n",
    "        acc_test_w2.append(np.mean(lr_tune.predict(test_x) == test_y))\n",
    "        idx1_list.append(down)\n",
    "        idx2_list.append(up)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a029291",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best validation and best test accuracy over the weight grid, shown together.\n",
    "# NOTE: the test maximum is picked on the test set itself, so it is an optimistic bound;\n",
    "# the following cell reports test accuracy at the best-validation setting instead.\n",
    "np.max(acc_val_w1), np.max(acc_test_w2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c223be8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test accuracy at every weight setting tied for the best validation accuracy.\n",
    "acc_val_w1_arr = np.asarray(acc_val_w1)\n",
    "idx_val_highest = np.flatnonzero(acc_val_w1_arr == acc_val_w1_arr.max())\n",
    "np.asarray(acc_test_w2)[idx_val_highest]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae92a14f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale factors s for weights 1 + s * influence: 0.001 up to (excluding) 0.1 in steps of 1e-5.\n",
    "# linspace with an explicit count avoids np.arange's float-step rounding, which can change\n",
    "# the number of points (and the last value) by one.\n",
    "scale_rate = np.linspace(0.001, 0.1, num=9900, endpoint=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee4081df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each scale s, weight training samples by 1 + s * influence, refit, and record\n",
    "# validation (acc_val_w3) and test (acc_test_w4) accuracy.\n",
    "# NOTE(review): a weight goes negative when s * influence < -1 -- confirm that cannot happen.\n",
    "acc_val_w3, acc_test_w4 = [], []\n",
    "for rate in tqdm(scale_rate):\n",
    "    w1 = 1 + rate * actual_influence_weight\n",
    "    lr_tune = LogisticRegression(C=1 / reg_term, tol=tol_term, fit_intercept=True)\n",
    "    lr_tune.fit(train_x, train_y, sample_weight=w1)\n",
    "    acc_val_w3.append(np.mean(lr_tune.predict(val_x) == val_y))\n",
    "    acc_test_w4.append(np.mean(lr_tune.predict(test_x) == test_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d5e8d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns: validation acc, test acc, scale factor (written without header or index).\n",
    "df_tune = pd.DataFrame([acc_val_w3, acc_test_w4, scale_rate]).T\n",
    "out_path = f'result_tune/df_{dataname}_tune2_scale.csv'\n",
    "df_tune.to_csv(out_path, header=None, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "512eef28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positions of the scale factors tied for the best validation accuracy.\n",
    "acc_val_w3_arr = np.asarray(acc_val_w3)\n",
    "idx_val_highest_2 = np.flatnonzero(acc_val_w3_arr == acc_val_w3_arr.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d23de6a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean test accuracy over the best-validation scale factors.\n",
    "np.mean(np.asarray(acc_test_w4)[idx_val_highest_2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01973fa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Population std (ddof=0) of test accuracy over those same scale factors.\n",
    "np.std(np.asarray(acc_test_w4)[idx_val_highest_2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73657d4c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
