{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "hFwLefIsTwZL",
    "outputId": "2b72ce52-b6a6-4a94-c80e-ca06cc73411d"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_8332\\3843220337.py:10: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# !pip install deeprobust\n",
    "# !conda install pytorch torchvision torchaudio -c pytorch\n",
    "import torch\n",
    "# print(torch.__version__)\n",
    "# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
    "# !pip install torch-geometric\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "\n",
    "import random\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "dYuSfuanVdLy"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import collections\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torch_geometric\n",
    "from torch_geometric.utils import to_networkx\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import networkx as nx\n",
    "from networkx.algorithms import community\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "data_dir = \"./data\"\n",
    "os.makedirs(data_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 53
    },
    "id": "rn5YNSFOog43",
    "outputId": "b97eb826-741f-4db0-cf9f-146633f74155"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_8332\\1859218086.py:7: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy\n",
    "import torch\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline\n",
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "\n",
    "\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "1tCWvnpupR37"
   },
   "outputs": [],
   "source": [
    "from random import sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "cKccKEapUqT4"
   },
   "outputs": [],
   "source": [
    "# from deeprobust.graph.data import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "l0NC0KhdT8JA"
   },
   "outputs": [],
   "source": [
    "from scipy.sparse import csr_matrix\n",
    "from scipy.sparse import csgraph\n",
    "from scipy.sparse.linalg import inv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Experiment for Sparsity\\\\Exp for APPNPNET'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Experiment for Sparsity\\\\Exp for APPNPNET\\\\CITESEER'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = os.path.join(os.getcwd(),'CITESEER')\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Ch6kq6OxM8Ur",
    "outputId": "b945478f-0017-41c3-f38b-f2e43fc7bd50"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index\n",
      "Processing...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])\n",
      "torch.Size([3327, 3703]) torch.Size([3327, 3327])\n",
      "torch.Size([3327, 3703]) torch.Size([3327, 3327])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Done!\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.utils import to_dense_adj\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "dataset= Planetoid(root=dataset, name='CITESEER')\n",
    "print(dataset[0])\n",
    "adj = to_dense_adj(dataset[0].edge_index)\n",
    "adj = adj[0]\n",
    "labels = dataset[0].y\n",
    "labels = labels.numpy()\n",
    "\n",
    "X = dataset[0].x\n",
    "X = X.to_dense()\n",
    "N = X.shape[0]\n",
    "NO_OF_CLASSES =  len(set(np.array(dataset[0].y)))\n",
    "\n",
    "print(X.shape, adj.shape)\n",
    "\n",
    "nn = int(1*N)\n",
    "X = X[:nn,:]\n",
    "adj = adj[:nn,:nn]\n",
    "labels = labels[:nn]\n",
    "print(X.shape,adj.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "3FQJfical415"
   },
   "outputs": [],
   "source": [
    "del dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "STsLjDMMN2bk",
    "outputId": "ecf8a181-5c73-4c83-be48-550a2ddbc05b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3327, 3327])\n"
     ]
    }
   ],
   "source": [
    "def get_laplacian(adj):\n",
    "    b=torch.ones(adj.shape[0])\n",
    "    return torch.diag(adj@b)-adj\n",
    "\n",
    "theta = get_laplacian(adj)\n",
    "print(theta.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BxPj9tu-YZjX",
    "outputId": "a960d8d5-e0d2-42d4-c09a-0f7189803145"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6 3327\n"
     ]
    }
   ],
   "source": [
    "# dataset_name = 'flickr' \n",
    "\n",
    "# data = Dataset(root='', name=dataset_name, setting='gcn',seed=10)\n",
    "\n",
    "# adj, features, labels = data.adj, data.features, data.labels\n",
    "# idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test\n",
    "\n",
    "# theta = csgraph.laplacian(adj).tocsr()\n",
    "features = X.numpy()\n",
    "NO_OF_NODES = X.shape[0]\n",
    "# NO_OF_CLASSES =  7\n",
    "\n",
    "\n",
    "print(NO_OF_CLASSES,NO_OF_NODES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "YiN9k3_MueR-"
   },
   "outputs": [],
   "source": [
    "def convertScipyToTensor(coo):\n",
    "  try:\n",
    "    coo = coo.tocoo()\n",
    "  except:\n",
    "    coo = coo\n",
    "  values = coo.data\n",
    "  indices = np.vstack((coo.row, coo.col))\n",
    "\n",
    "  i = torch.LongTensor(indices)\n",
    "  v = torch.FloatTensor(values)\n",
    "  shape = coo.shape\n",
    "\n",
    "  return torch.sparse.FloatTensor(i, v, torch.Size(shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "Xla8XecUULkS"
   },
   "outputs": [],
   "source": [
    "from scipy.sparse import random\n",
    "from scipy.sparse.linalg import norm\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "p = X.shape[0]\n",
    "k = int(p*0.1)\n",
    "n = X.shape[1]\n",
    "lambda_param = 100\n",
    "beta_param = 50\n",
    "alpha_param = 100\n",
    "gamma_param = 100\n",
    "lr = 1e-5\n",
    "thresh = 1e-10\n",
    "\n",
    "from scipy.sparse import random\n",
    "from scipy.stats import rv_continuous\n",
    "class CustomDistribution(rv_continuous):\n",
    "    def _rvs(self,  size=None, random_state=None):\n",
    "        return random_state.standard_normal(size)\n",
    "temp = CustomDistribution(seed=1)\n",
    "temp2 = temp()  # get a frozen version of the distribution\n",
    "X_tilde = random(k, n, density=0.25, random_state=1, data_rvs=temp2.rvs)\n",
    "C = random(p, k, density=0.25, random_state=1, data_rvs=temp2.rvs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "rnKqrqAS9qmw"
   },
   "outputs": [],
   "source": [
    "from torch.nn import Linear\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import APPNP\n",
    "\n",
    "class APPNPNet(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(APPNPNet, self).__init__()\n",
    "        self.lin1 = Linear(X.shape[1], 128)\n",
    "        self.lin2 = Linear(128,NO_OF_CLASSES)\n",
    "        self.prop1 = APPNP(16, 0.1)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        self.lin1.reset_parameters()\n",
    "        self.lin2.reset_parameters()\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = F.relu(self.lin1(x))\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.lin2(x)\n",
    "        x = self.prop1(x, edge_index)\n",
    "\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "StlALggCABGw"
   },
   "outputs": [],
   "source": [
    "from random import sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "78DErAOL9vVT"
   },
   "outputs": [],
   "source": [
    "def get_accuracy(C_0,L):\n",
    "    global labels, NO_OF_CLASSES,k\n",
    "    t=[]\n",
    "    for i in [1,2,3,4,5,6,7,8,9,10]: \n",
    "        C_0_new=np.zeros(C_0.shape)\n",
    "        for i in range(C_0.shape[0]):\n",
    "            C_0_new[i][np.argmax(C_0[i])]=1\n",
    "       \n",
    "        from scipy import sparse\n",
    "     \n",
    "        Lc=C_0_new.T@L@C_0_new\n",
    " \n",
    "        Wc=(-1*Lc)*(1-np.eye(Lc.shape[0]))\n",
    "        \n",
    "        Wc[Wc<0.1]=0\n",
    "        Wc=sparse.csr_matrix(Wc)\n",
    "        Wc = Wc.tocoo()\n",
    "        row = torch.from_numpy(Wc.row).to(torch.long)\n",
    "        col = torch.from_numpy(Wc.col).to(torch.long)\n",
    "        edge_index_coarsen2 = torch.stack([row, col], dim=0)\n",
    "      \n",
    "        edge_weight = torch.from_numpy(Wc.data)\n",
    "       \n",
    "        def one_hot(x, class_count):\n",
    "            return torch.eye(class_count)[x, :]\n",
    "\n",
    "        device = torch.device('cpu')\n",
    "        labels=labels\n",
    "        Y = labels\n",
    "     \n",
    "        Y = one_hot(Y,NO_OF_CLASSES)\n",
    "       \n",
    "        P=np.linalg.pinv(C_0_new)\n",
    "        labels_coarse = torch.argmax(torch.sparse.mm(torch.Tensor(P).double() , Y.double()).double() , 1)\n",
    "       \n",
    "        Wc=Wc.toarray()\n",
    "        C2=np.linalg.pinv(C_0_new)\n",
    "        model= APPNPNet().to(device)\n",
    "        device = torch.device('cpu')\n",
    "        lr=0.01\n",
    "        decay=0.0001\n",
    "        try:\n",
    "          X=np.array(features.todense())\n",
    "        except:\n",
    "          X = np.array(features)\n",
    "    \n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)\n",
    "        x=sample(range(0, int(k)), k)\n",
    "      \n",
    "        from datetime import datetime\n",
    "        Xt=P@X\n",
    "        def train():\n",
    "            model.train()\n",
    "            optimizer.zero_grad()\n",
    "            out = model(torch.Tensor(Xt).to(device),edge_index_coarsen2)\n",
    "            loss = F.nll_loss(out[x], labels_coarse[x])\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            return loss\n",
    "        now1 = datetime.now()\n",
    "        losses=[]\n",
    "        for epoch in range(200):\n",
    "            loss=train()\n",
    "            losses.append(loss)\n",
    "            if(epoch%100==0):\n",
    "                print(f'Epoch: {epoch:03d},loss: {loss:.4f}')\n",
    "        now2 = datetime.now()        \n",
    "        pred=model(torch.Tensor(Xt).to(device),edge_index_coarsen2).argmax(dim=1)        \n",
    "        def train_accuracy():\n",
    "            model.eval()\n",
    "            correct = (pred[x] == labels_coarse[x]).sum()\n",
    "            acc = int(correct) /len(x)\n",
    "            return acc\n",
    "    \n",
    "        t+=[(now2-now1).total_seconds()]\n",
    "\n",
    "        zz=sample(range(0, int(NO_OF_NODES)), NO_OF_NODES)\n",
    "        Wc=sparse.csr_matrix(adj)\n",
    "        Wc = Wc.tocoo()\n",
    "        row = torch.from_numpy(Wc.row).to(torch.long)\n",
    "        col = torch.from_numpy(Wc.col).to(torch.long)\n",
    "        edge_index_coarsen = torch.stack([row, col], dim=0)\n",
    "        edge_weight = torch.from_numpy(Wc.data)\n",
    "        pred=model(torch.Tensor(X),edge_index_coarsen).argmax(dim=1)\n",
    "        pred=np.array(pred)\n",
    "        correct =(pred[zz]==labels[zz]).sum()\n",
    "        acc = int(correct) /NO_OF_NODES\n",
    "        return acc\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment_sparsity(lambda_param,beta_param,gamma_param,C,theta,X):\n",
    "      p = X.shape[0]\n",
    "      k = int(p*0.1)\n",
    "      n = X.shape[1]\n",
    "      ones = csr_matrix(np.ones((k,k)))\n",
    "      ones = convertScipyToTensor(ones)\n",
    "      ones = ones.to_dense()\n",
    "\n",
    "      J = np.outer(np.ones(k), np.ones(k))/k\n",
    "      J = csr_matrix(J)\n",
    "      J = convertScipyToTensor(J)\n",
    "      J = J.to_dense()\n",
    "\n",
    "      zeros = csr_matrix(np.zeros((p,k)))\n",
    "      zeros = convertScipyToTensor(zeros)\n",
    "      zeros = zeros.to_dense()\n",
    "\n",
    "      C = convertScipyToTensor(C)\n",
    "      C = C.to_dense()\n",
    "      eye = torch.eye(k)\n",
    "      try:\n",
    "        theta = convertScipyToTensor(theta)\n",
    "      except:\n",
    "        theta = theta\n",
    "      try:\n",
    "        X = convertScipyToTensor(X)\n",
    "        X = X.to_dense()\n",
    "      except:\n",
    "        X = X\n",
    "\n",
    "      if(torch.cuda.is_available()):\n",
    "        print(\"yes\")\n",
    "        C = C.cuda()\n",
    "        theta = theta.cuda()\n",
    "        X = X.cuda()\n",
    "        J = J.cuda()\n",
    "        zeros = zeros.cuda()\n",
    "        ones = ones.cuda()\n",
    "        eye = eye.cuda()\n",
    "\n",
    "      def update(C,i):\n",
    "          global L\n",
    "          thetaC = theta@C\n",
    "          CT = torch.transpose(C,0,1)\n",
    "          t1 = CT@thetaC + J\n",
    "          term_bracket = torch.linalg.pinv(t1)\n",
    "          L = 1/k\n",
    "          t1 = -2*gamma_param*(thetaC@term_bracket)\n",
    "          t4 = lambda_param*(C@ones)\n",
    "          t5 = 2*beta_param*(thetaC@CT@thetaC)\n",
    "          T2=(t1+t4+t5)/L\n",
    "          Cnew = (C-T2).maximum(zeros)\n",
    "          Cnew[Cnew<thresh] = thresh\n",
    "          for i in range(len(Cnew)):\n",
    "              Cnew[i] = Cnew[i]/torch.linalg.norm(Cnew[i],1)\n",
    "          return Cnew\n",
    "\n",
    "      for i in tqdm(range(20)):   \n",
    "          C = update(C,i)\n",
    "    \n",
    "\n",
    "      return C\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ejPLG3L9zYrc",
    "outputId": "57c7f0b2-f8e2-4806-89ec-b6afc8c204ca"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:12<00:00,  1.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7759\n",
      "Epoch: 100,loss: 0.0746\n",
      "Accuracy = 0.2131048993086865 100 0.01 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7710\n",
      "Epoch: 100,loss: 0.0549\n",
      "Average accuracy = 21.10009017132552 +/- 0.21039975954313256\n",
      "Params =  1 0.01 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7845\n",
      "Epoch: 100,loss: 0.1278\n",
      "Accuracy = 0.36188758641418695 100 0.01 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7273\n",
      "Epoch: 100,loss: 0.1569\n",
      "Accuracy = 0.545536519386835 100 0.01 10\n",
      "Average accuracy = 45.371205290051094 +/- 9.182446648632403\n",
      "Params =  1 0.01 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7884\n",
      "Epoch: 100,loss: 0.8982\n",
      "Accuracy = 0.6149684400360685 100 0.01 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7946\n",
      "Epoch: 100,loss: 0.8802\n",
      "Accuracy = 0.6429215509467989 100 0.01 100\n",
      "Average accuracy = 62.89449954914337 +/- 1.3976555455365192\n",
      "Params =  1 0.01 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8188\n",
      "Epoch: 100,loss: 0.1098\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7785\n",
      "Epoch: 100,loss: 0.0759\n",
      "Average accuracy = 24.241058010219415 +/- 0.5560565073639917\n",
      "Params =  1 0.1 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7583\n",
      "Epoch: 100,loss: 0.0674\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8516\n",
      "Epoch: 100,loss: 0.2214\n",
      "Average accuracy = 24.81214307183649 +/- 3.6218815749924858\n",
      "Params =  1 0.1 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7964\n",
      "Epoch: 100,loss: 0.8929\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8088\n",
      "Epoch: 100,loss: 0.9018\n",
      "Average accuracy = 63.30027051397656 +/- 0.6612563871355559\n",
      "Params =  1 0.1 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8582\n",
      "Epoch: 100,loss: 0.0509\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7496\n",
      "Epoch: 100,loss: 0.0381\n",
      "Average accuracy = 26.345055605650735 +/- 0.7363991584009633\n",
      "Params =  1 1 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8322\n",
      "Epoch: 100,loss: 0.0494\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7564\n",
      "Epoch: 100,loss: 0.0494\n",
      "Average accuracy = 22.076946197775772 +/- 0.7063420498948009\n",
      "Params =  1 1 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8018\n",
      "Epoch: 100,loss: 0.7831\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7961\n",
      "Epoch: 100,loss: 0.7386\n",
      "Accuracy = 0.6672678088367899 100 1 100\n",
      "Average accuracy = 65.43432521791404 +/- 1.2924556657649577\n",
      "Params =  1 1 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7217\n",
      "Epoch: 100,loss: 0.0003\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7756\n",
      "Epoch: 100,loss: 0.0002\n",
      "Average accuracy = 21.265404268109407 +/- 0.045085662759242195\n",
      "Params =  10 0.01 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7788\n",
      "Epoch: 100,loss: 0.1073\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8220\n",
      "Epoch: 100,loss: 0.1091\n",
      "Average accuracy = 64.9233543733093 +/- 0.8115419296663651\n",
      "Params =  10 0.01 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7885\n",
      "Epoch: 100,loss: 0.4827\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7916\n",
      "Epoch: 100,loss: 0.5722\n",
      "Average accuracy = 63.45055605650736 +/- 0.06011421701232478\n",
      "Params =  10 0.01 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7854\n",
      "Epoch: 100,loss: 0.0002\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7454\n",
      "Epoch: 100,loss: 0.0004\n",
      "Average accuracy = 21.220318605350165 +/- 0.06011421701232339\n",
      "Params =  10 0.1 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8373\n",
      "Epoch: 100,loss: 0.0514\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7834\n",
      "Epoch: 100,loss: 0.0645\n",
      "Average accuracy = 21.911632100991884 +/- 0.06011421701232339\n",
      "Params =  10 0.1 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7963\n",
      "Epoch: 100,loss: 0.0614\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7387\n",
      "Epoch: 100,loss: 0.0938\n",
      "Average accuracy = 23.549744514577696 +/- 2.780282536819957\n",
      "Params =  10 0.1 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8221\n",
      "Epoch: 100,loss: 0.0925\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8064\n",
      "Epoch: 100,loss: 0.0008\n",
      "Average accuracy = 24.406372107003307 +/- 3.787195671776375\n",
      "Params =  10 1 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8044\n",
      "Epoch: 100,loss: 0.0785\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7984\n",
      "Epoch: 100,loss: 0.0605\n",
      "Average accuracy = 20.544033663961528 +/- 0.07514277126540458\n",
      "Params =  10 1 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8620\n",
      "Epoch: 100,loss: 0.0854\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:08<00:00,  2.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8371\n",
      "Epoch: 100,loss: 0.0745\n",
      "Average accuracy = 26.55545536519387 +/- 2.9005109708446053\n",
      "Params =  10 1 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8355\n",
      "Epoch: 100,loss: 0.0004\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7866\n",
      "Epoch: 100,loss: 0.0013\n",
      "Average accuracy = 21.160204388337842 +/- 0.03005710850616239\n",
      "Params =  100 0.01 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7485\n",
      "Epoch: 100,loss: 0.0002\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7262\n",
      "Epoch: 100,loss: 0.0003\n",
      "Average accuracy = 21.29546137661557 +/- 0.16531409678389036\n",
      "Params =  100 0.01 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7936\n",
      "Epoch: 100,loss: 0.1431\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7953\n",
      "Epoch: 100,loss: 0.1533\n",
      "Average accuracy = 63.3453561767358 +/- 0.916741809437932\n",
      "Params =  100 0.01 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7811\n",
      "Epoch: 100,loss: 0.0004\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8620\n",
      "Epoch: 100,loss: 0.0009\n",
      "Average accuracy = 21.220318605350165 +/- 0.27051397655545456\n",
      "Params =  100 0.1 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7888\n",
      "Epoch: 100,loss: 0.0003\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  4.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7459\n",
      "Epoch: 100,loss: 0.0003\n",
      "Average accuracy = 20.94980462879471 +/- 0.21039975954313117\n",
      "Params =  100 0.1 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7944\n",
      "Epoch: 100,loss: 0.1335\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:08<00:00,  2.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7655\n",
      "Epoch: 100,loss: 0.1727\n",
      "Average accuracy = 65.19386834986474 +/- 0.18034265103697433\n",
      "Params =  100 0.1 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8487\n",
      "Epoch: 100,loss: 0.0004\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8113\n",
      "Epoch: 100,loss: 0.0005\n",
      "Average accuracy = 21.024947400060114 +/- 0.13525698827772797\n",
      "Params =  100 1 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8100\n",
      "Epoch: 100,loss: 0.0004\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.6922\n",
      "Epoch: 100,loss: 0.0003\n",
      "Average accuracy = 21.024947400060114 +/- 0.22542831379621237\n",
      "Params =  100 1 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.8157\n",
      "Epoch: 100,loss: 0.0652\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.7563\n",
      "Epoch: 100,loss: 0.1058\n",
      "Average accuracy = 22.197174631800422 +/- 2.5097685602645026\n",
      "Params =  100 1 100\n"
     ]
    }
   ],
   "source": [
    "\n",
    "highest_accuracy=0\n",
    "for lambda_param in [1,10,100]:\n",
    "  for beta_param in [0.01,0.1,1]:\n",
    "      for gamma_param in [1,10,100]:\n",
    "\n",
    "        av = []\n",
    "\n",
    "        for _ in range(2):\n",
    "            avg_accuracy_all=[]\n",
    "            for _ in range(1):\n",
    "              C = random(p, k, density=0.15, random_state=1, data_rvs=temp2.rvs)\n",
    "              C_0 = experiment_sparsity(lambda_param,beta_param,gamma_param,C,theta,X)\n",
    "              L = theta\n",
    "              C_0 = C_0.cpu().detach().numpy()\n",
    "              C_t_0 = C_0.T\n",
    "              try:\n",
    "                L = L.cpu().detach().numpy()\n",
    "              except:\n",
    "                L = L\n",
    "\n",
    "              acc = get_accuracy(C_0,L)\n",
    "              av.append(acc)\n",
    "              if highest_accuracy<acc:\n",
    "                highest_accuracy=acc\n",
    "                print(\"Accuracy = \" + str(acc) + \" \" + str(alpha_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))\n",
    "        print(\"Average accuracy = \" + str(np.mean(av)*100)  + \" +/- \" + str(np.std(av)*100))\n",
    "        print(\"Params =  \" + str(lambda_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ERROR: Could not find a version that satisfies the requirement stellafraph (from versions: none)\n",
      "ERROR: No matching distribution found for stellafraph\n"
     ]
    }
   ],
   "source": [
    "!pip install stellafraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'Planetoid' object has no attribute 'load'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[1;32mIn [22], line 3\u001b[0m\n\u001b[0;32m      1\u001b[0m dataset \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(os\u001b[38;5;241m.\u001b[39mgetcwd(),\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mCITESEER\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m      2\u001b[0m dataset\u001b[38;5;241m=\u001b[39m Planetoid(root\u001b[38;5;241m=\u001b[39mdataset, name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mCITESEER\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m----> 3\u001b[0m G, node_subjects \u001b[38;5;241m=\u001b[39m \u001b[43mdataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m()\n\u001b[0;32m      5\u001b[0m row \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfrom_numpy(Wc\u001b[38;5;241m.\u001b[39mrow)\u001b[38;5;241m.\u001b[39mto(torch\u001b[38;5;241m.\u001b[39mlong)\n\u001b[0;32m      6\u001b[0m col \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfrom_numpy(Wc\u001b[38;5;241m.\u001b[39mcol)\u001b[38;5;241m.\u001b[39mto(torch\u001b[38;5;241m.\u001b[39mlong)\n",
      "\u001b[1;31mAttributeError\u001b[0m: 'Planetoid' object has no attribute 'load'"
     ]
    }
   ],
   "source": [
    "dataset = os.path.join(os.getcwd(),'CITESEER')\n",
    "dataset= Planetoid(root=dataset, name='CITESEER')\n",
    "G, node_subjects = dataset.load()\n",
    "\n",
    "row = torch.from_numpy(Wc.row).to(torch.long)\n",
    "col = torch.from_numpy(Wc.col).to(torch.long)\n",
    "edge_index_coarsen = torch.stack([row, col], dim=0)\n",
    "edge_weight = torch.from_numpy(Wc.data)\n",
    "node_embeddings =model(torch.Tensor(X),edge_index_coarsen).argmax(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node_subject = node_subjects.astype(\"category\").cat.codes\n",
    "\n",
    "X = node_embeddings\n",
    "if X.shape[1] > 2:\n",
    "    transform = TSNE  # PCA\n",
    "\n",
    "    trans = transform(n_components=2)\n",
    "    emb_transformed = pd.DataFrame(trans.fit_transform(X), index=node_ids)\n",
    "    emb_transformed[\"label\"] = node_subject\n",
    "else:\n",
    "    emb_transformed = pd.DataFrame(X, index=node_ids)\n",
    "    emb_transformed = emb_transformed.rename(columns={\"0\": 0, \"1\": 1})\n",
    "    emb_transformed[\"label\"] = node_subject"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.7\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "ax.scatter(\n",
    "    emb_transformed[0],\n",
    "    emb_transformed[1],\n",
    "    c=emb_transformed[\"label\"].astype(\"category\"),\n",
    "    cmap=\"jet\",\n",
    "    alpha=alpha,\n",
    ")\n",
    "ax.set(aspect=\"equal\", xlabel=\"$X_1$\", ylabel=\"$X_2$\")\n",
    "plt.title(\n",
    "    \"{} visualization of GraphSAGE embeddings for cora dataset\".format(transform.__name__)\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6672678088367899"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "highest_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "5c1f741a4f83aa020b4b2a4d7353a073a4e5e4a855a3258a20da40294ddbf005"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
