{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "Optimization Finished! Best Test Result: 0.6128, Training Loss: 0.5955\n",
      "Optimization Finished! Best Test Result: 0.6184, Training Loss: 0.6149\n",
      "Optimization Finished! Best Test Result: 0.6234, Training Loss: 0.5943\n",
      "Optimization Finished! Best Test Result: 0.6207, Training Loss: 0.6042\n",
      "Optimization Finished! Best Test Result: 0.6251, Training Loss: 0.5870\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset deezer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "(array([   11,    12,    16, ..., 18620, 18642, 18653]),)\n",
      "[-1  0  1]\n",
      "[0 1 2]\n",
      "Traceback (most recent call last):\n",
      "  File \"train.py\", line 56, in <module>\n",
      "    adj, adj_high, features, labels = full_load_data(args.dataset_name, args.sub_dataname)\n",
      "  File \"/home/will/Desktop/GCN/utils.py\", line 280, in full_load_data\n",
      "    assert (np.array_equal(np.unique(labels), np.arange(len(np.unique(labels)))))\n",
      "AssertionError\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset facebook --sub Cornell5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "[3 0 2 3 4 3 0 0 3 0 3 3 3 3 3 4 3 3 0 3 0 3 3 3 3 1 3 3 0 2 3 3 4 3 4 4 2\n",
      " 3 3 3 0 4 0 3 3 3 2 2 0 3 0 3 3 3 3 3 0 2 2 4 4 4 3 3 3 3 0 3 3 3 4 3 3 4\n",
      " 4 3 0 3 0 3 4 3 2 4 2 4 3 3 0 3 3 3 0 3 3 4 3 3 3 4 0 0 4 3 3 0 4 3 2 3 3\n",
      " 0 3 0 0 3 3 4 3 3 3 0 3 0 3 2 4 2 3 3 0 4 3 4 3 4 3 3 0 3 2 3 3 3 3 3 3 2\n",
      " 3 4 3 4 3 3 3 3 0 2 0 2 3 3 4 3 0 3 3 2 0 3 3 4 3 2 0 3 0 4 3 3 4 3 3]\n",
      "/pytorch/torch/csrc/utils/python_arg_parser.cpp:756: UserWarning: This overload of nonzero is deprecated:\n",
      "\tnonzero(Tensor input, *, Tensor out)\n",
      "Consider using one of the following signatures instead:\n",
      "\tnonzero(Tensor input, *, bool as_tuple)\n",
      "/pytorch/aten/src/ATen/native/BinaryOps.cpp:81: UserWarning: Integer division of tensors using div or / is deprecated, and in a future release div will perform true division as in Python 3. Use true_divide or floor_divide (// in Python) instead.\n",
      "Optimization Finished! Best Test Result: 0.6230, Training Loss: 1.3329\n",
      "Optimization Finished! Best Test Result: 0.7213, Training Loss: 1.3822\n",
      "Optimization Finished! Best Test Result: 0.6721, Training Loss: 1.3573\n",
      "^C\n",
      "Traceback (most recent call last):\n",
      "  File \"train.py\", line 141, in <module>\n",
      "    loss_train.backward()\n",
      "  File \"/home/will/.local/lib/python3.8/site-packages/torch/tensor.py\", line 198, in backward\n",
      "    torch.autograd.backward(self, gradient, retain_graph, create_graph)\n",
      "  File \"/home/will/.local/lib/python3.8/site-packages/torch/autograd/__init__.py\", line 98, in backward\n",
      "    Variable._execution_engine.run_backward(\n",
      "KeyboardInterrupt\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset cornell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "198493\n",
      "/home/will/Desktop/whentousegnn/utils.py:32: RuntimeWarning: divide by zero encountered in power\n",
      "  r_inv = np.power(rowsum, -1).flatten()\n",
      "2322336\n",
      "Optimization Finished! Best Test Result: 0.1988, Training Loss: 1.6068\n",
      "Total time elapsed: 65.9029s\n",
      "Optimization Finished! Best Test Result: 0.2267, Training Loss: 1.6107\n",
      "Total time elapsed: 140.7847s\n",
      "Optimization Finished! Best Test Result: 0.3948, Training Loss: 1.6088\n",
      "Total time elapsed: 216.1823s\n",
      "Optimization Finished! Best Test Result: 0.3948, Training Loss: 1.6029\n",
      "Total time elapsed: 295.4667s\n",
      "Optimization Finished! Best Test Result: 0.5946, Training Loss: 1.6073\n",
      "Total time elapsed: 370.1925s\n",
      "Optimization Finished! Best Test Result: 0.3842, Training Loss: 1.6021\n",
      "Total time elapsed: 446.0324s\n",
      "Optimization Finished! Best Test Result: 0.2200, Training Loss: 1.6089\n",
      "Total time elapsed: 517.8973s\n",
      "Optimization Finished! Best Test Result: 0.5860, Training Loss: 1.6081\n",
      "Total time elapsed: 590.7854s\n",
      "Optimization Finished! Best Test Result: 0.3919, Training Loss: 1.6107\n",
      "Total time elapsed: 667.0666s\n",
      "Optimization Finished! Best Test Result: 0.3996, Training Loss: 1.6066\n",
      "Total time elapsed: 743.9513s\n",
      "Test Mean: 0.3792, Test Std: 0.1309\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset squirrel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset squirrel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "31421\n",
      "/home/will/Desktop/whentousegnn/utils.py:32: RuntimeWarning: divide by zero encountered in power\n",
      "  r_inv = np.power(rowsum, -1).flatten()\n",
      "444965\n",
      "Optimization Finished! Best Test Result: 0.6557, Training Loss: 1.6083\n",
      "Total time elapsed: 19.3116s\n",
      "Optimization Finished! Best Test Result: 0.5877, Training Loss: 1.6115\n",
      "Total time elapsed: 36.4617s\n",
      "Optimization Finished! Best Test Result: 0.8092, Training Loss: 1.6021\n",
      "Total time elapsed: 56.8758s\n",
      "Optimization Finished! Best Test Result: 0.6096, Training Loss: 1.6033\n",
      "Total time elapsed: 74.4862s\n",
      "Optimization Finished! Best Test Result: 0.7127, Training Loss: 1.6095\n",
      "Total time elapsed: 93.1951s\n",
      "Optimization Finished! Best Test Result: 0.4101, Training Loss: 1.6087\n",
      "Total time elapsed: 110.7431s\n",
      "Optimization Finished! Best Test Result: 0.6250, Training Loss: 1.5988\n",
      "Total time elapsed: 130.1281s\n",
      "Optimization Finished! Best Test Result: 0.8092, Training Loss: 1.5987\n",
      "Total time elapsed: 148.1195s\n",
      "Optimization Finished! Best Test Result: 0.6535, Training Loss: 1.6089\n",
      "Total time elapsed: 165.9896s\n",
      "Optimization Finished! Best Test Result: 0.4430, Training Loss: 1.6029\n",
      "Total time elapsed: 185.6924s\n",
      "Test Mean: 0.6316, Test Std: 0.1256\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset chameleon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "31421\n",
      "/home/will/Desktop/whentousegnn/utils.py:32: RuntimeWarning: divide by zero encountered in power\n",
      "  r_inv = np.power(rowsum, -1).flatten()\n",
      "444965\n",
      "Optimization Finished! Best Test Result: 0.9759, Training Loss: 0.0675\n",
      "Total time elapsed: 137.4342s\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset chameleon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "/home/will/Desktop/whentousegnn/utils.py:32: RuntimeWarning: divide by zero encountered in power\n",
      "  r_inv = np.power(rowsum, -1).flatten()\n",
      "Optimization Finished! Best Test Result: 0.2785, Training Loss: 1.4386\n",
      "Total time elapsed: 18.1537s\n",
      "Optimization Finished! Best Test Result: 0.4956, Training Loss: 1.0364\n",
      "Total time elapsed: 47.8977s\n",
      "Optimization Finished! Best Test Result: 0.5373, Training Loss: 0.6683\n",
      "Total time elapsed: 90.8215s\n",
      "Optimization Finished! Best Test Result: 0.6776, Training Loss: 0.6083\n",
      "Total time elapsed: 173.9548s\n",
      "Optimization Finished! Best Test Result: 0.6820, Training Loss: 0.5595\n",
      "Total time elapsed: 329.1835s\n",
      "Optimization Finished! Best Test Result: 0.7237, Training Loss: 0.6134\n",
      "Total time elapsed: 708.5331s\n",
      "Optimization Finished! Best Test Result: 0.7390, Training Loss: 0.7291\n",
      "Total time elapsed: 1137.9224s\n",
      "Optimization Finished! Best Test Result: 0.7500, Training Loss: 0.8154\n",
      "Total time elapsed: 1644.2990s\n",
      "Optimization Finished! Best Test Result: 0.6930, Training Loss: 0.9515\n",
      "Total time elapsed: 2587.7906s\n",
      "Optimization Finished! Best Test Result: 0.7368, Training Loss: 0.9778\n",
      "Total time elapsed: 3228.4496s\n",
      "Test Mean: 0.6314, Test Std: 0.1434\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset chameleon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "Optimization Finished! Best Test Result: 0.7451, Training Loss: 0.4887\n",
      "Total time elapsed: 74.4992s\n",
      "Optimization Finished! Best Test Result: 0.8627, Training Loss: 0.1802\n",
      "Total time elapsed: 83.7329s\n",
      "Optimization Finished! Best Test Result: 0.9020, Training Loss: 0.2842\n",
      "Total time elapsed: 104.9682s\n",
      "Optimization Finished! Best Test Result: 0.9020, Training Loss: 0.1923\n",
      "Total time elapsed: 128.4146s\n",
      "Optimization Finished! Best Test Result: 0.7843, Training Loss: 0.3592\n",
      "Total time elapsed: 133.8676s\n",
      "Optimization Finished! Best Test Result: 0.9020, Training Loss: 0.3011\n",
      "Total time elapsed: 141.5850s\n",
      "Optimization Finished! Best Test Result: 0.9216, Training Loss: 0.3015\n",
      "Total time elapsed: 150.8749s\n",
      "Optimization Finished! Best Test Result: 0.8039, Training Loss: 0.3447\n",
      "Total time elapsed: 158.6990s\n",
      "Optimization Finished! Best Test Result: 0.8627, Training Loss: 0.3399\n",
      "Total time elapsed: 167.8894s\n",
      "Optimization Finished! Best Test Result: 0.8235, Training Loss: 0.4043\n",
      "Total time elapsed: 178.1315s\n",
      "Test Mean: 0.8510, Test Std: 0.0563\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset wisconsin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "/home/will/Desktop/whentousegnn/utils.py:32: RuntimeWarning: divide by zero encountered in power\n",
      "  r_inv = np.power(rowsum, -1).flatten()\n",
      "Optimization Finished! Best Test Result: 0.1988, Training Loss: 1.6082\n",
      "Total time elapsed: 162.0248s\n",
      "Optimization Finished! Best Test Result: 0.3631, Training Loss: 1.4354\n",
      "Total time elapsed: 312.5068s\n",
      "Optimization Finished! Best Test Result: 0.4611, Training Loss: 1.0656\n",
      "Total time elapsed: 544.0539s\n",
      "Optimization Finished! Best Test Result: 0.5255, Training Loss: 1.1182\n",
      "Total time elapsed: 920.1775s\n",
      "Optimization Finished! Best Test Result: 0.5427, Training Loss: 1.0917\n",
      "Total time elapsed: 1724.7475s\n",
      "Optimization Finished! Best Test Result: 0.5514, Training Loss: 1.1414\n",
      "Total time elapsed: 3194.2338s\n",
      "Optimization Finished! Best Test Result: 0.5850, Training Loss: 1.1522\n",
      "Total time elapsed: 5608.5630s\n",
      "Optimization Finished! Best Test Result: 0.6033, Training Loss: 1.1461\n",
      "Total time elapsed: 8657.9146s\n",
      "Optimization Finished! Best Test Result: 0.5812, Training Loss: 1.2398\n",
      "Total time elapsed: 10961.3591s\n",
      "Optimization Finished! Best Test Result: 0.6033, Training Loss: 1.3896\n",
      "Total time elapsed: 12691.9021s\n",
      "Test Mean: 0.5015, Test Std: 0.1230\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset squirrel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "Optimization Finished! Best Test Result: 0.4668, Training Loss: 1.0656\n",
      "Total time elapsed: 25.1107s\n",
      "Optimization Finished! Best Test Result: 0.5775, Training Loss: 0.8516\n",
      "Total time elapsed: 67.3064s\n",
      "Optimization Finished! Best Test Result: 0.6660, Training Loss: 0.5889\n",
      "Total time elapsed: 161.7107s\n",
      "Optimization Finished! Best Test Result: 0.7143, Training Loss: 0.6832\n",
      "Total time elapsed: 354.3153s\n",
      "Optimization Finished! Best Test Result: 0.6881, Training Loss: 0.8226\n",
      "Total time elapsed: 584.7521s\n",
      "Optimization Finished! Best Test Result: 0.7384, Training Loss: 0.7834\n",
      "Total time elapsed: 1044.0145s\n",
      "Optimization Finished! Best Test Result: 0.6781, Training Loss: 0.9458\n",
      "Total time elapsed: 1494.8600s\n",
      "Optimization Finished! Best Test Result: 0.7082, Training Loss: 0.9054\n",
      "Total time elapsed: 2282.6014s\n",
      "Optimization Finished! Best Test Result: 0.7022, Training Loss: 0.9392\n",
      "Total time elapsed: 2976.1967s\n",
      "Optimization Finished! Best Test Result: 0.5875, Training Loss: 1.3750\n",
      "Total time elapsed: 3791.6863s\n",
      "Test Mean: 0.6527, Test Std: 0.0795\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset cora"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "/home/will/Desktop/whentousegnn/utils.py:32: RuntimeWarning: divide by zero encountered in power\n",
      "  r_inv = np.power(rowsum, -1).flatten()\n",
      "Optimization Finished! Best Test Result: 0.6081, Training Loss: 0.6729\n",
      "Total time elapsed: 56.4084s\n",
      "Optimization Finished! Best Test Result: 0.6727, Training Loss: 0.4069\n",
      "Total time elapsed: 149.2332s\n",
      "Optimization Finished! Best Test Result: 0.7898, Training Loss: 0.3941\n",
      "Total time elapsed: 266.9952s\n",
      "Optimization Finished! Best Test Result: 0.7868, Training Loss: 0.4206\n",
      "Total time elapsed: 546.6572s\n",
      "Optimization Finished! Best Test Result: 0.8491, Training Loss: 0.3640\n",
      "Total time elapsed: 798.7956s\n",
      "Optimization Finished! Best Test Result: 0.8255, Training Loss: 0.4406\n",
      "Total time elapsed: 1298.6861s\n",
      "Optimization Finished! Best Test Result: 0.7087, Training Loss: 0.6044\n",
      "Total time elapsed: 1817.0769s\n",
      "Optimization Finished! Best Test Result: 0.7312, Training Loss: 0.7729\n",
      "Total time elapsed: 3103.1190s\n",
      "Optimization Finished! Best Test Result: 0.7778, Training Loss: 0.7549\n",
      "Total time elapsed: 4528.9549s\n",
      "Optimization Finished! Best Test Result: 0.7267, Training Loss: 0.9259\n",
      "Total time elapsed: 5733.4701s\n",
      "Test Mean: 0.7476, Test Std: 0.0692\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset citeseer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "Optimization Finished! Best Test Result: 0.3211, Training Loss: 1.3919\n",
      "Total time elapsed: 243.1211s\n",
      "Optimization Finished! Best Test Result: 0.3974, Training Loss: 1.3061\n",
      "Total time elapsed: 600.6184s\n",
      "Optimization Finished! Best Test Result: 0.4066, Training Loss: 1.3105\n",
      "Total time elapsed: 2625.7850s\n",
      "Optimization Finished! Best Test Result: 0.3664, Training Loss: 1.3517\n",
      "Total time elapsed: 3435.1205s\n",
      "Optimization Finished! Best Test Result: 0.3921, Training Loss: 1.3218\n",
      "Total time elapsed: 6656.1858s\n",
      "Optimization Finished! Best Test Result: 0.4164, Training Loss: 1.2984\n",
      "Total time elapsed: 14281.7776s\n",
      "Optimization Finished! Best Test Result: 0.3829, Training Loss: 1.4224\n",
      "Total time elapsed: 20948.8995s\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset film"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n",
      "Optimization Finished! Best Test Result: 0.7027, Training Loss: 0.5227\n",
      "Total time elapsed: 50.1186s\n",
      "Optimization Finished! Best Test Result: 0.8919, Training Loss: 0.1340\n",
      "Total time elapsed: 57.6934s\n",
      "Optimization Finished! Best Test Result: 0.8378, Training Loss: 0.2231\n",
      "Total time elapsed: 66.1424s\n",
      "Optimization Finished! Best Test Result: 0.8919, Training Loss: 0.1467\n",
      "Total time elapsed: 74.3932s\n",
      "Optimization Finished! Best Test Result: 0.9189, Training Loss: 0.3145\n",
      "Total time elapsed: 80.5562s\n",
      "Optimization Finished! Best Test Result: 0.8649, Training Loss: 0.3979\n",
      "Total time elapsed: 86.4470s\n",
      "Optimization Finished! Best Test Result: 0.8649, Training Loss: 0.3206\n",
      "Total time elapsed: 95.8980s\n",
      "Optimization Finished! Best Test Result: 0.9189, Training Loss: 0.3900\n",
      "Total time elapsed: 158.5565s\n",
      "Optimization Finished! Best Test Result: 0.7568, Training Loss: 0.4470\n",
      "Total time elapsed: 165.3276s\n",
      "Optimization Finished! Best Test Result: 0.7297, Training Loss: 0.4333\n",
      "Total time elapsed: 171.2940s\n",
      "Test Mean: 0.8378, Test Std: 0.0755\n"
     ]
    }
   ],
   "source": [
    "!python train.py --dataset cornell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found existing installation: torch-cluster 1.5.8\n",
      "Uninstalling torch-cluster-1.5.8:\n",
      "  Would remove:\n",
      "    /home/will/.local/lib/python3.8/site-packages/torch_cluster-1.5.8.dist-info/*\n",
      "    /home/will/.local/lib/python3.8/site-packages/torch_cluster/*\n",
      "Proceed (y/n)? "
     ]
    }
   ],
   "source": [
    "!pip uninstall torch_cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Looking in links: https://pytorch-geometric.com/whl/torch-1.5.0+.html\n",
      "Requirement already satisfied: torch-cluster in /home/will/.local/lib/python3.8/site-packages (1.5.8)\n",
      "\u001b[33mWARNING: You are using pip version 20.1; however, version 20.3.3 is available.\n",
      "You should consider upgrading via the '/usr/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.5.0+${CUDA}.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "5788\n"
     ]
    }
   ],
   "source": [
    "import utils\n",
    "g, g_high, fea, la = utils.full_load_data('cornell')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([3, 0, 2, 3, 4, 3, 0, 0, 3, 0, 3, 3, 3, 3, 3, 4, 3, 3, 0, 3, 0, 3, 3, 3,\n",
       "        3, 1, 3, 3, 0, 2, 3, 3, 4, 3, 4, 4, 2, 3, 3, 3, 0, 4, 0, 3, 3, 3, 2, 2,\n",
       "        0, 3, 0, 3, 3, 3, 3, 3, 0, 2, 2, 4, 4, 4, 3, 3, 3, 3, 0, 3, 3, 3, 4, 3,\n",
       "        3, 4, 4, 3, 0, 3, 0, 3, 4, 3, 2, 4, 2, 4, 3, 3, 0, 3, 3, 3, 0, 3, 3, 4,\n",
       "        3, 3, 3, 4, 0, 0, 4, 3, 3, 0, 4, 3, 2, 3, 3, 0, 3, 0, 0, 3, 3, 4, 3, 3,\n",
       "        3, 0, 3, 0, 3, 2, 4, 2, 3, 3, 0, 4, 3, 4, 3, 4, 3, 3, 0, 3, 2, 3, 3, 3,\n",
       "        3, 3, 3, 2, 3, 4, 3, 4, 3, 3, 3, 3, 0, 2, 0, 2, 3, 3, 4, 3, 0, 3, 3, 2,\n",
       "        0, 3, 3, 4, 3, 2, 0, 3, 0, 4, 3, 3, 4, 3, 3])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "la"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.multiclass import OneVsRestClassifier\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "pred = OneVsRestClassifier(LinearRegression()).fit(fea, la).predict(fea)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([3, 0, 2, 3, 4, 3, 0, 0, 3, 0, 3, 3, 3, 3, 3, 4, 3, 3, 0, 3, 0, 3,\n",
       "       3, 3, 3, 1, 3, 3, 0, 2, 3, 3, 4, 3, 4, 4, 2, 3, 3, 3, 0, 4, 0, 3,\n",
       "       3, 3, 2, 2, 0, 3, 0, 3, 3, 3, 3, 3, 0, 2, 2, 4, 4, 4, 3, 3, 3, 3,\n",
       "       0, 3, 3, 3, 4, 3, 3, 4, 4, 3, 0, 3, 0, 3, 4, 3, 2, 4, 2, 4, 3, 3,\n",
       "       0, 3, 3, 3, 0, 3, 3, 4, 3, 3, 3, 4, 0, 0, 4, 3, 3, 0, 4, 3, 2, 3,\n",
       "       3, 0, 3, 0, 0, 3, 3, 4, 3, 3, 3, 0, 3, 0, 3, 2, 4, 2, 3, 3, 0, 4,\n",
       "       3, 4, 3, 4, 3, 3, 0, 3, 2, 3, 3, 3, 3, 3, 3, 2, 3, 4, 3, 4, 3, 3,\n",
       "       3, 3, 0, 2, 0, 2, 3, 3, 4, 3, 0, 3, 3, 2, 0, 3, 3, 4, 3, 2, 0, 3,\n",
       "       0, 4, 3, 3, 4, 3, 3])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
