{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch as pt\n",
    "import torch_geometric as ptg\n",
    "import pytorch_lightning as ptl\n",
    "from torch_geometric.typing import (\n",
    "    Adj,\n",
    "    OptPairTensor,\n",
    "    OptTensor,\n",
    "    SparseTensor,\n",
    "    torch_sparse,\n",
    ")\n",
    "\n",
    "from torch.nn import Linear, Sequential, ReLU, BatchNorm1d, ModuleList, Sigmoid, Tanh, ELU, Identity, Parameter, LeakyReLU\n",
    "from torch.optim import Adam, AdamW\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from torch.optim.swa_utils import AveragedModel, SWALR\n",
    "from torch_geometric.data import Data, DataLoader\n",
    "from torch_geometric.nn import GINConv, GCNConv, GATConv, global_mean_pool, global_add_pool\n",
    "from torch_geometric.nn.norm import BatchNorm, GraphNorm, PairNorm, DiffGroupNorm\n",
    "from torch_geometric.transforms import BaseTransform\n",
    "from torch_geometric.utils import to_torch_coo_tensor, scatter\n",
    "from torch_geometric.utils import to_torch_coo_tensor, scatter, to_scipy_sparse_matrix\n",
    "\n",
    "from torchmetrics import Accuracy\n",
    "\n",
    "from pytorch_lightning import Trainer\n",
    "from lightning.pytorch.callbacks import LearningRateFinder\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from scipy.sparse import diags, csr_matrix, eye\n",
    "from scipy.sparse.linalg import eigs, eigsh\n",
    "from scipy.linalg import pinv, pinvh\n",
    "\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "# Apply seaborn's default theme to all matplotlib/seaborn figures drawn below.\n",
    "sns.set_theme()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AppendEVs(BaseTransform):\n",
    "    \"\"\"Pre-transform that attaches eigenvectors of three GNN propagation operators.\n",
    "\n",
    "    Each graph's adjacency matrix A is augmented with one auxiliary node joined to\n",
    "    every node with weight 1e-5. From it three operators are built:\n",
    "\n",
    "    * ``gin_EVs``: the augmented A itself,\n",
    "    * ``gcn_EVs``: D^-1/2 (I + A) D^-1/2 (D = row sums of I + A, zeros replaced by 1),\n",
    "    * ``gat_EVs``: the row-normalised D^-1 A.\n",
    "\n",
    "    Their largest-magnitude eigenvectors (ordered by ascending real eigenvalue) are\n",
    "    sign-aligned with the auxiliary node, whose row is then dropped, and stored on\n",
    "    ``data`` under the names above, each zero-padded to ``num_evs`` columns. For\n",
    "    ``gin_EVs``/``gcn_EVs`` the last column holds the residual\n",
    "    ``1/sqrt(n) - sum of the (rescaled) eigenvector columns``.\n",
    "\n",
    "    Args:\n",
    "        num_evs: number of columns of every stored eigenvector matrix.\n",
    "        make_undirected: symmetrise ``data.edge_index`` first.\n",
    "        resize_y: flatten ``data.y`` to shape ``(-1,)``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_evs, make_undirected=False, resize_y=False):\n",
    "        # Stored minus one: one column of gin/gcn_EVs is reserved for the residual.\n",
    "        self.num_evs = num_evs -1\n",
    "        # NOTE(review): self.v and self.first_EVs are not read anywhere in this notebook;\n",
    "        # kept because the randn call advances torch's global RNG.\n",
    "        self.v = pt.randn((10000, ))\n",
    "        self.v = self.v * pt.sign(pt.sum(self.v[0]))\n",
    "        self.i = 0  # number of graphs processed\n",
    "        self.make_undirected = make_undirected\n",
    "        self.resize_y = resize_y\n",
    "        self.first_EVs = None\n",
    "\n",
    "    def __call__(self, data):\n",
    "        from scipy.sparse import bmat\n",
    "\n",
    "        self.i = self.i + 1\n",
    "        if self.make_undirected:\n",
    "            data.edge_index = ptg.utils.to_undirected(data.edge_index)\n",
    "        if self.resize_y:\n",
    "            data.y = data.y.reshape((-1,))\n",
    "\n",
    "        # Pass num_nodes so trailing isolated nodes still get a row; otherwise the\n",
    "        # EV matrices would have fewer rows than the graph has nodes.\n",
    "        adjacency = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)\n",
    "        n = adjacency.shape[0]\n",
    "        eps = 1e-5\n",
    "        # Augmented (n+1)x(n+1) matrix, assembled sparsely: densifying the adjacency\n",
    "        # costs O(n^2) memory (about 3 GB in float64 for a 20k-node graph).\n",
    "        topology = bmat([[adjacency, csr_matrix(np.full((n, 1), eps))],\n",
    "                         [csr_matrix(np.full((1, n), eps)), csr_matrix(np.full((1, 1), eps))]],\n",
    "                        format='csr', dtype=np.float32)\n",
    "\n",
    "        gin_conv = topology.copy()\n",
    "\n",
    "        gcn_conv = eye(n + 1) + topology\n",
    "        d = gcn_conv @ np.ones(gcn_conv.shape[1])\n",
    "        d[d == 0] = 1\n",
    "        d = 1 / np.sqrt(d)\n",
    "        gcn_conv = diags(d) @ gcn_conv @ diags(d)\n",
    "\n",
    "        # Every row sum is > 0 thanks to the eps border, so no zero guard is needed.\n",
    "        d = topology @ np.ones(topology.shape[1])\n",
    "        gat_conv = diags(1 / d) @ topology\n",
    "\n",
    "        data.gin_EVs = self._evs_with_ones_residual(gin_conv)\n",
    "        data.gcn_EVs = self._evs_with_ones_residual(gcn_conv)\n",
    "        data.gat_EVs = self._padded_evs(gat_conv)\n",
    "        return data\n",
    "\n",
    "    @staticmethod\n",
    "    def _leading_evectors(op, k):\n",
    "        \"\"\"Top-k largest-magnitude eigenvectors of op as float32, auxiliary row removed.\"\"\"\n",
    "        evalues, evectors = eigs(op, k=k, which='LM', tol=1e-8, maxiter=100000)\n",
    "        evectors = evectors[:, np.argsort(evalues.real)].real\n",
    "        evectors = pt.tensor(evectors, dtype=pt.float32)\n",
    "        # Fix each eigenvector's sign via the auxiliary (last) node, then drop that row.\n",
    "        return evectors[:-1] * pt.sign(evectors[-1])\n",
    "\n",
    "    def _evs_with_ones_residual(self, op):\n",
    "        \"\"\"num_evs-1 rescaled, zero-padded eigenvectors plus the all-ones residual column.\"\"\"\n",
    "        evectors = self._leading_evectors(op, min(self.num_evs, op.shape[0] - 2))\n",
    "        ones = pt.ones(evectors.shape[0]) / np.sqrt(evectors.shape[0])\n",
    "        # Scale each column by its overlap with the normalised all-ones vector (+1e-4).\n",
    "        evectors = ((evectors.transpose(0, 1) @ ones) + 1e-4) * evectors\n",
    "        rest = ones - pt.sum(evectors, dim=1)\n",
    "        padding = pt.zeros((evectors.shape[0], self.num_evs - evectors.shape[1]))\n",
    "        return pt.cat((evectors, padding, rest.reshape(-1, 1)), dim=1)\n",
    "\n",
    "    def _padded_evs(self, op):\n",
    "        \"\"\"num_evs eigenvectors of op, zero-padded if the graph is too small.\"\"\"\n",
    "        evectors = self._leading_evectors(op, min(self.num_evs + 1, op.shape[0] - 2))\n",
    "        padding = pt.zeros((evectors.shape[0], self.num_evs + 1 - evectors.shape[1]))\n",
    "        return pt.cat((evectors, padding), dim=1)\n",
    "\n",
    "\n",
    "\n",
    "class GNN_Module_with_metrics(ptl.LightningModule):\n",
    "    \"\"\"Lightning module stacking ``num_layers`` GCN/GIN/GAT convolutions.\n",
    "\n",
    "    forward() starts from randomly initialised node embeddings (``data.x`` is only\n",
    "    read by the ``'res_con'`` normalization), applies conv -> activation ->\n",
    "    normalization in every layer, records each layer's output in ``self.embeddings``\n",
    "    and decodes the last embeddings (sum-pooled per graph if ``graph_level``).\n",
    "\n",
    "    Args:\n",
    "        config: dict with keys ``gnn_conv`` ('GCN' | 'GIN' | 'GAT'), ``embedding_size``,\n",
    "            ``num_layers``, ``initialization`` ('normal' | 'constant'),\n",
    "            ``learning_rate``, ``regularization`` (AdamW weight decay), ``num_inputs``,\n",
    "            ``num_outputs``, ``num_evs``, ``graph_level``, ``activation_func``\n",
    "            ('relu' | 'leakyrelu' | 'id'), ``dropout`` and ``normalization``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: unsupported ``initialization``, ``activation_func`` or ``gnn_conv``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "\n",
    "        # hyperparameters\n",
    "        self.gnn_conv = config['gnn_conv']\n",
    "        self.embedding_size = config['embedding_size']\n",
    "        self.num_layers = config['num_layers']\n",
    "        if config['initialization'] == 'normal':\n",
    "            self.initialization = lambda x : pt.nn.init.normal_(x, mean=0, std=1)\n",
    "        elif config['initialization'] == 'constant':\n",
    "            self.initialization = lambda x : pt.nn.init.constant_(x, val=1)\n",
    "        else:\n",
    "            # Fail fast like the other switches instead of an AttributeError in forward().\n",
    "            raise ValueError('Initialization not supported.')\n",
    "        self.learning_rate = config['learning_rate']\n",
    "        self.regularization = config['regularization']\n",
    "        self.num_inputs = config['num_inputs']\n",
    "        self.num_outputs = config['num_outputs']\n",
    "        self.num_evs = config['num_evs']\n",
    "        self.graph_Level = config['graph_level']\n",
    "        self.loss = pt.nn.CrossEntropyLoss()\n",
    "        self.out_function = pt.nn.Softmax(dim=1)\n",
    "        if config['activation_func'] == 'relu':\n",
    "            self.relu = ReLU()\n",
    "        elif config['activation_func'] == 'leakyrelu':\n",
    "            self.relu = LeakyReLU()\n",
    "        elif config['activation_func'] == 'id':\n",
    "            self.relu = Identity()\n",
    "        else:\n",
    "            raise ValueError('Activation function not supported.')\n",
    "        # NOTE(review): constructed but never applied in forward().\n",
    "        self.dropout = pt.nn.Dropout(config['dropout'])\n",
    "\n",
    "        # Learnable weight of the per-graph mean subtracted by graphnorm().\n",
    "        self.graphnorm_lambda = Parameter(pt.tensor(0.5, device=self.device, dtype=pt.float32))\n",
    "        # graphnorm2_lambda holds one weight per stored eigenvector column (shape\n",
    "        # (num_evs, 1) for GIN/GCN only when num_evs >= 3); EV_key picks the AppendEVs output.\n",
    "        if self.gnn_conv == 'GIN':\n",
    "            self.graphnorm2_lambda = Parameter(pt.tensor(([[0.0]] * (3)) + ([[1.0]] * (self.num_evs - 3)), device=self.device, dtype=pt.float32, requires_grad=True))\n",
    "            self.EV_key = 'gin_EVs'\n",
    "        elif self.gnn_conv == 'GCN':\n",
    "            self.graphnorm2_lambda = Parameter(pt.tensor(([[0.0]] * (3)) + ([[1.0]] * (self.num_evs - 3)), device=self.device, dtype=pt.float32, requires_grad=True))\n",
    "            self.EV_key = 'gcn_EVs'\n",
    "        elif self.gnn_conv == 'GAT':\n",
    "            # NOTE(review): requires_grad=False has no effect here, Parameter() defaults to\n",
    "            # requires_grad=True. Confirm whether these weights were meant to be frozen.\n",
    "            self.graphnorm2_lambda = Parameter(pt.randn((self.num_evs, 1), device=self.device, dtype=pt.float32, requires_grad=False))\n",
    "            self.EV_key = 'gat_EVs'\n",
    "        else:\n",
    "            raise ValueError('GNN type not supported.')\n",
    "\n",
    "        # Affine scale/shift shared by the custom normalizations below.\n",
    "        self.norm_gamma = Parameter(pt.tensor(1.0, device=self.device, dtype=pt.float32))\n",
    "        self.norm_beta = Parameter(pt.tensor(0.0, device=self.device, dtype=pt.float32))\n",
    "\n",
    "        # Normalization applied after every layer; each callable takes (X, data).\n",
    "        if config['normalization'] == 'batch':\n",
    "            self.normalization = self.batch_normalize\n",
    "        elif config['normalization'] == 'powerembed':\n",
    "            self.normalization = self.batch_without_mean\n",
    "        elif config['normalization'] == 'graph':\n",
    "            self.normalization = self.graphnorm\n",
    "        elif config['normalization'] == 'graph2':\n",
    "            self.normalization = self.graphnorm2\n",
    "        elif config['normalization'] == 'pair':\n",
    "            # Registered as submodules so they follow .to()/.train()/.eval() and their\n",
    "            # parameters (DiffGroupNorm has some) are returned by self.parameters().\n",
    "            self.pair_norm = PairNorm()\n",
    "            self.normalization = lambda x,y : self.pair_norm(x, y.batch)\n",
    "        elif config['normalization'] == 'group':\n",
    "            self.group_norm = DiffGroupNorm(self.embedding_size, self.num_outputs)\n",
    "            self.normalization = lambda x,y : self.group_norm(x)\n",
    "        elif config['normalization'] == 'none':\n",
    "            # Rescale by the Frobenius norm of the whole embedding matrix.\n",
    "            self.normalization = lambda x,y : x/(pt.linalg.norm(x) + 1e-7)\n",
    "        elif config['normalization'] == 'res_con':\n",
    "            self.alpha = Parameter(pt.tensor(0.3, device=self.device, dtype=pt.float32))\n",
    "            # NOTE(review): W_1/W_2 are plain tensors (not Parameters or buffers): they are\n",
    "            # not trained, not saved in the state_dict and not moved by .to(). Confirm intent.\n",
    "            W_1 = pt.nn.init.normal_(pt.empty((self.embedding_size, self.embedding_size), device=self.device), 0, 0.04)\n",
    "            W_2 = pt.nn.init.normal_(pt.empty((self.embedding_size, self.embedding_size), device=self.device), 0, 0.04)\n",
    "            self.normalization = lambda x,y : self.alpha * x @ W_1 + (1-self.alpha) * self.encoder(y.x) @ W_2\n",
    "        else:\n",
    "            # Any other value (e.g. 'undefined') leaves the embeddings unchanged.\n",
    "            self.normalization = lambda x,y : x\n",
    "\n",
    "        if self.gnn_conv == 'GCN':\n",
    "            self.convs = ModuleList([GCNConv(self.embedding_size, self.embedding_size, cached=True, normalize=False, bias=False, add_self_loops=False) for _ in range(self.num_layers)])\n",
    "        elif self.gnn_conv == 'GIN':\n",
    "            self.convs = ModuleList([GINConv(Sequential(Linear(self.embedding_size, self.embedding_size, bias=False),), aggr='mean') for _ in range(self.num_layers)])\n",
    "        elif self.gnn_conv == 'GAT':\n",
    "            self.convs = ModuleList([GATConv(self.embedding_size, self.embedding_size, heads=8, concat=False, cached=False, bias=False, aggr='mean') for _ in range(self.num_layers)])\n",
    "        else:\n",
    "            raise ValueError('GNN type not supported.')\n",
    "        self.encoder = Linear(self.num_inputs, self.embedding_size, bias=False, device=self.device)\n",
    "        self.decoder = Linear(self.embedding_size, self.num_outputs, bias=False, device=self.device)\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "        # metrics\n",
    "        self.save_hyperparameters(config, logger=False)\n",
    "        self.train_accuracy = Accuracy(task='multiclass', num_classes=self.num_outputs)\n",
    "        self.val_accuracy = Accuracy(task='multiclass', num_classes=self.num_outputs)\n",
    "        self.test_accuracy = Accuracy(task='multiclass', num_classes=self.num_outputs)\n",
    "\n",
    "        # Per-layer outputs of the most recent forward pass.\n",
    "        self.embeddings = []\n",
    "\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Propagate random initial embeddings through all layers and return logits.\n",
    "\n",
    "        Side effect: ``self.embeddings`` is reset and filled with every layer's output.\n",
    "        \"\"\"\n",
    "        self.embeddings = []\n",
    "        embeddings = pt.empty((data.num_nodes, self.embedding_size), device=self.device)\n",
    "        self.initialization(embeddings)\n",
    "        embeddings.requires_grad = True\n",
    "\n",
    "        for it in range(self.num_layers):\n",
    "            embeddings = self.convs[it](embeddings, data.edge_index)\n",
    "            embeddings = self.relu(embeddings)\n",
    "            embeddings = self.normalization(embeddings, data)\n",
    "            self.embeddings.append(embeddings)\n",
    "\n",
    "        return self.decoder(embeddings if not self.graph_Level else global_add_pool(embeddings, data.batch))\n",
    "\n",
    "\n",
    "    def training_step(self, batch):\n",
    "        \"\"\"Cross-entropy loss on ``batch.y``; logs train loss and accuracy.\"\"\"\n",
    "        prediction = self(batch)\n",
    "        loss = self.loss(prediction, batch.y)\n",
    "        self.log('train_loss', loss, on_epoch=True, prog_bar=True)\n",
    "        self.log('train_acc', self.train_accuracy(self.out_function(prediction), batch.y), on_epoch=True, prog_bar=True)\n",
    "        # Loss is passed to backward pass\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch):\n",
    "        \"\"\"Log validation loss, accuracy and the MSE between softmax output and one-hot labels.\"\"\"\n",
    "        prediction = self(batch)\n",
    "        loss = self.loss(prediction, batch.y)\n",
    "        acc = self.val_accuracy(self.out_function(prediction), batch.y)\n",
    "        score = pt.nn.functional.mse_loss(self.out_function(prediction), pt.nn.functional.one_hot(batch.y, self.num_outputs).float())\n",
    "        self.log('val_loss', loss, on_epoch=True, prog_bar=True)\n",
    "        self.log('val_acc', acc, on_epoch=True, prog_bar=True)\n",
    "        self.log('val_score', score, on_epoch=True, prog_bar=True)\n",
    "\n",
    "    def test_step(self, batch):\n",
    "        \"\"\"Log test loss and accuracy.\"\"\"\n",
    "        prediction = self(batch)\n",
    "        loss = self.loss(prediction, batch.y)\n",
    "        self.log('test_loss', loss)\n",
    "        self.log('test_acc', self.test_accuracy(self.out_function(prediction), batch.y))\n",
    "\n",
    "    # The scatter calls below use dim_size = number of nodes (>= number of graphs);\n",
    "    # rows for non-existent graph ids are never selected by index_select.\n",
    "\n",
    "    def batch_normalize(self, X, data):\n",
    "        \"\"\"Per-graph, per-feature centring and division by the RMS, then the affine map.\"\"\"\n",
    "        mean = scatter(X, data.batch, dim=0, dim_size=data.batch.shape[0],\n",
    "                           reduce='mean')\n",
    "\n",
    "        X = X - mean.index_select(0, data.batch)\n",
    "\n",
    "        var = scatter(X * X, data.batch, dim=0, dim_size=data.batch.shape[0],\n",
    "                          reduce='mean')\n",
    "\n",
    "        return (X / (var + 1e-8).sqrt().index_select(0, data.batch)) * self.norm_gamma + self.norm_beta\n",
    "\n",
    "    def batch_without_mean(self, X, data):\n",
    "        \"\"\"Divide each feature by sqrt of its per-graph max |X| (no centring), then affine.\"\"\"\n",
    "        var = scatter(pt.abs(X), data.batch, dim=0, dim_size=data.batch.shape[0],\n",
    "                          reduce='max')\n",
    "\n",
    "        return (X / (var + 1e-8).sqrt().index_select(0, data.batch)) * self.norm_gamma + self.norm_beta\n",
    "\n",
    "\n",
    "    def graphnorm(self, X, data):\n",
    "        \"\"\"Subtract graphnorm_lambda x per-graph mean, divide by the per-graph RMS, then affine.\"\"\"\n",
    "        mean = scatter(X, data.batch, dim=0, dim_size=data.batch.shape[0],\n",
    "                           reduce='mean')\n",
    "\n",
    "        X = X - self.graphnorm_lambda * mean.index_select(0, data.batch)\n",
    "\n",
    "        var = scatter(X * X, data.batch, dim=0, dim_size=data.batch.shape[0],\n",
    "                          reduce='mean')\n",
    "\n",
    "        return (X / (var + 1e-8).sqrt().index_select(0, data.batch)) * self.norm_gamma + self.norm_beta\n",
    "\n",
    "    #not instance norm!\n",
    "    def graphnorm2(self, X, data):\n",
    "        \"\"\"Subtract the per-graph mean of X weighted by an eigenvector-based node weight.\n",
    "\n",
    "        The node weight is ``data[self.EV_key] @ self.graphnorm2_lambda``; the result is\n",
    "        then divided by the per-graph RMS and passed through the affine map.\n",
    "        \"\"\"\n",
    "        projector = data[self.EV_key] @ self.graphnorm2_lambda\n",
    "\n",
    "        mean = scatter(X * projector, data.batch, dim=0, dim_size=data.batch.shape[0],\n",
    "                           reduce='mean')\n",
    "\n",
    "        X =  X - mean.index_select(0, data.batch)\n",
    "\n",
    "        var = scatter(X * X, data.batch, dim=0, dim_size=data.batch.shape[0],\n",
    "                          reduce='mean')\n",
    "\n",
    "        return (X / (var + 1e-8).sqrt().index_select(0, data.batch)) * self.norm_gamma + self.norm_beta\n",
    "\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Re-initialise all convolutions, the encoder and the decoder.\"\"\"\n",
    "        for comp in self.convs:\n",
    "            comp.reset_parameters()\n",
    "        self.encoder.reset_parameters()\n",
    "        self.decoder.reset_parameters()\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"AdamW plus ReduceLROnPlateau (factor 0.5, patience 50) monitoring train_loss.\"\"\"\n",
    "        # Fused AdamW only supports CUDA tensors on many torch versions and raises on CPU.\n",
    "        optimizer = AdamW(self.parameters(),\n",
    "                          lr=self.learning_rate,\n",
    "                          weight_decay=self.regularization,\n",
    "                          fused=self.device.type == 'cuda')\n",
    "        return {'optimizer':optimizer,\n",
    "                'lr_scheduler': ReduceLROnPlateau(optimizer,\n",
    "                                                  factor=0.5,\n",
    "                                                  patience=50,\n",
    "                                                  verbose=True,\n",
    "                                                  mode='min'),\n",
    "                'monitor':'train_loss'}\n",
    "\n",
    "\n",
    "    # NOTE(review): returns 'GIN_Module' whatever gnn_conv is configured.\n",
    "    def __module_name__(self=None):\n",
    "        return 'GIN_Module'\n",
    "   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def colum_distance(X):\n",
    "    \"\"\"Half the mean pairwise Euclidean distance between the L1-normalised columns of X.\"\"\"\n",
    "    normalised = X / pt.linalg.vector_norm(X, dim=0, ord=1)\n",
    "    columns = normalised.T.unsqueeze(0)  # (1, num_cols, num_rows): one point per column\n",
    "    return pt.cdist(columns, columns).mean() / 2\n",
    "\n",
    "def column_projection_distance(X):\n",
    "    \"\"\"Mean of |1 - |<x_i, x_j>|| over all pairs of L2-normalised columns (float64).\"\"\"\n",
    "    unit_columns = X / pt.linalg.vector_norm(X, dim=0, ord=2)\n",
    "    overlaps = (unit_columns.transpose(0, 1) @ unit_columns).abs()\n",
    "    gap = pt.ones_like(overlaps, dtype=pt.float64) - overlaps\n",
    "    return gap.abs().mean()\n",
    "\n",
    "def projection_cost(X, V, V_inv):\n",
    "    \"\"\"Frobenius norm of X's residual after projecting onto V (via V @ V_inv), / num_rows.\n",
    "\n",
    "    X is divided by its spectral norm first.\n",
    "    \"\"\"\n",
    "    scaled = X / pt.linalg.matrix_norm(X, ord=2)\n",
    "    residual = scaled - V @ (V_inv @ scaled)\n",
    "    return (1 / scaled.shape[0]) * pt.norm(residual)\n",
    "\n",
    "def mu(X, v):\n",
    "    \"\"\"Frobenius distance of the spectrally normalised X from its projection onto v.\n",
    "\n",
    "    Computes ||X_n - v (v^T X_n)||_F with X_n = X / ||X||_2; v is not re-normalised, so\n",
    "    it should be a unit vector. v is cast to X's dtype/device first because torch\n",
    "    matmul does not promote dtypes (a float64 SciPy eigenvector with float32\n",
    "    embeddings would otherwise raise).\n",
    "    \"\"\"\n",
    "    X = X / pt.linalg.matrix_norm(X, ord=2)\n",
    "    v = v.to(dtype=X.dtype, device=X.device)\n",
    "    return pt.norm(X - pt.outer(v, v @ X), p='fro')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalization variants compared below (values of config['normalization']).\n",
    "normalizations = ['none', 'pair', 'batch', 'powerembed', 'graph', 'graph2', 'res_con']\n",
    "num_evs = 4        # eigenvector columns stored per operator by AppendEVs\n",
    "num_trials = 10    # random re-initialisations per normalization\n",
    "num_layers = 256   # GNN depth = length of every per-layer cost curve\n",
    "seed = 0           # fixed so that Restart & Run All reproduces the figures\n",
    "\n",
    "pt.manual_seed(seed)\n",
    "np.random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cora (CitationFull). AppendEVs runs as a *pre*-transform, so its output is cached in\n",
    "# ./datasets_ablation/processed: delete that folder after changing AppendEVs or num_evs.\n",
    "dataset = ptg.datasets.CitationFull(root='./datasets_ablation', name=\"Cora\", pre_transform=AppendEVs(num_evs, make_undirected=True))\n",
    "GNN_conv = 'GCN'\n",
    "model_config = {'embedding_size': 8,\n",
    "                'learning_rate': 1e-2,\n",
    "                'num_layers': num_layers,\n",
    "                'a_iteration': 1,\n",
    "                'dropout': 0,\n",
    "                'regularization': 0,\n",
    "                'num_classes': dataset.num_classes,\n",
    "                'dataset_num_classes': dataset.num_classes,\n",
    "                'num_features': dataset.num_features,\n",
    "                'initialization': 'normal',\n",
    "                'jk': None,\n",
    "                'num_evs': num_evs,\n",
    "                'activation_func': 'id',\n",
    "                'graph_level': True,\n",
    "                'num_outputs': dataset.num_classes,\n",
    "                'num_inputs': dataset.num_features,\n",
    "                'normalization': 'undefined',  # placeholder, set per experiment\n",
    "                'gnn_conv': GNN_conv}\n",
    "\n",
    "# Basis for projection_cost. The GCNConv layers use normalize=False and\n",
    "# add_self_loops=False, i.e. they propagate with the plain adjacency, whose\n",
    "# (augmented) eigenvectors AppendEVs stores as gin_EVs.\n",
    "V_np = dataset[0].gin_EVs.numpy()\n",
    "V = pt.tensor(V_np, dtype=pt.float32)\n",
    "V_pinv = pt.tensor(pinv(V_np), dtype=pt.float32)\n",
    "\n",
    "# Dominant adjacency eigenvector: the reference direction for mu(). SciPy's eigs does\n",
    "# not guarantee any ordering of the returned pairs, so select the largest |lambda|.\n",
    "topology = to_scipy_sparse_matrix(dataset[0].edge_index, num_nodes=dataset[0].num_nodes)\n",
    "evalues, evectors = eigs(topology, k=5, which='LM', tol=1e-8, maxiter=100000)\n",
    "lead = np.argmax(np.abs(evalues))\n",
    "plt.plot(evectors[:, lead].real)\n",
    "# float32 to match the model's embeddings (torch matmul does not promote dtypes).\n",
    "v = pt.tensor(evectors[:, lead].real, dtype=pt.float32)\n",
    "\n",
    "costs = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear experiment (identity activation). Every trial re-draws the random node\n",
    "# embeddings, runs one forward pass and scores each layer's output:\n",
    "#   0: mu   1: colum_distance   2: projection_cost   3: column_projection_distance\n",
    "#   4: numerical rank (atol=0.1)\n",
    "dl = ptg.loader.DataLoader(dataset, batch_size=1)\n",
    "for norm in normalizations:\n",
    "    costs[norm] = np.zeros((5, num_trials, num_layers))\n",
    "    # Per-run copy instead of mutating model_config: the non-linear cell switches\n",
    "    # activation_func to 'leakyrelu', which would otherwise leak into a re-run here.\n",
    "    config = {**model_config, 'normalization': norm, 'activation_func': 'id'}\n",
    "    for it in range(num_trials):\n",
    "        model = GNN_Module_with_metrics(config)\n",
    "        trainer = Trainer(accelerator='cpu', precision=32)\n",
    "        trainer.predict(model, dl)\n",
    "\n",
    "        for i_emb, emb in enumerate(model.embeddings):\n",
    "            costs[norm][0, it, i_emb] = mu(emb, v).item()\n",
    "            costs[norm][1, it, i_emb] = colum_distance(emb).item()\n",
    "            costs[norm][2, it, i_emb] = projection_cost(emb, V.to(emb.device), V_pinv.to(emb.device)).item()\n",
    "            costs[norm][3, it, i_emb] = column_projection_distance(emb).item()\n",
    "            costs[norm][4, it, i_emb] = pt.linalg.matrix_rank(emb, atol=1e-1).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look: trial-averaged mu per layer for the un-normalised linear model.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(costs['none'][0].mean(axis=0), label='mu')\n",
    "ax.set(xlabel='layer', ylabel='mu', title='normalization = none (linear)')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Statistics over the trial axis (axis 1 of each (5, num_trials, num_layers) array).\n",
    "mean_costs, std_costs = {}, {}\n",
    "for key, per_trial in costs.items():\n",
    "    mean_costs[key] = per_trial.mean(axis=1)\n",
    "    std_costs[key] = per_trial.std(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-linear experiment: same protocol and metrics as the linear loop, with LeakyReLU.\n",
    "dl = ptg.loader.DataLoader(dataset, batch_size=1)\n",
    "costs_nl = {}\n",
    "for norm in normalizations:\n",
    "    costs_nl[norm] = np.zeros((5, num_trials, num_layers))\n",
    "    # Per-run copy: mutating model_config here would make a later re-run of the\n",
    "    # linear cell silently use LeakyReLU as well.\n",
    "    config = {**model_config, 'normalization': norm, 'activation_func': 'leakyrelu'}\n",
    "    for it in range(num_trials):\n",
    "        model = GNN_Module_with_metrics(config)\n",
    "        trainer = Trainer(accelerator='cpu', precision=32)\n",
    "        trainer.predict(model, dl)\n",
    "\n",
    "        for i_emb, emb in enumerate(model.embeddings):\n",
    "            costs_nl[norm][0, it, i_emb] = mu(emb, v).item()\n",
    "            costs_nl[norm][1, it, i_emb] = colum_distance(emb).item()\n",
    "            costs_nl[norm][2, it, i_emb] = projection_cost(emb, V.to(emb.device), V_pinv.to(emb.device)).item()\n",
    "            costs_nl[norm][3, it, i_emb] = column_projection_distance(emb).item()\n",
    "            costs_nl[norm][4, it, i_emb] = pt.linalg.matrix_rank(emb, atol=1e-1).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trial-axis statistics of the NON-linear runs: must read costs_nl (not costs),\n",
    "# since these arrays feed every '-- non-linear' panel below.\n",
    "mean_costs_nl = {k: np.mean(vals, axis=1) for k, vals in costs_nl.items()}\n",
    "std_costs_nl = {k: np.std(vals, axis=1) for k, vals in costs_nl.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draft 5-panel figure: mu (linear / non-linear), d_ev (linear) and the mean rank per\n",
    "# layer (linear / non-linear). Shaded bands are mean +/- one std over trials.\n",
    "fs = 18\n",
    "fig = plt.figure(constrained_layout=True, figsize=(15, 3))\n",
    "x = np.arange(0, num_layers)\n",
    "subfigs = fig.subfigures(1, 5, wspace=0.01)\n",
    "\n",
    "ax = subfigs[0].subplots(1, 1)\n",
    "for norm in normalizations:\n",
    "    ax.fill_between(x=x,\n",
    "                    y1=mean_costs[norm][0] + std_costs[norm][0],\n",
    "                    y2=mean_costs[norm][0] - std_costs[norm][0],\n",
    "                    alpha=0.5)\n",
    "    ax.plot(x, mean_costs[norm][0], label=norm)\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "# Raw strings: '\\m' is an invalid escape sequence in a normal Python string literal.\n",
    "ax.set_title(r'$\\mu(X^{(t)})$ -- linear', fontsize=fs-2)\n",
    "\n",
    "ax = subfigs[1].subplots(1, 1)\n",
    "for norm in normalizations:\n",
    "    ax.fill_between(x=x,\n",
    "                    y1=mean_costs_nl[norm][0] + std_costs_nl[norm][0],\n",
    "                    y2=mean_costs_nl[norm][0] - std_costs_nl[norm][0],\n",
    "                    alpha=0.5)\n",
    "    ax.plot(x, mean_costs_nl[norm][0], label=norm)\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "ax.set_title(r'$\\mu(X^{(t)})$ -- non-linear', fontsize=fs-2)\n",
    "\n",
    "ax = subfigs[2].subplots(1, 1)\n",
    "for norm in normalizations:\n",
    "    # Band only above the mean -- presumably because mean - std can be <= 0 on the log axis.\n",
    "    ax.fill_between(x=x,\n",
    "                    y1=mean_costs[norm][2] + std_costs[norm][2],\n",
    "                    y2=mean_costs[norm][2],\n",
    "                    alpha=0.5)\n",
    "    ax.plot(x, mean_costs[norm][2], label=norm)\n",
    "ax.set_yscale('log')\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "ax.set_title('$d_{ev}(X)$ -- linear', fontsize=fs-2)\n",
    "\n",
    "ax = subfigs[3].subplots(1, 1)\n",
    "for norm in normalizations:\n",
    "    ax.plot(mean_costs[norm][4], label=norm)\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "ax.set_title('$Rank(X^{(t)})$ -- linear', fontsize=fs-2)\n",
    "\n",
    "ax = subfigs[4].subplots(1, 1)\n",
    "for norm in normalizations:\n",
    "    ax.plot(mean_costs_nl[norm][4], label=norm)\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "ax.set_title('$Rank(X^{(t)})$ -- non-linear', fontsize=fs-2)\n",
    "\n",
    "# One shared legend below all panels.\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "leg = fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.0), ncol=7, fontsize=11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (normalization, 1-based layer, rank) rows at layers 4/16/64/256, one per trial.\n",
    "# NOTE(review): the linear 'pair' rank at layer 256 is replaced by 1 rather than the\n",
    "# measured value -- confirm why before publishing.\n",
    "report_layers = [3, 15, 63, 255]\n",
    "data = [(norm, layer + 1, 1 if (layer == 255 and norm == 'pair') else costs[norm][4, trial, layer])\n",
    "        for norm in normalizations\n",
    "        for layer in report_layers\n",
    "        for trial in range(costs[norm].shape[1])]\n",
    "data_nl = [(norm, layer + 1, costs_nl[norm][4, trial, layer])\n",
    "           for norm in normalizations\n",
    "           for layer in report_layers\n",
    "           for trial in range(costs_nl[norm].shape[1])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# Long-format tables (one row per normalization, layer and trial) for seaborn.\n",
    "rank_columns = ['Normalization', 'Layer', 'Rank']\n",
    "mean_df = pd.DataFrame(data=data, columns=rank_columns)\n",
    "mean_df_nl = pd.DataFrame(data=data_nl, columns=rank_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-linear ranks per reported layer; error bars span min..max over trials.\n",
    "sns.barplot(x='Layer', y='Rank', hue='Normalization', data=mean_df_nl,\n",
    "            errorbar=lambda ranks: (ranks.min(), ranks.max()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear ranks per reported layer (seaborn's default error bars).\n",
    "sns.barplot(x='Layer', y='Rank', hue='Normalization', data=mean_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final 5-panel figure, saved to figures/together_reduced_GCN.pdf: mu_v (linear /\n",
    "# non-linear), d_ev (linear) and rank bar charts at layers 4/16/64/256 (+/- 1 sd).\n",
    "from pathlib import Path\n",
    "\n",
    "fs = 18\n",
    "fig = plt.figure(constrained_layout=True, figsize=(15, 3))\n",
    "x = np.arange(0, num_layers)\n",
    "subfigs = fig.subfigures(1, 5, wspace=0.01)\n",
    "\n",
    "ax = subfigs[0].subplots(1, 1)\n",
    "for norm in normalizations:\n",
    "    ax.fill_between(x=x,\n",
    "                    y1=mean_costs[norm][0] + std_costs[norm][0],\n",
    "                    y2=mean_costs[norm][0] - std_costs[norm][0],\n",
    "                    alpha=0.5)\n",
    "    ax.plot(x, mean_costs[norm][0], label=norm)\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "# Raw strings: '\\m' is an invalid escape sequence in a normal Python string literal.\n",
    "ax.set_title(r'$\\mu_{v}(X^{(t)})$ -- linear', fontsize=fs-2)\n",
    "\n",
    "ax = subfigs[1].subplots(1, 1)\n",
    "for norm in normalizations:\n",
    "    ax.fill_between(x=x,\n",
    "                    y1=mean_costs_nl[norm][0] + std_costs_nl[norm][0],\n",
    "                    y2=mean_costs_nl[norm][0] - std_costs_nl[norm][0],\n",
    "                    alpha=0.5)\n",
    "    ax.plot(x, mean_costs_nl[norm][0], label=norm)\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "ax.set_title(r'$\\mu_{v}(X^{(t)})$ -- non-linear', fontsize=fs-2)\n",
    "\n",
    "ax = subfigs[2].subplots(1, 1)\n",
    "for norm in normalizations:\n",
    "    # Band only above the mean -- presumably because mean - std can be <= 0 on the log axis.\n",
    "    ax.fill_between(x=x,\n",
    "                    y1=mean_costs[norm][2] + std_costs[norm][2],\n",
    "                    y2=mean_costs[norm][2],\n",
    "                    alpha=0.5)\n",
    "    ax.plot(x, mean_costs[norm][2], label=norm)\n",
    "ax.set_yscale('log')\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "ax.set_title('$d_{ev}(X^{(t)})$ -- linear', fontsize=fs-2)\n",
    "\n",
    "# One shared legend (taken from the line panels) below all panels.\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "leg = fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.0), ncol=7, fontsize=fs, frameon=False)\n",
    "\n",
    "ax = subfigs[3].subplots(1, 1)\n",
    "sns.barplot(data=mean_df, x='Layer', y='Rank', hue='Normalization', legend=False, errorbar='sd', ax=ax)\n",
    "ax.set(xlabel=None, ylabel=None)\n",
    "for i in range(3):\n",
    "    ax.axvline(i + 0.5, color='grey', lw=0.25)  # separators between layer groups\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "ax.set_ylim(0.5, 8.5)\n",
    "ax.set_title('$Rank(X^{(t)})$ -- linear', fontsize=fs-2)\n",
    "\n",
    "ax = subfigs[4].subplots(1, 1)\n",
    "sns.barplot(data=mean_df_nl, x='Layer', y='Rank', hue='Normalization', legend=False, errorbar='sd', ax=ax)\n",
    "ax.set(xlabel=None, ylabel=None)\n",
    "for i in range(3):\n",
    "    ax.axvline(i + 0.5, color='grey', lw=0.25)\n",
    "ax.tick_params(axis='both', which='major', labelsize=fs)\n",
    "ax.set_ylim(0.5, 8.5)\n",
    "ax.set_title('$Rank(X^{(t)})$ -- non-linear', fontsize=fs-2)\n",
    "\n",
    "# savefig does not create missing directories.\n",
    "Path('figures').mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig('figures/together_reduced_GCN.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv_oan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
