{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3f80e3a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_28276\\273441208.py:10: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# !pip install deeprobust\n",
    "# !conda install pytorch torchvision torchaudio -c pytorch\n",
    "import torch\n",
    "# print(torch.__version__)\n",
    "# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
    "# !pip install torch-geometric\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "\n",
    "import random\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c287cb5c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_28276\\4057615484.py:25: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import collections\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torch_geometric\n",
    "from torch_geometric.utils import to_networkx\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import networkx as nx\n",
    "from networkx.algorithms import community\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "data_dir = \"./data\"\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "import numpy\n",
    "import torch\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline\n",
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "\n",
    "from random import sample\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "import random\n",
    "\n",
    "from scipy.sparse import csr_matrix\n",
    "from scipy.sparse import csgraph\n",
    "from scipy.sparse.linalg import inv\n",
    "\n",
    "import os\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3669a0f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\PubMed'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = os.path.join(os.getcwd(),'PubMed')\n",
    "dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "067c3d06",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset: PubMed\n",
      "num_nodes: 19717\n",
      "num_edges: 88648\n",
      "num_classes: 3\n",
      "num_features: 500\n"
     ]
    }
   ],
   "source": [
    "import os.path as osp\n",
    "import torch\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import torch_geometric.transforms as T\n",
    "\n",
    "def get_planetoid_dataset(name, normalize_features=False, transform=None, split=\"public\"):\n",
    "    path = osp.join(osp.dirname(osp.realpath(os.getcwd())), '..', 'data', name)\n",
    "    if split == 'complete':\n",
    "        dataset = Planetoid(path, name)\n",
    "        dataset[0].train_mask.fill_(False)\n",
    "        dataset[0].train_mask[:dataset[0].num_nodes - 1000] = 1\n",
    "        dataset[0].val_mask.fill_(False)\n",
    "        dataset[0].val_mask[dataset[0].num_nodes - 1000:dataset[0].num_nodes - 500] = 1\n",
    "        dataset[0].test_mask.fill_(False)\n",
    "        dataset[0].test_mask[dataset[0].num_nodes - 500:] = 1\n",
    "    else:\n",
    "        dataset = Planetoid(path, name, split=split)\n",
    "    if transform is not None and normalize_features:\n",
    "        dataset.transform = T.Compose([T.NormalizeFeatures(), transform])\n",
    "    elif normalize_features:\n",
    "        dataset.transform = T.NormalizeFeatures()\n",
    "    elif transform is not None:\n",
    "        dataset.transform = transform\n",
    "    return dataset\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "#     lst_names = ['Cora', 'CiteSeer', 'PubMed']\n",
    "    lst_names = ['PubMed']\n",
    "    for name in lst_names:\n",
    "        dataset = get_planetoid_dataset(name)\n",
    "        print(f\"dataset: {name}\")\n",
    "        print(f\"num_nodes: {dataset[0]['x'].shape[0]}\")\n",
    "        print(f\"num_edges: {dataset[0]['edge_index'].shape[1]}\")\n",
    "        print(f\"num_classes: {dataset.num_classes}\")\n",
    "        print(f\"num_features: {dataset.num_node_features}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7f8577f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([19717, 500]) torch.Size([19717, 19717])\n",
      "torch.Size([19717, 500]) torch.Size([19717, 19717])\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.utils import to_dense_adj\n",
    "\n",
    "\n",
    "k_ = dataset.num_classes   #eigen value 0 lena hai itna\n",
    "\n",
    "adj = to_dense_adj(dataset[0].edge_index).cuda()\n",
    "adj = adj[0]\n",
    "labels = dataset[0].y\n",
    "labels = labels.numpy()\n",
    "\n",
    "X = dataset[0].x\n",
    "X = X.to_dense()\n",
    "N = X.shape[0]\n",
    "NO_OF_CLASSES =  len(set(np.array(dataset[0].y)))\n",
    "\n",
    "print(X.shape, adj.shape)\n",
    "\n",
    "nn = int(1*N)\n",
    "X = X[:nn,:]\n",
    "adj = adj[:nn,:nn]\n",
    "labels = labels[:nn]\n",
    "print(X.shape,adj.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "475846d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([19717, 500])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a0757b15",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([19717, 19717])\n"
     ]
    }
   ],
   "source": [
    "def get_laplacian(adj):\n",
    "    b=torch.ones(adj.shape[0]).cuda()\n",
    "    return torch.diag(adj@b)-adj\n",
    "\n",
    "theta = get_laplacian(adj)\n",
    "print(theta.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cad0c60b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Delete later\n",
    "theta.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "114a42e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3 19717\n"
     ]
    }
   ],
   "source": [
    "features = torch.Tensor(X).cuda()\n",
    "NO_OF_NODES = X.shape[0]\n",
    "print(NO_OF_CLASSES,NO_OF_NODES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2ba32732",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Delete later\n",
    "features.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b33c30e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convertScipyToTensor(coo):\n",
    "  try:\n",
    "    coo = coo.tocoo()\n",
    "  except:\n",
    "    coo = coo\n",
    "  values = coo.data\n",
    "  indices = np.vstack((coo.row, coo.col))\n",
    "\n",
    "  i = torch.LongTensor(indices)\n",
    "  v = torch.FloatTensor(values)\n",
    "  shape = coo.shape\n",
    "\n",
    "  return torch.sparse.FloatTensor(i, v, torch.Size(shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b7f9d72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.sparse import random\n",
    "from scipy.sparse.linalg import norm\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "p = X.shape[0]\n",
    "k = int(p*0.1)\n",
    "n = X.shape[1]\n",
    "lr = 1e-5\n",
    "thresh = 1e-10\n",
    "\n",
    "from scipy.sparse import random\n",
    "from scipy.stats import rv_continuous\n",
    "class CustomDistribution(rv_continuous):\n",
    "    def _rvs(self,  size=None, random_state=None):\n",
    "        return random_state.standard_normal(size)\n",
    "temp = CustomDistribution(seed=1)\n",
    "temp2 = temp()  # get a frozen version of the distribution\n",
    "C = random(p, k, density=0.25, random_state=1, data_rvs=temp2.rvs)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1f724733",
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment_K_component(alpha_param,lambda_param,beta_param,gamma_param,C,theta,X):\n",
    "      p = X.shape[0]\n",
    "      k = int(p*0.1)\n",
    "      n = X.shape[1]\n",
    "      ones = csr_matrix(np.ones((k,k)))\n",
    "      ones = convertScipyToTensor(ones).cuda()\n",
    "      ones = ones.to_dense()\n",
    "      try:\n",
    "        C = convertScipyToTensor(C)\n",
    "        C = C.to_dense()\n",
    "      except:\n",
    "        C=C\n",
    "      try:\n",
    "        theta = convertScipyToTensor(theta)\n",
    "      except:\n",
    "        theta = theta\n",
    "      try:\n",
    "        X = convertScipyToTensor(X)\n",
    "        X = X.to_dense()\n",
    "      except:\n",
    "        X = X\n",
    "      if(torch.cuda.is_available()):\n",
    "        print(\"GPU is available\")\n",
    "        C = C.cuda()\n",
    "        theta = theta.cuda()\n",
    "        X = X.cuda()\n",
    "        ones = ones.cuda()\n",
    "      def update_U(C,theta):\n",
    "          CT= torch.transpose(C,0,1)\n",
    "          lamb,U=torch.linalg.eig(CT@theta@C)  #U lena ahi\n",
    "          return U   \n",
    "      def bracket_term2fun(C,CT,theta):\n",
    "          U  = update_U(C,theta).double()\n",
    "          UT= torch.transpose(U,0,1)\n",
    "          Lw = (CT @theta @C).double()\n",
    "#           k_ = 7   #%notebookNumber of classes\n",
    "          lb= 1e-5\n",
    "          ub = 1e+4\n",
    "          beta = 0.5 \n",
    "          lambda_ =  laplacian_lambda_update(lb, ub, beta, U, Lw, k_)   \n",
    "          lambda_matrix =  torch.diag(lambda_,0).cuda()\n",
    "          return U@lambda_matrix@UT\n",
    "\n",
    "      def update_C(C):\n",
    "          thetaC = theta@C\n",
    "          CT = torch.transpose(C,0,1)\n",
    "          t1 = alpha_param*(C@ones)\n",
    "          bracket_term1 = (CT@thetaC)\n",
    "          bracket_term2 = bracket_term2fun(C,CT,theta) \n",
    "          bracket_term = bracket_term1 - bracket_term2   # bracket term (CT*theta*C - U*lambda*UT)\n",
    "          t2 = beta_param*(theta@C@bracket_term.float())\n",
    "          grad_fc= t1+t2\n",
    "          C_new=C-gamma_param*grad_fc\n",
    "          C_new[C_new<thresh] = thresh\n",
    "          for i in range(len(C_new)):\n",
    "              C_new[i] = C_new[i]/torch.linalg.norm(C_new[i],1)\n",
    "          return C_new        \n",
    "            \n",
    "\n",
    "        \n",
    "        \n",
    "        \n",
    "\n",
    "\n",
    "      #We set c1 = 10−5 and c2 = 10^4 We observed that the experimental performances of the algorithms \n",
    "       #are not sensitive to different values of c1 and c2 as long as they are reasonably small and large,respectively\n",
    "      # K is the number of smallest eigenvalues of the Laplacian matrix that are being ignored while updating the eigenvalues.\n",
    "      def laplacian_lambda_update(lb, ub, beta, U, Lw, k):\n",
    "        q = Lw.size(1) - k\n",
    "        UT= torch.transpose(U,0,1)\n",
    "        UT = UT.type(torch.float64)\n",
    "        d = torch.diag(UT @ Lw @ U)\n",
    "        # unconstrained solution as initial point\n",
    "        lambda_ = 0.5 * (d + torch.sqrt(d.pow(2) + 4 / beta))\n",
    "        lambda_,indices = torch.sort(lambda_, dim=- 1, descending=True)\n",
    "        eps = 1\n",
    "        condition = torch.tensor([(lambda_[q] - ub) <= eps,\n",
    "                                  (lambda_[0] - lb) >= -eps]).all()\n",
    "#                                   (lambda_[1:(q)] - lambda_[0:(q-1)]) >= -eps])\n",
    "        if condition:\n",
    "            return lambda_\n",
    "        else:\n",
    "            greater_ub = lambda_ > ub\n",
    "            lesser_lb = lambda_ < lb\n",
    "            lambda_[greater_ub] = ub\n",
    "            lambda_[lesser_lb] = lb\n",
    "            condition = torch.tensor([(lambda_[q] - ub) <= eps,\n",
    "                                  (lambda_[0] - lb) >= -eps]).all()\n",
    "#                                   (lambda_[1:q] - lambda_[0:(q-1)]) >= -eps])\n",
    "            if condition:\n",
    "                return lambda_\n",
    "            else:\n",
    "                print(lambda_)\n",
    "                raise ValueError(\"eigenvalues are not in increasing order, consider increasing the value of beta\")\n",
    "            \n",
    "\n",
    "      for i in tqdm(range(20)): #update C only 21\n",
    "         C = update_C(C)\n",
    "            \n",
    "      return C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "006d1aaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv\n",
    "\n",
    "\n",
    "class Net(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = GCNConv(X.shape[1], 64)\n",
    "        self.conv2 = GCNConv(64, NO_OF_CLASSES)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        self.conv1.reset_parameters()\n",
    "        self.conv2.reset_parameters()\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "    \n",
    "    \n",
    "####### NO output layer is written\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1c5a519a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accu(C_0,L,X_t_0):\n",
    "    global labels, NO_OF_CLASSES,k\n",
    "    t=[]\n",
    "    for i in [1]: \n",
    "        C_0_new=np.zeros(C_0.shape)\n",
    "        for i in range(C_0.shape[0]):\n",
    "            C_0_new[i][np.argmax(C_0[i])]=1\n",
    "        from scipy import sparse\n",
    "        Lc=C_0_new.T@L@C_0_new\n",
    "        Wc=(-1*Lc)*(1-np.eye(Lc.shape[0]))\n",
    "        Wc[Wc<0.1]=0\n",
    "        Wc=sparse.csr_matrix(Wc)\n",
    "        Wc = Wc.tocoo()\n",
    "        row = torch.from_numpy(Wc.row).to(torch.long)\n",
    "        col = torch.from_numpy(Wc.col).to(torch.long)\n",
    "        edge_index_coarsen2 = torch.stack([row, col], dim=0)\n",
    "        edge_weight = torch.from_numpy(Wc.data)\n",
    "        def one_hot(x, class_count):\n",
    "            return torch.eye(class_count)[x, :]\n",
    "\n",
    "        device = torch.device('cpu')\n",
    "        labels=labels\n",
    "        Y = labels\n",
    "        Y = one_hot(Y,NO_OF_CLASSES)\n",
    "        P=np.linalg.pinv(C_0_new)\n",
    "        labels_coarse = torch.argmax(torch.sparse.mm(torch.Tensor(P).double() , Y.double()).double() , 1)\n",
    "        Wc=Wc.toarray()\n",
    "        C2=np.linalg.pinv(C_0_new)\n",
    "        model=Net().to(device)\n",
    "        lr=0.01\n",
    "        decay=0.0001\n",
    "        features_= features.cpu().detach().numpy()\n",
    "        try:\n",
    "          X=np.array(features_.todense())\n",
    "        except:\n",
    "          X = np.array(features_)\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)\n",
    "        x=sample(range(0, int(k)), k)\n",
    "      \n",
    "        from datetime import datetime\n",
    "        Xt=P@X\n",
    "        def train():\n",
    "            model.train()\n",
    "            optimizer.zero_grad()\n",
    "            out = model(torch.Tensor(Xt).to(device),edge_index_coarsen2)\n",
    "            loss = F.nll_loss(out[x], labels_coarse[x])\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            return loss\n",
    "        now1 = datetime.now()\n",
    "        losses=[]\n",
    "        for epoch in range(100):\n",
    "            loss=train()\n",
    "            losses.append(loss)\n",
    "            if(epoch%100==0):\n",
    "                print(f'Epoch: {epoch:03d},loss: {loss:.4f}')\n",
    "        now2 = datetime.now()        \n",
    "        pred=model(torch.Tensor(Xt).to(device),edge_index_coarsen2).argmax(dim=1)        \n",
    "        def train_accuracy():\n",
    "            model.eval()\n",
    "            correct = (pred[x] == labels_coarse[x]).sum()\n",
    "            acc = int(correct) /len(x)\n",
    "            return acc\n",
    "    \n",
    "        t+=[(now2-now1).total_seconds()]\n",
    "\n",
    "        zz=sample(range(0, int(NO_OF_NODES)), NO_OF_NODES)\n",
    "        adj_ = adj.cpu().detach().numpy()\n",
    "        Wc=sparse.csr_matrix(adj_)\n",
    "        Wc = Wc.tocoo()\n",
    "        row = torch.from_numpy(Wc.row).to(torch.long)\n",
    "        col = torch.from_numpy(Wc.col).to(torch.long)\n",
    "        edge_index_coarsen = torch.stack([row, col], dim=0)\n",
    "        edge_weight = torch.from_numpy(Wc.data)\n",
    "        pred=model(torch.Tensor(X),edge_index_coarsen).argmax(dim=1)\n",
    "        pred=np.array(pred)\n",
    "        correct =(pred[zz]==labels[zz]).sum()\n",
    "        acc = int(correct) /NO_OF_NODES\n",
    "        return acc\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "556afd48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                           | 0/20 [00:00<?, ?it/s]C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_28276\\2449261001.py:33: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at  C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen\\native\\Copy.cpp:250.)\n",
      "  U  = update_U(C,theta).double()\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [02:23<00:00,  7.19s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unexpected exception formatting exception. Falling back to standard exception\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 3433, in run_code\n",
      "    exec(code_obj, self.user_global_ns, self.user_ns)\n",
      "  File \"C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_28276\\1929427005.py\", line 23, in <module>\n",
      "    C_test = C_0.cpu().detach().numpy()\n",
      "RuntimeError: [enforce fail at C:\\cb\\pytorch_1000000000000\\work\\c10\\core\\impl\\alloc_cpu.cpp:81] data. DefaultCPUAllocator: not enough memory: you tried to allocate 155448828 bytes.\n",
      "\n",
      "During handling of the above exception, another exception occurred:\n",
      "\n",
      "Traceback (most recent call last):\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\executing\\executing.py\", line 317, in executing\n",
      "    args = executing_cache[key]\n",
      "KeyError: (<code object run_code at 0x0000014844009D40, file \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 3397>, 1409890164032, 74)\n",
      "\n",
      "During handling of the above exception, another exception occurred:\n",
      "\n",
      "Traceback (most recent call last):\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py\", line 2052, in showtraceback\n",
      "    stb = self.InteractiveTB.structured_traceback(\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\IPython\\core\\ultratb.py\", line 1112, in structured_traceback\n",
      "    return FormattedTB.structured_traceback(\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\IPython\\core\\ultratb.py\", line 1006, in structured_traceback\n",
      "    return VerboseTB.structured_traceback(\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\IPython\\core\\ultratb.py\", line 859, in structured_traceback\n",
      "    formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\IPython\\core\\ultratb.py\", line 793, in format_exception_as_a_whole\n",
      "    self.get_records(etb, number_of_lines_of_context, tb_offset) if etb else []\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\IPython\\core\\ultratb.py\", line 848, in get_records\n",
      "    return list(stack_data.FrameInfo.stack_data(etb, options=options))[tb_offset:]\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\stack_data\\core.py\", line 565, in stack_data\n",
      "    yield from collapse_repeated(\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\stack_data\\utils.py\", line 84, in collapse_repeated\n",
      "    yield from map(mapper, original_group)\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\stack_data\\core.py\", line 555, in mapper\n",
      "    return cls(f, options)\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\stack_data\\core.py\", line 520, in __init__\n",
      "    self.executing = Source.executing(frame_or_tb)\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\executing\\executing.py\", line 369, in executing\n",
      "    args = find(source=cls.for_frame(frame), retry_cache=True)\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\executing\\executing.py\", line 252, in for_frame\n",
      "    return cls.for_filename(frame.f_code.co_filename, frame.f_globals or {}, use_cache)\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\executing\\executing.py\", line 270, in for_filename\n",
      "    result = source_cache[filename] = cls._for_filename_and_lines(filename, lines)\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\executing\\executing.py\", line 281, in _for_filename_and_lines\n",
      "    result = source_cache[(filename, lines)] = cls(filename, lines)\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\stack_data\\core.py\", line 79, in __init__\n",
      "    super(Source, self).__init__(*args, **kwargs)\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\site-packages\\executing\\executing.py\", line 228, in __init__\n",
      "    self.tree = ast.parse(ast_text, filename=filename)\n",
      "  File \"C:\\Users\\Sandeep\\anaconda3\\lib\\ast.py\", line 50, in parse\n",
      "    return compile(source, filename, mode, flags,\n",
      "MemoryError\n"
     ]
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pylab as plt\n",
    "        # sns.heatmap(C_0.T@C_0)\n",
    "        \n",
    "        \n",
    "highest_accuracy=0\n",
    "lambda_param = 0.001\n",
    "#0.0001,0.0001,10,0.0001\n",
    "for alpha_param in [10,1,0.1,0.01,0.001]:\n",
    "  for beta_param in [ 100,10,1,0.1,0.01,0.001]:\n",
    "      for gamma_param in [100,10,1,0.1,0.01,0.001]:\n",
    "            \n",
    "        av = []\n",
    "        for _ in range(2):\n",
    "            avg_accuracy_all=[]\n",
    "            X=X.cuda()\n",
    "            for _ in range(1):\n",
    "              C = random(p, k, density=0.15, random_state=1, data_rvs=temp2.rvs)\n",
    "              C_0 = experiment_K_component(alpha_param,lambda_param,beta_param,gamma_param,C,theta,X)\n",
    "              L = theta\n",
    "              pseudo_C = torch.linalg.pinv(C_0)\n",
    "              X_t_0 = pseudo_C@X\n",
    "              C_test = C_0.cpu().detach().numpy()\n",
    "              X_t_test = X_t_0.cpu().detach().numpy()\n",
    "              L_test = L.cpu().detach().numpy() \n",
    "              acc = get_accu(C_test,L_test,X_t_test)\n",
    "              av.append(acc)\n",
    "              if highest_accuracy<acc:\n",
    "                highest_accuracy=acc\n",
    "                print(\"Accuracy = \" + str(acc) + \" \" + str(alpha_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))\n",
    "        print(\"Average accuracy = \" + str(np.mean(av)*100)  + \" +/- \" + str(np.std(av)*100))\n",
    "        print(\"Params =  \" + str(alpha_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d8a93bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "highest_accuracy*100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff99ff7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
