{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "da4c808a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import networkx as nx\n",
    "import os\n",
    "from dgl import function as fn\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from dataset import load_graph_dataset\n",
    "from dataset_wikics import load_wikics\n",
    "from tqdm import tqdm\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from train import train_preprocessed_data\n",
    "from model_edge_influence import EdgeInfluenceSGC\n",
    "from tqdm import tqdm\n",
    "from dgl.data import AmazonCoBuyComputerDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "13694360",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----Data statistics------'\n",
      "      #Edges 297110\n",
      "      #Classes 10\n",
      "      #Train samples 580\n",
      "      #Val samples 1769\n",
      "      #Stopping samples 3505\n",
      "      #Test samples 5847\n"
     ]
    }
   ],
   "source": [
    "# Load WikiCS (load_wikics also prints the dataset statistics) and unpack its\n",
    "# seven return values into the names used by the cells below.\n",
    "wikics_data = load_wikics()\n",
    "(graph, feat, labels,\n",
    " train_mask, val_mask, test_mask,\n",
    " number_classes) = wikics_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b4bbfcfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sym_norm_propagate(g, x, scale, hops):\n",
    "    \"\"\"Run `hops` rounds of degree-normalised sum aggregation over `g`.\n",
    "\n",
    "    Each round multiplies the node features by `scale`, sums them from source\n",
    "    to destination nodes with DGL's copy_u/sum, then multiplies by `scale` again.\n",
    "    Returns the propagated features; the input tensor is not modified.\n",
    "    \"\"\"\n",
    "    for _ in range(hops):\n",
    "        g.ndata['h'] = x * scale\n",
    "        g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))\n",
    "        x = g.ndata.pop('h') * scale\n",
    "    return x\n",
    "\n",
    "\n",
    "# Per-node in-degree^(-1/2) as a column vector; degrees are clamped to 1 so\n",
    "# zero-in-degree nodes do not produce inf.\n",
    "inv_sqrt_deg = graph.in_degrees().float().clamp(min = 1).pow(-0.5)\n",
    "inv_sqrt_deg = inv_sqrt_deg.to(feat.device).unsqueeze(1)\n",
    "\n",
    "# Two propagation hops over a copy of the raw features.\n",
    "feat0 = sym_norm_propagate(graph, feat.clone(), inv_sqrt_deg, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "6faab705",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 4/4 [00:07<00:00,  1.87s/it]\n",
      "100%|█████████████████████████████████████████████| 4/4 [00:07<00:00,  1.92s/it]\n",
      "100%|█████████████████████████████████████████████| 4/4 [00:07<00:00,  1.86s/it]\n",
      "100%|█████████████████████████████████████████████| 4/4 [00:07<00:00,  1.99s/it]\n",
      "100%|█████████████████████████████████████████████| 4/4 [00:07<00:00,  1.88s/it]\n",
      "100%|█████████████████████████████████████████████| 4/4 [00:07<00:00,  1.91s/it]\n",
      "100%|█████████████████████████████████████████████| 4/4 [00:07<00:00,  1.86s/it]\n",
      "100%|█████████████████████████████████████████████| 4/4 [00:07<00:00,  1.90s/it]\n",
      "100%|█████████████████████████████████████████████| 4/4 [00:06<00:00,  1.58s/it]\n",
      "100%|█████████████████████████████████████████████| 4/4 [00:07<00:00,  1.91s/it]\n"
     ]
    }
   ],
   "source": [
    "# Per-split hyper-parameter search. For every split, fit one model per\n",
    "# (l2_reg, tol) pair on the split's training nodes, pick the pair with the\n",
    "# highest validation accuracy and record that pair's test accuracy.\n",
    "# NOTE(review): assumes train_mask / val_mask are indexable by split id and\n",
    "# test_mask is a single mask shared by all splits -- confirm against load_wikics.\n",
    "N_SPLITS = 10\n",
    "L2_REG_GRID = [0.0001, 0.001, 0.01, 1]\n",
    "TOL_GRID = [1e-2, 1e-3, 5e-4, 1e-4, 5e-5]\n",
    "\n",
    "# The test set is the same for every split, so extract it once.\n",
    "test_x = feat0[test_mask].numpy().astype(np.float64)\n",
    "test_y = labels[test_mask].numpy().astype(np.float64)\n",
    "\n",
    "best_acc_all = []\n",
    "for split in range(N_SPLITS):\n",
    "    train_x = feat0[train_mask[split]].numpy().astype(np.float64)\n",
    "    train_y = labels[train_mask[split]].numpy().astype(np.float64)\n",
    "    val_x = feat0[val_mask[split]].numpy().astype(np.float64)\n",
    "    val_y = labels[val_mask[split]].numpy().astype(np.float64)\n",
    "\n",
    "    # Fresh result lists for every split, so model selection only ranks this\n",
    "    # split's runs (and the cell works on a fresh kernel).\n",
    "    acc = []        # test accuracy per configuration\n",
    "    acc2 = []       # validation accuracy per configuration\n",
    "    l2_tried = []   # l2_reg of each configuration\n",
    "    tol_tried = []  # tol of each configuration\n",
    "\n",
    "    # Loop names differ from `split` so the split index is never shadowed.\n",
    "    for reg in tqdm(L2_REG_GRID, desc=f'split {split}'):\n",
    "        for tolerance in TOL_GRID:\n",
    "            clf = SimplifiedGraphNeuralNetwork(l2_reg=reg, tol=tolerance, fit_intercept=True)\n",
    "            clf.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "            acc.append(np.mean(clf.model.predict(test_x) == test_y))\n",
    "            acc2.append(np.mean(clf.model.predict(val_x) == val_y))\n",
    "            l2_tried.append(reg)\n",
    "            tol_tried.append(tolerance)\n",
    "\n",
    "    # Select on validation accuracy only. argmax returns the first of any tied\n",
    "    # configurations, so exactly one scalar test score is kept per split and\n",
    "    # best_acc_all stays a flat list that np.mean / np.std can reduce.\n",
    "    best = int(np.argmax(acc2))\n",
    "    best_l2_reg = l2_tried[best]\n",
    "    best_tol = tol_tried[best]\n",
    "    best_acc_val = acc2[best]\n",
    "    best_acc_test = acc[best]\n",
    "    best_acc_all.append(best_acc_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "f99d26ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8003841549249451"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the test accuracies collected in best_acc_all.\n",
    "np.asarray(best_acc_all).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "f7e796be",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0003062832149773956"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Population standard deviation (ddof=0) of the collected test accuracies.\n",
    "np.asarray(best_acc_all).std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "1b9a17ce",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.0001, 0.0002, 0.0003, 0.0004, 0.0005, 0.0006, 0.0007, 0.0008,\n",
       "       0.0009])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Candidate finer L2 grid: 0.0001, 0.0002, ..., 0.0009 (nine values).\n",
    "# linspace with an explicit count avoids np.arange's float-step rounding,\n",
    "# which can silently add or drop the endpoint.\n",
    "np.linspace(0.0001, 0.0009, 9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "ac0164c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([0.79972636, 0.79972636, 0.79972636, 0.79972636, 0.80006841,\n",
       "        0.80006841, 0.80006841, 0.80006841, 0.80023944, 0.80023944,\n",
       "        0.80023944, 0.80023944, 0.80058149, 0.80041047, 0.80058149,\n",
       "        0.80075252, 0.80041047, 0.80041047, 0.80041047, 0.80041047,\n",
       "        0.80075252, 0.80075252, 0.80075252, 0.80075252, 0.80075252,\n",
       "        0.80075252, 0.80075252, 0.80075252, 0.80075252])]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test accuracies appended by the grid-search cell (one entry per split).\n",
    "best_acc_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "766fa985",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([431, 432, 433, 434, 436, 437, 438, 439, 441, 442, 443, 444, 470,\n",
       "       475, 485, 490, 491, 492, 493, 494, 496, 497, 498, 499, 500, 501,\n",
       "       502, 503, 504])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Indices of every configuration tied for the best validation accuracy in\n",
    "# acc2 (as left behind by the grid-search cell, i.e. its last split).\n",
    "# np.asarray makes the comparison element-wise even if acc2 holds plain\n",
    "# Python floats, where `list == scalar` would just evaluate to False.\n",
    "np.flatnonzero(np.asarray(acc2) == np.max(acc2))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
