{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "hFwLefIsTwZL",
    "outputId": "2b72ce52-b6a6-4a94-c80e-ca06cc73411d"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_9964\\3843220337.py:10: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# !pip install deeprobust\n",
    "# !conda install pytorch torchvision torchaudio -c pytorch\n",
    "import torch\n",
    "# print(torch.__version__)\n",
    "# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
    "# !pip install torch-geometric\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# `IPython.core.display` re-exports are deprecated since IPython 7.14 (see the\n",
    "# DeprecationWarning this cell used to emit); import from the public\n",
    "# `IPython.display` module instead.\n",
    "from IPython.display import display, HTML\n",
    "# Inject a CSS rule that sets the `.container` element to 90% width.\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "\n",
    "import random\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "dYuSfuanVdLy"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import collections\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torch_geometric\n",
    "from torch_geometric.utils import to_networkx\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import networkx as nx\n",
    "from networkx.algorithms import community\n",
    "device = torch.device('cpu')\n",
    "data_dir = \"./data\"\n",
    "os.makedirs(data_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 53
    },
    "id": "rn5YNSFOog43",
    "outputId": "b97eb826-741f-4db0-cf9f-146633f74155"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_9964\\1859218086.py:7: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy\n",
    "import torch\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline\n",
    "# `IPython.core.display` re-exports are deprecated since IPython 7.14 (see the\n",
    "# DeprecationWarning this cell used to emit); import from the public\n",
    "# `IPython.display` module instead.\n",
    "from IPython.display import display, HTML\n",
    "# Inject a CSS rule that sets the `.container` element to 90% width.\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "\n",
    "\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "1tCWvnpupR37"
   },
   "outputs": [],
   "source": [
    "from random import sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "cKccKEapUqT4"
   },
   "outputs": [],
   "source": [
    "# from deeprobust.graph.data import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "l0NC0KhdT8JA"
   },
   "outputs": [],
   "source": [
    "from scipy.sparse import csr_matrix\n",
    "from scipy.sparse import csgraph\n",
    "from scipy.sparse.linalg import inv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Experiment for Sparsity\\\\Exp for APPNPNET'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Experiment for Sparsity\\\\Exp for APPNPNET\\\\CoauhorPhysics'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build `<current working directory>/CoauhorPhysics` and display it.\n",
    "# NOTE(review): in the cells visible here this string is never used as a path;\n",
    "# the name `dataset` is rebound to the Coauthor dataset object in the\n",
    "# data-loading cell. The folder name is spelled 'CoauhorPhysics' (no 't') --\n",
    "# confirm whether that is intentional before relying on it.\n",
    "dataset = os.path.join(os.getcwd(),'CoauhorPhysics')\n",
    "dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Ch6kq6OxM8Ur",
    "outputId": "b945478f-0017-41c3-f38b-f2e43fc7bd50"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/shchur/gnn-benchmark/raw/master/data/npz/ms_academic_phy.npz\n",
      "Processing...\n",
      "Done!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(x=[34493, 8415], edge_index=[2, 495924], y=[34493])\n",
      "torch.Size([34493, 8415]) torch.Size([34493, 34493])\n",
      "torch.Size([34493, 8415]) torch.Size([34493, 34493])\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.datasets import Coauthor,Planetoid,CitationFull,WebKB\n",
    "from torch_geometric.utils import to_dense_adj\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Alternative datasets (swap the constructor below to use them):\n",
    "#   dataset = Planetoid(root='/pubmed',name='PubMed')   # Cora, Citeseer, PubMed\n",
    "#   dataset = CitationFull(root='/dblp',name='DBLP')\n",
    "dataset = Coauthor(root='Physics',name='Physics') # change for CS, Physics\n",
    "print(dataset[0])\n",
    "\n",
    "# Graph connectivity as a (2, num_edges) index tensor.\n",
    "edge_index = dataset[0].edge_index\n",
    "\n",
    "# edge_index, added_edges = add_random_edge(edge_index, p=0.2,force_undirected=True)\n",
    "\n",
    "# to_dense_adj returns a batched tensor; keep the single graph at index 0.\n",
    "adj = to_dense_adj(edge_index)[0]\n",
    "\n",
    "# Node labels as a NumPy array, node features as a dense tensor.\n",
    "labels = dataset[0].y.numpy()\n",
    "X = dataset[0].x.to_dense()\n",
    "N = X.shape[0]\n",
    "# Number of distinct label values present in the graph.\n",
    "NO_OF_CLASSES = len(np.unique(labels))\n",
    "\n",
    "print(X.shape, adj.shape)\n",
    "\n",
    "# Keep the first `nn` nodes; with the factor 1 this is the whole graph.\n",
    "nn = int(1*N)\n",
    "X, adj, labels = X[:nn, :], adj[:nn, :nn], labels[:nn]\n",
    "print(X.shape, adj.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "STsLjDMMN2bk",
    "outputId": "ecf8a181-5c73-4c83-be48-550a2ddbc05b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34493, 34493])\n"
     ]
    }
   ],
   "source": [
    "def get_laplacian(adj):\n",
    "    \"\"\"Return the combinatorial graph Laplacian ``D - A`` of a dense adjacency matrix.\n",
    "\n",
    "    Args:\n",
    "        adj: square 2-D ``torch.Tensor`` of edge weights.\n",
    "\n",
    "    Returns:\n",
    "        ``diag(row_sums(adj)) - adj`` as a tensor of the same shape as ``adj``.\n",
    "    \"\"\"\n",
    "    # Row sums are the (weighted) node degrees. Summing `adj` directly, rather\n",
    "    # than multiplying by a fresh float32 CPU `torch.ones` vector, keeps the\n",
    "    # computation on adj's own dtype and device, so float64 or CUDA\n",
    "    # adjacency matrices no longer raise a dtype/device mismatch.\n",
    "    degree = adj.sum(dim=1)\n",
    "    return torch.diag(degree) - adj\n",
    "\n",
    "theta = get_laplacian(adj)\n",
    "print(theta.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BxPj9tu-YZjX",
    "outputId": "a960d8d5-e0d2-42d4-c09a-0f7189803145"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5 34493\n"
     ]
    }
   ],
   "source": [
    "# dataset_name = 'flickr' \n",
    "\n",
    "# data = Dataset(root='', name=dataset_name, setting='gcn',seed=10)\n",
    "\n",
    "# adj, features, labels = data.adj, data.features, data.labels\n",
    "# idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test\n",
    "\n",
    "# theta = csgraph.laplacian(adj).tocsr()\n",
    "# NumPy form of the node feature tensor X; get_accuracy() reads this global.\n",
    "features = X.numpy()\n",
    "# Number of rows of X (one row per node); used by get_accuracy() as the\n",
    "# denominator of the reported accuracy.\n",
    "NO_OF_NODES = X.shape[0]\n",
    "# NO_OF_CLASSES =  7\n",
    "\n",
    "\n",
    "# Sanity print: class count (set in the data-loading cell) and node count.\n",
    "print(NO_OF_CLASSES,NO_OF_NODES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "YiN9k3_MueR-"
   },
   "outputs": [],
   "source": [
    "def convertScipyToTensor(coo):\n",
    "    \"\"\"Convert a SciPy sparse matrix into a float32 torch sparse COO tensor.\n",
    "\n",
    "    Args:\n",
    "        coo: a SciPy sparse matrix (converted to COO format first). An object\n",
    "            without ``tocoo`` is used as-is and must already expose ``data``,\n",
    "            ``row``, ``col`` and ``shape``.\n",
    "\n",
    "    Returns:\n",
    "        A sparse COO ``torch.Tensor`` of dtype float32 with the same shape and\n",
    "        stored entries as the input.\n",
    "\n",
    "    Raises:\n",
    "        AttributeError: if the input has neither ``tocoo`` nor the COO\n",
    "            attributes (for example a dense array or tensor). Callers in this\n",
    "            notebook rely on an exception being raised in that case.\n",
    "    \"\"\"\n",
    "    # Only the \"no tocoo method\" case is tolerated; the previous bare `except:`\n",
    "    # also swallowed unrelated failures such as KeyboardInterrupt.\n",
    "    if hasattr(coo, 'tocoo'):\n",
    "        coo = coo.tocoo()\n",
    "    indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)\n",
    "    values = torch.as_tensor(coo.data, dtype=torch.float32)\n",
    "    # `torch.sparse.FloatTensor(i, v, size)` is a deprecated legacy constructor;\n",
    "    # `torch.sparse_coo_tensor` is its supported replacement.\n",
    "    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "Xla8XecUULkS"
   },
   "outputs": [],
   "source": [
    "# NOTE: this rebinds the name `random` to scipy.sparse.random, shadowing the\n",
    "# stdlib `random` module imported in the earlier import cells.\n",
    "from scipy.sparse import random\n",
    "from scipy.sparse.linalg import norm\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "# Problem sizes derived from the feature matrix X.\n",
    "p = X.shape[0]  # number of rows of X (one per node)\n",
    "k = int(p*0.1)  # 10% of p; number of columns of the matrix C built below\n",
    "n = X.shape[1]  # number of feature columns of X\n",
    "# Default hyper-parameters. lambda_param, beta_param and gamma_param are\n",
    "# overwritten by the grid-search loop at the end of the notebook.\n",
    "lambda_param = 100\n",
    "beta_param = 50\n",
    "alpha_param = 100\n",
    "gamma_param = 100\n",
    "# NOTE(review): `lr` is not read by the functions in the visible cells\n",
    "# (get_accuracy sets its own local lr) -- confirm whether it is still needed.\n",
    "lr = 1e-5\n",
    "# Lower bound applied to the entries of C in experiment_sparsity's update step.\n",
    "thresh = 1e-10\n",
    "\n",
    "from scipy.sparse import random\n",
    "from scipy.stats import rv_continuous\n",
    "class CustomDistribution(rv_continuous):\n",
    "    \"\"\"Distribution whose samples are standard-normal draws from `random_state`.\"\"\"\n",
    "    def _rvs(self,  size=None, random_state=None):\n",
    "        return random_state.standard_normal(size)\n",
    "temp = CustomDistribution(seed=1)\n",
    "temp2 = temp()  # get a frozen version of the distribution\n",
    "# Random sparse matrices whose stored values are drawn through temp2.rvs.\n",
    "# NOTE(review): X_tilde is not referenced again in the visible cells, and C is\n",
    "# regenerated (density 0.15) inside the grid-search loop.\n",
    "X_tilde = random(k, n, density=0.25, random_state=1, data_rvs=temp2.rvs)\n",
    "C = random(p, k, density=0.25, random_state=1, data_rvs=temp2.rvs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "rnKqrqAS9qmw"
   },
   "outputs": [],
   "source": [
    "from torch.nn import Linear\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import APPNP\n",
    "\n",
    "class APPNPNet(torch.nn.Module):\n",
    "    \"\"\"Two linear layers followed by APPNP propagation; outputs log-probabilities.\n",
    "\n",
    "    All constructor arguments are optional, so ``APPNPNet()`` builds exactly the\n",
    "    network the notebook used before these sizes became configurable.\n",
    "\n",
    "    Args:\n",
    "        in_channels: input feature size. ``None`` (default) reads the notebook\n",
    "            global ``X.shape[1]`` at construction time.\n",
    "        hidden_channels: width of the hidden linear layer.\n",
    "        num_classes: output size. ``None`` (default) reads the notebook global\n",
    "            ``NO_OF_CLASSES`` at construction time.\n",
    "        K: first argument passed to ``APPNP`` (number of propagation steps).\n",
    "        alpha: second argument passed to ``APPNP`` (teleport probability).\n",
    "        dropout: dropout probability applied before each linear layer; 0.5 is\n",
    "            the ``F.dropout`` default the original code relied on.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels=None, hidden_channels=128, num_classes=None,\n",
    "                 K=16, alpha=0.1, dropout=0.5):\n",
    "        super().__init__()\n",
    "        if in_channels is None:\n",
    "            in_channels = X.shape[1]\n",
    "        if num_classes is None:\n",
    "            num_classes = NO_OF_CLASSES\n",
    "        self.dropout = dropout\n",
    "        # Layers are created in the same order as before so that parameter\n",
    "        # initialisation consumes the torch RNG identically.\n",
    "        self.lin1 = Linear(in_channels, hidden_channels)\n",
    "        self.lin2 = Linear(hidden_channels, num_classes)\n",
    "        self.prop1 = APPNP(K, alpha)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Re-initialise the weights of both linear layers.\"\"\"\n",
    "        self.lin1.reset_parameters()\n",
    "        self.lin2.reset_parameters()\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Return row-wise log-softmax scores for features `x` on graph `edge_index`.\"\"\"\n",
    "        x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "        x = F.relu(self.lin1(x))\n",
    "        x = F.dropout(x, p=self.dropout, training=self.training)\n",
    "        x = self.lin2(x)\n",
    "        # Propagate the per-node predictions over the graph.\n",
    "        x = self.prop1(x, edge_index)\n",
    "\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "StlALggCABGw"
   },
   "outputs": [],
   "source": [
    "from random import sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "78DErAOL9vVT"
   },
   "outputs": [],
   "source": [
    "def _hard_assignment(C_0):\n",
    "    \"\"\"Return a one-hot matrix with a single 1 per row, at that row's arg-max column of C_0.\"\"\"\n",
    "    one_hot = np.zeros(C_0.shape)\n",
    "    one_hot[np.arange(C_0.shape[0]), np.argmax(C_0, axis=1)] = 1\n",
    "    return one_hot\n",
    "\n",
    "\n",
    "def _nonzero_edge_index(weights):\n",
    "    \"\"\"Return a (2, nnz) long tensor with the row/column positions of the non-zero entries.\"\"\"\n",
    "    from scipy import sparse\n",
    "    coo = sparse.csr_matrix(weights).tocoo()\n",
    "    row = torch.from_numpy(coo.row).to(torch.long)\n",
    "    col = torch.from_numpy(coo.col).to(torch.long)\n",
    "    return torch.stack([row, col], dim=0)\n",
    "\n",
    "\n",
    "def get_accuracy(C_0, L, epochs=60, lr=0.01, weight_decay=0.0001,\n",
    "                 edge_threshold=0.1, log_every=100):\n",
    "    \"\"\"Train APPNPNet on the coarsened graph and return its accuracy on the full graph.\n",
    "\n",
    "    Reads the notebook globals ``labels``, ``NO_OF_CLASSES``, ``features``,\n",
    "    ``adj`` and ``NO_OF_NODES`` and uses ``APPNPNet``, ``sample`` and ``F``.\n",
    "\n",
    "    Args:\n",
    "        C_0: 2-D NumPy array (nodes x coarse nodes); each row is hardened to a\n",
    "            one-hot assignment at its arg-max column.\n",
    "        L: NumPy Laplacian of the original graph (nodes x nodes).\n",
    "        epochs: number of full-batch training steps.\n",
    "        lr: Adam learning rate.\n",
    "        weight_decay: Adam weight decay.\n",
    "        edge_threshold: coarse edge weights below this value are dropped.\n",
    "        log_every: the loss is printed every `log_every` epochs.\n",
    "\n",
    "    Returns:\n",
    "        Fraction (float in [0, 1]) of original nodes whose predicted class equals\n",
    "        ``labels``.\n",
    "\n",
    "    Changes with respect to the previous version:\n",
    "        * The final prediction runs under ``model.eval()`` / ``torch.no_grad()``;\n",
    "          before, dropout was still active while measuring accuracy.\n",
    "        * The outer ten-pass loop is gone: it always returned during its first\n",
    "          pass, so exactly one train/evaluate run happens, as before.\n",
    "        * The pseudo-inverse of the assignment matrix is computed once instead\n",
    "          of twice.\n",
    "    \"\"\"\n",
    "    device = torch.device('cpu')\n",
    "    C_hard = _hard_assignment(C_0)\n",
    "    n_coarse = C_hard.shape[1]\n",
    "\n",
    "    # Coarse Laplacian -> off-diagonal weights -> thresholded coarse edges.\n",
    "    Lc = C_hard.T @ L @ C_hard\n",
    "    Wc = (-1*Lc)*(1-np.eye(Lc.shape[0]))\n",
    "    Wc[Wc < edge_threshold] = 0\n",
    "    edge_index_coarse = _nonzero_edge_index(Wc)\n",
    "\n",
    "    # The pseudo-inverse of the assignment maps node quantities to coarse nodes.\n",
    "    P = np.linalg.pinv(C_hard)\n",
    "\n",
    "    # Coarse label = arg-max of the projected one-hot node labels.\n",
    "    Y = torch.eye(NO_OF_CLASSES)[labels, :]\n",
    "    labels_coarse = torch.argmax(torch.from_numpy(P).double() @ Y.double(), 1)\n",
    "\n",
    "    # Dense node features (SciPy matrices expose `todense`, arrays do not).\n",
    "    if hasattr(features, 'todense'):\n",
    "        X_full = np.array(features.todense())\n",
    "    else:\n",
    "        X_full = np.array(features)\n",
    "    # Build the coarse feature tensor once instead of on every epoch.\n",
    "    X_coarse = torch.Tensor(P @ X_full).to(device)\n",
    "\n",
    "    model = APPNPNet().to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    # Random permutation of every coarse node: all of them are used for training.\n",
    "    train_idx = sample(range(n_coarse), n_coarse)\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        optimizer.zero_grad()\n",
    "        out = model(X_coarse, edge_index_coarse)\n",
    "        loss = F.nll_loss(out[train_idx], labels_coarse[train_idx])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if epoch % log_every == 0:\n",
    "            print(f'Epoch: {epoch:03d},loss: {loss.item():.4f}')\n",
    "\n",
    "    # Evaluate on the original graph with dropout disabled and no autograd graph.\n",
    "    edge_index_full = _nonzero_edge_index(adj)\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        pred = model(torch.Tensor(X_full).to(device), edge_index_full).argmax(dim=1)\n",
    "    pred = pred.cpu().numpy()\n",
    "    correct = (pred == labels).sum()\n",
    "    return int(correct) / NO_OF_NODES\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment_sparsity(lambda_param,beta_param,gamma_param,C,theta,X,n_iters=10):\n",
    "    \"\"\"Iteratively update the matrix C with a projected gradient-style step.\n",
    "\n",
    "    Reads the notebook globals ``thresh`` and ``tqdm`` and calls\n",
    "    ``convertScipyToTensor``.\n",
    "\n",
    "    Args:\n",
    "        lambda_param: weight of the row-sum term of the update.\n",
    "        beta_param: weight of the ``theta C (C^T theta C)`` term.\n",
    "        gamma_param: weight of the ``theta C pinv(C^T theta C + J)`` term.\n",
    "        C: SciPy sparse matrix (p x k) used as the starting point.\n",
    "        theta: p x p matrix; a torch tensor, SciPy sparse matrix or array.\n",
    "        X: unused; kept so existing callers keep working. k is now taken from\n",
    "            ``C.shape[1]`` instead of ``int(X.shape[0]*0.1)``.\n",
    "        n_iters: number of update steps (default 10, the former hard-coded value).\n",
    "\n",
    "    Returns:\n",
    "        Dense torch tensor (p x k) with entries >= ``thresh`` (for thresh >= 0)\n",
    "        and every row scaled to unit L1 norm.\n",
    "\n",
    "    Changes with respect to the previous version:\n",
    "        * No longer executes ``global L; L = 1/k``, which overwrote the\n",
    "          notebook-level ``L`` as a side effect.\n",
    "        * ``theta C C^T theta C`` is evaluated as ``(theta C) @ (C^T theta C)``,\n",
    "          reusing a product that is needed anyway and avoiding a p x p\n",
    "          intermediate; ``C @ ones(k, k)`` is replaced by broadcast row sums.\n",
    "          Both are mathematically identical to the old expressions.\n",
    "        * Row normalisation is vectorised.\n",
    "    \"\"\"\n",
    "    from scipy import sparse\n",
    "\n",
    "    C = convertScipyToTensor(C).to_dense()\n",
    "    k = C.shape[1]\n",
    "\n",
    "    # Bring theta into torch form explicitly instead of try/except guessing.\n",
    "    if sparse.issparse(theta):\n",
    "        theta = convertScipyToTensor(theta)\n",
    "    elif not torch.is_tensor(theta):\n",
    "        theta = torch.as_tensor(theta, dtype=C.dtype)\n",
    "\n",
    "    # k x k matrix with every entry equal to 1/k.\n",
    "    J = torch.full((k, k), 1.0/k, dtype=C.dtype)\n",
    "    # Scalar the combined update is divided by (was the clobbered global `L`).\n",
    "    step = 1/k\n",
    "\n",
    "    def update(C):\n",
    "        thetaC = theta@C\n",
    "        CT_thetaC = torch.transpose(C,0,1)@thetaC\n",
    "        term_bracket = torch.linalg.pinv(CT_thetaC + J)\n",
    "        gamma_term = -2*gamma_param*(thetaC@term_bracket)\n",
    "        # Row sums broadcast over the k columns == C @ ones(k, k).\n",
    "        lambda_term = lambda_param*C.sum(dim=1, keepdim=True)\n",
    "        beta_term = 2*beta_param*(thetaC@CT_thetaC)\n",
    "        T2 = (gamma_term + lambda_term + beta_term)/step\n",
    "        # Project onto the non-negative orthant, then floor at `thresh`.\n",
    "        Cnew = torch.clamp(C - T2, min=0.0)\n",
    "        Cnew = torch.clamp(Cnew, min=thresh)\n",
    "        # Scale every row to unit L1 norm.\n",
    "        return Cnew / torch.linalg.norm(Cnew, ord=1, dim=1, keepdim=True)\n",
    "\n",
    "    for _ in tqdm(range(n_iters)):\n",
    "        C = update(C)\n",
    "\n",
    "    return C\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ejPLG3L9zYrc",
    "outputId": "57c7f0b2-f8e2-4806-89ec-b6afc8c204ca"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [05:47<00:00, 34.79s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.6766\n",
      "Accuracy = 0.528802945525179 100 0.01 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [05:44<00:00, 34.48s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.6053\n",
      "Average accuracy = 50.185544893166735 +/- 2.6947496593511726\n",
      "Params =  100 0.01 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [03:57<00:00, 23.76s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5976\n",
      "Accuracy = 0.9399298408372714 100 0.01 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [03:56<00:00, 23.69s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.6506\n",
      "Average accuracy = 93.88136723393153 +/- 0.11161684979561493\n",
      "Params =  100 0.01 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:44<00:00, 28.48s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5793\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:45<00:00, 28.53s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.6160\n",
      "Average accuracy = 48.401124865914824 +/- 6.531760067260025\n",
      "Params =  100 0.01 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:45<00:00, 28.58s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5831\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:46<00:00, 28.62s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5746\n",
      "Average accuracy = 54.93143536369698 +/- 0.9523671469573558\n",
      "Params =  100 0.1 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:08<00:00, 24.90s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.6850\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:26<00:00, 26.68s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5776\n",
      "Average accuracy = 50.505899747774905 +/- 0.17394833734381\n",
      "Params =  100 0.1 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:45<00:00, 28.60s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5960\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:46<00:00, 28.69s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5509\n",
      "Average accuracy = 51.659757052155506 +/- 3.7022004464673963\n",
      "Params =  100 0.1 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:00<00:00, 24.00s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5464\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [03:59<00:00, 23.97s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5440\n",
      "Average accuracy = 53.25718261676282 +/- 2.249731829646595\n",
      "Params =  10 0.01 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:03<00:00, 24.38s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5588\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:03<00:00, 24.36s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 135812464.0000\n",
      "Average accuracy = 57.80883077725916 +/- 34.2533267619517\n",
      "Params =  10 0.01 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:47<00:00, 28.72s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.6417\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:46<00:00, 28.68s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.6015\n",
      "Average accuracy = 50.89003565940916 +/- 0.9871568144261184\n",
      "Params =  10 0.01 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [03:57<00:00, 23.72s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 459158144.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [03:59<00:00, 23.91s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.6876\n",
      "Average accuracy = 41.88820920186704 +/- 8.446641347519787\n",
      "Params =  10 0.1 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [04:08<00:00, 24.81s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.5225\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|█████████████████████████████████████████████████▊                                 | 6/10 [02:44<01:49, 27.41s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn [20], line 12\u001b[0m\n\u001b[0;32m     10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m):\n\u001b[0;32m     11\u001b[0m   C \u001b[38;5;241m=\u001b[39m random(p, k, density\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.15\u001b[39m, random_state\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, data_rvs\u001b[38;5;241m=\u001b[39mtemp2\u001b[38;5;241m.\u001b[39mrvs)\n\u001b[1;32m---> 12\u001b[0m   C_0 \u001b[38;5;241m=\u001b[39m \u001b[43mexperiment_sparsity\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlambda_param\u001b[49m\u001b[43m,\u001b[49m\u001b[43mbeta_param\u001b[49m\u001b[43m,\u001b[49m\u001b[43mgamma_param\u001b[49m\u001b[43m,\u001b[49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\u001b[43mtheta\u001b[49m\u001b[43m,\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     13\u001b[0m   L \u001b[38;5;241m=\u001b[39m theta\n\u001b[0;32m     14\u001b[0m   C_0 \u001b[38;5;241m=\u001b[39m C_0\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mnumpy()\n",
      "Cell \u001b[1;32mIn [17], line 56\u001b[0m, in \u001b[0;36mexperiment_sparsity\u001b[1;34m(lambda_param, beta_param, gamma_param, C, theta, X)\u001b[0m\n\u001b[0;32m     53\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m Cnew\n\u001b[0;32m     55\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m10\u001b[39m)):   \u001b[38;5;66;03m#update C only 21\u001b[39;00m\n\u001b[1;32m---> 56\u001b[0m     C \u001b[38;5;241m=\u001b[39m \u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\u001b[43mi\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     59\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m C\n",
      "Cell \u001b[1;32mIn [17], line 40\u001b[0m, in \u001b[0;36mexperiment_sparsity.<locals>.update\u001b[1;34m(C, i)\u001b[0m\n\u001b[0;32m     38\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mupdate\u001b[39m(C,i):\n\u001b[0;32m     39\u001b[0m     \u001b[38;5;28;01mglobal\u001b[39;00m L\n\u001b[1;32m---> 40\u001b[0m     thetaC \u001b[38;5;241m=\u001b[39m \u001b[43mtheta\u001b[49m\u001b[38;5;129;43m@C\u001b[39;49m\n\u001b[0;32m     41\u001b[0m     CT \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtranspose(C,\u001b[38;5;241m0\u001b[39m,\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m     42\u001b[0m     t1 \u001b[38;5;241m=\u001b[39m CT\u001b[38;5;129m@thetaC\u001b[39m \u001b[38;5;241m+\u001b[39m J\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "\n",
    "# Grid search over (lambda, beta, gamma): for each setting, learn C, train on\n",
    "# the coarsened graph and record the accuracy on the full graph.\n",
    "highest_accuracy=0\n",
    "N_REPEATS = 2  # independent runs averaged per parameter setting\n",
    "# theta does not change during the search, so convert it to NumPy only once.\n",
    "theta_np = theta.cpu().detach().numpy() if torch.is_tensor(theta) else theta\n",
    "for lambda_param in [100,10,1]:\n",
    "  for beta_param in [0.01,0.1]:\n",
    "      for gamma_param in [10,100,1]:\n",
    "\n",
    "        av = []\n",
    "\n",
    "        for _ in range(N_REPEATS):\n",
    "            C = random(p, k, density=0.15, random_state=1, data_rvs=temp2.rvs)\n",
    "            C_0 = experiment_sparsity(lambda_param,beta_param,gamma_param,C,theta,X)\n",
    "            C_0 = C_0.cpu().detach().numpy()\n",
    "            # Rebind L after the call: experiment_sparsity may overwrite the\n",
    "            # notebook-level name `L`.\n",
    "            L = theta_np\n",
    "\n",
    "            acc = get_accuracy(C_0,L)\n",
    "            av.append(acc)\n",
    "            if highest_accuracy<acc:\n",
    "                highest_accuracy=acc\n",
    "                # Report the searched lambda_param (this line used to print the\n",
    "                # unrelated constant alpha_param).\n",
    "                print(\"Accuracy = \" + str(acc) + \" \" + str(lambda_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))\n",
    "        print(\"Average accuracy = \" + str(np.mean(av)*100)  + \" +/- \" + str(np.std(av)*100))\n",
    "        print(\"Params =  \" + str(lambda_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9399298408372714"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "highest_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "5c1f741a4f83aa020b4b2a4d7353a073a4e5e4a855a3258a20da40294ddbf005"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
