{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3f80e3a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_15092\\273441208.py:10: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# !pip install deeprobust\n",
    "# !conda install pytorch torchvision torchaudio -c pytorch\n",
    "import torch\n",
    "# print(torch.__version__)\n",
    "# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
    "# !pip install torch-geometric\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "# IPython.core.display.display is deprecated since IPython 7.14; IPython.display is the supported path.\n",
    "from IPython.display import display, HTML\n",
    "# Inject a CSS rule that widens the notebook container to 90% of the page.\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "\n",
    "import random\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c287cb5c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_15092\\4057615484.py:25: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import collections\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torch_geometric\n",
    "from torch_geometric.utils import to_networkx\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import networkx as nx\n",
    "from networkx.algorithms import community\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "data_dir = \"./data\"\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "import numpy\n",
    "import torch\n",
    "# The first cell already loads autoreload, so %load_ext here only printed an\n",
    "# 'already loaded' notice; %reload_ext works whether or not it is loaded.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline\n",
    "# IPython.core.display.display is deprecated since IPython 7.14; IPython.display is the supported path.\n",
    "from IPython.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "\n",
    "from random import sample\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "import random\n",
    "\n",
    "from scipy.sparse import csr_matrix\n",
    "from scipy.sparse import csgraph\n",
    "from scipy.sparse.linalg import inv\n",
    "\n",
    "import os\n",
    "# Display the kernel's working directory; relative paths such as ./data resolve against it.\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3669a0f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Cora'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Path of a 'Cora' folder under the working directory.\n",
    "# NOTE(review): the name `dataset` is rebound to the Planetoid dataset object by the\n",
    "# loader cell that follows, so this string is only displayed here.\n",
    "dataset = os.path.join(os.getcwd(),'Cora')\n",
    "dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "067c3d06",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset: Cora\n",
      "num_nodes: 2708\n",
      "num_edges: 10556\n",
      "num_classes: 7\n",
      "num_features: 1433\n"
     ]
    }
   ],
   "source": [
    "import os.path as osp\n",
    "import torch\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import torch_geometric.transforms as T\n",
    "\n",
    "def get_planetoid_dataset(name, normalize_features=False, transform=None, split=\"public\"):\n",
    "    \"\"\"Load a Planetoid dataset, optionally re-splitting it and attaching transforms.\n",
    "\n",
    "    Args:\n",
    "        name: dataset name forwarded to ``Planetoid`` (e.g. ``'Cora'``).\n",
    "        normalize_features: when truthy, ``T.NormalizeFeatures()`` becomes (the first\n",
    "            part of) the dataset transform.\n",
    "        transform: optional extra transform; runs after normalisation when both are set.\n",
    "        split: ``'complete'`` overwrites the masks of ``dataset[0]`` in place (last 500\n",
    "            nodes = test, the 500 before them = validation, all earlier nodes = train);\n",
    "            any other value is passed straight to ``Planetoid``.\n",
    "\n",
    "    Returns:\n",
    "        The ``Planetoid`` dataset object.\n",
    "    \"\"\"\n",
    "    # Data root is <parent of cwd>/../data/<name>.\n",
    "    root = osp.join(osp.dirname(osp.realpath(os.getcwd())), '..', 'data', name)\n",
    "\n",
    "    if split != 'complete':\n",
    "        dataset = Planetoid(root, name, split=split)\n",
    "    else:\n",
    "        dataset = Planetoid(root, name)\n",
    "        graph = dataset[0]\n",
    "        val_start = graph.num_nodes - 1000\n",
    "        test_start = graph.num_nodes - 500\n",
    "        graph.train_mask.fill_(False)\n",
    "        graph.train_mask[:val_start] = 1\n",
    "        graph.val_mask.fill_(False)\n",
    "        graph.val_mask[val_start:test_start] = 1\n",
    "        graph.test_mask.fill_(False)\n",
    "        graph.test_mask[test_start:] = 1\n",
    "\n",
    "    # Collect the requested transforms; normalisation always comes first.\n",
    "    steps = []\n",
    "    if normalize_features:\n",
    "        steps.append(T.NormalizeFeatures())\n",
    "    if transform is not None:\n",
    "        steps.append(transform)\n",
    "    if len(steps) == 2:\n",
    "        dataset.transform = T.Compose(steps)\n",
    "    elif len(steps) == 1:\n",
    "        dataset.transform = steps[0]\n",
    "    return dataset\n",
    "\n",
    "\n",
    "# Print headline statistics for each listed dataset; `dataset` keeps the last one loaded.\n",
    "if __name__ == '__main__':\n",
    "    # Alternative list: ['Cora', 'CiteSeer', 'PubMed']\n",
    "    lst_names = ['Cora']\n",
    "    for name in lst_names:\n",
    "        dataset = get_planetoid_dataset(name)\n",
    "        print(\n",
    "            f\"dataset: {name}\",\n",
    "            f\"num_nodes: {dataset[0].x.size(0)}\",\n",
    "            f\"num_edges: {dataset[0].edge_index.size(1)}\",\n",
    "            f\"num_classes: {dataset.num_classes}\",\n",
    "            f\"num_features: {dataset.num_node_features}\",\n",
    "            sep='\\n',\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7f8577f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])\n",
      "torch.Size([2708, 1433]) torch.Size([2708, 2708])\n",
      "torch.Size([2708, 1433]) torch.Size([2708, 2708])\n"
     ]
    }
   ],
   "source": [
    "# from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.utils import to_dense_adj\n",
    "\n",
    "# # dataset = NELL(root='/nell')\n",
    "\n",
    "\n",
    "# dataset= Planetoid(root=dataset, name='Cora')\n",
    "print(dataset[0])\n",
    "# Dense adjacency on the `device` chosen in the configuration cell (no hard-coded .cuda(),\n",
    "# so the notebook also runs without a GPU). max_num_nodes keeps the matrix N x N even\n",
    "# when the highest-numbered nodes have no edges.\n",
    "adj = to_dense_adj(dataset[0].edge_index, max_num_nodes=dataset[0].num_nodes)[0].to(device)\n",
    "# Labels stay on the host as a NumPy array; get_accuracy compares them with NumPy predictions.\n",
    "labels = dataset[0].y.numpy()\n",
    "\n",
    "X = dataset[0].x.to_dense()\n",
    "N = X.shape[0]\n",
    "NO_OF_CLASSES = len(np.unique(labels))\n",
    "\n",
    "print(X.shape, adj.shape)\n",
    "\n",
    "# Keep only the first `nn` nodes; the factor 1 keeps the whole graph.\n",
    "nn = int(1*N)\n",
    "X = X[:nn,:]\n",
    "adj = adj[:nn,:nn]\n",
    "labels = labels[:nn]\n",
    "print(X.shape,adj.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "475846d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2708, 1433])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: shape of the feature matrix X.\n",
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a0757b15",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2708, 2708])\n"
     ]
    }
   ],
   "source": [
    "def get_laplacian(adj):\n",
    "    b=torch.ones(adj.shape[0]).cuda()\n",
    "    return torch.diag(adj@b)-adj\n",
    "\n",
    "# Laplacian of the loaded graph; passed as `theta` to the coarsening routine later on.\n",
    "theta = get_laplacian(adj)\n",
    "print(theta.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cad0c60b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Temporary debug check (delete later): show which device holds `theta`.\n",
    "theta.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "114a42e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7 2708\n"
     ]
    }
   ],
   "source": [
    "# Use the `device` chosen in the configuration cell instead of a hard-coded .cuda(), and\n",
    "# avoid the legacy torch.Tensor(...) constructor, so this cell works on CPU-only machines\n",
    "# and also when X already lives on the GPU.\n",
    "features = X.to(device=device, dtype=torch.float32)\n",
    "NO_OF_NODES = X.shape[0]\n",
    "print(NO_OF_CLASSES,NO_OF_NODES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2ba32732",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Temporary debug check (delete later): show which device holds `features`.\n",
    "features.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b33c30e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convertScipyToTensor(coo):\n",
    "  try:\n",
    "    coo = coo.tocoo()\n",
    "  except:\n",
    "    coo = coo\n",
    "  values = coo.data\n",
    "  indices = np.vstack((coo.row, coo.col))\n",
    "\n",
    "  i = torch.LongTensor(indices)\n",
    "  v = torch.FloatTensor(values)\n",
    "  shape = coo.shape\n",
    "\n",
    "  return torch.sparse.FloatTensor(i, v, torch.Size(shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b7f9d72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.sparse import random\n",
    "from scipy.sparse.linalg import norm\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "# p = number of rows of X, n = number of columns of X, k = 10% of p\n",
    "# (presumably the number of coarse nodes; confirm against the method description).\n",
    "p = X.shape[0]\n",
    "k = int(p*0.1)\n",
    "n = X.shape[1]\n",
    "lr = 1e-5\n",
    "# Floor applied to the entries of C in experiment_K_component's update step.\n",
    "thresh = 1e-10\n",
    "\n",
    "# NOTE(review): this rebinds the name `random`, which earlier cells bound to the stdlib module.\n",
    "from scipy.sparse import random\n",
    "from scipy.stats import rv_continuous\n",
    "class CustomDistribution(rv_continuous):\n",
    "    \"\"\"Distribution whose samples are standard-normal draws from the supplied random state.\"\"\"\n",
    "    def _rvs(self,  size=None, random_state=None):\n",
    "        return random_state.standard_normal(size)\n",
    "temp = CustomDistribution(seed=1)\n",
    "temp2 = temp()  # get a frozen version of the distribution\n",
    "# Random sparse p x k matrix (density 0.25, fixed seed) whose non-zeros are drawn via temp2.rvs.\n",
    "C = random(p, k, density=0.25, random_state=1, data_rvs=temp2.rvs)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1f724733",
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment_K_component(alpha_param,lambda_param,beta_param,gamma_param,C,theta,X):\n",
    "    \"\"\"Refine the coarsening matrix ``C`` with 20 projected-gradient steps.\n",
    "\n",
    "    Every step computes\n",
    "    ``grad = alpha_param * C @ 1 + beta_param * theta @ C @ (C^T theta C - U diag(lam) U^T)``,\n",
    "    moves ``C`` by ``-gamma_param * grad``, raises entries below the notebook-level\n",
    "    ``thresh`` to ``thresh`` and rescales each row to unit L1 norm.\n",
    "\n",
    "    Args:\n",
    "        alpha_param: weight of the ``C @ 1`` term.\n",
    "        lambda_param: unused; kept so existing calls keep working.\n",
    "        beta_param: weight of the spectral term.\n",
    "        gamma_param: gradient step size.\n",
    "        C: (p, k) starting matrix, SciPy sparse or dense torch tensor.\n",
    "        theta: (p, p) matrix (the graph Laplacian in this notebook), SciPy sparse or\n",
    "            torch tensor.\n",
    "        X: unused; kept so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        Dense (p, k) tensor, on the GPU when CUDA is available and on the CPU otherwise.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the updated eigenvalues still violate the bounds after clipping.\n",
    "    \"\"\"\n",
    "    # One device for everything; the original mixed unconditional .cuda() calls with a\n",
    "    # CUDA availability check and therefore crashed on CPU-only machines.\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    # SciPy sparse inputs expose .tocoo(); tensors are used as they are.\n",
    "    if hasattr(C, 'tocoo'):\n",
    "        C = convertScipyToTensor(C).to_dense()\n",
    "    if hasattr(theta, 'tocoo'):\n",
    "        theta = convertScipyToTensor(theta)\n",
    "    C = C.to(device)\n",
    "    theta = theta.to(device)\n",
    "\n",
    "    # Size the all-ones matrix from C itself so that any (p, k) input is accepted.\n",
    "    k = C.shape[1]\n",
    "    ones = torch.ones(k, k, device=device)\n",
    "    print(ones.shape)\n",
    "    if device.type == 'cuda':\n",
    "        print(\"yes\")\n",
    "\n",
    "    def update_U(coarse_lap):\n",
    "        # NOTE(review): torch.linalg.eig returns complex output and only the real part is\n",
    "        # used below. If C^T theta C is always symmetric, torch.linalg.eigh may be the\n",
    "        # better fit, but it orders the eigenvectors differently, so it is not swapped here.\n",
    "        _, U = torch.linalg.eig(coarse_lap)\n",
    "        return U.real\n",
    "\n",
    "    def bracket_term2fun(coarse_lap):\n",
    "        U = update_U(coarse_lap).double()\n",
    "        UT = torch.transpose(U, 0, 1)\n",
    "        Lw = coarse_lap.double()\n",
    "        k_ = 7  # number of classes (hard-coded)\n",
    "        # Bounds c1 = 1e-5 and c2 = 1e4. Per the original note, results are not sensitive\n",
    "        # to them as long as they are reasonably small and large, respectively.\n",
    "        lb = 1e-5\n",
    "        ub = 1e+4\n",
    "        beta = 0.5\n",
    "        lambda_ = laplacian_lambda_update(lb, ub, beta, U, Lw, k_)\n",
    "        # lambda_ already lives on U's device, so no explicit transfer is needed.\n",
    "        return U @ torch.diag(lambda_, 0) @ UT\n",
    "\n",
    "    def update_C(C):\n",
    "        CT = torch.transpose(C, 0, 1)\n",
    "        # C^T theta C feeds the eigen-decomposition and the bracket term; compute it once.\n",
    "        coarse_lap = CT @ theta @ C\n",
    "        t1 = alpha_param * (C @ ones)\n",
    "        # bracket term: C^T theta C - U diag(lambda) U^T\n",
    "        bracket_term = coarse_lap - bracket_term2fun(coarse_lap)\n",
    "        t2 = beta_param * (theta @ C @ bracket_term.float())\n",
    "        C_new = C - gamma_param * (t1 + t2)\n",
    "        C_new[C_new < thresh] = thresh\n",
    "        # Row-wise L1 normalisation in one vectorised call instead of a Python loop over rows.\n",
    "        return C_new / torch.linalg.norm(C_new, 1, dim=1, keepdim=True)\n",
    "\n",
    "    # n_ignored is the number of smallest Laplacian eigenvalues that are ignored while\n",
    "    # updating the eigenvalues (the original parameter `k`).\n",
    "    def laplacian_lambda_update(lb, ub, beta, U, Lw, n_ignored):\n",
    "        q = Lw.size(1) - n_ignored\n",
    "        UT = torch.transpose(U, 0, 1).type(torch.float64)\n",
    "        d = torch.diag(UT @ Lw @ U)\n",
    "        # unconstrained solution as initial point\n",
    "        lambda_ = 0.5 * (d + torch.sqrt(d.pow(2) + 4 / beta))\n",
    "        # NOTE(review): sorting lambda_ on its own means its entries no longer line up with\n",
    "        # the columns of U used by the caller; confirm this is intended.\n",
    "        lambda_, _ = torch.sort(lambda_, dim=-1, descending=True)\n",
    "        eps = 1\n",
    "\n",
    "        def within_bounds(values):\n",
    "            return bool((values[q] - ub) <= eps) and bool((values[0] - lb) >= -eps)\n",
    "\n",
    "        if within_bounds(lambda_):\n",
    "            return lambda_\n",
    "        # Clip into [lb, ub] and re-check once before giving up.\n",
    "        lambda_[lambda_ > ub] = ub\n",
    "        lambda_[lambda_ < lb] = lb\n",
    "        if within_bounds(lambda_):\n",
    "            return lambda_\n",
    "        print(lambda_)\n",
    "        raise ValueError(\"eigenvalues are not in increasing order, consider increasing the value of beta\")\n",
    "\n",
    "    for _ in tqdm(range(20)):  # only C is updated, for a fixed 20 iterations\n",
    "        C = update_C(C)\n",
    "\n",
    "    return C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e6f7fc35",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accuracy(C_0,L,X_t_0):\n",
    "    \"\"\"Train ``Net`` on the coarsened graph and score it on the original graph.\n",
    "\n",
    "    Steps: harden ``C_0`` to a one-hot node -> super-node assignment, build the coarse\n",
    "    adjacency from ``C^T L C`` (weights below 0.1 dropped), pool features and labels\n",
    "    with the pseudo-inverse of the assignment, train for 100 epochs on all coarse\n",
    "    nodes, then predict on the full graph.\n",
    "\n",
    "    Reads the notebook globals ``labels``, ``NO_OF_CLASSES``, ``features`` and ``adj``.\n",
    "    NOTE(review): the model class ``Net`` is not defined in any saved cell of this\n",
    "    notebook; it must exist in the kernel before this function is called.\n",
    "\n",
    "    Args:\n",
    "        C_0: (p, k) soft assignment matrix (tensor on any device, or array-like).\n",
    "        L: (p, p) Laplacian of the original graph (tensor on any device, or array-like).\n",
    "        X_t_0: unused; kept so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of original nodes whose predicted class equals ``labels``.\n",
    "    \"\"\"\n",
    "    import torch.nn.functional as F\n",
    "    from scipy import sparse\n",
    "\n",
    "    def to_numpy(value):\n",
    "        # NumPy cannot read CUDA tensors directly (the recorded TypeError); copy to host first.\n",
    "        if torch.is_tensor(value):\n",
    "            value = value.detach().cpu()\n",
    "            return (value.to_dense() if value.is_sparse else value).numpy()\n",
    "        if sparse.issparse(value):\n",
    "            return value.toarray()\n",
    "        return np.asarray(value)\n",
    "\n",
    "    def edge_index_of(weights):\n",
    "        # Non-zero entries of a dense weight matrix as a 2 x E int64 tensor.\n",
    "        coo = sparse.csr_matrix(weights).tocoo()\n",
    "        return torch.from_numpy(np.vstack((coo.row, coo.col))).to(torch.long)\n",
    "\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    # Hard assignment: every node goes to the super-node with its largest weight.\n",
    "    soft = to_numpy(C_0)\n",
    "    C_0_new = np.zeros(soft.shape)\n",
    "    C_0_new[np.arange(soft.shape[0]), np.argmax(soft, axis=1)] = 1\n",
    "\n",
    "    # Coarse Laplacian -> coarse adjacency (negated off-diagonal), weak edges removed.\n",
    "    Lc = C_0_new.T @ to_numpy(L) @ C_0_new\n",
    "    Wc = (-1*Lc)*(1-np.eye(Lc.shape[0]))\n",
    "    Wc[Wc<0.1] = 0\n",
    "    edge_index_coarsen2 = edge_index_of(Wc).to(device)\n",
    "\n",
    "    # Pool one-hot labels and features onto the super-nodes with the pseudo-inverse.\n",
    "    y_full = to_numpy(labels)\n",
    "    Y = torch.eye(NO_OF_CLASSES)[torch.as_tensor(y_full, dtype=torch.long)]\n",
    "    P = np.linalg.pinv(C_0_new)\n",
    "    labels_coarse = torch.argmax(torch.from_numpy(P) @ Y.double(), 1).to(device)\n",
    "\n",
    "    X_full = to_numpy(features)\n",
    "    Xt = torch.as_tensor(P @ X_full, dtype=torch.float32).to(device)\n",
    "\n",
    "    # Model, inputs and targets all live on the same device.\n",
    "    model = Net().to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(100):\n",
    "        optimizer.zero_grad()\n",
    "        out = model(Xt, edge_index_coarsen2)\n",
    "        # The loss covers every coarse node; the original indexed them through a random\n",
    "        # permutation, which does not change a mean over all of them.\n",
    "        loss = F.nll_loss(out, labels_coarse)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if epoch % 100 == 0:\n",
    "            print(f'Epoch: {epoch:03d},loss: {loss.item():.4f}')\n",
    "\n",
    "    # Evaluate on the original graph in eval mode and without gradient tracking.\n",
    "    edge_index_full = edge_index_of(to_numpy(adj)).to(device)\n",
    "    X_dev = torch.as_tensor(X_full, dtype=torch.float32).to(device)\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        pred = model(X_dev, edge_index_full).argmax(dim=1).cpu().numpy()\n",
    "    correct = (pred == y_full).sum()\n",
    "    acc = int(correct) / y_full.shape[0]\n",
    "    return acc\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "556afd48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([270, 270])\n",
      "yes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.88it/s]\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\numpy\\core\\fromnumeric.py:57\u001b[0m, in \u001b[0;36m_wrapfunc\u001b[1;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[0;32m     56\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m---> 57\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m bound(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n\u001b[0;32m     58\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[0;32m     59\u001b[0m     \u001b[38;5;66;03m# A TypeError occurs if the object does have such a method in its\u001b[39;00m\n\u001b[0;32m     60\u001b[0m     \u001b[38;5;66;03m# class, but its signature is not identical to that of NumPy's. This\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m     64\u001b[0m     \u001b[38;5;66;03m# Call _wrapit from within the except clause to ensure a potential\u001b[39;00m\n\u001b[0;32m     65\u001b[0m     \u001b[38;5;66;03m# exception has a traceback chain.\u001b[39;00m\n",
      "\u001b[1;31mTypeError\u001b[0m: argmax() got an unexpected keyword argument 'axis'",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[1;32mIn [16], line 23\u001b[0m\n\u001b[0;32m     21\u001b[0m X_t_0 \u001b[38;5;241m=\u001b[39m pseudo_C\u001b[38;5;129m@X\u001b[39m\n\u001b[0;32m     22\u001b[0m C_t_0\u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtranspose(C_0,\u001b[38;5;241m0\u001b[39m,\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m---> 23\u001b[0m acc \u001b[38;5;241m=\u001b[39m \u001b[43mget_accuracy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mC_0\u001b[49m\u001b[43m,\u001b[49m\u001b[43mL\u001b[49m\u001b[43m,\u001b[49m\u001b[43mX_t_0\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     24\u001b[0m av\u001b[38;5;241m.\u001b[39mappend(acc)\n\u001b[0;32m     25\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m highest_accuracy\u001b[38;5;241m<\u001b[39macc:\n",
      "Cell \u001b[1;32mIn [15], line 7\u001b[0m, in \u001b[0;36mget_accuracy\u001b[1;34m(C_0, L, X_t_0)\u001b[0m\n\u001b[0;32m      5\u001b[0m C_0_new\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39mzeros(C_0\u001b[38;5;241m.\u001b[39mshape)\n\u001b[0;32m      6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(C_0\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]):\n\u001b[1;32m----> 7\u001b[0m     C_0_new[i][\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margmax\u001b[49m\u001b[43m(\u001b[49m\u001b[43mC_0\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m]\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[0;32m      8\u001b[0m \u001b[38;5;66;03m# print(C_0_new)\u001b[39;00m\n\u001b[0;32m      9\u001b[0m \u001b[38;5;66;03m# C_0_new=C_0\u001b[39;00m\n\u001b[0;32m     10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mscipy\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m sparse\n",
      "File \u001b[1;32m<__array_function__ internals>:5\u001b[0m, in \u001b[0;36margmax\u001b[1;34m(*args, **kwargs)\u001b[0m\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\numpy\\core\\fromnumeric.py:1195\u001b[0m, in \u001b[0;36margmax\u001b[1;34m(a, axis, out)\u001b[0m\n\u001b[0;32m   1121\u001b[0m \u001b[38;5;129m@array_function_dispatch\u001b[39m(_argmax_dispatcher)\n\u001b[0;32m   1122\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21margmax\u001b[39m(a, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, out\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m   1123\u001b[0m     \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m   1124\u001b[0m \u001b[38;5;124;03m    Returns the indices of the maximum values along an axis.\u001b[39;00m\n\u001b[0;32m   1125\u001b[0m \n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1193\u001b[0m \n\u001b[0;32m   1194\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m-> 1195\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_wrapfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43margmax\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\numpy\\core\\fromnumeric.py:66\u001b[0m, in \u001b[0;36m_wrapfunc\u001b[1;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[0;32m     57\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m bound(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n\u001b[0;32m     58\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[0;32m     59\u001b[0m     \u001b[38;5;66;03m# A TypeError occurs if the object does have such a method in its\u001b[39;00m\n\u001b[0;32m     60\u001b[0m     \u001b[38;5;66;03m# class, but its signature is not identical to that of NumPy's. This\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m     64\u001b[0m     \u001b[38;5;66;03m# Call _wrapit from within the except clause to ensure a potential\u001b[39;00m\n\u001b[0;32m     65\u001b[0m     \u001b[38;5;66;03m# exception has a traceback chain.\u001b[39;00m\n\u001b[1;32m---> 66\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m _wrapit(obj, method, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\numpy\\core\\fromnumeric.py:43\u001b[0m, in \u001b[0;36m_wrapit\u001b[1;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[0;32m     41\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n\u001b[0;32m     42\u001b[0m     wrap \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m---> 43\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m, method)(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n\u001b[0;32m     44\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m wrap:\n\u001b[0;32m     45\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(result, mu\u001b[38;5;241m.\u001b[39mndarray):\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\_tensor.py:757\u001b[0m, in \u001b[0;36mTensor.__array__\u001b[1;34m(self, dtype)\u001b[0m\n\u001b[0;32m    755\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(Tensor\u001b[38;5;241m.\u001b[39m__array__, (\u001b[38;5;28mself\u001b[39m,), \u001b[38;5;28mself\u001b[39m, dtype\u001b[38;5;241m=\u001b[39mdtype)\n\u001b[0;32m    756\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 757\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnumpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    758\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m    759\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnumpy()\u001b[38;5;241m.\u001b[39mastype(dtype, copy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
      "\u001b[1;31mTypeError\u001b[0m: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first."
     ]
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pylab as plt\n",
    "        # sns.heatmap(C_0.T@C_0)\n",
    "        \n",
    "        \n",
    "# Grid search over (alpha, beta, gamma); lambda_param stays fixed. Each setting runs the\n",
    "# coarsening once, trains on the coarse graph and reports accuracy on the full graph.\n",
    "highest_accuracy=0\n",
    "lambda_param = 0.001\n",
    "#0.0001,0.0001,10,0.0001\n",
    "for alpha_param in [100,10,1,0.1,0.01]:\n",
    "    for beta_param in [100,10,1,0.1,0.01]:\n",
    "        for gamma_param in [100,10,1,0.1,0.01]:\n",
    "            av = []\n",
    "            for _ in range(1):\n",
    "                # Fresh random start with a fixed seed, so every setting begins from the same C.\n",
    "                C = random(p, k, density=0.15, random_state=1, data_rvs=temp2.rvs)\n",
    "                C_0 = experiment_K_component(alpha_param,lambda_param,beta_param,gamma_param,C,theta,X)\n",
    "                L = theta\n",
    "                # Move X next to C_0 for this product only, instead of rebinding the global X\n",
    "                # with a hard-coded .cuda().\n",
    "                X_t_0 = torch.linalg.pinv(C_0) @ X.to(C_0.device)\n",
    "                acc = get_accuracy(C_0,L,X_t_0)\n",
    "                av.append(acc)\n",
    "                if highest_accuracy<acc:\n",
    "                    highest_accuracy=acc\n",
    "                    # Report every hyper-parameter; the original message left out alpha.\n",
    "                    print(\"Accuracy = \" + str(acc) + \" alpha=\" + str(alpha_param) + \" lambda=\" + str(lambda_param) + \" beta=\" + str(beta_param) + \" gamma=\" + str(gamma_param))\n",
    "            print(\"Average accuracy = \" + str(np.mean(av)*100)  + \" +/- \" + str(np.std(av)*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8ecd8c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test on a small random problem. Shapes follow what experiment_K_component expects:\n",
    "# C is (p, k) with k = int(0.1 * p), theta is (p, p) and X has p rows. p = 80 gives k = 8,\n",
    "# which is at least the 7 classes hard-coded inside the routine.\n",
    "A = torch.randn(80, 8, device=device)\n",
    "B = torch.randn(80, 80, device=device)\n",
    "D = torch.randn(80, 30, device=device)\n",
    "anss=  experiment_K_component(0.01,0.001,0.01,0.01,A,B,D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd227935",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the matrix returned by the smoke-test call to experiment_K_component.\n",
    "anss.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed03df49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: push an all-zero p x k matrix through the SciPy -> torch conversion.\n",
    "# NOTE(review): this rebinds the notebook-level p and k (k becomes 2% of p).\n",
    "p = 2708\n",
    "k = int(p * 0.02)\n",
    "zeros = convertScipyToTensor(csr_matrix(np.zeros((p, k)))).to_dense()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d892e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the dense all-zero tensor built in the previous cell.\n",
    "zeros.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c5a519a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
