{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3f80e3a6",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 53
    },
    "id": "3f80e3a6",
    "outputId": "688d4a21-c2c8-4090-8785-ba9fec49eaef"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_36368\\273441208.py:10: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# !pip install deeprobust\n",
    "# !conda install pytorch torchvision torchaudio -c pytorch\n",
    "import torch\n",
    "# print(torch.__version__)\n",
    "# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html\n",
    "# !pip install torch-geometric\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "\n",
    "import random\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c287cb5c",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 71
    },
    "id": "c287cb5c",
    "outputId": "54552bfd-7681-4b66-bce2-74e2c38e399f"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\anaconda3\\envs\\nimesh\\lib\\site-packages\\torch_geometric\\typing.py:56: UserWarning: An issue occurred while importing 'torch-scatter'. Disabling its usage. Stacktrace: [WinError 127] The specified procedure could not be found\n",
      "  warnings.warn(f\"An issue occurred while importing 'torch-scatter'. \"\n",
      "C:\\Users\\Sandeep\\anaconda3\\envs\\nimesh\\lib\\site-packages\\torch_geometric\\typing.py:93: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: [WinError 127] The specified procedure could not be found\n",
      "  warnings.warn(f\"An issue occurred while importing 'torch-sparse'. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_36368\\4057615484.py:25: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
      "  from IPython.core.display import display, HTML\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>.container { width:90% !important; }</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Experiment Bipartite\\\\Co-CS'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import collections\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torch_geometric\n",
    "from torch_geometric.utils import to_networkx\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import networkx as nx\n",
    "from networkx.algorithms import community\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "data_dir = \"./data\"\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "import numpy\n",
    "import torch\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline\n",
    "from IPython.core.display import display, HTML\n",
    "display(HTML(\"<style>.container { width:90% !important; }</style>\"))\n",
    "\n",
    "\n",
    "from random import sample\n",
    "from networkx.generators.random_graphs import erdos_renyi_graph\n",
    "from networkx.generators.random_graphs import barabasi_albert_graph\n",
    "from networkx.generators.community import stochastic_block_model\n",
    "from networkx.generators.random_graphs import watts_strogatz_graph\n",
    "from networkx.generators.community import random_partition_graph\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import FactorAnalysis\n",
    "import random\n",
    "\n",
    "from scipy.sparse import csr_matrix\n",
    "from scipy.sparse import csgraph\n",
    "from scipy.sparse.linalg import inv\n",
    "\n",
    "import os\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3669a0f9",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "3669a0f9",
    "outputId": "213cddc0-551b-4948-a025-9704edf67d7e"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'C:\\\\Users\\\\Sandeep\\\\Downloads\\\\Subhanu_ RESULTS\\\\FGC\\\\Experiment Bipartite\\\\Co-CS\\\\CS'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dataset = os.path.join(os.getcwd(),'Cora')\n",
    "# dataset\n",
    "dataset = os.path.join(os.getcwd(),'CS')\n",
    "dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7f8577f3",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "7f8577f3",
    "outputId": "de9c336b-dbd9-40ca-e1fa-58107ade7ce8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(x=[18333, 6805], edge_index=[2, 163788], y=[18333])\n",
      "torch.Size([18333, 6805]) torch.Size([18333, 18333])\n",
      "torch.Size([18333, 6805]) torch.Size([18333, 18333])\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.datasets import Coauthor,Planetoid,CitationFull,WebKB\n",
    "from torch_geometric.utils import to_dense_adj,add_random_edge\n",
    "\n",
    "# dataset = Planetoid(root='/pubmed',name='PubMed') #change for Cora, Citeseer, PubMed\n",
    "# dataset = CitationFull(root='/dblp',name='DBLP')\n",
    "dataset = Coauthor(root=dataset,name='CS') # change for CS, Physics\n",
    "print(dataset[0])\n",
    "\n",
    "edge_index = dataset[0].edge_index\n",
    "\n",
    "edge_index, added_edges = add_random_edge(edge_index, p=0.2,force_undirected=True)\n",
    "\n",
    "adj = to_dense_adj(edge_index)\n",
    "adj = adj[0]\n",
    "\n",
    "labels = dataset[0].y\n",
    "labels = labels.numpy()\n",
    "\n",
    "X = dataset[0].x\n",
    "X = X.to_dense()\n",
    "N = X.shape[0]\n",
    "NO_OF_CLASSES =  len(set(np.array(dataset[0].y)))\n",
    "\n",
    "print(X.shape, adj.shape)\n",
    "\n",
    "nn = int(1*N)\n",
    "X = X[:nn,:]\n",
    "adj = adj[:nn,:nn]\n",
    "A = adj[:nn,:nn]\n",
    "AT= torch.transpose(A,0,1)\n",
    "labels = labels[:nn]\n",
    "print(X.shape,adj.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "475846d5",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "475846d5",
    "outputId": "c6102e64-ae54-46d6-8f9b-48ba85a81564"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([18333, 6805])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a0757b15",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "a0757b15",
    "outputId": "d35c00b4-f747-4b97-c075-b98efcdb77eb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([18333, 18333])\n"
     ]
    }
   ],
   "source": [
    "def get_laplacian(adj):\n",
    "    b=torch.ones(adj.shape[0])\n",
    "    return torch.diag(adj@b)-adj\n",
    "\n",
    "theta = get_laplacian(adj)\n",
    "print(theta.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cad0c60b",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "cad0c60b",
    "outputId": "4de3ae5a-7f39-4649-bae9-c437127dc115"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Delete later\n",
    "theta.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "114a42e0",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "114a42e0",
    "outputId": "62b0a402-9d33-44be-8e7f-20de623be138"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15 18333\n"
     ]
    }
   ],
   "source": [
    "\n",
    "features = torch.Tensor(X)\n",
    "NO_OF_NODES = X.shape[0]\n",
    "print(NO_OF_CLASSES,NO_OF_NODES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2ba32732",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "2ba32732",
    "outputId": "4d8096c2-cca6-41e5-f7f9-ebd3e815a404"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Delete later\n",
    "features.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b33c30e2",
   "metadata": {
    "id": "b33c30e2"
   },
   "outputs": [],
   "source": [
    "def convertScipyToTensor(coo):\n",
    "  try:\n",
    "    coo = coo.tocoo()\n",
    "  except:\n",
    "    coo = coo\n",
    "  values = coo.data\n",
    "  indices = np.vstack((coo.row, coo.col))\n",
    "\n",
    "  i = torch.LongTensor(indices)\n",
    "  v = torch.FloatTensor(values)\n",
    "  shape = coo.shape\n",
    "\n",
    "  return torch.sparse.FloatTensor(i, v, torch.Size(shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b7f9d72b",
   "metadata": {
    "id": "b7f9d72b"
   },
   "outputs": [],
   "source": [
    "from scipy.sparse import random\n",
    "from scipy.sparse.linalg import norm\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "p = X.shape[0]\n",
    "k = int(p*0.3)\n",
    "n = X.shape[1]\n",
    "lr = 1e-5\n",
    "thresh = 1e-10\n",
    "\n",
    "from scipy.sparse import random\n",
    "from scipy.stats import rv_continuous\n",
    "class CustomDistribution(rv_continuous):\n",
    "    def _rvs(self,  size=None, random_state=None):\n",
    "        return random_state.standard_normal(size)\n",
    "temp = CustomDistribution(seed=1)\n",
    "temp2 = temp()  # get a frozen version of the distribution\n",
    "C = random(p, k, density=0.25, random_state=1, data_rvs=temp2.rvs)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "eecc9adb",
   "metadata": {
    "id": "eecc9adb"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv\n",
    "\n",
    "\n",
    "class Net(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = GCNConv(X.shape[1], 64)\n",
    "        self.conv2 = GCNConv(64, NO_OF_CLASSES)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        self.conv1.reset_parameters()\n",
    "        self.conv2.reset_parameters()\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "    \n",
    "    \n",
    "####### NO output layer is written\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1c5a519a",
   "metadata": {
    "id": "1c5a519a"
   },
   "outputs": [],
   "source": [
    "def get_accu(C_0,L,X_t_0):\n",
    "    global labels, NO_OF_CLASSES,k\n",
    "    t=[]\n",
    "    for i in [1]: \n",
    "        C_0_new=np.zeros(C_0.shape)\n",
    "        for i in range(C_0.shape[0]):\n",
    "            C_0_new[i][np.argmax(C_0[i])]=1\n",
    "        # print(C_0_new)\n",
    "        # C_0_new=C_0\n",
    "        from scipy import sparse\n",
    "        #Lc=C_0.T@L@C_0\n",
    "        Lc=C_0_new.T@L@C_0_new\n",
    "        # print(\"L:\", Lc.shape)\n",
    "        # Lc=L_new\n",
    "        #print(Lc)\n",
    "        Wc=(-1*Lc)*(1-np.eye(Lc.shape[0]))\n",
    "        # print(\"W:\", Wc.shape)\n",
    "        Wc[Wc<0.1]=0\n",
    "        Wc=sparse.csr_matrix(Wc)\n",
    "        Wc = Wc.tocoo()\n",
    "        row = torch.from_numpy(Wc.row).to(torch.long)\n",
    "        col = torch.from_numpy(Wc.col).to(torch.long)\n",
    "        edge_index_coarsen2 = torch.stack([row, col], dim=0)\n",
    "        #print(\"edgecoarsen:\", edge_index_coarsen2.shape)\n",
    "        edge_weight = torch.from_numpy(Wc.data)\n",
    "        #print(\"edgeweight:\", edge_weight.shape)\n",
    "        def one_hot(x, class_count):\n",
    "            return torch.eye(class_count)[x, :]\n",
    "\n",
    "        device = torch.device('cpu')\n",
    "#         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "        labels=labels\n",
    "        Y = labels\n",
    "        #print(\"Y:\", Y.shape)\n",
    "        Y = one_hot(Y,NO_OF_CLASSES)\n",
    "        # NO_OF_CLASSES=Y.shape[1]\n",
    "        P=np.linalg.pinv(C_0_new)\n",
    "        labels_coarse = torch.argmax(torch.sparse.mm(torch.Tensor(P).double() , Y.double()).double() , 1)\n",
    "        #print(\"Lables:\", labels_coarse.shape)\n",
    "\n",
    "        #torch.Tensor(C2)@X\n",
    "        Wc=Wc.toarray()\n",
    "        #Wc[Wc<0.01]=0\n",
    "        C2=np.linalg.pinv(C_0_new)\n",
    "#         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "        model=Net().to(device)\n",
    "        lr=0.01\n",
    "        decay=0.0001\n",
    "        features_= features.cpu().detach().numpy()\n",
    "        try:\n",
    "          X=np.array(features_.todense())\n",
    "        except:\n",
    "          X = np.array(features_)\n",
    "        #print(\"X:\",X.shape)\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)\n",
    "        # criterion=torch.nn.CrossEntropyLoss()\n",
    "        x=sample(range(0, int(k)), k)\n",
    "      \n",
    "        from datetime import datetime\n",
    "        Xt=P@X\n",
    "        # Xt=X_t_0\n",
    "        def train():\n",
    "            model.train()\n",
    "            optimizer.zero_grad()\n",
    "            out = model(torch.Tensor(Xt).to(device),edge_index_coarsen2)\n",
    "            loss = F.nll_loss(out[x], labels_coarse[x])\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            return loss\n",
    "        now1 = datetime.now()\n",
    "        losses=[]\n",
    "        for epoch in range(200):\n",
    "            loss=train()\n",
    "            losses.append(loss)\n",
    "            if(epoch%100==0):\n",
    "                print(f'Epoch: {epoch:03d},loss: {loss:.4f}')\n",
    "        now2 = datetime.now()        \n",
    "        pred=model(torch.Tensor(Xt).to(device),edge_index_coarsen2).argmax(dim=1)        \n",
    "        def train_accuracy():\n",
    "            model.eval()\n",
    "            correct = (pred[x] == labels_coarse[x]).sum()\n",
    "            acc = int(correct) /len(x)\n",
    "            return acc\n",
    "    \n",
    "        t+=[(now2-now1).total_seconds()]\n",
    "\n",
    "        zz=sample(range(0, int(NO_OF_NODES)), NO_OF_NODES)\n",
    "        adj_ = adj.cpu().detach().numpy()\n",
    "        Wc=sparse.csr_matrix(adj_)\n",
    "        Wc = Wc.tocoo()\n",
    "        row = torch.from_numpy(Wc.row).to(torch.long)\n",
    "        col = torch.from_numpy(Wc.col).to(torch.long)\n",
    "        edge_index_coarsen = torch.stack([row, col], dim=0)\n",
    "        edge_weight = torch.from_numpy(Wc.data)\n",
    "        pred=model(torch.Tensor(X),edge_index_coarsen).argmax(dim=1)\n",
    "        pred=np.array(pred)\n",
    "        correct =(pred[zz]==labels[zz]).sum()\n",
    "        acc = int(correct) /NO_OF_NODES\n",
    "        return acc\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6443025a",
   "metadata": {},
   "outputs": [],
   "source": [
    "k_=NO_OF_CLASSES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1f724733",
   "metadata": {
    "id": "1f724733"
   },
   "outputs": [],
   "source": [
    "def experiment_structure(alpha_param,lambda_param,beta_param,gamma_param,C,theta,X,A):\n",
    "      p = X.shape[0]\n",
    "      k = int(p*0.3)\n",
    "      n = X.shape[1]\n",
    "      ones = csr_matrix(np.ones((k,k)))\n",
    "      ones = convertScipyToTensor(ones)  \n",
    "      ones = ones.to_dense()\n",
    "      \n",
    "      try:\n",
    "        C = convertScipyToTensor(C)\n",
    "        C = C.to_dense()\n",
    "      except:\n",
    "        C=C\n",
    "      try:\n",
    "        theta = convertScipyToTensor(theta)\n",
    "      except:\n",
    "        theta = theta\n",
    "      try:\n",
    "        X = convertScipyToTensor(X)\n",
    "        X = X.to_dense()\n",
    "      except:\n",
    "        X = X\n",
    "      if(torch.cuda.is_available()):\n",
    "        print(\"GPU is available\")\n",
    "        C = C  \n",
    "        theta = theta  \n",
    "        X = X  \n",
    "        ones = ones  \n",
    "      \n",
    "        \n",
    "      def bracket_term2fun(C,CT,theta):\n",
    "          # U  = update_V(C,theta).double()\n",
    "          U = torch.stack(update_V(C, theta)).double()\n",
    "          UT= torch.transpose(U,0,1)\n",
    "          Lw = (CT @theta @C).double()\n",
    "          lb= 1e-5\n",
    "          ub = 1e+4\n",
    "          beta = 0.5 \n",
    "          lambda_ =  laplacian_lambda_update(lb, ub, beta, U, Lw, k_,C)   \n",
    "          lambda_matrix =  torch.diag(lambda_,0)  \n",
    "          # print(lambda_matrix)\n",
    "          # print(f'Shape of U is: {U.shape}')\n",
    "          # print(f'Shape of lambda: {}')\n",
    "          # print(f'Shape of lambda_matrix is: {lambda_matrix.shape}')\n",
    "          return UT@lambda_matrix@U\n",
    "        \n",
    "      def update_V(C,theta):\n",
    "            \n",
    "        CT= torch.transpose(C,0,1)\n",
    "        product = CT @ A @ C               \n",
    "        matrix = torch.tensor(product)         \n",
    "        eigenvalues, eigenvectors = torch.linalg.eig(product)\n",
    "\n",
    "        # select non-zero eigenvalues and eigenvectors\n",
    "        non_zero_eigenvalues = []\n",
    "        non_zero_eigenvectors = []\n",
    "          \n",
    "        # for i in range(matrix.shape[0]):\n",
    "        #     if matrix[i, i] != 0:\n",
    "        #         non_zero_eigenvalues.append(eigenvalues[i])\n",
    "        #         non_zero_eigenvectors.append(eigenvectors[:, i])\n",
    "        # U = [torch.tensor(eigenvector) for eigenvector in non_zero_eigenvectors]\n",
    "        # return U    \n",
    "\n",
    "        for i in range(matrix.shape[0]):\n",
    "            if eigenvalues[i] != 0:\n",
    "                non_zero_eigenvalues.append(eigenvalues[i])\n",
    "                non_zero_eigenvectors.append(eigenvectors[:, i])\n",
    "        V = [torch.tensor(eigenvector) for eigenvector in non_zero_eigenvectors]\n",
    "        return V \n",
    " \n",
    "\n",
    "      def update_C(C):\n",
    "          CT = torch.transpose(C,0,1)\n",
    "          C.size()\n",
    "          t1 = alpha_param*(C@ones)  \n",
    "          bracket_term1 = (CT@theta@C)\n",
    "          bracket_term3 = (CT@A@C)\n",
    "          bracket_term2 = bracket_term2fun(C,CT,theta) \n",
    "          bracket_term = bracket_term3 - bracket_term2             # bracket term (CT*A*C - U*lambda*UT)\n",
    "          t22 = -2*(theta@C)                                      # Check for multiplication of gamma\n",
    "#           print(t22.type())\n",
    "          t3 = bracket_term1\n",
    "          t7 = bracket_term2\n",
    "          t6 = (CT@A@C)  \n",
    "          t5 = 2* beta_param*(A@C)\n",
    "          t5 = t5.float()\n",
    "          t4 = (1.0/k)\n",
    "          t44 = t4*((torch.ones(k,k)).double())  \n",
    "#           print(t3.device)\n",
    "#           print(t44.device)\n",
    "          t8 = (t3 + t44)  \n",
    "          t9 = torch.pinverse(t8)                  \n",
    "          t9 = t9.float()\n",
    "#           print(t9.type())\n",
    "#           print(t9)x\n",
    "          t10 = (t22@t9)  \n",
    "          t11 = (t6 - t7)  \n",
    "          t11 = t11.float()\n",
    "          t12 = (t5@t11)\n",
    "          t13 = (t1 + t10 +t12)  \n",
    "        \n",
    "          #t2 = beta_param*(theta@C@bracket_term.float())\n",
    "          grad_fc= t13\n",
    "          C_new=C-gamma_param*grad_fc\n",
    "          C_new[C_new<thresh] = thresh\n",
    "          for i in range(len(C_new)):\n",
    "              C_new[i] = C_new[i]/torch.linalg.norm(C_new[i],1)\n",
    "          return C_new        \n",
    "            \n",
    "\n",
    "        \n",
    "        \n",
    "        \n",
    "\n",
    "\n",
    "      #We set c1 = 10−5 and c2 = 10^4 We observed that the experimental performances of the algorithms \n",
    "       #are not sensitive to different values of c1 and c2 as long as they are reasonably small and large,respectively\n",
    "      # K is the number of smallest eigenvalues of the Laplacian matrix that are being ignored while updating the eigenvalues.\n",
    "      def laplacian_lambda_update(lb, ub, beta, U, Lw, k, C):\n",
    "        q = Lw.size(1) - k\n",
    "        # print(f'q is: {q}')\n",
    "        U = U\n",
    "        UT= torch.transpose(U,0,1)\n",
    "        UT = UT.type(torch.float64)\n",
    "        UT = UT\n",
    "        \n",
    "        CT= torch.transpose(C,0,1)\n",
    "        CT = CT.type(torch.float64)\n",
    "        CT = CT\n",
    "        \n",
    "        AC=(A@C).double()\n",
    "        AC = AC\n",
    "        \n",
    "        Af=(CT@AC).double()\n",
    "        Af = Af  \n",
    "        Af.device\n",
    "        U.device\n",
    "        dd = U@Af@UT\n",
    "          \n",
    "        # cc = UT@A@U\n",
    "        \n",
    "        product = dd\n",
    "        matrix = torch.tensor(product)     \n",
    "\n",
    "        non_zero_diag_elements = []\n",
    "        for i in range(matrix.shape[0]):\n",
    "            if matrix[i, i] != 0:\n",
    "                non_zero_diag_elements.append(matrix[i, i])\n",
    "            if len(non_zero_diag_elements) == len(matrix):\n",
    "                break\n",
    "\n",
    "        k = len(non_zero_diag_elements)\n",
    "        e = non_zero_diag_elements\n",
    "        d = torch.diag(torch.tensor(non_zero_diag_elements))\n",
    "#-----------------------------------------------------------------------------------------------------------------------------------------------\n",
    "\n",
    "       # Trial-2                 -########################################\n",
    "\n",
    "        e_bar = torch.tensor([])\n",
    "        if(k%2 == 0):\n",
    "            for i in range(k//2):\n",
    "                e_bar = torch.cat((e_bar, ((e[i] - e[k - i - 1]) / 2).unsqueeze(0)), dim=0)\n",
    "            \n",
    "        if(k%2 != 0):\n",
    "            for i in range((k+1)//2) :\n",
    "                e_bar = torch.cat((e_bar, ((e[i] - e[k - i - 1]) / 2).unsqueeze(0)), dim=0)       \n",
    "                \n",
    "        lambda_,indices = torch.sort(e_bar, dim=- 1, descending=True)\n",
    "        eps = 1\n",
    "        qq = lambda_.size(0)-1\n",
    "        condition = torch.stack([(lambda_[qq] - ub) <= eps,\n",
    "                         (lambda_[0] - lb) >= -eps]).all(dim=0)\n",
    "\n",
    "#                                   (lambda_[1:(q)] - lambda_[0:(q-1)]) >= -eps])\n",
    "        \n",
    "          \n",
    "        if condition.all():\n",
    "            # while(lambda_.size(0) != 135):\n",
    "            #     lambda_ = torch.cat((lambda_, torch.tensor(0).unsqueeze(0)), dim=0)\n",
    "            for i in range(k//2):\n",
    "                lambda_ = torch.cat((lambda_, -lambda_[(k//2)-1-i].unsqueeze(0)), dim=0)\n",
    "            # print(f'Shape of updated lambda1_ is: {lambda_.shape}')\n",
    "            # hm= sns.heatmap(data =lambda_)\n",
    "            # plt.show()\n",
    "            # print(lambda_)\n",
    "            return lambda_\n",
    "        else:\n",
    "            greater_ub = lambda_ > ub\n",
    "            lesser_lb = lambda_ < lb\n",
    "            lambda_[greater_ub] = ub\n",
    "            lambda_[lesser_lb] = lb\n",
    "            condition = torch.stack([(lambda_[qq] - ub) <= eps,\n",
    "                         (lambda_[0] - lb) >= -eps]).all(dim=0)\n",
    "        for i in range(k//2):\n",
    "            lambda_ = torch.cat((lambda_, -lambda_[(k//2)-1-i].unsqueeze(0)), dim=0)            \n",
    "            \n",
    "\n",
    "        print(f'Shape of updated lambda2_ is: {lambda_.shape}')\n",
    "        if condition.all():\n",
    "            return lambda_\n",
    "        else:\n",
    "#           print(lambda_)\n",
    "            raise ValueError(\"eigenvalues are not in increasing order, consider increasing the value of beta\")\n",
    "\n",
    "      for i in tqdm(range(10)): #update C only 21\n",
    "         C = update_C(C)\n",
    "            \n",
    "      return C\n",
    "          \n",
    "#-----------------------------------------------------------------------------------------------------------------------------------------------\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "556afd48",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "556afd48",
    "outputId": "7f0b9b03-7a98-4a01-ac78-fc754863ff46"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_36368\\2387417581.py:13: UserWarning: torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated.  Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=). (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\torch\\csrc\\utils\\tensor_new.cpp:607.)\n",
      "  return torch.sparse.FloatTensor(i, v, torch.Size(shape))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is available\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                           | 0/10 [00:00<?, ?it/s]C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_36368\\2463256959.py:51: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  matrix = torch.tensor(product)\n",
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_36368\\2463256959.py:69: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  V = [torch.tensor(eigenvector) for eigenvector in non_zero_eigenvectors]\n",
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_36368\\2463256959.py:33: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen\\native\\Copy.cpp:299.)\n",
      "  U = torch.stack(update_V(C, theta)).double()\n",
      "C:\\Users\\Sandeep\\AppData\\Local\\Temp\\ipykernel_36368\\2463256959.py:144: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  matrix = torch.tensor(product)\n",
      " 40%|████████████████████████████████▊                                                 | 4/10 [09:53<17:24, 174.15s/it]"
     ]
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pylab as plt\n",
    "        # sns.heatmap(C_0.T@C_0)\n",
    "import time\n",
    "        \n",
    "highest_accuracy=0\n",
    "lambda_param = 0.001\n",
    "#0.0001,0.0001,10,0.0001\n",
    "for alpha_param in [1000,100,10,1,0.1,0.01,0.001,0.0001]:\n",
    "  for beta_param in [1000,100,10,1,0.1,0.01,0.001,0.0001]:\n",
    "      for gamma_param in [1000,100,10,1,0.1,0.01,0.001,0.0001]:\n",
    "            \n",
    "        av = []\n",
    "        for _ in range(2):\n",
    "            \n",
    "            avg_accuracy_all=[]\n",
    "            X=X\n",
    "            for _ in range(1):\n",
    "              C = random(p, k, density=0.15, random_state=1, data_rvs=temp2.rvs)\n",
    "          \n",
    "              theta = theta\n",
    "              s_c = time.time()\n",
    "              C_0 = experiment_structure(alpha_param,lambda_param,beta_param,gamma_param,C,theta,X,A)\n",
    "              e_c = time.time()\n",
    "              L = theta\n",
    "             \n",
    "              pseudo_C = torch.linalg.pinv(C_0)\n",
    "              X_t_0 = pseudo_C@X\n",
    "              C_test = C_0.cpu().detach().numpy()\n",
    "              X_t_test = X_t_0.cpu().detach().numpy()\n",
    "              L_test = L.cpu().detach().numpy() \n",
    "              g_s = time.time()\n",
    "              acc = get_accu(C_test,L_test,X_t_test)\n",
    "              g_e = time.time()\n",
    "\n",
    "              # print(\"Time taken:\",e_c-s_c + g_e - g_s)\n",
    "              av.append(acc)\n",
    "              if highest_accuracy<acc:\n",
    "                highest_accuracy=acc\n",
    "                print(\"Accuracy = \" + str(acc) + \" \" + str(alpha_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))\n",
    "        print(\"Average accuracy = \" + str(np.mean(av)*100)  + \" +/- \" + str(np.std(av)*100))\n",
    "        print(\"Params =  \" + str(alpha_param)+\" \" + str(beta_param)+\" \"+str(gamma_param))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18fd2468",
   "metadata": {},
   "outputs": [],
   "source": [
    "highest_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "VYrR83c2grMy",
   "metadata": {
    "id": "VYrR83c2grMy"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from scipy.sparse import coo_matrix\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "\n",
    "def largest_connected_component(laplacian):\n",
    "    # Convert Laplacian to NetworkX graph\n",
    "    coo_laplacian = coo_matrix(laplacian.cpu().numpy())\n",
    "    G = nx.from_scipy_sparse_matrix(coo_laplacian)\n",
    "    \n",
    "    # Find largest connected component\n",
    "    largest_cc = max(nx.connected_components(G), key=len)\n",
    "    subgraph = G.subgraph(largest_cc)\n",
    "    \n",
    "    # Convert subgraph to Laplacian\n",
    "    subgraph_laplacian = torch.tensor(nx.laplacian_matrix(subgraph).todense())\n",
    "    \n",
    "    return subgraph_laplacian\n",
    "\n",
    "def graph_metrics(laplacian):\n",
    "    # Get largest connected component\n",
    "    largest_cc_laplacian = largest_connected_component(laplacian)\n",
    "    \n",
    "    # Convert Laplacian to NetworkX graph\n",
    "    coo_laplacian = coo_matrix(largest_cc_laplacian.cpu().numpy())\n",
    "    G = nx.from_scipy_sparse_matrix(coo_laplacian)\n",
    "    \n",
    "    # Calculate metrics\n",
    "    num_nodes = G.number_of_nodes()\n",
    "    num_edges = G.number_of_edges()\n",
    "    avg_degree = num_edges / num_nodes\n",
    "    density = num_edges / (num_nodes * (num_nodes - 1))\n",
    "    diameter = nx.diameter(G)\n",
    "    avg_shortest_path = nx.average_shortest_path_length(G)\n",
    "    # avg_ricci_curvature = np.mean(list(nx.ricci_curvature(G).values()))\n",
    "    \n",
    "    return avg_degree, density, diameter, avg_shortest_path, 0\n",
    "\n",
    "# Example usage\n",
    "laplacian = theta\n",
    "avg_degree, density, diameter, avg_shortest_path, avg_ricci_curvature = graph_metrics(laplacian)\n",
    "print(f\"Average degree: {avg_degree}\")\n",
    "print(f\"Density: {density}\")\n",
    "print(f\"Diameter: {diameter}\")\n",
    "print(f\"Average shortest path: {avg_shortest_path}\")\n",
    "print(f\"Average Ricci curvature: {avg_ricci_curvature}\")\n",
    "highest_accuracy*100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03d64fbc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
