{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "921ab8c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from dgl import function as fn\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from dataset import load_graph_dataset\n",
    "from tqdm import tqdm\n",
    "from model_softmax import SimplifiedGraphNeuralNetwork\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from train import train_preprocessed_data\n",
    "from model_edge_influence import EdgeInfluenceSGC\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1ba733c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "from sklearnex import patch_sklearn, config_context\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3bfeef12",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataname = 'reddit'\n",
    "graph, feat, labels, train_mask, val_mask, test_mask, number_classes = load_graph_dataset(dataname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6ab4e15d",
   "metadata": {},
   "outputs": [],
   "source": [
    "feat0 = feat.clone()\n",
    "degs = graph.in_degrees().float().clamp(min = 1)\n",
    "norm = torch.pow(degs, -0.5)\n",
    "norm = norm.to(feat0.device).unsqueeze(1)\n",
    "\n",
    "for _ in range(2):\n",
    "    feat0 = feat0 * norm\n",
    "    graph.ndata['h'] = feat0\n",
    "    graph.update_all(fn.copy_u('h', 'm'),\n",
    "                     fn.sum('m', 'h'))\n",
    "    feat0 = graph.ndata.pop('h')\n",
    "    feat0 = feat0 * norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5b4153b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = feat0[train_mask].numpy().astype(np.float64)\n",
    "train_y = labels[train_mask].numpy().astype(np.float64)\n",
    "\n",
    "val_x = feat0[val_mask].numpy().astype(np.float64)\n",
    "val_y = labels[val_mask].numpy().astype(np.float64)\n",
    "\n",
    "test_x = feat0[test_mask].numpy().astype(np.float64)\n",
    "test_y = labels[test_mask].numpy().astype(np.float64)\n",
    "\n",
    "train_node_idx = torch.where(train_mask == 1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "63889070",
   "metadata": {},
   "outputs": [],
   "source": [
    "enc = OneHotEncoder(handle_unknown='ignore')\n",
    "enc.fit(train_y.reshape(-1, 1))\n",
    "\n",
    "one_hot_labels_train = enc.transform(train_y.reshape(-1, 1)).toarray()\n",
    "one_hot_labels_test = enc.transform(test_y.reshape(-1, 1)).toarray()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "828bb468",
   "metadata": {},
   "outputs": [],
   "source": [
    "l2_reg = [1， 1.5, ]\n",
    "# tol = [1e-2, 1e-3, 5e-4, 1e-4, 5e-5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7ec91662",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                     | 0/7 [00:00<?, ?it/s]/home/zizhang/anaconda3/lib/python3.8/site-packages/daal4py/sklearn/linear_model/logistic_path.py:459: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      " 14%|██████▎                                     | 1/7 [09:03<54:23, 543.92s/it]/home/zizhang/anaconda3/lib/python3.8/site-packages/daal4py/sklearn/linear_model/logistic_path.py:459: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "100%|████████████████████████████████████████████| 7/7 [37:33<00:00, 321.97s/it]\n"
     ]
    }
   ],
   "source": [
    "acc = []\n",
    "acc2 = []\n",
    "l = []\n",
    "t = []\n",
    "for i in tqdm(l2_reg):\n",
    "    lr = SimplifiedGraphNeuralNetwork(l2_reg=i, fit_intercept=True)\n",
    "    lr.fit(train_x, train_y, sample_weight=None, verbose=False)\n",
    "    acc.append(np.mean(lr.model.predict(test_x) == test_y))\n",
    "    acc2.append(np.mean(lr.model.predict(val_x) == val_y))\n",
    "    l.append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "5566a557",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.9321580525285891,\n",
       " 0.9361434752167747,\n",
       " 0.9383695671687342,\n",
       " 0.943252607579484,\n",
       " 0.9451555571513204,\n",
       " 0.947758648546757,\n",
       " 0.947830457964562]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d77c522b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.9334060677269104,\n",
       " 0.9364693046871722,\n",
       " 0.9388191850950443,\n",
       " 0.9439805295623348,\n",
       " 0.9462464856699258,\n",
       " 0.947883009525408,\n",
       " 0.9483865553270949]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23ba4db4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "925f0612",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
