{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os.path as osp\n",
    "import time\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.lines as mlines\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import CitationFull, WikipediaNetwork, Amazon, Coauthor\n",
    "from torch_geometric.logging import log\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.nn import SGConv, SAGEConv, WLConv, WLConvContinuous, GINConv, MLP\n",
    "from gcn_conv import GCNConv\n",
    "from gat_conv import GATConv\n",
    "from torch_geometric.utils import degree, add_remaining_self_loops, k_hop_subgraph, to_networkx, subgraph\n",
    "from torch_scatter import scatter, scatter_mean, scatter_sum\n",
    "import torch_sparse\n",
    "import networkx as nx\n",
    "\n",
    "from utils import *\n",
    "import json\n",
    "from copy import copy\n",
    "import os\n",
    "from os import environ\n",
    "\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "environ['CUDA_LAUNCH_BLOCKING'] = \"1\"\n",
    "dataset = environ.get('dataset', 'CiteSeer')\n",
    "num_gnn_layers = int(environ.get('num_gnn_layers', '3'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--dataset', type=str, default=dataset)\n",
    "parser.add_argument('--hidden_channels', type=int, default=64)\n",
    "parser.add_argument('--lr', type=float, default=5e-3)\n",
    "parser.add_argument('--epochs', type=int, default=500)\n",
    "parser.add_argument('--num_gnn_layers', type=int, default=num_gnn_layers)\n",
    "parser.add_argument('--num_mlp_layers', type=int, default=num_gnn_layers)\n",
    "args = parser.parse_args(\"\")\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda:6')\n",
    "elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n",
    "    device = torch.device('mps')\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_everything(seed: int):\n",
    "    import random, os\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    \n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_everything(42)\n",
    "transform = T.Compose([\n",
    "                # T.LargestConnectedComponents(),\n",
    "                T.RandomNodeSplit(),\n",
    "                T.NormalizeFeatures()\n",
    "            ])\n",
    "\n",
    "if args.dataset in ['Cora_ML', 'CiteSeer']: \n",
    "    path = osp.join('.', 'data', 'CitationFull')\n",
    "    dataset = CitationFull(path, args.dataset, transform=transform)\n",
    "elif args.dataset in ['CS', 'Physics']:\n",
    "    path = osp.join('.', 'data', 'Coauthor')\n",
    "    dataset = Coauthor(path, args.dataset, transform=transform)\n",
    "elif args.dataset in ['Photo', 'Computers']:\n",
    "    path = osp.join('.', 'data', 'Amazon')\n",
    "    dataset = Amazon(path, args.dataset, transform=transform)\n",
    "elif args.dataset in ['chameleon', 'squirrel']:\n",
    "    path = osp.join('.', 'data', 'WikipediaNetwork')\n",
    "    dataset = WikipediaNetwork(path, args.dataset, transform=transform)\n",
    "\n",
    "data = dataset[0].to(device)\n",
    "threshold = 100\n",
    "seeds_list = [5, 340, 87, 23, 409, 104, 278, 12, 43, 76]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model architectures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(torch.nn.Module):\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels, model_type='mlp', linear=False):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.conv_list = torch.nn.ModuleList([torch.nn.Linear(in_channels, hidden_channels)] + \\\n",
    "                                       [torch.nn.Linear(hidden_channels, hidden_channels) for i in range(args.num_mlp_layers - 2)] + \\\n",
    "                                       [torch.nn.Linear(hidden_channels, out_channels)])\n",
    "        self.model_type = model_type\n",
    "        self.linear = linear\n",
    "\n",
    "    def forward(self, x, edge_index, edge_weight=None):\n",
    "        for idx, conv in enumerate(self.conv_list):\n",
    "            # x = F.dropout(x, p=0.5, training=self.training)\n",
    "            if idx != 0 and not self.linear:\n",
    "                x = x.relu()\n",
    "            x = conv(x)\n",
    "            \n",
    "        return x\n",
    "    \n",
    "class GNN(torch.nn.Module):\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels, model_type='gcn', linear=False):\n",
    "        super().__init__()\n",
    "        \n",
    "        if model_type == 'gcn':\n",
    "            conv = GCNConv\n",
    "            conv_type = \"sym\"\n",
    "            self.conv_list = torch.nn.ModuleList([conv(in_channels, in_channels, hidden_channels, conv=conv_type)] + \\\n",
    "                                           [conv(in_channels, hidden_channels, hidden_channels, conv=conv_type) for i in range(args.num_gnn_layers - 2)] + \\\n",
    "                                           [conv(in_channels, hidden_channels, out_channels, conv=conv_type)])\n",
    "            \n",
    "        elif model_type == 'sage':\n",
    "            conv = GCNConv\n",
    "            conv_type = \"rw\"\n",
    "            self.conv_list = torch.nn.ModuleList([conv(in_channels, in_channels, hidden_channels, conv=conv_type)] + \\\n",
    "                                           [conv(in_channels, hidden_channels, hidden_channels, conv=conv_type) for i in range(args.num_gnn_layers - 2)] + \\\n",
    "                                           [conv(in_channels, hidden_channels, out_channels, conv=conv_type)])\n",
    "#             conv = SAGEConv\n",
    "#             conv_type = None\n",
    "#             self.conv_list = torch.nn.ModuleList([conv(in_channels, hidden_channels)] + \\\n",
    "#                                            [conv(hidden_channels, hidden_channels) for i in range(args.num_gnn_layers - 2)] + \\\n",
    "#                                            [conv(hidden_channels, out_channels)])\n",
    "            \n",
    "        elif 'gat' in model_type:\n",
    "            conv = GATConv\n",
    "            self.conv_list = torch.nn.ModuleList([conv(in_channels, in_channels, hidden_channels)] + \\\n",
    "                                           [conv(in_channels, hidden_channels, hidden_channels) for i in range(args.num_gnn_layers - 2)] + \\\n",
    "                                           [conv(in_channels, hidden_channels, out_channels)])\n",
    "        \n",
    "        elif model_type == 'sgconv':\n",
    "            conv = SGConv\n",
    "            self.conv_list = torch.nn.ModuleList([conv(in_channels, hidden_channels)] + \\\n",
    "                                           [conv(hidden_channels, hidden_channels) for i in range(args.num_gnn_layers - 2)] + \\\n",
    "                                           [conv(hidden_channels, out_channels)])\n",
    "        \n",
    "        elif model_type == 'gin':\n",
    "            conv = GINConv\n",
    "            self.conv_list = torch.nn.ModuleList([conv(nn=MLP([in_channels, hidden_channels, hidden_channels]), train_eps=False)] + \\\n",
    "                                           [conv(nn=MLP([hidden_channels, hidden_channels, hidden_channels]), train_eps=False) for i in range(args.num_gnn_layers - 2)] + \\\n",
    "                                           [conv(nn=MLP([hidden_channels, hidden_channels, out_channels]), train_eps=False)])\n",
    "        \n",
    "        # self.bn_list = torch.nn.ModuleList([torch.nn.BatchNorm1d(hidden_channels) for i in range(args.num_gnn_layers - 1)])\n",
    "        \n",
    "        # self.mlp = MLP(hidden_channels, hidden_channels, out_channels)\n",
    "        self.model_type = model_type\n",
    "        self.linear = linear\n",
    "\n",
    "    def forward(self, x, edge_index, edge_weight=None, get_att=False):\n",
    "        atts = []\n",
    "        x0 = torch.clone(x)\n",
    "        for idx, conv in enumerate(self.conv_list):\n",
    "            # x = F.dropout(x, p=0.5, training=self.training)\n",
    "            if idx != 0 and not self.linear:\n",
    "                x = x.relu()\n",
    "            if get_att:\n",
    "                x, alpha = conv(x, x0, edge_index, edge_weight, get_att=True)\n",
    "                atts.append(alpha)\n",
    "            else:\n",
    "                x = conv(x, x0, edge_index, edge_weight)\n",
    "#             if idx != len(self.conv_list) - 1:\n",
    "#                 x = self.bn_list[idx](x)\n",
    "            \n",
    "        # return self.mlp(x, edge_index, edge_weight)\n",
    "        if get_att:\n",
    "            return x, atts\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    out = model(data.x, data.edge_index, data.edge_attr)\n",
    "    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n",
    "    \n",
    "    if model.model_type.startswith('fair-'):\n",
    "        pred = F.log_softmax(out, dim=1)[data.train_mask]\n",
    "        target = data.y[data.train_mask]\n",
    "        err = -pred[range(target.size(0)), target].flatten()\n",
    "        \n",
    "        # compute standard deviation\n",
    "        deg = degree(data.edge_index[0], num_nodes=data.x.size(0))[data.train_mask]\n",
    "        err_by_deg = scatter_mean(err, deg.long())\n",
    "        count_by_deg = scatter_sum(torch.ones_like(data.y)[data.train_mask].float(), deg.long())\n",
    "        err_by_deg = err_by_deg[count_by_deg > 0]\n",
    "        \n",
    "        lam = float(model.model_type.split('-')[-1])\n",
    "        loss += lam * torch.std(err_by_deg)\n",
    "    \n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    # scheduler.step()\n",
    "    return float(loss)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(model):\n",
    "    model.eval()\n",
    "    pred = model(data.x, data.edge_index, data.edge_attr).argmax(dim=-1)\n",
    "\n",
    "    accs = []\n",
    "    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n",
    "        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n",
    "    return accs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspect data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# edge_index = data.edge_index\n",
    "# edge_weight = torch.ones_like(edge_index[0]).float()\n",
    "# C_u = get_compatibility_matrix(data.y, edge_index, edge_weight)\n",
    "# fig, ax = plt.subplots()\n",
    "# ax = sns.heatmap(C_u, annot=C_u)\n",
    "# plt.title(dataset.name + \", u\")\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wl = WLConv()\n",
    "_, colors = data.x.unique(dim=0, return_inverse=True)\n",
    "for _ in range(args.num_gnn_layers):\n",
    "    colors = wl(colors, data.edge_index)\n",
    "\n",
    "X_uniq, inv = colors[data.train_mask].unique(dim=0, return_inverse=True)\n",
    "inv_uniq = inv.unique().tolist()\n",
    "maj_pred = torch.zeros_like(data.y[data.train_mask])\n",
    "\n",
    "for u in inv_uniq:\n",
    "    maj_pred[inv == u] = torch.mode(data.y[data.train_mask][inv == u]).values\n",
    "\n",
    "best_acc = ((maj_pred == data.y[data.train_mask]).sum() / data.y[data.train_mask].size(0)).item()\n",
    "print('Best possible accuracy:', best_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_list = [\n",
    "#                 'mlp',\n",
    "#                 'gin', \\\n",
    "                'sage', \\\n",
    "                'gcn', \\\n",
    "                'gat', \\\n",
    "            ]\n",
    "\n",
    "# for alpha in [1.0, 2.0, 4.0, 8.0, 12.0]:\n",
    "#     model_list.append('fair-gat-{}'.format(str(alpha)))\n",
    "\n",
    "if args.dataset == 'Computers':\n",
    "    args.epochs = 3000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = {}\n",
    "overfit_models = {}\n",
    "\n",
    "acc_by_grads = {}\n",
    "\n",
    "low_deg_spreads = {}\n",
    "high_deg_spreads = {}\n",
    "\n",
    "low_deg_losses = {}\n",
    "high_deg_losses = {}\n",
    "\n",
    "loss_by_deg = {}\n",
    "\n",
    "low_deg_svd = {}\n",
    "high_deg_svd = {}\n",
    "\n",
    "for model_name in model_list:\n",
    "    \n",
    "    models[model_name] = {}\n",
    "    overfit_models[model_name] = {}\n",
    "\n",
    "    acc_by_grads[model_name] = {}\n",
    "    \n",
    "    low_deg_spreads[model_name] = {}\n",
    "    high_deg_spreads[model_name] = {}\n",
    "\n",
    "    low_deg_losses[model_name] = {}\n",
    "    high_deg_losses[model_name] = {}\n",
    "\n",
    "    loss_by_deg[model_name] = {}\n",
    "\n",
    "    low_deg_svd[model_name] = {}\n",
    "    high_deg_svd[model_name] = {}\n",
    "    \n",
    "    for seed in seeds_list:\n",
    "        \n",
    "        seed_everything(seed)\n",
    "        print()\n",
    "        print(model_name, seed)\n",
    "\n",
    "        if model_name == 'mlp':\n",
    "            model = MLP(\n",
    "                in_channels=dataset.num_features,\n",
    "                hidden_channels=args.hidden_channels,\n",
    "                out_channels=dataset.num_classes,\n",
    "                model_type=model_name\n",
    "            ).to(device)\n",
    "        else:\n",
    "            model = GNN(\n",
    "                in_channels=dataset.num_features,\n",
    "                hidden_channels=args.hidden_channels,\n",
    "                out_channels=dataset.num_classes,\n",
    "                model_type=model_name\n",
    "            ).to(device)\n",
    "        \n",
    "        acc_by_grads[model_name][seed] = []\n",
    "        \n",
    "        low_deg_spreads[model_name][seed] = []\n",
    "        high_deg_spreads[model_name][seed] = []\n",
    "\n",
    "        low_deg_losses[model_name][seed] = []\n",
    "        high_deg_losses[model_name][seed] = []\n",
    "\n",
    "        low_deg_svd[model_name][seed] = []\n",
    "        high_deg_svd[model_name][seed] = []\n",
    "        \n",
    "        loss_by_deg[model_name][seed] = []\n",
    "    \n",
    "        optimizer = torch.optim.Adam([\n",
    "            dict(params=conv.parameters(), weight_decay=0) for conv in model.conv_list\n",
    "        ], lr=args.lr)\n",
    "        # scheduler = StepLR(optimizer, step_size=250, gamma=0.25)\n",
    "\n",
    "        train_acc = 0\n",
    "        best_val_acc = test_acc = 0\n",
    "        best_model = None\n",
    "        times = []\n",
    "\n",
    "        edge_index = data.edge_index\n",
    "        deg = degree(edge_index[0], num_nodes=data.x.size(0))\n",
    "\n",
    "        epoch = 1\n",
    "        # while train_acc < best_acc:\n",
    "        while epoch < args.epochs + 1:\n",
    "            out = model(data.x, data.edge_index, data.edge_attr)\n",
    "\n",
    "            degs_split = deg[data.test_mask]\n",
    "\n",
    "            # high_cutoff = torch.quantile(degs_split, q=1-ratio)\n",
    "            # high_degs_mask = degs_split >= high_cutoff\n",
    "            high_idx = torch.topk(degs_split, k=threshold)[1]\n",
    "            high_degs_mask = torch.zeros_like(degs_split).bool()\n",
    "            high_degs_mask[high_idx] = True\n",
    "\n",
    "            # low_cutoff = torch.quantile(degs_split, q=ratio)\n",
    "            # low_degs_mask = degs_split <= low_cutoff\n",
    "            low_idx = torch.topk(-degs_split, k=threshold)[1]\n",
    "            low_degs_mask = torch.zeros_like(degs_split).bool()\n",
    "            low_degs_mask[low_idx] = True\n",
    "\n",
    "            low_deg_spread, high_deg_spread = scatter(out[data.test_mask], high_degs_mask, low_degs_mask, data)\n",
    "            low_deg_spreads[model_name][seed].append(low_deg_spread.item())\n",
    "            high_deg_spreads[model_name][seed].append(high_deg_spread.item())\n",
    "            \n",
    "            pred = F.log_softmax(out, dim=1)\n",
    "            target = data.y\n",
    "            err = -pred[range(target.size(0)), target]\n",
    "\n",
    "            train_degs = deg[data.train_mask]\n",
    "            train_low_idx = torch.topk(-train_degs, k=threshold)[1]\n",
    "            train_high_idx = torch.topk(train_degs, k=threshold)[1]\n",
    "\n",
    "            low_deg_losses[model_name][seed].append(err[data.train_mask][train_low_idx].mean().item())\n",
    "            high_deg_losses[model_name][seed].append(err[data.train_mask][train_high_idx].mean().item())\n",
    "\n",
    "            if epoch == args.epochs:\n",
    "                mask = data.test_mask\n",
    "                U, S, V = torch.svd(out[mask])\n",
    "\n",
    "                low_deg_svd[model_name][seed].append((U[:, 0][low_degs_mask].tolist(), U[:, 1][low_degs_mask].tolist(), data.y[mask][low_degs_mask].tolist()))\n",
    "                high_deg_svd[model_name][seed].append((U[:, 0][high_degs_mask].tolist(), U[:, 1][high_degs_mask].tolist(), data.y[mask][high_degs_mask].tolist()))\n",
    "                \n",
    "            start = time.time()\n",
    "            loss = train(model)\n",
    "            train_acc, val_acc, tmp_test_acc = test(model)\n",
    "            if val_acc > best_val_acc:\n",
    "                best_val_acc = val_acc\n",
    "                test_acc = tmp_test_acc\n",
    "                best_model = model\n",
    "            times.append(time.time() - start)\n",
    "            \n",
    "            total_norm = 0\n",
    "            num_params = 0\n",
    "            for p in model.parameters():\n",
    "                if p.grad is not None:\n",
    "                    total_norm += torch.abs(p.grad.detach().data).sum().item()\n",
    "                    num_params += p.grad.detach().data.numel()\n",
    "            total_norm = total_norm / num_params\n",
    "            acc_by_grads[model_name][seed].append((total_norm, train_acc))\n",
    "\n",
    "            if epoch % 100 == 0:\n",
    "                log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)\n",
    "            epoch += 1\n",
    "\n",
    "        models[model_name][seed] = best_model\n",
    "        overfit_models[model_name][seed] = model\n",
    "        \n",
    "        out = models[model_name][seed](data.x, data.edge_index, data.edge_attr)\n",
    "        pred = F.log_softmax(out, dim=1)\n",
    "        target = data.y\n",
    "        err = -pred[range(target.size(0)), target]\n",
    "        \n",
    "#         pred = pred.argmax(dim=-1)\n",
    "#         err = (pred != data.y).float()\n",
    "                \n",
    "        err_by_deg = torch.zeros(int(deg.max()) + 1, device=deg.device)\n",
    "        count_by_deg = torch.zeros(int(deg.max()) + 1, device=deg.device)\n",
    "        \n",
    "        scatter_mean(err[data.test_mask], deg[data.test_mask].long(), out=err_by_deg)\n",
    "        scatter_sum(torch.ones_like(data.y)[data.test_mask].float(), deg[data.test_mask].long(), out=count_by_deg)\n",
    "        \n",
    "        loss_by_deg[model_name][seed] = err_by_deg[count_by_deg > 0].tolist()\n",
    "                \n",
    "        print(f'Median time per epoch: {torch.tensor(times).median():.4f}s')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('plot_data/{}_{}.json'.format(args.dataset, args.num_gnn_layers), 'w') as fout:\n",
    "    json.dump([loss_by_deg, \\\n",
    "               acc_by_grads, \\\n",
    "              low_deg_svd, \\\n",
    "              high_deg_svd, \\\n",
    "              low_deg_spreads, \\\n",
    "              high_deg_spreads, \\\n",
    "              low_deg_losses, \\\n",
    "              high_deg_losses], fout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('plot_data/{}_{}.json'.format(args.dataset, args.num_gnn_layers), 'r') as fout:\n",
    "    [loss_by_deg, \\\n",
    "      acc_by_grads, \\\n",
    "      low_deg_svd, \\\n",
    "      high_deg_svd, \\\n",
    "      low_deg_spreads, \\\n",
    "      high_deg_spreads, \\\n",
    "      low_deg_losses, \\\n",
    "      high_deg_losses] = json.load(fout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plots and Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 20})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_name_map(model_name):\n",
    "    if model_name == 'gcn':\n",
    "        return 'SYM'\n",
    "    elif model_name == 'sage':\n",
    "        return 'RW'\n",
    "    elif model_name == 'gat':\n",
    "        return 'ATT'\n",
    "    elif 'fair-' in model_name:\n",
    "        return r'FAIR ATT ($\\lambda$ = {})'.format(model_name.split('-')[-1])\n",
    "    else:\n",
    "        return 'MLP'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unfair_models_to_plot = [model_name for model_name in model_list if not model_name.startswith('fair-')]\n",
    "\n",
    "fair_models_to_plot = [model_name for model_name in model_list if model_name.startswith('fair-')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 16})\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, figsize=(10, 5))\n",
    "\n",
    "for model_name in unfair_models_to_plot:\n",
    "    grads = torch.tensor([[e[0] for e in acc_by_grads[model_name][seed]] for seed in acc_by_grads[model_name]])\n",
    "    accs = torch.tensor([[e[1] for e in acc_by_grads[model_name][seed]] for seed in acc_by_grads[model_name]])\n",
    "    \n",
    "    grads_y = np.array(grads.mean(dim=0).tolist())\n",
    "    grads_y_err = np.array(grads.std(dim=0).tolist())\n",
    "    accs_y = np.array(accs.mean(dim=0).tolist())\n",
    "    accs_y_err = np.array(accs.std(dim=0).tolist())\n",
    "    x = np.array(range(len(grads_y)))\n",
    "    \n",
    "    cutoff = -100\n",
    "    ax1.plot(x[cutoff:], grads_y[cutoff:], label=model_name_map(model_name))\n",
    "    ax1.fill_between(x[cutoff:], (grads_y-grads_y_err)[cutoff:], (grads_y+grads_y_err)[cutoff:],\n",
    "                            alpha=0.2, # edgecolor='C1', facecolor='C1',\n",
    "                            linewidth=4, linestyle='dashdot', antialiased=True)\n",
    "    \n",
    "    ax2.plot(x[cutoff:], accs_y[cutoff:], label=model_name_map(model_name))\n",
    "    ax2.fill_between(x[cutoff:], (accs_y-accs_y_err)[cutoff:], (accs_y+accs_y_err)[cutoff:],\n",
    "                            alpha=0.2, # edgecolor='C1', facecolor='C1',\n",
    "                            linewidth=4, linestyle='dashdot', antialiased=True)\n",
    "    \n",
    "\n",
    "accs = [best_acc for _ in x[cutoff:]]\n",
    "ax2.plot(x[cutoff:], accs, linestyle='dashed', label=r\"${\\mathrm{MAJ}}_{\\mathrm{WL}}$\")\n",
    "\n",
    "ax1.set(ylim=0, xlabel='Epoch', ylabel='Mean absolute gradient')\n",
    "accs_range = (max(accs) - min(accs)) / 2\n",
    "ax2.set(ylim=(min(accs) - accs_range, max(accs) + accs_range), xlabel='Epoch', ylabel='Training accuracy')\n",
    "ax1.legend()\n",
    "ax2.legend()\n",
    "plt.suptitle(args.dataset.upper())\n",
    "plt.tight_layout()\n",
    "plt.savefig('plots/{}_{}_loss_vs_grad.pdf'.format(args.dataset, args.num_gnn_layers))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# matplotlib.rcParams.update({'font.size': 16})\n",
    "\n",
    "# fig, axs = plt.subplots(1, args.num_gnn_layers + 1, figsize=(15, 5))\n",
    "\n",
    "# # for norm_idx, (norm, norm_type) in enumerate([(gcn_norm, 'sym'), (sage_norm, 'rw')]):\n",
    "# edge_index = data.edge_index\n",
    "# edge_weight = torch.ones_like(edge_index[0]).float()\n",
    "# edge_weight = gcn_norm(edge_index, edge_weight, data.x.size(0))\n",
    "\n",
    "# for l in range(args.num_gnn_layers + 1):\n",
    "#     pow_edge_index, pow_edge_weight = power(edge_index, edge_weight, data.y.size(0), k=l)\n",
    "\n",
    "#     dense_adj = torch.sparse.FloatTensor(pow_edge_index, pow_edge_weight).detach().cpu().to_dense()\n",
    "#     eigvals = (-torch.sort(-torch.linalg.eigvals(dense_adj).real)[0]).tolist()\n",
    "#     axs[l].scatter(range(len(eigvals)), eigvals)\n",
    "\n",
    "#     axs[l].set(xlabel='Index', ylabel='Singular value', title=fr\"$P_{{sym}}^{l}, P_{{rw}}^{l}$\")\n",
    "\n",
    "# plt.suptitle(args.dataset.upper())\n",
    "# plt.tight_layout()\n",
    "# plt.savefig('plots/{}_{}_rank.pdf'.format(args.dataset, args.num_gnn_layers))\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for models_to_plot in [unfair_models_to_plot]: #, fair_models_to_plot[::2]]:\n",
    "    ylim_min = None\n",
    "    ylim_max = None\n",
    "    \n",
    "    fig, axs = plt.subplots(1, len(models_to_plot), figsize=(5 * len(models_to_plot), 5), constrained_layout=True, sharex='row', sharey='row')\n",
    "    \n",
    "    for idx, model_name in enumerate(models_to_plot):\n",
    "\n",
    "        deg_loss = [loss_by_deg[model_name][seed] for seed in loss_by_deg[model_name]]\n",
    "        deg_loss_mean = np.array(deg_loss).mean(axis=0)\n",
    "        deg_loss_std = np.array(deg_loss).std(axis=0)\n",
    "\n",
    "        deg = degree(data.edge_index[0], num_nodes=data.x.size(0))\n",
    "        count_by_deg = torch.zeros(int(deg.max()) + 1, device=deg.device)\n",
    "        scatter_sum(torch.ones_like(data.y)[data.test_mask].float(), deg[data.test_mask].long(), out=count_by_deg)\n",
    "        deg_x = torch.arange((deg.max()) + 1, device=deg.device)[count_by_deg > 0].tolist()\n",
    "\n",
    "        axs[idx].errorbar(deg_x, deg_loss_mean, yerr=deg_loss_std, color='C0', fmt='o')\n",
    "        axs[idx].set(xlabel='Degree', ylabel='Test loss', title=model_name_map(model_name))\n",
    "        axs[idx].yaxis.set_tick_params(labelbottom=True)\n",
    "        \n",
    "        if ylim_min is None:\n",
    "            ylim_min = min(deg_loss_mean) - 0.1\n",
    "        else:\n",
    "            ylim_min = min(ylim_min, min(deg_loss_mean) - 0.1)\n",
    "            \n",
    "        if ylim_max is None:\n",
    "            ylim_max = max(deg_loss_mean) + 0.1\n",
    "        else:\n",
    "            ylim_max = max(ylim_max, max(deg_loss_mean) + 0.1)\n",
    "\n",
    "    plt.setp(axs, ylim=(ylim_min, ylim_max))\n",
    "        \n",
    "    plt.suptitle(args.dataset.upper())\n",
    "    plt.savefig('plots/{}_{}_loss_vs_deg.pdf'.format(args.dataset, args.num_gnn_layers))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for models_type, models_to_plot in [('_', unfair_models_to_plot), ('_fair_', fair_models_to_plot[1:-1])]:\n",
    "    \n",
    "    if len(models_to_plot) == 0:\n",
    "        continue\n",
    "    \n",
    "    fig, axs = plt.subplots(len(models_to_plot), 4, figsize=(20, 5 * len(models_to_plot)), constrained_layout=True, sharex='col', sharey='col', \\\n",
    "                       gridspec_kw={'width_ratios': [3, 3, 2, 1]})\n",
    "    \n",
    "    for idx, model_name in enumerate(models_to_plot):\n",
    "\n",
    "        low_U0 = [e[0] for e in low_deg_svd[model_name][str(seeds_list[0])]]\n",
    "        low_U1 = [e[1] for e in low_deg_svd[model_name][str(seeds_list[0])]]\n",
    "        low_c = [e[2] for e in low_deg_svd[model_name][str(seeds_list[0])]]\n",
    "\n",
    "        res = axs[idx, 0].scatter(low_U0, low_U1, c=low_c, marker='x', label='Low-degree nodes')\n",
    "\n",
    "        high_U0 = [e[0] for e in high_deg_svd[model_name][str(seeds_list[0])]]\n",
    "        high_U1 = [e[1] for e in high_deg_svd[model_name][str(seeds_list[0])]]\n",
    "        high_c = [e[2] for e in high_deg_svd[model_name][str(seeds_list[0])]]\n",
    "\n",
    "        axs[idx, 0].scatter(high_U0, high_U1, c=high_c, label='High-degree nodes')\n",
    "\n",
    "        low_label = mlines.Line2D([], [], color='k', marker='x', ls='', label='Low-degree nodes')\n",
    "        high_label = mlines.Line2D([], [], color='k', marker='o', ls='', label='High-degree nodes')\n",
    "        axs[idx, 0].legend(handles=[low_label, high_label], prop={'size': 16})\n",
    "        axs[idx, 0].set(xlabel='PC1', ylabel='PC2', title=model_name_map(model_name))\n",
    "\n",
    "        axs[idx, 0].xaxis.set_tick_params(labelbottom=True)\n",
    "        \n",
    "#         if idx > 0:\n",
    "#             axs[idx, 0].get_shared_x_axes().remove(axs[idx - 1, 0])\n",
    "#             axs[idx, 0].get_shared_y_axes().remove(axs[idx - 1, 0])\n",
    "        \n",
    "        ###\n",
    "\n",
    "        low_y = [low_deg_spreads[model_name][seed] for seed in low_deg_spreads[model_name]]\n",
    "        low_y_mean = np.array(low_y).mean(axis=0)\n",
    "        low_y_std = np.array(low_y).std(axis=0)\n",
    "\n",
    "        high_y = [high_deg_spreads[model_name][seed] for seed in high_deg_spreads[model_name]]\n",
    "        high_y_mean = np.array(high_y).mean(axis=0)\n",
    "        high_y_std = np.array(high_y).std(axis=0)\n",
    "\n",
    "        axs[idx, 1].plot(low_y_mean, label='Low-degree nodes', color='C0')\n",
    "        axs[idx, 1].fill_between(range(len(low_y_mean)), low_y_mean-low_y_std, low_y_mean+low_y_std,\n",
    "                            alpha=0.2, edgecolor='C0', facecolor='C0',\n",
    "                            linewidth=4, linestyle='dashdot', antialiased=True)\n",
    "        axs[idx, 1].plot(high_y_mean, label='High-degree nodes', color='C1')\n",
    "        axs[idx, 1].fill_between(range(len(high_y_mean)), high_y_mean-high_y_std, high_y_mean+high_y_std,\n",
    "                            alpha=0.2, edgecolor='C1', facecolor='C1',\n",
    "                            linewidth=4, linestyle='dashdot', antialiased=True)\n",
    "\n",
    "        axs[idx, 1].set(xlabel='Epochs of training', ylabel='Trace of sample covariance\\nof test representations', \\\n",
    "                        title=model_name_map(model_name))\n",
    "        axs[idx, 1].legend(prop={'size': 16})\n",
    "        \n",
    "        axs[idx, 1].xaxis.set_tick_params(labelbottom=True)\n",
    "\n",
    "        ###\n",
    "\n",
    "        if idx > 0:\n",
    "            axs[idx, 2].get_shared_x_axes().remove(axs[0, 2])\n",
    "            axs[idx, 3].get_shared_x_axes().remove(axs[0, 3])\n",
    "            axs[idx, 2].get_shared_y_axes().join(axs[idx, 2], axs[idx, 3])\n",
    "\n",
    "        low_y = [low_deg_losses[model_name][seed] for seed in low_deg_losses[model_name]]\n",
    "        low_y_mean = np.array(low_y).mean(axis=0)\n",
    "        low_y_std = np.array(low_y).std(axis=0)\n",
    "\n",
    "        high_y = [high_deg_losses[model_name][seed] for seed in high_deg_losses[model_name]]\n",
    "        high_y_mean = np.array(high_y).mean(axis=0)\n",
    "        high_y_std = np.array(high_y).std(axis=0)\n",
    "\n",
    "        for ax in [axs[idx, 2], axs[idx, 3]]:\n",
    "            ax.plot(low_y_mean, label='Low-degree nodes', color='C0')\n",
    "            ax.fill_between(range(len(low_y_mean)), low_y_mean-low_y_std, low_y_mean+low_y_std,\n",
    "                                alpha=0.2, edgecolor='C0', facecolor='C0',\n",
    "                                linewidth=4, linestyle='dashdot', antialiased=True)\n",
    "            ax.plot(high_y_mean, label='High-degree nodes', color='C1')\n",
    "            ax.fill_between(range(len(high_y_mean)), high_y_mean-high_y_std, high_y_mean+high_y_std,\n",
    "                                alpha=0.2, edgecolor='C1', facecolor='C1',\n",
    "                                linewidth=4, linestyle='dashdot', antialiased=True)\n",
    "\n",
    "        axs[idx, 2].set(xlabel='Epochs of training', ylabel='Training loss', \\\n",
    "                        title=model_name_map(model_name))\n",
    "        axs[idx, 2].legend(prop={'size': 16})\n",
    "        axs[idx, 2].set_xlim(0, 100)\n",
    "        axs[idx, 3].set_xlim(args.epochs - 50, args.epochs)\n",
    "\n",
    "        axs[idx, 2].spines['right'].set_visible(False)\n",
    "        axs[idx, 3].spines['left'].set_visible(False)\n",
    "        axs[idx, 3].set_yticks([])\n",
    "\n",
    "        d = .015\n",
    "        kwargs = dict(transform=axs[idx, 2].transAxes, color='k', clip_on=False)\n",
    "        axs[idx, 2].plot((1-d, 1+d), (-d, +d), **kwargs)\n",
    "        axs[idx, 2].plot((1-d, 1+d), (1-d, 1+d), **kwargs)\n",
    "\n",
    "        kwargs.update(transform=axs[idx, 3].transAxes)\n",
    "        axs[idx, 3].plot((-d, +d), (1-d, 1+d), **kwargs)\n",
    "        axs[idx, 3].plot((-d, +d), (-d, +d), **kwargs)\n",
    "        \n",
    "        axs[idx, 2].xaxis.set_tick_params(labelbottom=True)\n",
    "        axs[idx, 3].xaxis.set_tick_params(labelbottom=True)\n",
    "\n",
    "    plt.suptitle(args.dataset.upper())\n",
    "    plt.savefig('plots/{}_{}{}theorem_validation.pdf'.format(args.dataset, args.num_gnn_layers, models_type))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# csv = []\n",
    "\n",
    "# for model_name in model_list:\n",
    "#     print()\n",
    "#     print(model_name)\n",
    "    \n",
    "#     fairness_scores = []\n",
    "#     util_scores = []\n",
    "#     for seed in seeds_list:\n",
    "    \n",
    "#         mask = data.test_mask\n",
    "\n",
    "#         model = models[model_name][seed]\n",
    "#         out = model(data.x, data.edge_index, data.edge_attr)\n",
    "\n",
    "#         pred = F.log_softmax(out, dim=1)[mask]\n",
    "#         target = data.y[mask]\n",
    "#         err = -pred[range(target.size(0)), target].flatten()\n",
    "\n",
    "#         deg = degree(data.edge_index[0], num_nodes=data.x.size(0))[mask]\n",
    "#         err_by_deg = scatter_mean(err, deg.long())\n",
    "#         count_by_deg = scatter_sum(torch.ones_like(data.y)[mask].float(), deg.long())\n",
    "\n",
    "#         err_by_deg = err_by_deg[count_by_deg > 0]\n",
    "#         fairness_scores.append(torch.var(err_by_deg).item())       \n",
    "#         util_scores.append(torch.mean(err).item())\n",
    "    \n",
    "#     csv.append([model_name_map(model_name), 'fairness', np.array(fairness_scores).mean(), np.array(fairness_scores).std()])\n",
    "#     csv.append([model_name_map(model_name), 'utility', np.array(util_scores).mean(), np.array(util_scores).std()])\n",
    "    \n",
    "#     print('Fairness:', np.array(fairness_scores).mean(), '±', np.array(fairness_scores).std())\n",
    "#     print('Utility:', np.array(util_scores).mean(), '±', np.array(util_scores).std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import pandas as pd\n",
    "\n",
    "# pd.DataFrame(csv).to_csv('results/{}_{}.csv'.format(args.dataset, args.num_gnn_layers), header=False, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collision Probability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 18})\n",
    "\n",
    "fig, axs = plt.subplots(1, len(unfair_models_to_plot), figsize=(5 * len(unfair_models_to_plot), 5), constrained_layout=True, sharex='col', sharey='col')\n",
    "\n",
    "deg = degree(data.edge_index[0], num_nodes=data.x.size(0))[data.test_mask]\n",
    "\n",
    "for idx, model_name in enumerate(unfair_models_to_plot):\n",
    "    \n",
    "    if model_name == 'gat':\n",
    "        model = models[model_name][seeds_list[0]]\n",
    "        out, atts = model(data.x, data.edge_index, data.edge_attr, get_att=True)\n",
    "\n",
    "        t_edge_index = torch.cat([data.edge_index[1].reshape(1, -1), data.edge_index[0].reshape(1, -1)])\n",
    "        pow_edge_index = t_edge_index.clone()\n",
    "        n = data.x.size(0)\n",
    "        pow_edge_index, pow_edge_weight = get_id(n, pow_edge_index)\n",
    "        \n",
    "        ent = torch_sparse.spmm(pow_edge_index, pow_edge_weight ** 2, data.y.size(0), data.y.size(0), torch.ones_like(data.y).reshape(-1, 1))\n",
    "        for k in range(len(atts)):\n",
    "            pow_edge_index, pow_edge_weight = torch_sparse.spspmm(t_edge_index, atts[k].flatten(), pow_edge_index, pow_edge_weight, n, n, n)\n",
    "            ent += torch_sparse.spmm(pow_edge_index, pow_edge_weight ** 2, data.y.size(0), data.y.size(0), torch.ones_like(data.y).reshape(-1, 1))\n",
    "    \n",
    "    else:\n",
    "        edge_index = data.edge_index\n",
    "        edge_weight = torch.ones_like(edge_index[0]).float()\n",
    "\n",
    "        _, deg_inv_sqrt = gcn_norm(edge_index, edge_weight, data.x.size(0), return_inv_sqrt=True)\n",
    "        edge_weight = sage_norm(edge_index, edge_weight, data.x.size(0))\n",
    "\n",
    "        deg = degree(edge_index[0], num_nodes=data.x.size(0))\n",
    "\n",
    "        t_edge_index = torch.cat([edge_index[1].reshape(1, -1), edge_index[0].reshape(1, -1)])\n",
    "        ent = torch.zeros_like(data.y).float()\n",
    "        for l in range(args.num_gnn_layers + 1):\n",
    "            pow_edge_index, pow_edge_weight = power(t_edge_index, edge_weight, data.y.size(0), k=l)\n",
    "                \n",
    "            if model_name == 'sage':\n",
    "                agg = torch.ones_like(data.y).reshape(-1, 1)\n",
    "            elif model_name == 'gcn':\n",
    "                agg = (deg_inv_sqrt ** 2 * torch.ones_like(data.y)).reshape(-1, 1)\n",
    "                \n",
    "            ent += torch_sparse.spmm(pow_edge_index, pow_edge_weight ** 2, data.y.size(0), data.y.size(0), agg).flatten()\n",
    "\n",
    "    axs[idx].scatter(deg[data.test_mask].tolist(), (1 / ent[data.test_mask]).tolist())\n",
    "    axs[idx].set(xlabel=r'Degree', ylabel=r'Inverse collision probability', title=model_name_map(model_name))\n",
    "\n",
    "plt.suptitle(args.dataset.upper())\n",
    "plt.savefig('plots/{}_{}_icp_vs_deg.pdf'.format(args.dataset, args.num_gnn_layers))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 16})\n",
    "\n",
    "fair_models_to_plot = [model_name for model_name in model_list if model_name.startswith('fair-')][1:-1]\n",
    "\n",
    "if len(fair_models_to_plot) > 0:\n",
    "    fig, axs = plt.subplots(1, len(fair_models_to_plot), figsize=(5 * len(fair_models_to_plot), 5), constrained_layout=True, sharex='col', sharey='col')\n",
    "\n",
    "    ents = []\n",
    "\n",
    "    deg = degree(data.edge_index[0], num_nodes=data.x.size(0))[data.test_mask]\n",
    "\n",
    "    for model_name in ['gat'] + [model_name for model_name in model_list if model_name.startswith('fair-')]:\n",
    "        model = models[model_name][seeds_list[0]]\n",
    "        out, atts = model(data.x, data.edge_index, data.edge_attr, get_att=True)\n",
    "\n",
    "        t_edge_index = torch.cat([data.edge_index[1].reshape(1, -1), data.edge_index[0].reshape(1, -1)])\n",
    "        pow_edge_index = t_edge_index.clone()\n",
    "        n = data.x.size(0)\n",
    "        pow_edge_index, pow_edge_weight = get_id(n, pow_edge_index)\n",
    "\n",
    "        ent = torch_sparse.spmm(pow_edge_index, pow_edge_weight ** 2, data.y.size(0), data.y.size(0), torch.ones_like(data.y).reshape(-1, 1))\n",
    "        for k in range(len(atts)):\n",
    "            pow_edge_index, pow_edge_weight = torch_sparse.spspmm(t_edge_index, atts[k].flatten(), pow_edge_index, pow_edge_weight, n, n, n)\n",
    "            ent += torch_sparse.spmm(pow_edge_index, pow_edge_weight ** 2, data.y.size(0), data.y.size(0), torch.ones_like(data.y).reshape(-1, 1))\n",
    "        ents.append(ent[data.test_mask])\n",
    "\n",
    "    for k, model_name in enumerate(fair_models_to_plot):\n",
    "        axs[k].scatter(deg.tolist(), ((1 / ents[k + 1]) / (1 / ents[0])).tolist())\n",
    "        axs[k].axhline(y=1.0, color='C1', linestyle='--')\n",
    "\n",
    "        axs[k].set(xlabel='Degree', ylabel='Ratio of inverse\\ncollision probabilities', title=f\"{model_name_map(model_name)} vs. {model_name_map('gat')}\")\n",
    "\n",
    "    plt.suptitle(args.dataset.upper())\n",
    "    plt.savefig('plots/{}_{}_icp_ratio_vs_deg.pdf'.format(args.dataset, args.num_gnn_layers))\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
