{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "hFwLefIsTwZL",
    "outputId": "2b72ce52-b6a6-4a94-c80e-ca06cc73411d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_15960\\3843220337.py:10: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# !pip install deeprobust\n",
    "# !conda install pytorch torchvision torchaudio -c pytorch\n",
    "import torch\n",
    "# print(torch.__version__)\n",
    "# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
    "# !pip install torch-geometric\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "\n",
    "import random\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "id": "dYuSfuanVdLy"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import collections\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torch_geometric\n",
    "from torch_geometric.utils import to_networkx\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import networkx as nx\n",
    "from networkx.algorithms import community\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "data_dir = \"./data\"\n",
    "os.makedirs(data_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 53
    },
    "id": "rn5YNSFOog43",
    "outputId": "b97eb826-741f-4db0-cf9f-146633f74155"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_15960\\1859218086.py:7: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy\n",
    "import torch\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline\n",
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "\n",
    "\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "id": "1tCWvnpupR37"
   },
   "outputs": [],
   "source": [
    "from random import sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "id": "cKccKEapUqT4"
   },
   "outputs": [],
   "source": [
    "# from deeprobust.graph.data import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "id": "l0NC0KhdT8JA"
   },
   "outputs": [],
   "source": [
    "from scipy.sparse import csr_matrix\n",
    "from scipy.sparse import csgraph\n",
    "from scipy.sparse.linalg import inv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'c:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Experiment for Sparsity\\\\Exp for  GAT'"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'c:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Experiment for Sparsity\\\\Exp for  GAT\\\\CoauhorCS'"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = os.path.join(os.getcwd(),'CoauhorCS')\n",
    "dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Ch6kq6OxM8Ur",
    "outputId": "b945478f-0017-41c3-f38b-f2e43fc7bd50"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(x=[18333, 6805], edge_index=[2, 163788], y=[18333])\n",
      "torch.Size([18333, 6805]) torch.Size([18333, 18333])\n",
      "torch.Size([18333, 6805]) torch.Size([18333, 18333])\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.datasets import Coauthor,Planetoid,CitationFull,WebKB\n",
    "from torch_geometric.utils import to_dense_adj\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# dataset = Planetoid(root='/pubmed',name='PubMed') #change for Cora, Citeseer, PubMed\n",
    "# dataset = CitationFull(root='/dblp',name='DBLP')\n",
    "dataset = Coauthor(root='CS',name='CS') # change for CS, Physics\n",
    "print(dataset[0])\n",
    "\n",
    "edge_index = dataset[0].edge_index\n",
    "\n",
    "# edge_index, added_edges = add_random_edge(edge_index, p=0.2,force_undirected=True)\n",
    "\n",
    "adj = to_dense_adj(edge_index)\n",
    "adj = adj[0]\n",
    "\n",
    "labels = dataset[0].y\n",
    "labels = labels.numpy()\n",
    "\n",
    "X = dataset[0].x\n",
    "X = X.to_dense()\n",
    "N = X.shape[0]\n",
    "NO_OF_CLASSES =  len(set(np.array(dataset[0].y)))\n",
    "\n",
    "print(X.shape, adj.shape)\n",
    "\n",
    "nn = int(1*N)\n",
    "X = X[:nn,:]\n",
    "adj = adj[:nn,:nn]\n",
    "labels = labels[:nn]\n",
    "print(X.shape,adj.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "STsLjDMMN2bk",
    "outputId": "ecf8a181-5c73-4c83-be48-550a2ddbc05b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([18333, 18333])\n"
     ]
    }
   ],
   "source": [
    "def get_laplacian(adj):\n",
    "    b=torch.ones(adj.shape[0])\n",
    "    return torch.diag(adj@b)-adj\n",
    "\n",
    "theta = get_laplacian(adj)\n",
    "print(theta.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BxPj9tu-YZjX",
    "outputId": "a960d8d5-e0d2-42d4-c09a-0f7189803145"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15 18333\n"
     ]
    }
   ],
   "source": [
    "# dataset_name = 'flickr' \n",
    "\n",
    "# data = Dataset(root='', name=dataset_name, setting='gcn',seed=10)\n",
    "\n",
    "# adj, features, labels = data.adj, data.features, data.labels\n",
    "# idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test\n",
    "\n",
    "# theta = csgraph.laplacian(adj).tocsr()\n",
    "features = X.numpy()\n",
    "NO_OF_NODES = X.shape[0]\n",
    "# NO_OF_CLASSES =  7\n",
    "\n",
    "\n",
    "print(NO_OF_CLASSES,NO_OF_NODES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "id": "YiN9k3_MueR-"
   },
   "outputs": [],
   "source": [
    "def convertScipyToTensor(coo):\n",
    "  try:\n",
    "    coo = coo.tocoo()\n",
    "  except:\n",
    "    coo = coo\n",
    "  values = coo.data\n",
    "  indices = np.vstack((coo.row, coo.col))\n",
    "\n",
    "  i = torch.LongTensor(indices)\n",
    "  v = torch.FloatTensor(values)\n",
    "  shape = coo.shape\n",
    "\n",
    "  return torch.sparse.FloatTensor(i, v, torch.Size(shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "id": "Xla8XecUULkS"
   },
   "outputs": [],
   "source": [
    "from scipy.sparse import random\n",
    "from scipy.sparse.linalg import norm\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "p = X.shape[0]\n",
    "k = int(p*0.1)\n",
    "n = X.shape[1]\n",
    "lambda_param = 100\n",
    "beta_param = 50\n",
    "alpha_param = 100\n",
    "gamma_param = 100\n",
    "lr = 1e-5\n",
    "thresh = 1e-10\n",
    "\n",
    "from scipy.sparse import random\n",
    "from scipy.stats import rv_continuous\n",
    "class CustomDistribution(rv_continuous):\n",
    "    def _rvs(self,  size=None, random_state=None):\n",
    "        return random_state.standard_normal(size)\n",
    "temp = CustomDistribution(seed=1)\n",
    "temp2 = temp()  # get a frozen version of the distribution\n",
    "X_tilde = random(k, n, density=0.25, random_state=1, data_rvs=temp2.rvs)\n",
    "C = random(p, k, density=0.25, random_state=1, data_rvs=temp2.rvs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "id": "rnKqrqAS9qmw"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GATConv\n",
    "class GAT(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GAT, self).__init__()\n",
    "        self.hid = 64\n",
    "        self.in_head = 64\n",
    "        self.out_head = 32\n",
    "        \n",
    "        self.conv1 = GATConv(X.shape[1], self.hid, heads=self.in_head, dropout=0.2)\n",
    "        self.conv2 = GATConv(self.hid*self.in_head, NO_OF_CLASSES, concat=False, heads=self.out_head, dropout=0.2)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        self.conv1.reset_parameters()\n",
    "        self.conv2.reset_parameters()\n",
    "\n",
    "\n",
    "    def forward(self, x,edge_index):\n",
    "        \n",
    "        x = F.dropout(x, p=0.2, training=self.training)\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.elu(x)\n",
    "        x = F.dropout(x, p=0.2, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        \n",
    "        return F.log_softmax(x, dim=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "id": "StlALggCABGw"
   },
   "outputs": [],
   "source": [
    "from random import sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "id": "78DErAOL9vVT"
   },
   "outputs": [],
   "source": [
    "def get_accuracy(C_0,L):\n",
    "    global labels, NO_OF_CLASSES,k\n",
    "    t=[]\n",
    "    for i in [1,2,3,4,5,6,7,8,9,10]: \n",
    "        C_0_new=np.zeros(C_0.shape)\n",
    "        for i in range(C_0.shape[0]):\n",
    "            C_0_new[i][np.argmax(C_0[i])]=1\n",
    "        # print(C_0_new)\n",
    "        # C_0_new=C_0\n",
    "        from scipy import sparse\n",
    "        #Lc=C_0.T@L@C_0\n",
    "        Lc=C_0_new.T@L@C_0_new\n",
    "        # print(\"L:\", Lc.shape)\n",
    "        # Lc=L_new\n",
    "        #print(Lc)\n",
    "        Wc=(-1*Lc)*(1-np.eye(Lc.shape[0]))\n",
    "        # print(\"W:\", Wc.shape)\n",
    "        Wc[Wc<0.1]=0\n",
    "        Wc=sparse.csr_matrix(Wc)\n",
    "        Wc = Wc.tocoo()\n",
    "        row = torch.from_numpy(Wc.row).to(torch.long)\n",
    "        col = torch.from_numpy(Wc.col).to(torch.long)\n",
    "        edge_index_coarsen2 = torch.stack([row, col], dim=0)\n",
    "        #print(\"edgecoarsen:\", edge_index_coarsen2.shape)\n",
    "        edge_weight = torch.from_numpy(Wc.data)\n",
    "        #print(\"edgeweight:\", edge_weight.shape)\n",
    "        def one_hot(x, class_count):\n",
    "            return torch.eye(class_count)[x, :]\n",
    "\n",
    "        device = torch.device('cpu')\n",
    "        labels=labels\n",
    "        Y = labels\n",
    "        #print(\"Y:\", Y.shape)\n",
    "        Y = one_hot(Y,NO_OF_CLASSES)\n",
    "        # NO_OF_CLASSES=Y.shape[1]\n",
    "        P=np.linalg.pinv(C_0_new)\n",
    "        labels_coarse = torch.argmax(torch.sparse.mm(torch.Tensor(P).double() , Y.double()).double() , 1)\n",
    "        #print(\"Lables:\", labels_coarse.shape)\n",
    "\n",
    "        #torch.Tensor(C2)@X\n",
    "        Wc=Wc.toarray()\n",
    "        #Wc[Wc<0.01]=0\n",
    "        C2=np.linalg.pinv(C_0_new)\n",
    "        model=GAT().to(device)\n",
    "        device = torch.device('cpu')\n",
    "        lr=0.01\n",
    "        decay=0.0001\n",
    "        try:\n",
    "          X=np.array(features.todense())\n",
    "        except:\n",
    "          X = np.array(features)\n",
    "        #print(\"X:\",X.shape)\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)\n",
    "        # criterion=torch.nn.CrossEntropyLoss()\n",
    "        x=sample(range(0, int(k)), k)\n",
    "      \n",
    "        from datetime import datetime\n",
    "        Xt=P@X\n",
    "        # Xt=X_t_0\n",
    "        def train():\n",
    "            model.train()\n",
    "            optimizer.zero_grad()\n",
    "            out = model(torch.Tensor(Xt).to(device),edge_index_coarsen2)\n",
    "            loss = F.nll_loss(out[x], labels_coarse[x])\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            return loss\n",
    "        now1 = datetime.now()\n",
    "        losses=[]\n",
    "        for epoch in range(60):\n",
    "            loss=train()\n",
    "            losses.append(loss)\n",
    "            if(epoch%100==0):\n",
    "                print(f'Epoch: {epoch:03d},loss: {loss:.4f}')\n",
    "        now2 = datetime.now()        \n",
    "        pred=model(torch.Tensor(Xt).to(device),edge_index_coarsen2).argmax(dim=1)        \n",
    "        def train_accuracy():\n",
    "            model.eval()\n",
    "            correct = (pred[x] == labels_coarse[x]).sum()\n",
    "            acc = int(correct) /len(x)\n",
    "            return acc\n",
    "    \n",
    "        t+=[(now2-now1).total_seconds()]\n",
    "\n",
    "        zz=sample(range(0, int(NO_OF_NODES)), NO_OF_NODES)\n",
    "        Wc=sparse.csr_matrix(adj)\n",
    "        Wc = Wc.tocoo()\n",
    "        row = torch.from_numpy(Wc.row).to(torch.long)\n",
    "        col = torch.from_numpy(Wc.col).to(torch.long)\n",
    "        edge_index_coarsen = torch.stack([row, col], dim=0)\n",
    "        edge_weight = torch.from_numpy(Wc.data)\n",
    "        pred=model(torch.Tensor(X),edge_index_coarsen).argmax(dim=1)\n",
    "        pred=np.array(pred)\n",
    "        correct =(pred[zz]==labels[zz]).sum()\n",
    "        acc = int(correct) /NO_OF_NODES\n",
    "        return acc\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment_sparsity(lambda_param,beta_param,gamma_param,C,theta,X):\n",
    "      p = X.shape[0]\n",
    "      k = int(p*0.1)\n",
    "      n = X.shape[1]\n",
    "      ones = csr_matrix(np.ones((k,k)))\n",
    "      ones = convertScipyToTensor(ones)\n",
    "      ones = ones.to_dense()\n",
    "      J = np.outer(np.ones(k), np.ones(k))/k\n",
    "      J = csr_matrix(J)\n",
    "      J = convertScipyToTensor(J)\n",
    "      J = J.to_dense()\n",
    "      zeros = csr_matrix(np.zeros((p,k)))\n",
    "      zeros = convertScipyToTensor(zeros)\n",
    "      zeros = zeros.to_dense()\n",
    "      C = convertScipyToTensor(C)\n",
    "      C = C.to_dense()\n",
    "      eye = torch.eye(k)\n",
    "      try:\n",
    "        theta = convertScipyToTensor(theta)\n",
    "      except:\n",
    "        theta = theta\n",
    "      try:\n",
    "        X = convertScipyToTensor(X)\n",
    "        X = X.to_dense()\n",
    "      except:\n",
    "        X = X\n",
    "\n",
    "      if(torch.cuda.is_available()):\n",
    "        print(\"yes\")\n",
    "        C = C.cuda()\n",
    "        theta = theta.cuda()\n",
    "        X = X.cuda()\n",
    "        J = J.cuda()\n",
    "        zeros = zeros.cuda()\n",
    "        ones = ones.cuda()\n",
    "        eye = eye.cuda()\n",
    "\n",
    "      def update(C,i):\n",
    "          global L\n",
    "          thetaC = theta@C\n",
    "          CT = torch.transpose(C,0,1)\n",
    "          t1 = CT@thetaC + J\n",
    "          term_bracket = torch.linalg.pinv(t1)         \n",
    "          L = 1/k\n",
    "          t1 = -2*gamma_param*(thetaC@term_bracket)\n",
    "          t4 = lambda_param*(C@ones)\n",
    "          t5 = 2*beta_param*(thetaC@CT@thetaC)\n",
    "          T2=(t1+t4+t5)/L\n",
    "          Cnew = (C-T2).maximum(zeros)\n",
    "          Cnew[Cnew<thresh] = thresh\n",
    "          for i in range(len(Cnew)):\n",
    "              Cnew[i] = Cnew[i]/torch.linalg.norm(Cnew[i],1)\n",
    "          return Cnew\n",
    "    \n",
    "      for i in tqdm(range(20)):   #update C only 21\n",
    "          C = update(C,i)\n",
    "    \n",
    "\n",
    "      return C\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ejPLG3L9zYrc",
    "outputId": "57c7f0b2-f8e2-4806-89ec-b6afc8c204ca"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:45<00:00,  2.28s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7081\n",
      "Accuracy = 0.24720449462717503 100 0.001 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:45<00:00,  2.28s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "Average accuracy = 24.147711776577758 +/- 0.5727376861397474\n",
      "Params =  100 0.001 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:44<00:00,  2.23s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7081\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:41<00:00,  2.06s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7081\n",
      "Average accuracy = 24.19680357824688 +/- 0.46909943817160377\n",
      "Params =  100 0.001 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:48<00:00,  2.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7073\n",
      "Accuracy = 0.8852342769868543 100 0.001 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:50<00:00,  2.53s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7047\n",
      "Average accuracy = 88.32433316969399 +/- 0.19909452899143365\n",
      "Params =  100 0.001 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:42<00:00,  2.11s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7081\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:41<00:00,  2.08s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "Average accuracy = 24.933180603283695 +/- 0.005454644629902805\n",
      "Params =  100 0.01 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:48<00:00,  2.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [01:00<00:00,  3.03s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "Average accuracy = 24.930453280968745 +/- 0.15000272732231396\n",
      "Params =  100 0.01 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:49<00:00,  2.50s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7079\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:40<00:00,  2.03s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7078\n",
      "Average accuracy = 86.11247477226857 +/- 0.5454644629902361\n",
      "Params =  100 0.01 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:35<00:00,  1.78s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:33<00:00,  1.65s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "Average accuracy = 24.507718322151312 +/- 0.39273441335297005\n",
      "Params =  100 0.0 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:33<00:00,  1.66s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:37<00:00,  1.86s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "Average accuracy = 23.85588828887798 +/- 0.13909343806250973\n",
      "Params =  100 0.0 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:34<00:00,  1.72s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:40<00:00,  2.03s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "Average accuracy = 3.940980745104457 +/- 0.07363770250368197\n",
      "Params =  100 0.0 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [01:23<00:00,  4.20s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7087\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [01:10<00:00,  3.55s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7065\n",
      "Average accuracy = 86.2515682103311 +/- 1.0063819342169822\n",
      "Params =  10 0.001 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [01:54<00:00,  5.70s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [01:09<00:00,  3.45s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "Average accuracy = 24.905907380134185 +/- 0.10909289259804639\n",
      "Params =  10 0.001 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:48<00:00,  2.43s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7082\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [01:52<00:00,  5.62s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "Average accuracy = 26.340478917798503 +/- 10.494736267932145\n",
      "Params =  10 0.001 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [02:20<00:00,  7.03s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:52<00:00,  2.61s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "Average accuracy = 24.927725958653795 +/- 0.04363715703921828\n",
      "Params =  10 0.01 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:50<00:00,  2.51s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7081\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:41<00:00,  2.09s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7081\n",
      "Average accuracy = 24.425898652702777 +/- 0.07636502481863372\n",
      "Params =  10 0.01 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [01:02<00:00,  3.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:52<00:00,  2.61s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1170961024.0000\n",
      "Average accuracy = 10.707467408498337 +/- 3.4418807614683904\n",
      "Params =  10 0.01 100\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:41<00:00,  2.07s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7081\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:34<00:00,  1.72s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "Average accuracy = 10.137457044673539 +/- 6.275568646702667\n",
      "Params =  10 0.0 10\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:40<00:00,  2.04s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7080\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:44<00:00,  2.24s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7081\n",
      "Average accuracy = 24.417716685757924 +/- 0.0681830578737802\n",
      "Params =  10 0.0 1\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [00:48<00:00,  2.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7078\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [01:48<00:00,  5.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 2.7081\n",
      "Average accuracy = 32.37877052310042 +/- 6.621938580701469\n",
      "Params =  10 0.0 100\n"
     ]
    }
   ],
   "source": [
    "\n",
    "highest_accuracy=0\n",
    "for lambda_param in [100,10]:\n",
    "  for beta_param in [0.001,0.01,0.,]:\n",
    "      for gamma_param in [10,1,100]:\n",
    "\n",
    "        av = []\n",
    "\n",
    "        for _ in range(2):\n",
    "            avg_accuracy_all=[]\n",
    "            for _ in range(1):\n",
    "              C = random(p, k, density=0.15, random_state=1, data_rvs=temp2.rvs)\n",
    "              C_0 = experiment_sparsity(lambda_param,beta_param,gamma_param,C,theta,X)\n",
    "              L = theta\n",
    "              C_0 = C_0.cpu().detach().numpy()\n",
    "              C_t_0 = C_0.T\n",
    "              try:\n",
    "                L = L.cpu().detach().numpy()\n",
    "              except:\n",
    "                L = L\n",
    "\n",
    "              acc = get_accuracy(C_0,L)\n",
    "              av.append(acc)\n",
    "              if highest_accuracy<acc:\n",
    "                highest_accuracy=acc\n",
    "                print(\"Accuracy = \" + str(acc) + \" \" + str(alpha_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))\n",
    "        print(\"Average accuracy = \" + str(np.mean(av)*100)  + \" +/- \" + str(np.std(av)*100))\n",
    "        print(\"Params =  \" + str(lambda_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8852342769868543"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "highest_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "5c1f741a4f83aa020b4b2a4d7353a073a4e5e4a855a3258a20da40294ddbf005"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
