{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22a5600e-428e-4c05-b8c6-8ebb3374b1aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports and global configuration for the CGL coarsening experiments.\n",
    "import collections\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.io as sio\n",
    "import scipy.sparse as sp\n",
    "from scipy.stats import rv_continuous\n",
    "import torch\n",
    "from torch import Tensor\n",
    "import torch_geometric\n",
    "from torch_geometric.datasets import WebKB, WikipediaNetwork\n",
    "from torch_geometric.utils import to_dense_adj\n",
    "from tqdm import tqdm\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "print(torch.__version__)\n",
    "\n",
    "# Seed every RNG the notebook draws from (stdlib `random`, NumPy, torch) so GCN weight\n",
    "# initialisation, dropout and `random.sample` orderings are repeatable across runs.\n",
    "# The stdlib module is seeded here because a later cell rebinds the name `random`\n",
    "# to `scipy.sparse.random`.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "data_dir = \"./data\"\n",
    "os.makedirs(data_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0f5120d-1832-4798-8ebc-849424889f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative benchmark datasets: uncomment one of these and use it in place of the\n",
    "# Chameleon `WikipediaNetwork(...)` line in the next cell.\n",
    "# NOTE(review): these roots ('/cornell', '/texas', ...) sit at the filesystem root, which is\n",
    "# usually not writable; os.path.join(data_dir, '<name>') would keep them under data_dir.\n",
    "# dataset = WebKB(root='/cornell',name='Cornell')\n",
    "\n",
    "# dataset = WebKB(root='/texas',name='Texas')\n",
    "# dataset = WebKB(root='/wisconsin',name='Wisconsin')\n",
    "# dataset = WikipediaNetwork(root='/squirrel',name='Squirrel')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69e41f5b-976e-4588-84d7-883fd745669b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Chameleon (downloaded under data_dir rather than '/Chameleon' at the filesystem\n",
    "# root, which is usually not writable) and derive dense adjacency, labels and features.\n",
    "dataset = WikipediaNetwork(root=os.path.join(data_dir, 'Chameleon'), name='Chameleon')\n",
    "data = dataset[0]  # index once and reuse\n",
    "print(data)\n",
    "edge_index = data.edge_index\n",
    "# max_num_nodes keeps adj aligned with X / labels even if the highest-numbered nodes\n",
    "# have no edges (without it to_dense_adj sizes the matrix by the largest index seen).\n",
    "adj = to_dense_adj(edge_index, max_num_nodes=data.num_nodes)[0]\n",
    "\n",
    "labels = data.y.numpy()\n",
    "\n",
    "X = data.x.to_dense()\n",
    "p = X.shape[0]\n",
    "NO_OF_CLASSES = len(set(labels))\n",
    "NO_OF_NODES = len(labels)\n",
    "\n",
    "total_indices1 = torch.arange(NO_OF_NODES)\n",
    "shuffled_indices1 = torch.randperm(total_indices1.numel())\n",
    "# NOTE(review): despite their names, train_mask / test_mask are plain integers\n",
    "# (10% of the node count), not masks or index sets, and shuffled_indices1 is never\n",
    "# used. `P[train_mask, :] = 0` in the CGL code therefore zeroes a single row.\n",
    "# TODO confirm the intended train/test split.\n",
    "train_mask = int(0.1 * total_indices1.numel())\n",
    "print(\"Train Mask\", train_mask)\n",
    "test_mask = int(0.1 * total_indices1.numel())\n",
    "print(\"Test Mask\", test_mask)\n",
    "\n",
    "print(X.shape, adj.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0cae803-9b3c-46f8-b7a6-7a077e4b6b47",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_laplacian(adj):\n",
    "    \"\"\"Return the combinatorial Laplacian D - A of a square adjacency tensor.\n",
    "\n",
    "    D is the diagonal matrix of row sums of `adj`. Taking them with `adj.sum` keeps\n",
    "    the computation in adj's own dtype and device, so float64 or CUDA inputs work\n",
    "    (a product with a default float32 CPU `torch.ones` vector would not).\n",
    "\n",
    "    Args:\n",
    "        adj: (n, n) torch tensor of edge weights.\n",
    "\n",
    "    Returns:\n",
    "        (n, n) tensor diag(adj.sum(1)) - adj.\n",
    "    \"\"\"\n",
    "    degrees = adj.sum(dim=1)\n",
    "    return torch.diag(degrees) - adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edd99e3c-3170-4451-83e4-394f372d9c55",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convertScipyToTensor(coo):\n",
    "  \"\"\"Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.\n",
    "\n",
    "  Any SciPy sparse format is accepted (it is converted with `tocoo()`). Inputs\n",
    "  without `tocoo` must already expose `.row`, `.col` and `.data`; otherwise an\n",
    "  AttributeError is raised.\n",
    "\n",
    "  Returns:\n",
    "    Sparse tensor with int64 indices of shape (2, nnz), float32 values and the\n",
    "    same shape as the input; call `.to_dense()` for a dense copy.\n",
    "  \"\"\"\n",
    "  try:\n",
    "    coo = coo.tocoo()\n",
    "  except AttributeError:\n",
    "    pass  # not a SciPy matrix; the attribute reads below decide whether it is usable\n",
    "  indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)\n",
    "  values = torch.as_tensor(coo.data, dtype=torch.float32)\n",
    "  # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.\n",
    "  return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d67018de-b1d6-4d8a-8085-ddedc2dad4d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def L_operator(w,k):\n",
    "  \"\"\"Build the k x k Laplacian whose edge weights are the entries of `w`.\n",
    "\n",
    "  `w` holds k*(k-1)/2 weights for the node pairs (i, j), i < j, in row-major\n",
    "  order: (0,1), (0,2), ..., (0,k-1), (1,2), ... Each pair gets -w on both\n",
    "  off-diagonal entries, and each diagonal entry is minus its row's off-diagonal\n",
    "  sum, so every row sums to zero.\n",
    "\n",
    "  Args:\n",
    "    w: tensor with k*(k-1)/2 elements (any shape, e.g. (p1, 1)).\n",
    "    k: number of nodes.\n",
    "\n",
    "  Returns:\n",
    "    (k, k) tensor with w's dtype, on w's device (CUDA weights give a CUDA Laplacian).\n",
    "  \"\"\"\n",
    "  rows, cols = torch.triu_indices(k, k, offset=1, device=w.device)\n",
    "  Lw = torch.zeros(k, k, dtype=w.dtype, device=w.device)\n",
    "  Lw[rows, cols] = -torch.reshape(w, [-1])\n",
    "  Lw = Lw + torch.transpose(Lw, 0, 1)\n",
    "  # The diagonal is still zero here, so the row sums are the off-diagonal sums.\n",
    "  Lw = Lw - torch.diag(Lw.sum(dim=1))\n",
    "  return Lw\n",
    "\n",
    "def L_inv_operator(Lw,p1):\n",
    "  \"\"\"Adjoint of `L_operator`: map a k x k matrix to a (p1, 1) weight vector.\n",
    "\n",
    "  For every pair (i, j), i < j, taken in the same row-major order as\n",
    "  `L_operator`, the entry is -Lw[j, i] - Lw[i, j] + Lw[j, j] + Lw[i, i].\n",
    "  All pairs are gathered at once rather than concatenating one row slice per\n",
    "  step (which would re-copy the growing result every time, cubic in k).\n",
    "\n",
    "  Args:\n",
    "    Lw: square (k, k) tensor.\n",
    "    p1: output length, k*(k-1)/2.\n",
    "\n",
    "  Returns:\n",
    "    (p1, 1) tensor with Lw's dtype, on Lw's device.\n",
    "  \"\"\"\n",
    "  k = Lw.shape[0]\n",
    "  diag = torch.diagonal(Lw)\n",
    "  rows, cols = torch.triu_indices(k, k, offset=1, device=Lw.device)\n",
    "  w = -Lw[cols, rows] - Lw[rows, cols] + diag[cols] + diag[rows]\n",
    "  return torch.reshape(w, (p1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "034a5720-8d31-4624-9c91-a3aa9e5bf353",
   "metadata": {},
   "outputs": [],
   "source": [
    "def CGL(alpha_param,beta_param,gamma_param,delta_param,eta_param,lambda_param,w,C,X_tilde,X,r,epoch):\n",
    "      \"\"\"Run `epoch` rounds of the CGL alternating updates of (w, X_tilde, C).\n",
    "\n",
    "      Args:\n",
    "        alpha_param, beta_param, gamma_param, delta_param, eta_param, lambda_param:\n",
    "          scalar weights of the terms in the update rules in `update` below.\n",
    "        w, C, X_tilde: initial values as SciPy sparse matrices (converted with\n",
    "          convertScipyToTensor), of shapes (p1, 1), (k, p) and (k, n).\n",
    "        X: (p, n) node features; a SciPy sparse matrix or anything torch.as_tensor\n",
    "          accepts (tensor, NumPy array).\n",
    "        r: coarsening ratio; k = int(p * r) and p1 = k*(k-1)/2.\n",
    "        epoch: number of update rounds.\n",
    "\n",
    "      Returns:\n",
    "        (w, X_tilde, C) as dense float32 tensors on the compute device, with the\n",
    "        entries of C that are <= thresh set to zero.\n",
    "\n",
    "      Reads the notebook globals labels, NO_OF_CLASSES, train_mask and thresh.\n",
    "      \"\"\"\n",
    "      p = X.shape[0]\n",
    "      k = int(p*r)\n",
    "      p1 = int((k)*(k-1)/2)\n",
    "\n",
    "      # Everything `update` touches must live on one device. The label matrix P and\n",
    "      # the helper outputs may otherwise be created on the CPU, so they are moved\n",
    "      # explicitly; a mismatch would make the first CUDA matmul raise.\n",
    "      dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "      J = torch.full((k, k), 1.0/k, device=dev)\n",
    "      zeros = torch.zeros((p1, 1), device=dev)\n",
    "      eye = torch.eye(k, device=dev)\n",
    "\n",
    "      X_tilde = convertScipyToTensor(X_tilde).to_dense().to(dev)\n",
    "      C = convertScipyToTensor(C).to_dense().to(dev)\n",
    "      w = convertScipyToTensor(w).to_dense().to(dev)\n",
    "\n",
    "      if sp.issparse(X):\n",
    "        X = convertScipyToTensor(X).to_dense()\n",
    "      else:\n",
    "        X = torch.as_tensor(X, dtype=torch.float32)\n",
    "      X = X.to(dev)\n",
    "      XT = torch.transpose(X,0,1)\n",
    "\n",
    "      # One-hot label matrix (p, NO_OF_CLASSES) used by the lambda term.\n",
    "      P = torch.eye(NO_OF_CLASSES)[labels, :]\n",
    "      # NOTE(review): train_mask is an integer, so this zeroes one row only.\n",
    "      P[train_mask, :] = 0\n",
    "      P = P.to(dev)\n",
    "      PT = torch.transpose(P,0,1)\n",
    "\n",
    "      def update(w,X_tilde,C):\n",
    "          Lw = L_operator(w,k).to(dev)\n",
    "          # Updating C (multiplicative update)\n",
    "          CX = C@X\n",
    "          num = alpha_param*CX@XT + delta_param*C + lambda_param*C@P@PT\n",
    "          CCT = C@torch.transpose(C,0,1)\n",
    "          den = alpha_param*(X_tilde@XT) + delta_param*CCT@C\n",
    "          Cnew = torch.mul(C, torch.div(num, den))\n",
    "          # Updating X_tilde\n",
    "          t1 = torch.linalg.pinv(2*Lw + (alpha_param + eta_param)*eye)\n",
    "          X_tilde_new = alpha_param*t1@CX\n",
    "          # Updating w: step by -t3, then project onto w >= 0 (elementwise max with 0)\n",
    "          t2 = torch.linalg.pinv(Lw + J)\n",
    "          t0 = (-1/beta_param)*(X_tilde@torch.transpose(X_tilde,0,1))\n",
    "          c = L_inv_operator(t0,p1).to(dev)\n",
    "          t3 = -gamma_param*L_inv_operator(t2,p1).to(dev) + L_inv_operator(Lw,p1).to(dev) - c\n",
    "          w_new = (w - t3).maximum(zeros)\n",
    "\n",
    "          # Clamp tiny entries of C, then L1-normalise every row of C and X_tilde.\n",
    "          Cnew[Cnew<thresh] = thresh\n",
    "          Cnew = Cnew / torch.linalg.norm(Cnew, 1, dim=1, keepdim=True)\n",
    "          X_tilde_new = X_tilde_new / torch.linalg.norm(X_tilde_new, 1, dim=1, keepdim=True)\n",
    "          return w_new,X_tilde_new,Cnew\n",
    "\n",
    "      for _ in range(epoch):\n",
    "          w,X_tilde,C = update(w,X_tilde,C)\n",
    "      C[C<=thresh] = 0\n",
    "\n",
    "      return w,X_tilde,C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9617c840-77f4-4d33-9b54-f31228f94f6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.sparse import random\n",
    "# NOTE: this rebinds the name `random` (the stdlib module imported in the first cell)\n",
    "# to scipy.sparse.random, which is what the grid-search cell calls.\n",
    "\n",
    "\n",
    "class CustomDistribution(rv_continuous):\n",
    "    \"\"\"Continuous distribution whose samples are standard-normal draws.\n",
    "\n",
    "    Only `_rvs` is overridden. The frozen `.rvs` is passed as `data_rvs` to\n",
    "    `random(...)`, so the non-zero entries of the sparse initialisations are N(0, 1).\n",
    "    \"\"\"\n",
    "\n",
    "    def _rvs(self, size=None, random_state=None):\n",
    "        return random_state.standard_normal(size=size)\n",
    "\n",
    "\n",
    "temp = CustomDistribution(seed=1)\n",
    "temp2 = temp.freeze()  # frozen distribution, equivalent to calling temp()\n",
    "\n",
    "# Number of nodes and features, and the clamp threshold used by CGL.\n",
    "p, n = X.shape[0], X.shape[1]\n",
    "thresh = 1e-6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "047e010a-3e28-4d43-b7c0-d447fe17ca9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv\n",
    "\n",
    "\n",
    "class Net(torch.nn.Module):\n",
    "    \"\"\"Two-layer GCN (in -> hidden -> classes) that returns log-probabilities.\n",
    "\n",
    "    Args:\n",
    "        in_channels: input feature size; defaults to X.shape[1] (notebook global).\n",
    "        hidden_channels: hidden-layer width (default 64).\n",
    "        num_classes: output size; defaults to NO_OF_CLASSES (notebook global).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels=None, hidden_channels=64, num_classes=None):\n",
    "        super(Net, self).__init__()\n",
    "        if in_channels is None:\n",
    "            in_channels = X.shape[1]\n",
    "        if num_classes is None:\n",
    "            num_classes = NO_OF_CLASSES\n",
    "        self.conv1 = GCNConv(in_channels, hidden_channels)\n",
    "        self.conv2 = GCNConv(hidden_channels, num_classes)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        self.conv1.reset_parameters()\n",
    "        self.conv2.reset_parameters()\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Return (num_nodes, num_classes) log-softmax scores.\n",
    "\n",
    "        Inputs are moved onto the device holding the model's parameters, so CUDA\n",
    "        tensors can be passed to a CPU model (and vice versa).\n",
    "        \"\"\"\n",
    "        device = next(self.parameters()).device\n",
    "        x = x.to(device)\n",
    "        edge_index = edge_index.to(device)\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "\n",
    "        return F.log_softmax(x, dim=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b93dc8d-0b5f-43e8-89ef-f6d9c86f3f63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the GCN on the CPU and show its layer summary.\n",
    "model = Net()\n",
    "model = model.to(device='cpu')\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "277c9d87-56f3-40ce-8121-8f28f1352785",
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import sample\n",
    "\n",
    "\n",
    "def _as_cpu_tensor(a, dtype=torch.float32):\n",
    "    \"\"\"Return `a` (tensor, NumPy array or SciPy sparse matrix) as a dense CPU tensor of `dtype`.\"\"\"\n",
    "    if sp.issparse(a):\n",
    "        a = a.toarray()\n",
    "    if torch.is_tensor(a):\n",
    "        return a.detach().cpu().to(dtype)\n",
    "    return torch.as_tensor(np.asarray(a), dtype=dtype)\n",
    "\n",
    "\n",
    "def _dense_to_edge_index(weights):\n",
    "    \"\"\"Return the (2, nnz) int64 edge_index of the non-zero entries of a 2-D NumPy array, row-major.\"\"\"\n",
    "    rows, cols = np.nonzero(weights)\n",
    "    return torch.stack([torch.from_numpy(rows), torch.from_numpy(cols)]).to(torch.long)\n",
    "\n",
    "\n",
    "def get_accuracy(C,w,Xc,L, num_coarscened_rows, epochs=200, lr=0.1, weight_decay=0.0001):\n",
    "    \"\"\"Train a GCN on the coarsened graph and return its accuracy on the full graph.\n",
    "\n",
    "    1. Rebuild the coarse Laplacian from `w`; coarse edges are the off-diagonal\n",
    "       entries of -Lc that are >= 0.1.\n",
    "    2. Coarse features are C @ X; coarse labels are argmax(C @ Y), where Y is the\n",
    "       one-hot label matrix with row `train_mask` zeroed.\n",
    "    3. Train a fresh `Net` for `epochs` full-batch Adam steps on every coarse node.\n",
    "    4. Predict on the original graph (features X, edges from `adj`) in eval mode,\n",
    "       i.e. with dropout disabled.\n",
    "\n",
    "    Args:\n",
    "        C: (k, p) coarsening matrix (tensor, possibly on CUDA, or array).\n",
    "        w: coarse edge-weight vector in the layout expected by L_operator.\n",
    "        Xc, L: unused; kept for call compatibility.\n",
    "        num_coarscened_rows: k, the number of coarse nodes.\n",
    "        epochs, lr, weight_decay: GCN training settings (defaults 200, 0.1, 1e-4).\n",
    "\n",
    "    Returns:\n",
    "        Fraction of all NO_OF_NODES nodes whose prediction equals `labels`.\n",
    "\n",
    "    Reads the notebook globals X, adj, labels, NO_OF_CLASSES, NO_OF_NODES and\n",
    "    train_mask; it works on local copies and does not rebind X or adj.\n",
    "    \"\"\"\n",
    "    k = num_coarscened_rows\n",
    "    C_t = _as_cpu_tensor(C)\n",
    "    w_t = _as_cpu_tensor(w)\n",
    "    X_t = _as_cpu_tensor(X)\n",
    "\n",
    "    # 1. Coarse graph from the learned Laplacian; weak edges are dropped.\n",
    "    Lc = L_operator(w_t, k).numpy()\n",
    "    Wc = (-1 * Lc) * (1 - np.eye(Lc.shape[0]))\n",
    "    Wc[Wc < 0.1] = 0\n",
    "    edge_index_coarse = _dense_to_edge_index(Wc)\n",
    "\n",
    "    # 2. Coarse features and labels.\n",
    "    Y = torch.eye(NO_OF_CLASSES)[labels, :]\n",
    "    # NOTE(review): train_mask is an integer, so only one row of Y is hidden.\n",
    "    Y[train_mask, :] = 0\n",
    "    labels_coarse = torch.argmax(C_t.double() @ Y.double(), dim=1)\n",
    "    Xt = C_t @ X_t\n",
    "\n",
    "    # 3. Train on all coarse nodes (visited in a random order).\n",
    "    model = Net().to('cpu')\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    train_idx = sample(range(k), k)\n",
    "    model.train()\n",
    "    for _ in range(epochs):\n",
    "        optimizer.zero_grad()\n",
    "        out = model(Xt, edge_index_coarse)\n",
    "        loss = F.nll_loss(out[train_idx], labels_coarse[train_idx])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # 4. Evaluate on the original graph with dropout disabled.\n",
    "    edge_index_full = _dense_to_edge_index(_as_cpu_tensor(adj).numpy())\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        pred = model(X_t, edge_index_full).argmax(dim=1).numpy()\n",
    "    correct = int((pred == np.asarray(labels)).sum())\n",
    "    return correct / NO_OF_NODES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b4aa16e-a080-463c-a471-02493ed4b688",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combinatorial Laplacian of the full graph (passed to get_accuracy as `L`).\n",
    "# torch.as_tensor accepts adj both as a tensor and as a NumPy array, so this cell\n",
    "# still runs if an earlier run may have left `adj` converted to NumPy.\n",
    "L = get_laplacian(torch.as_tensor(adj, dtype=torch.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de30239-4893-42ba-88df-9f9595279d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from scipy.sparse import csr_matrix\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0994b59d-f4f8-4a54-831e-3fa99ea6facf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cgl_node_classification(coarsening_ratio, epoch=10, param_grid=None):\n",
    "    \"\"\"Grid-search the CGL hyper-parameters for one coarsening ratio.\n",
    "\n",
    "    Each combination (a, b, c, d, e) runs CGL(alpha=a, beta=b, gamma=c, delta=d,\n",
    "    eta=0, lambda=e) from fixed random initialisations and is scored with\n",
    "    get_accuracy. The search stops as soon as the best accuracy exceeds\n",
    "    coarsening_ratios_accuracy_achieved[coarsening_ratio]; ratios without an\n",
    "    entry are searched exhaustively.\n",
    "\n",
    "    Args:\n",
    "        coarsening_ratio: fraction r of nodes kept; k = int(p * r).\n",
    "        epoch: CGL update rounds per configuration (default 10).\n",
    "        param_grid: optional dict mapping any of 'a'..'e' to the candidate values\n",
    "            for that axis; axes not given use the default 9-value grid.\n",
    "\n",
    "    Returns:\n",
    "        The highest accuracy seen (0 if every configuration failed).\n",
    "    \"\"\"\n",
    "    import itertools\n",
    "\n",
    "    default_axis = [1e-5,1e-4,1e-3,1e-1,1,10,1e2,1e3,1e4]\n",
    "    grid = {'a': [1e-5,1e-4,1e-3,1e-1,1,10,100,1e3,1e4],\n",
    "            'b': default_axis, 'c': default_axis, 'd': default_axis, 'e': default_axis}\n",
    "    if param_grid:\n",
    "        grid.update(param_grid)\n",
    "\n",
    "    r = coarsening_ratio\n",
    "    k = int(p*r)\n",
    "    p1 = int((k)*(k-1)/2)\n",
    "    # Fixed (random_state=1) sparse initialisations shared by every configuration;\n",
    "    # `random` is scipy.sparse.random here and temp2.rvs draws N(0, 1) values.\n",
    "    X_tilde_initial = random(k, n, density=0.15, random_state=1, data_rvs=temp2.rvs)\n",
    "    C_initial = random(k, p, density=0.15, random_state=1, data_rvs=temp2.rvs)\n",
    "    w_initial = random(p1, 1, density=0.15, random_state=1, data_rvs=temp2.rvs)\n",
    "    # Convert the global X once instead of rebinding it on every iteration.\n",
    "    X_features = torch.as_tensor(X, dtype=torch.float32)\n",
    "    target = coarsening_ratios_accuracy_achieved.get(coarsening_ratio, float('inf'))\n",
    "\n",
    "    print('\\033[1m' + \"CGL accuracy for coarsening_ratio:\", coarsening_ratio)\n",
    "    print('\\033[0m')\n",
    "\n",
    "    tried = 0\n",
    "    highest_accuracy = 0\n",
    "    # itertools.product varies the last axis fastest: e is the outermost loop, a the innermost.\n",
    "    for e, d, c, b, a in itertools.product(grid['e'], grid['d'], grid['c'], grid['b'], grid['a']):\n",
    "        try:\n",
    "            w, X_tilde, C = CGL(a,b,c,d,0,e,w_initial,C_initial,X_tilde_initial,X_features,r,epoch)\n",
    "            acc = get_accuracy(C,w,X_tilde,L,k)\n",
    "        except KeyboardInterrupt:\n",
    "            print('KeyboardInterrupt exception is caught')\n",
    "            raise\n",
    "        except Exception as exc:\n",
    "            # Skip this configuration, but report why it failed rather than hiding it.\n",
    "            print(\"Error Occured\", a,b,c,d,e, repr(exc))\n",
    "            continue\n",
    "        params = f\"a: {a} b: {b} c: {c} d: {d} e: {e}\"\n",
    "        if tried % 100 == 0:\n",
    "            print(params)\n",
    "        tried += 1\n",
    "        if highest_accuracy < acc:\n",
    "            highest_accuracy = acc\n",
    "            print(\"Highest Accuracy = \" + str(acc) + \" parameters= \" + params)\n",
    "        if highest_accuracy > target:\n",
    "            print('\\033[1m')\n",
    "            print(\"Highest Accuracy = \" + str(highest_accuracy) + \" parameters= \" + params)\n",
    "            print('\\033[0m')\n",
    "            break\n",
    "    return highest_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cb19674-87b5-4238-95ec-816f334f5130",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Methods and ratios to evaluate, plus the per-ratio accuracy at which the CGL\n",
    "# grid search may stop early.\n",
    "coarsening_methods = ['CGL']\n",
    "coarsening_ratios = [0.7]\n",
    "coarsening_ratios_accuracy_achieved = {0.3: 0.96, 0.5: 0.96, 0.7: 0.60}\n",
    "\n",
    "for coarsening_method in coarsening_methods:\n",
    "    for coarsening_ratio in coarsening_ratios:\n",
    "        if coarsening_method != 'CGL':\n",
    "            continue  # only CGL has a runner in this notebook\n",
    "        cgl_node_classification(coarsening_ratio)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d1cfba2-7062-4d1f-b550-28ab60f832ad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ebe5526-3703-4d7d-b642-c102c2963685",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
