{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3f80e3a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_26160\\273441208.py:10: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# !pip install deeprobust\n",
    "# !conda install pytorch torchvision torchaudio -c pytorch\n",
    "import torch\n",
    "# print(torch.__version__)\n",
    "# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
    "# !pip install torch-geometric\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "# Import from IPython.display: importing display from IPython.core.display\n",
    "# is deprecated since IPython 7.14 and emits a DeprecationWarning.\n",
    "from IPython.display import display, HTML\n",
    "# Widen the notebook's .container elements to 90% of the page width.\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "\n",
    "import random\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c287cb5c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_26160\\4057615484.py:25: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'c:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Experiment of K-component\\\\Exp k-comp GCN'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import collections\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "from random import sample\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse as sp\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch_geometric\n",
    "from IPython.display import display, HTML\n",
    "from networkx.algorithms import community\n",
    "from networkx.generators.community import random_partition_graph, stochastic_block_model\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph, erdos_renyi_graph, watts_strogatz_graph\n",
    "from scipy.sparse import csgraph, csr_matrix\n",
    "from scipy.sparse.linalg import inv\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "from torch import Tensor\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.utils import to_networkx\n",
    "from tqdm import tqdm\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "# Widen the notebook's .container elements to 90% of the page width.\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "# Compute device: CUDA when available, otherwise CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "data_dir = \"./data\"\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3669a0f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'c:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Experiment of K-component\\\\Exp k-comp GCN\\\\PubMed'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Absolute path of a 'PubMed' folder under the current working directory.\n",
    "# NOTE: `dataset` is rebound to the Planetoid dataset object in the next cell.\n",
    "dataset = os.path.join(os.getcwd(), 'PubMed')\n",
    "dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "067c3d06",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset: PubMed\n",
      "num_nodes: 19717\n",
      "num_edges: 88648\n",
      "num_classes: 3\n",
      "num_features: 500\n"
     ]
    }
   ],
   "source": [
    "import os.path as osp\n",
    "import torch\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import torch_geometric.transforms as T\n",
    "\n",
    "def get_planetoid_dataset(name, normalize_features=False, transform=None, split=\"public\"):\n",
    "    \"\"\"Load a Planetoid citation dataset with PyTorch Geometric.\n",
    "\n",
    "    The dataset root is dirname(realpath(cwd))/../data/<name>.\n",
    "\n",
    "    Args:\n",
    "        name: Planetoid dataset name, e.g. 'Cora', 'CiteSeer' or 'PubMed'.\n",
    "        normalize_features: if True, attach T.NormalizeFeatures() as transform.\n",
    "        transform: optional extra transform; when normalize_features is also\n",
    "            set it runs after NormalizeFeatures.\n",
    "        split: split name forwarded to Planetoid, or 'complete' to train on all\n",
    "            nodes except the last 1000, validate on the next 500 and test on\n",
    "            the final 500.\n",
    "\n",
    "    Returns:\n",
    "        The Planetoid dataset object.\n",
    "    \"\"\"\n",
    "    root = osp.join(osp.dirname(osp.realpath(os.getcwd())), '..', 'data', name)\n",
    "\n",
    "    if split != 'complete':\n",
    "        dataset = Planetoid(root, name, split=split)\n",
    "    else:\n",
    "        dataset = Planetoid(root, name)\n",
    "        graph = dataset[0]\n",
    "        n_nodes = graph.num_nodes\n",
    "        # The masks are overwritten in place on the tensors of dataset[0].\n",
    "        # NOTE(review): this relies on dataset[0] sharing tensor storage with\n",
    "        # the dataset; confirm for the installed PyG version.\n",
    "        graph.train_mask.fill_(False)\n",
    "        graph.train_mask[:n_nodes - 1000] = 1\n",
    "        graph.val_mask.fill_(False)\n",
    "        graph.val_mask[n_nodes - 1000:n_nodes - 500] = 1\n",
    "        graph.test_mask.fill_(False)\n",
    "        graph.test_mask[n_nodes - 500:] = 1\n",
    "\n",
    "    if normalize_features and transform is not None:\n",
    "        dataset.transform = T.Compose([T.NormalizeFeatures(), transform])\n",
    "    elif normalize_features:\n",
    "        dataset.transform = T.NormalizeFeatures()\n",
    "    elif transform is not None:\n",
    "        dataset.transform = transform\n",
    "    return dataset\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # lst_names = ['Cora', 'CiteSeer', 'PubMed']\n",
    "    lst_names = ['PubMed']\n",
    "    for name in lst_names:\n",
    "        dataset = get_planetoid_dataset(name)\n",
    "        print(f\"dataset: {name}\")\n",
    "        print(f\"num_nodes: {dataset[0]['x'].shape[0]}\")\n",
    "        print(f\"num_edges: {dataset[0]['edge_index'].shape[1]}\")\n",
    "        print(f\"num_classes: {dataset.num_classes}\")\n",
    "        print(f\"num_features: {dataset.num_node_features}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7f8577f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([19717, 500]) torch.Size([19717, 19717])\n",
      "torch.Size([19717, 500]) torch.Size([19717, 19717])\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.utils import to_dense_adj\n",
    "\n",
    "\n",
    "graph = dataset[0]\n",
    "\n",
    "# Number of zero eigenvalues targeted for the coarsened Laplacian (a graph\n",
    "# Laplacian has one zero eigenvalue per connected component); one per class.\n",
    "k_ = dataset.num_classes\n",
    "\n",
    "# Dense N x N adjacency on the compute device. max_num_nodes keeps it N x N\n",
    "# even if the highest-numbered nodes have no edges; without it to_dense_adj\n",
    "# sizes the matrix by the largest node index present in edge_index.\n",
    "adj = to_dense_adj(graph.edge_index, max_num_nodes=graph.num_nodes)[0].to(device)\n",
    "labels = graph.y.numpy()\n",
    "\n",
    "X = graph.x.to_dense()\n",
    "N = X.shape[0]\n",
    "NO_OF_CLASSES = len(np.unique(labels))\n",
    "\n",
    "print(X.shape, adj.shape)\n",
    "\n",
    "# Fraction of nodes kept for the experiment (1 = whole graph); the first\n",
    "# `nn` nodes are used.\n",
    "SUBSET_FRACTION = 1\n",
    "nn = int(SUBSET_FRACTION * N)\n",
    "X = X[:nn, :]\n",
    "adj = adj[:nn, :nn]\n",
    "labels = labels[:nn]\n",
    "print(X.shape, adj.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "475846d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([19717, 500])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the node-feature matrix (rows are nodes).\n",
    "X.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a0757b15",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([19717, 19717])\n"
     ]
    }
   ],
   "source": [
    "def get_laplacian(adj):\n",
    "    \"\"\"Return the combinatorial graph Laplacian L = D - A of a dense adjacency.\n",
    "\n",
    "    D is the diagonal matrix of row sums (weighted degrees). The result uses\n",
    "    adj's device and dtype, so CPU tensors work as well as CUDA tensors.\n",
    "    \"\"\"\n",
    "    degrees = adj.sum(dim=1)\n",
    "    return torch.diag(degrees) - adj\n",
    "\n",
    "# Dense combinatorial Laplacian of the working graph.\n",
    "theta = get_laplacian(adj)\n",
    "print(theta.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cad0c60b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Debug check (leftover; safe to remove): device holding the Laplacian theta.\n",
    "theta.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "114a42e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3 19717\n"
     ]
    }
   ],
   "source": [
    "# Node features as float32 on the compute device (`device` comes from the setup cell);\n",
    "# avoids the legacy torch.Tensor(...) constructor and a hard-coded .cuda().\n",
    "features = X.float().to(device)\n",
    "NO_OF_NODES = X.shape[0]\n",
    "print(NO_OF_CLASSES,NO_OF_NODES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2ba32732",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Debug check (leftover; safe to remove): device holding the feature tensor.\n",
    "features.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b33c30e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convertScipyToTensor(coo):\n",
    "    \"\"\"Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.\n",
    "\n",
    "    Any SciPy sparse format is accepted (it is first converted with .tocoo()).\n",
    "    Indices are stored as int64 and values as float32 (copied, not shared).\n",
    "\n",
    "    Raises:\n",
    "        AttributeError: if coo is not a SciPy sparse matrix (for example a\n",
    "            torch tensor), because it has no .row / .col attributes.\n",
    "    \"\"\"\n",
    "    # Explicit check instead of a bare except around .tocoo().\n",
    "    if hasattr(coo, 'tocoo'):\n",
    "        coo = coo.tocoo()\n",
    "    indices = np.vstack((coo.row, coo.col)).astype(np.int64)\n",
    "    values = np.array(coo.data, dtype=np.float32)\n",
    "    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.\n",
    "    return torch.sparse_coo_tensor(torch.from_numpy(indices), torch.from_numpy(values),\n",
    "                                   torch.Size(coo.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b7f9d72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.sparse import random\n",
    "from scipy.sparse.linalg import norm\n",
    "from scipy.sparse import csr_matrix\n",
    "from scipy.stats import rv_continuous\n",
    "\n",
    "# NOTE: from here on `random` is scipy.sparse.random; it shadows the stdlib\n",
    "# `random` module imported in the earlier cells.\n",
    "\n",
    "# Problem sizes: p nodes, k = 5% of p super-nodes, n feature dimensions.\n",
    "p = X.shape[0]\n",
    "k = int(p*0.05)\n",
    "n = X.shape[1]\n",
    "lr = 1e-5  # NOTE(review): not referenced by the visible cells (get_accu sets its own learning rate)\n",
    "# Entries of C below this value are raised to it after every gradient step.\n",
    "thresh = 1e-10\n",
    "\n",
    "\n",
    "class CustomDistribution(rv_continuous):\n",
    "    \"\"\"Continuous distribution whose samples are standard-normal draws.\"\"\"\n",
    "\n",
    "    def _rvs(self, size=None, random_state=None):\n",
    "        return random_state.standard_normal(size)\n",
    "\n",
    "\n",
    "temp = CustomDistribution(seed=1)\n",
    "temp2 = temp()  # frozen distribution; temp2.rvs supplies the non-zero values of C\n",
    "# Initial p x k sparse matrix C: 25% non-zeros, fixed seed 1.\n",
    "# NOTE(review): the experiment cell rebuilds C with density=0.15 before use.\n",
    "C = random(p, k, density=0.25, random_state=1, data_rvs=temp2.rvs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1f724733",
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment_K_component(alpha_param, lambda_param, beta_param, gamma_param, C, theta, X, num_iters=10):\n",
    "    \"\"\"Learn a p x k coarsening matrix C by projected gradient descent.\n",
    "\n",
    "    Every iteration computes\n",
    "        grad = alpha_param * C @ J + beta_param * (theta @ C) @ (C^T theta C - U diag(lam) U^T)\n",
    "    where J is the k x k all-ones matrix, U holds the (real parts of the)\n",
    "    eigenvectors of C^T theta C and lam comes from laplacian_lambda_update.\n",
    "    It then sets C <- C - gamma_param * grad, raises every entry below the\n",
    "    global `thresh` to `thresh`, and rescales each row to unit L1 norm.\n",
    "\n",
    "    Args:\n",
    "        alpha_param: weight of the C @ J gradient term.\n",
    "        lambda_param: unused; kept for interface compatibility.\n",
    "        beta_param: weight of the spectral gradient term.\n",
    "        gamma_param: gradient step size.\n",
    "        C: initial p x k matrix (SciPy sparse or torch tensor).\n",
    "        theta: p x p graph Laplacian (SciPy sparse or torch tensor).\n",
    "        X: node features; only X.shape[0] is used (p, and k = int(0.05 * p)).\n",
    "        num_iters: number of gradient iterations (default 10).\n",
    "\n",
    "    Returns:\n",
    "        The final dense C as a torch tensor (moved to CUDA when available).\n",
    "\n",
    "    Reads the notebook globals `k_` and `thresh`.\n",
    "    \"\"\"\n",
    "    p = X.shape[0]\n",
    "    k = int(p*0.05)\n",
    "\n",
    "    # SciPy sparse inputs are converted; torch tensors are used unchanged.\n",
    "    if not torch.is_tensor(C):\n",
    "        C = convertScipyToTensor(C).to_dense()\n",
    "    if not torch.is_tensor(theta):\n",
    "        theta = convertScipyToTensor(theta)\n",
    "    if torch.cuda.is_available():\n",
    "        print(\"GPU is available\")\n",
    "        C = C.cuda()\n",
    "        theta = theta.cuda()\n",
    "    # k x k all-ones matrix, created on C's device (no hard-coded .cuda()).\n",
    "    ones = torch.ones((k, k), dtype=C.dtype, device=C.device)\n",
    "\n",
    "    # Eigenvalue bounds c1 = lb, c2 = ub and the beta used by laplacian_lambda_update.\n",
    "    # From the original notes: performance is not sensitive to c1 and c2 as\n",
    "    # long as they are reasonably small and large, respectively.\n",
    "    lb, ub, lambda_beta = 1e-5, 1e+4, 0.5\n",
    "\n",
    "    def laplacian_lambda_update(lb, ub, beta, U, Lw, k):\n",
    "        \"\"\"Return eigenvalues for U diag(lam) U^T, clipped to [lb, ub] if needed.\n",
    "\n",
    "        Starts from 0.5 * (d + sqrt(d^2 + 4 / beta)) with d = diag(U^T Lw U),\n",
    "        sorted in descending order. If the check on entries q = Lw.size(1) - k\n",
    "        and 0 fails, values are clipped into [lb, ub] and checked again.\n",
    "\n",
    "        NOTE(review): the values are sorted but U's columns are not permuted to\n",
    "        match, so U diag(lam) U^T pairs eigenvectors with other eigenvalues;\n",
    "        confirm this is intended. Also, k only enters through the index q.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the check still fails after clipping.\n",
    "        \"\"\"\n",
    "        q = Lw.size(1) - k\n",
    "        UT = U.T.type(torch.float64)\n",
    "        d = torch.diag(UT @ Lw @ U)\n",
    "        # unconstrained solution as initial point\n",
    "        lambda_ = 0.5 * (d + torch.sqrt(d.pow(2) + 4 / beta))\n",
    "        lambda_, _ = torch.sort(lambda_, dim=-1, descending=True)\n",
    "        eps = 1\n",
    "        condition = torch.tensor([(lambda_[q] - ub) <= eps,\n",
    "                                  (lambda_[0] - lb) >= -eps]).all()\n",
    "        if condition:\n",
    "            return lambda_\n",
    "        greater_ub = lambda_ > ub\n",
    "        lesser_lb = lambda_ < lb\n",
    "        lambda_[greater_ub] = ub\n",
    "        lambda_[lesser_lb] = lb\n",
    "        condition = torch.tensor([(lambda_[q] - ub) <= eps,\n",
    "                                  (lambda_[0] - lb) >= -eps]).all()\n",
    "        if condition:\n",
    "            return lambda_\n",
    "        print(lambda_)\n",
    "        raise ValueError(\"eigenvalues are not in increasing order, consider increasing the value of beta\")\n",
    "\n",
    "    def update_C(C):\n",
    "        \"\"\"Take one projected gradient step on C and return the new matrix.\"\"\"\n",
    "        CT = C.T\n",
    "        # The p x p by p x k product theta @ C is formed once per step and reused.\n",
    "        thetaC = theta @ C\n",
    "        Lw = CT @ thetaC  # C^T theta C, k x k\n",
    "        # NOTE(review): Lw is symmetric, so torch.linalg.eigh would give real,\n",
    "        # orthonormal eigenvectors; eig is kept because laplacian_lambda_update's\n",
    "        # result depends on the column order of U.\n",
    "        U = torch.linalg.eig(Lw)[1].real.double()\n",
    "        lam = laplacian_lambda_update(lb, ub, lambda_beta, U, Lw.double(), k_)\n",
    "        bracket_term = Lw - U @ torch.diag(lam) @ U.T  # C^T theta C - U diag(lam) U^T\n",
    "        grad_fc = alpha_param * (C @ ones) + beta_param * (thetaC @ bracket_term.float())\n",
    "        C_new = C - gamma_param * grad_fc\n",
    "        C_new[C_new < thresh] = thresh\n",
    "        # Row-wise L1 normalisation, vectorised instead of a Python loop over p rows.\n",
    "        return C_new / torch.linalg.vector_norm(C_new, ord=1, dim=1, keepdim=True)\n",
    "\n",
    "    for _ in tqdm(range(num_iters)):\n",
    "        C = update_C(C)\n",
    "    return C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "006d1aaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv\n",
    "\n",
    "\n",
    "class Net(torch.nn.Module):\n",
    "    \"\"\"Two-layer GCN: GCNConv -> ReLU -> dropout -> GCNConv -> log_softmax.\n",
    "\n",
    "    Args:\n",
    "        in_channels: input feature size; defaults to X.shape[1] of the\n",
    "            notebook-global X (read when the model is constructed).\n",
    "        num_classes: output size; defaults to the global NO_OF_CLASSES.\n",
    "        hidden_channels: hidden layer width (default 64).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels=None, num_classes=None, hidden_channels=64):\n",
    "        super().__init__()\n",
    "        if in_channels is None:\n",
    "            in_channels = X.shape[1]\n",
    "        if num_classes is None:\n",
    "            num_classes = NO_OF_CLASSES\n",
    "        self.conv1 = GCNConv(in_channels, hidden_channels)\n",
    "        self.conv2 = GCNConv(hidden_channels, num_classes)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Re-initialise the weights of both convolution layers.\"\"\"\n",
    "        self.conv1.reset_parameters()\n",
    "        self.conv2.reset_parameters()\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Return log-softmax class scores (along dim 1) for every node.\"\"\"\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, training=self.training)  # p=0.5, active only in train mode\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "\n",
    "\n",
    "# conv2 is the output layer: forward() returns log-probabilities, which\n",
    "# get_accu trains with F.nll_loss.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1c5a519a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accu(C_0,L,X_t_0):\n",
    "    \"\"\"Train a GCN on the coarsened graph and return its accuracy on the full graph.\n",
    "\n",
    "    1. Hard-assign every node to the column of C_0 with the largest entry in\n",
    "       its row (one-hot matrix C_hard).\n",
    "    2. Coarse Laplacian Lc = C_hard^T L C_hard; its negated off-diagonal\n",
    "       entries that are >= 0.1 become the coarse edges.\n",
    "    3. With P = pinv(C_hard): coarse features P @ features and coarse labels\n",
    "       argmax(P @ one_hot(labels)).\n",
    "    4. Train Net for 100 epochs (Adam, lr=0.01, weight_decay=1e-4) on all k\n",
    "       coarse nodes.\n",
    "    5. Predict on the full graph (edges from the global adj) in eval mode and\n",
    "       compare with the global labels.\n",
    "\n",
    "    Args:\n",
    "        C_0: NumPy array of soft assignments, one row per node.\n",
    "        L: NumPy graph Laplacian matching C_0's rows.\n",
    "        X_t_0: unused; kept for interface compatibility.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of the NO_OF_NODES nodes whose predicted class equals labels.\n",
    "\n",
    "    Reads the notebook globals labels, NO_OF_CLASSES, NO_OF_NODES, k,\n",
    "    features and adj.\n",
    "    \"\"\"\n",
    "    from scipy import sparse\n",
    "\n",
    "    def to_edge_index(matrix):\n",
    "        \"\"\"Return the (2, nnz) int64 edge index of a dense NumPy matrix's non-zeros.\"\"\"\n",
    "        coo = sparse.csr_matrix(matrix).tocoo()\n",
    "        row = torch.from_numpy(coo.row).to(torch.long)\n",
    "        col = torch.from_numpy(coo.col).to(torch.long)\n",
    "        return torch.stack([row, col], dim=0)\n",
    "\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "    # 1. One-hot hard assignment (first maximum wins, as with np.argmax).\n",
    "    C_hard = np.zeros(C_0.shape)\n",
    "    C_hard[np.arange(C_0.shape[0]), np.argmax(C_0, axis=1)] = 1\n",
    "\n",
    "    # 2. Coarse graph: weights are the negated off-diagonal entries of Lc.\n",
    "    Lc = C_hard.T @ L @ C_hard\n",
    "    Wc = (-1 * Lc) * (1 - np.eye(Lc.shape[0]))\n",
    "    Wc[Wc < 0.1] = 0\n",
    "    # NOTE(review): the coarse edge weights are not passed to the GCN.\n",
    "    edge_index_coarsen2 = to_edge_index(Wc)\n",
    "\n",
    "    # 3. Coarse labels and features; pinv is the expensive step, so compute it once.\n",
    "    P = np.linalg.pinv(C_hard)\n",
    "    Y = torch.eye(NO_OF_CLASSES)[torch.as_tensor(labels), :]\n",
    "    # P is cast through float32 before the float64 product (keeps earlier results reproducible).\n",
    "    labels_coarse = torch.argmax(torch.tensor(P, dtype=torch.float32).double() @ Y.double(), 1)\n",
    "    X = features.detach().cpu().numpy()\n",
    "    Xt = torch.tensor(P @ X, dtype=torch.float32).to(device)  # built once, not every epoch\n",
    "\n",
    "    # 4. Train on all k coarse nodes (random order; the mean loss does not depend on it).\n",
    "    model = Net().to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)\n",
    "    x = sample(range(0, int(k)), k)\n",
    "    model.train()\n",
    "    for epoch in range(100):\n",
    "        optimizer.zero_grad()\n",
    "        out = model(Xt, edge_index_coarsen2)\n",
    "        loss = F.nll_loss(out[x], labels_coarse[x])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if epoch % 100 == 0:\n",
    "            print(f'Epoch: {epoch:03d},loss: {loss:.4f}')\n",
    "\n",
    "    # 5. Full-graph evaluation in eval mode: dropout must be off when predicting.\n",
    "    edge_index_full = to_edge_index(adj.cpu().detach().numpy())\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        pred = model(torch.tensor(X, dtype=torch.float32), edge_index_full).argmax(dim=1).numpy()\n",
    "    correct = (pred == labels).sum()\n",
    "    return int(correct) / NO_OF_NODES\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "556afd48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:25<00:00,  2.52s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000,loss: 1.0984\n",
      "Time taken: 46.09373497962952\n",
      "Accuracy = 0.8183800781051884 100 0.01 0.001\n",
      "Average accuracy = 81.83800781051885 +/- 0.0\n",
      "Params =  100 0.01 0.001\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Hyper-parameter grid; extend the lists to run a sweep.\n",
    "ALPHAS = [100]\n",
    "BETAS = [0.01]\n",
    "GAMMAS = [0.001]\n",
    "N_RUNS = 1  # repetitions per setting (every run uses the same seeded initial C)\n",
    "lambda_param = 0.001  # forwarded to experiment_K_component, which does not use it\n",
    "#0.0001,0.0001,10,0.0001\n",
    "\n",
    "highest_accuracy = 0\n",
    "X = X.to(device)\n",
    "L = theta\n",
    "L_test = L.cpu().detach().numpy()  # constant across runs, so converted once\n",
    "for alpha_param in ALPHAS:\n",
    "    for beta_param in BETAS:\n",
    "        for gamma_param in GAMMAS:\n",
    "            av = []\n",
    "            for _ in range(N_RUNS):\n",
    "                C = random(p, k, density=0.15, random_state=1, data_rvs=temp2.rvs)\n",
    "                a = time.time()\n",
    "                C_0 = experiment_K_component(alpha_param, lambda_param, beta_param, gamma_param, C, theta, X)\n",
    "                b = time.time()\n",
    "                C_test = C_0.cpu().detach().numpy()\n",
    "                c = time.time()\n",
    "                # get_accu ignores its third argument, so the expensive\n",
    "                # pinv(C_0) @ X projection is not computed here.\n",
    "                acc = get_accu(C_test, L_test, None)\n",
    "                d = time.time()\n",
    "                print(\"Time taken:\", b-a+d-c)\n",
    "                av.append(acc)\n",
    "                if highest_accuracy < acc:\n",
    "                    highest_accuracy = acc\n",
    "                    print(\"Accuracy = \" + str(acc) + \" \" + str(alpha_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))\n",
    "            print(\"Average accuracy = \" + str(np.mean(av)*100)  + \" +/- \" + str(np.std(av)*100))\n",
    "            print(\"Params =  \" + str(alpha_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d8a93bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best accuracy over the sweep, as a percentage.\n",
    "100 * highest_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff99ff7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
