{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "3ffd5ff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from dgl import function as fn\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from dataset import load_graph_dataset\n",
    "from tqdm import tqdm\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "# import daal4py as d4p\n",
    "# from daal4py.sklearn import patch_sklearn\n",
    "# from daal4py.oneapi import sycl_context\n",
    "import dgl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "dc289892",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "# Enable Intel(R) Extension for Scikit-learn.\n",
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()\n",
    "# patch_sklearn() only affects estimators imported *after* it runs. The first cell already\n",
    "# bound the stock LogisticRegression, so re-import it here to pick up the patched class.\n",
    "from sklearn.linear_model import LogisticRegression"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1627e9b4",
   "metadata": {},
   "source": [
    "##### Tune hyperparameters on the validation set first"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "c602cfe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset passed to load_graph_dataset; also used to build the CSV file names below.\n",
    "dataname = \"cora\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "07ce87d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  NumNodes: 2708\n",
      "  NumEdges: 10556\n",
      "  NumFeats: 1433\n",
      "  NumClasses: 7\n",
      "  NumTrainingSamples: 140\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n",
      "Done loading data from cached files.\n"
     ]
    }
   ],
   "source": [
    "# Load the graph, node features, labels, train/val/test masks and the number of classes.\n",
    "(graph, feat, labels,\n",
    " train_mask, val_mask, test_mask,\n",
    " number_classes) = load_graph_dataset(dataname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "a3d0213e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([   3,    5,    6,  ..., 2705, 2706, 2707]),\n",
       " tensor([   0,    0,    0,  ..., 2705, 2706, 2707]))"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick look at the edge list (DGL returns a pair of endpoint-id tensors).\n",
    "graph.edges()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "cac0bb5b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([   3,    7,   24,   34,   52,   65,   67,   69,   70,   72,   74,   79,\n",
       "         130,  138,  140,  168,  184,  186,  240,  242,  266,  310,  328,  346,\n",
       "         352,  403,  417,  419,  429,  431,  432,  452,  468,  479,  484,  501,\n",
       "         542,  554,  575,  580,  592,  602,  623,  624,  637,  645,  666,  736,\n",
       "         741,  760,  776,  812,  818,  821,  845,  850,  859,  866,  899,  910,\n",
       "         925, 1002, 1047, 1049, 1082, 1101, 1162, 1209, 1230, 1281, 1286, 1308,\n",
       "        1327, 1337, 1343, 1372, 1410, 1415, 1428, 1454, 1487, 1488, 1516, 1544,\n",
       "        1587, 1682, 1683, 1716, 1754, 1793, 1794, 1795, 1819, 1826, 1852, 1879,\n",
       "        1898, 1921, 1924, 1938, 1939, 1941, 1958, 1981, 1985, 2001, 2004, 2013,\n",
       "        2020, 2029, 2032, 2064, 2068, 2079, 2141, 2156, 2160, 2169, 2208, 2209,\n",
       "        2223, 2224, 2226, 2232, 2287, 2296, 2337, 2367, 2401, 2445, 2468, 2469,\n",
       "        2518, 2528, 2543, 2569, 2658, 2669, 2673, 2681])"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Node ids belonging to the training split.\n",
    "torch.nonzero(train_mask == 1, as_tuple=True)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "77e357ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-smooth node features with n_hops rounds of D^-1/2 A D^-1/2 (D = in-degree, clamped to >= 1),\n",
    "# the SGC-style fixed propagation. feat is cloned, so the raw features are left untouched.\n",
    "n_hops = 2\n",
    "norm = graph.in_degrees().float().clamp(min = 1).pow(-0.5)\n",
    "norm = norm.to(feat.device).unsqueeze(1)\n",
    "\n",
    "feat0 = feat.clone()\n",
    "for _ in range(n_hops):\n",
    "    graph.ndata['h'] = feat0 * norm\n",
    "    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h') * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "70fa70c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# float64 NumPy copies of propagated features and labels for each split (inputs for the sklearn-style models).\n",
    "train_x, train_y = (src[train_mask].numpy().astype(np.float64) for src in (feat0, labels))\n",
    "val_x, val_y = (src[val_mask].numpy().astype(np.float64) for src in (feat0, labels))\n",
    "test_x, test_y = (src[test_mask].numpy().astype(np.float64) for src in (feat0, labels))\n",
    "\n",
    "train_node_idx = torch.nonzero(train_mask == 1, as_tuple=True)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "38302aec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' Train Logistic Regression '"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot encode labels using categories learned from the training split;\n",
    "# handle_unknown='ignore' encodes a class unseen in training as an all-zero row.\n",
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "# NOTE(review): one_hot_labels_* are not used anywhere else in this notebook.\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()\n",
    "\n",
    "# Next: train logistic regression (a comment, not a bare string, so the cell shows no stray output)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "e0d54e09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# L2 strengths for the sklearn LogisticRegression sweep (passed as C = 1 / l2_reg).\n",
    "l2_reg = [0.001, 0.005, 0.01, 0.05, 0.1, 0.5,\n",
    "          1, 5, 10]\n",
    "# L2 strengths for the SimplifiedGraphNeuralNetwork sweep.\n",
    "l2_reg_epoch2 = [0.001, 0.01, 0.1,\n",
    "                 1, 10, 100]\n",
    "# Solver tolerances to try.\n",
    "tol = [1e-4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "c45e4c9a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(140, 1433)"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (number of training nodes, feature dimension)\n",
    "np.shape(train_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "42228f0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 6/6 [00:00<00:00, 17.21it/s]\n"
     ]
    }
   ],
   "source": [
    "# Grid-search SimplifiedGraphNeuralNetwork over l2_reg_epoch2 x tol.\n",
    "# acc collects test accuracy, acc2 validation accuracy; l / t record the hyperparameters used.\n",
    "acc, acc2, l, t = [], [], [], []\n",
    "for reg in tqdm(l2_reg_epoch2):\n",
    "    for tol_value in tol:\n",
    "        lr = SimplifiedGraphNeuralNetwork(l2_reg=reg, tol=tol_value, fit_intercept=True)\n",
    "        lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "        acc.append(np.mean(lr.model.predict(test_x) == test_y))\n",
    "        acc2.append(np.mean(lr.model.predict(val_x) == val_y))\n",
    "        l.append(reg)\n",
    "        t.append(tol_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "bb75f2f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.806\n",
      "0.792\n",
      "0.001\n",
      "0.0001\n"
     ]
    }
   ],
   "source": [
    "# Report test acc, val acc, l2_reg and tol of the first configuration with the best validation accuracy.\n",
    "# np.argmax does not depend on `list == max(list)` broadcasting, which only works when the\n",
    "# elements happen to be NumPy scalars (plain floats would give a single False).\n",
    "idx = int(np.argmax(acc2))\n",
    "print(np.array(acc)[idx])\n",
    "print(np.array(acc2)[idx])\n",
    "print(np.array(l)[idx])\n",
    "print(np.array(t)[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "92d3e86b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 99/99 [1:10:52<00:00, 42.96s/it]\n"
     ]
    }
   ],
   "source": [
    "# Grid-search sklearn LogisticRegression over l2_reg x tol. Records the same lists as the SGNN\n",
    "# sweep (acc = test acc, acc2 = val acc, l = l2_reg, t = tol); later cells index t, so it must be filled.\n",
    "acc = []\n",
    "acc2 = []\n",
    "l = []\n",
    "t = []\n",
    "for i in tqdm(l2_reg):\n",
    "    for j in tol:\n",
    "        lr = LogisticRegression(C=1 / i, tol=j, fit_intercept=True, max_iter=2048)\n",
    "        lr.fit(train_x, train_y, sample_weight=None)\n",
    "\n",
    "        acc.append(np.mean(lr.predict(test_x) == test_y))\n",
    "        acc2.append(np.mean(lr.predict(val_x) == val_y))\n",
    "        l.append(i)\n",
    "        t.append(j)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "45436dee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.94770479, 0.94770479, 0.94779455])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test accuracy of every configuration tied for the best validation accuracy.\n",
    "np.array(acc)[np.flatnonzero(np.asarray(acc2) == np.max(acc2))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "ff82b97e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positions of all configurations tied for the best validation accuracy (works for lists of plain floats too).\n",
    "index_tune1 = np.flatnonzero(np.asarray(acc2) == np.max(acc2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "8316a71d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save (test acc, val acc) per configuration. Name the file after dataname (it was hard-coded to 'reddit').\n",
    "pd.DataFrame([acc, acc2]).T.to_csv('result_data/' + dataname + '_acc.csv', header = None, index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "c2ecfea7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.795]\n",
      "[0.806]\n",
      "[0.001]\n",
      "[0.01]\n"
     ]
    }
   ],
   "source": [
    "# Test acc, val acc, l2_reg and tol for each configuration tied for best validation accuracy.\n",
    "print(np.asarray(acc)[index_tune1])\n",
    "print(np.asarray(acc2)[index_tune1])\n",
    "print(np.asarray(l)[index_tune1])\n",
    "print(np.asarray(t)[index_tune1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "b018dd22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.01,\n",
       " 0.02,\n",
       " 0.03,\n",
       " 0.04,\n",
       " 0.05,\n",
       " 0.060000000000000005,\n",
       " 0.06999999999999999,\n",
       " 0.08,\n",
       " 0.09,\n",
       " 0.09999999999999999,\n",
       " 0.11,\n",
       " 0.12,\n",
       " 0.13,\n",
       " 0.14,\n",
       " 0.15000000000000002,\n",
       " 0.16,\n",
       " 0.17,\n",
       " 0.18000000000000002,\n",
       " 0.19,\n",
       " 0.2,\n",
       " 0.21000000000000002,\n",
       " 0.22,\n",
       " 0.23,\n",
       " 0.24000000000000002,\n",
       " 0.25,\n",
       " 0.26,\n",
       " 0.27,\n",
       " 0.28,\n",
       " 0.29000000000000004,\n",
       " 0.3,\n",
       " 0.31,\n",
       " 0.32,\n",
       " 0.33,\n",
       " 0.34,\n",
       " 0.35000000000000003,\n",
       " 0.36000000000000004,\n",
       " 0.37,\n",
       " 0.38,\n",
       " 0.39,\n",
       " 0.4,\n",
       " 0.41000000000000003,\n",
       " 0.42000000000000004,\n",
       " 0.43,\n",
       " 0.44,\n",
       " 0.45,\n",
       " 0.46,\n",
       " 0.47000000000000003,\n",
       " 0.48000000000000004,\n",
       " 0.49,\n",
       " 0.5,\n",
       " 0.51,\n",
       " 0.52,\n",
       " 0.53,\n",
       " 0.54,\n",
       " 0.55,\n",
       " 0.56,\n",
       " 0.5700000000000001,\n",
       " 0.5800000000000001,\n",
       " 0.59,\n",
       " 0.6,\n",
       " 0.61,\n",
       " 0.62,\n",
       " 0.63,\n",
       " 0.64,\n",
       " 0.65,\n",
       " 0.66,\n",
       " 0.67,\n",
       " 0.68,\n",
       " 0.6900000000000001,\n",
       " 0.7000000000000001,\n",
       " 0.7100000000000001,\n",
       " 0.72,\n",
       " 0.73,\n",
       " 0.74,\n",
       " 0.75,\n",
       " 0.76,\n",
       " 0.77,\n",
       " 0.78,\n",
       " 0.79,\n",
       " 0.8,\n",
       " 0.81,\n",
       " 0.8200000000000001,\n",
       " 0.8300000000000001,\n",
       " 0.8400000000000001,\n",
       " 0.85,\n",
       " 0.86,\n",
       " 0.87,\n",
       " 0.88,\n",
       " 0.89,\n",
       " 0.9,\n",
       " 0.91,\n",
       " 0.92,\n",
       " 0.93,\n",
       " 0.9400000000000001,\n",
       " 0.9500000000000001,\n",
       " 0.9600000000000001,\n",
       " 0.97,\n",
       " 0.98,\n",
       " 0.99]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Regularisation values recorded by the most recently executed sweep.\n",
    "list(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "25ba6b12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the sweep results (rows: configurations; columns: test acc, val acc, l2_reg, tol) to the\n",
    "# working directory. NOTE: same file name as the influence-weight file read from result_tune/ below.\n",
    "df_ = pd.DataFrame([acc, acc2, l, t]).T\n",
    "df_.to_csv(f'{dataname}_tune.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bd651968",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-training-node influence weights: first column of a header-less CSV in result_tune/\n",
    "# (not written by this notebook). They are used as per-sample weights below, so fail fast if the\n",
    "# wrong file (e.g. the same-named sweep-results CSV) was picked up.\n",
    "actual_influence_weight = pd.read_csv('result_tune/' + dataname + '_tune.csv', header = None)[0].values\n",
    "assert len(actual_influence_weight) == len(train_x), (\n",
    "    f'expected {len(train_x)} influence weights, got {len(actual_influence_weight)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "51b09ed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# l2_reg and tol of the first configuration tied for best validation accuracy.\n",
    "reg_term = np.array(l)[index_tune1[0]]\n",
    "tol_term = np.array(t)[index_tune1[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "015ccdbe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "40.0"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training matrix shape and the largest label value (only the last bare expression is displayed,\n",
    "# so show both as one tuple).\n",
    "train_x.shape, np.max(train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "09135b57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-factor reweighting sweep: training nodes with negative influence get a weight in [0.5, 1),\n",
    "# those with positive influence a weight in [1, 1.5), the rest keep 1. Each setting refits\n",
    "# LogisticRegression with the tuned reg_term / tol_term. The next two cells read\n",
    "# acc_val_w1 / acc_test_w2, so this cell must run for Restart & Run All to work.\n",
    "# NOTE: this is len(value1_down_arr) * len(value2_up_arr) (~2,500) fits.\n",
    "acc_val_w1 = []\n",
    "acc_test_w2 = []\n",
    "idx1_list = []\n",
    "idx2_list = []\n",
    "value1_down_arr = np.arange(0.5, 1, 0.01)\n",
    "value2_up_arr = np.arange(1, 1.5, 0.01)\n",
    "\n",
    "negative_influence = actual_influence_weight < 0\n",
    "positive_influence = actual_influence_weight > 0\n",
    "\n",
    "for down_weight in tqdm(value1_down_arr):\n",
    "    for up_weight in value2_up_arr:\n",
    "        w1 = np.ones(len(train_x))\n",
    "        w1[negative_influence] = down_weight\n",
    "        w1[positive_influence] = up_weight\n",
    "        lr_tune = LogisticRegression(C = 1 / reg_term, tol = tol_term, fit_intercept=True)\n",
    "        lr_tune.fit(train_x, train_y, w1)\n",
    "\n",
    "        acc_val_w1.append(np.mean(lr_tune.predict(val_x) == val_y))\n",
    "        acc_test_w2.append(np.mean(lr_tune.predict(test_x) == test_y))\n",
    "        idx1_list.append(down_weight)\n",
    "        idx2_list.append(up_weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3834757",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best validation accuracy and (for reference only, not for selection) best test accuracy of the\n",
    "# two-factor sweep. Shown as one tuple because only the last bare expression is displayed.\n",
    "np.max(acc_val_w1), np.max(acc_test_w2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3302b3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test accuracies of all two-factor settings tied for the best validation accuracy.\n",
    "idx_val_highest = np.flatnonzero(np.asarray(acc_val_w1) == np.max(acc_val_w1))\n",
    "np.array(acc_test_w2)[idx_val_highest]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5a94e17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 9,900 scale factors in [0.001, 0.1) with step 1e-5. linspace pins the count; np.arange with a\n",
    "# float step can gain or lose an element through rounding.\n",
    "scale_rate = np.linspace(0.001, 0.1, 9900, endpoint=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf58387b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each scale s, refit LogisticRegression with sample weights 1 + s * influence and record\n",
    "# validation / test accuracy. NOTE(review): weights go negative if s * influence < -1 — confirm the range.\n",
    "acc_val_w3 = []\n",
    "acc_test_w4 = []\n",
    "for scale in tqdm(scale_rate):\n",
    "    w1 = 1 + scale * actual_influence_weight\n",
    "    lr_tune = LogisticRegression(C = 1 / reg_term, tol = tol_term, fit_intercept=True)\n",
    "    lr_tune.fit(train_x, train_y, w1)\n",
    "    acc_val_w3.append(np.mean(lr_tune.predict(val_x) == val_y))\n",
    "    acc_test_w4.append(np.mean(lr_tune.predict(test_x) == test_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3887217b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save (val acc, test acc, scale) per scale factor, without header or index.\n",
    "df_tune = pd.DataFrame([acc_val_w3, acc_test_w4, scale_rate]).T\n",
    "df_tune.to_csv(f'result_tune/df_{dataname}_tune2_scale.csv', header=None, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0a6b5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positions of all scale factors tied for the best validation accuracy.\n",
    "idx_val_highest_2 = np.flatnonzero(np.asarray(acc_val_w3) == np.max(acc_val_w3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76bf7d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean test accuracy over the validation-tied scale factors.\n",
    "np.mean(np.array(acc_test_w4)[idx_val_highest_2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19046cde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Population standard deviation (ddof=0) of test accuracy over the validation-tied scale factors.\n",
    "np.std(np.array(acc_test_w4)[idx_val_highest_2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36856edf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
