{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f85ecde3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "from os import environ\n",
    "import time\n",
    "\n",
    "import torch\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.datasets import CitationFull, Coauthor, WebKB, LastFMAsia, Twitch\n",
    "from dblp_fairness import DBLPFairness\n",
    "from say_no import SayNo\n",
    "from nifty import Nifty\n",
    "\n",
    "from gcn_conv import GCNConv\n",
    "from torch_geometric.utils import negative_sampling, to_dense_adj, add_remaining_self_loops, is_undirected, dense_to_sparse, subgraph, contains_self_loops, coalesce\n",
    "from torch_geometric.data import Data\n",
    "from connected_classes import ConnectedClasses, LargestConnectedComponents, LargestBiconnectedComponents\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.graph_objs as go\n",
    "import plotly.express as px\n",
    "import math\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "import re\n",
    "import pickle as pkl\n",
    "\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "from sklearn.metrics import mean_squared_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9b1f546",
   "metadata": {},
   "outputs": [],
   "source": [
    "environ['CUDA_LAUNCH_BLOCKING'] = \"1\"\n",
    "dataset_name = environ.get('dataset_name', 'DBLPFairness')\n",
    "conv_type = environ.get('conv_type', 'sym')\n",
    "k = int(environ.get('k', '4'))\n",
    "fairness_regularizer = float(environ.get('fairness_regularizer', '1.0'))\n",
    "\n",
    "seeds = [34, 87, 120, 11, 93, 24, 25, 56, 49, 54]\n",
    "fairness_datasets = [\"DBLPFairness\", \"NBA\", \"German\"]\n",
    "inspect_deviations = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2be1a921",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2023 PyG Team <team@pyg.org>\n",
    "\n",
    "# Permission is hereby granted, free of charge, to any person obtaining a copy\n",
    "# of this software and associated documentation files (the \"Software\"), to deal\n",
    "# in the Software without restriction, including without limitation the rights\n",
    "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
    "# copies of the Software, and to permit persons to whom the Software is\n",
    "# furnished to do so, subject to the following conditions:\n",
    "\n",
    "# The above copyright notice and this permission notice shall be included in\n",
    "# all copies or substantial portions of the Software.\n",
    "\n",
    "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
    "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
    "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
    "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
    "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
    "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
    "# THE SOFTWARE.\n",
    "\n",
    "class Net(torch.nn.Module):\n",
    "    def __init__(self, in_channels, hidden_channels, out_channels, conv_type):\n",
    "        super().__init__()\n",
    "        self.convs = torch.nn.ModuleList([GCNConv(in_channels, hidden_channels, conv=conv_type)] + \\\n",
    "                                [GCNConv(hidden_channels, hidden_channels, conv=conv_type) for i in range(k - 2)] + \\\n",
    "                                [GCNConv(hidden_channels, out_channels, conv=conv_type)])\n",
    "\n",
    "    def encode(self, x, edge_index):\n",
    "        for conv in self.convs[:-1]:\n",
    "            x = conv(x, edge_index).relu()\n",
    "        return self.convs[-1](x, edge_index)\n",
    "\n",
    "    def decode(self, z, edge_label_index):\n",
    "        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)\n",
    "\n",
    "    def decode_all(self, z):\n",
    "        prob_adj = z @ z.t()\n",
    "        return prob_adj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67109a67",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_data):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    x = train_data.x\n",
    "    z = model.encode(x, train_data.edge_index)\n",
    "\n",
    "    # We perform a new round of negative sampling for every training epoch:\n",
    "    neg_edge_index = negative_sampling(\n",
    "        edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,\n",
    "        num_neg_samples=train_data.edge_label_index.size(1), method='sparse')\n",
    "\n",
    "    edge_label_index = torch.cat(\n",
    "        [train_data.edge_label_index, neg_edge_index],\n",
    "        dim=-1,\n",
    "    )\n",
    "    edge_label = torch.cat([\n",
    "        train_data.edge_label,\n",
    "        train_data.edge_label.new_zeros(neg_edge_index.size(1))\n",
    "    ], dim=0)\n",
    "    \n",
    "    edge_label_index, edge_label = coalesce(edge_label_index, edge_label)\n",
    "    \n",
    "    out = model.decode(z, edge_label_index).view(-1)\n",
    "    loss = criterion(out, edge_label)\n",
    "    \n",
    "    fairness_loss = torch.tensor(0.0, device=device)\n",
    "    num_classes = 0\n",
    "    if dataset_name in fairness_datasets and fairness_regularizer > 0:\n",
    "        for c in range(int(train_data.y.max()) + 1):\n",
    "            fairness_score = fairness(c, train_data, edge_label_index, out)\n",
    "            if fairness_score is not None:\n",
    "                num_classes += 1\n",
    "                fairness_loss += fairness_score \n",
    "    if num_classes > 0:\n",
    "        loss += fairness_regularizer * fairness_loss / num_classes\n",
    "    assert not fairness_loss.isnan().any()\n",
    "    \n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "        \n",
    "    return loss\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(model, data):\n",
    "    model.eval()\n",
    "    x = data.x\n",
    "    z = model.encode(x, data.edge_index)\n",
    "    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()\n",
    "    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4893f613",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fairness(c, data, edge_index, probs):\n",
    "    c_mask = data.y == c\n",
    "    c0_mask = c_mask & (data.sens == 0)\n",
    "    c1_mask = c_mask & (data.sens == 1)\n",
    "    \n",
    "    c0a_mask = c0_mask[edge_index[0]] & c_mask[edge_index[1]]\n",
    "    c1a_mask = c1_mask[edge_index[0]] & c_mask[edge_index[1]]\n",
    "\n",
    "    if probs[c0a_mask].size(0) == 0 or probs[c1a_mask].size(0) == 0:\n",
    "        return None\n",
    "    \n",
    "    total_0 = torch.sigmoid(probs[c0a_mask]).mean()\n",
    "    total_1 = torch.sigmoid(probs[c1a_mask]).mean()\n",
    "    return torch.abs(total_0 - total_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc60e5a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_everything(seed: int):\n",
    "    import random, os\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    \n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3270bb68",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Union\n",
    "\n",
    "from torch_geometric.data import Data, HeteroData\n",
    "from torch_geometric.transforms import BaseTransform\n",
    "\n",
    "class NormalizeFeatures(BaseTransform):\n",
    "    def __init__(self, attrs: List[str] = [\"x\"]):\n",
    "        self.attrs = attrs\n",
    "\n",
    "    def __call__(self, data: Data) -> Data:\n",
    "        for store in data.stores:\n",
    "            for key, value in store.items(*self.attrs):\n",
    "                if value.numel() > 0:\n",
    "                    min_values = value.min(dim=0)[0]\n",
    "                    max_values = value.max(dim=0)[0]\n",
    "                    store[key] = 2 * (value - min_values).div(max_values - min_values) - 1\n",
    "        return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e391e9c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset(dataset_name):    \n",
    "    if dataset_name in [\"Cora\", \"Cora_ML\", \"CiteSeer\", \"DBLP\", \"PubMed\"]:\n",
    "        transform = T.Compose([ \n",
    "            ConnectedClasses(),\n",
    "            T.NormalizeFeatures(),\n",
    "            T.ToDevice(device),\n",
    "            T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,\n",
    "                              add_negative_train_samples=False),\n",
    "        ])\n",
    "        path = osp.join('.', 'data', 'CitationFull')\n",
    "        dataset = CitationFull(path, name=dataset_name, transform=transform)\n",
    "    elif dataset_name in [\"CS\", \"Physics\"]:\n",
    "        transform = T.Compose([ \n",
    "            ConnectedClasses(),\n",
    "            T.ToDevice(device),\n",
    "            T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,\n",
    "                              add_negative_train_samples=False),\n",
    "        ])\n",
    "        path = osp.join('.', 'data', 'Coauthor')\n",
    "        dataset = Coauthor(path, name=dataset_name, transform=transform)\n",
    "    elif dataset_name in [\"LastFMAsia\"]:\n",
    "        transform = T.Compose([ \n",
    "            ConnectedClasses(),\n",
    "            T.ToDevice(device),\n",
    "            T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,\n",
    "                              add_negative_train_samples=False),\n",
    "        ])\n",
    "        path = osp.join('.', 'data', 'LastFMAsia')\n",
    "        dataset = LastFMAsia(path, transform=transform)\n",
    "        dataset.name = \"LastFMAsia\"\n",
    "    elif dataset_name in [\"DE\", \"EN\", \"FR\"]:\n",
    "        transform = T.Compose([ \n",
    "            ConnectedClasses(),\n",
    "            T.ToDevice(device),\n",
    "            T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,\n",
    "                              add_negative_train_samples=False),\n",
    "        ])\n",
    "        path = osp.join('.', 'data', 'Twitch')\n",
    "        dataset = Twitch(path, name=dataset_name, transform=transform)\n",
    "    elif dataset_name in [\"DBLPFairness\"]:\n",
    "        transform = T.Compose([ \n",
    "            ConnectedClasses(),\n",
    "            NormalizeFeatures(),\n",
    "            T.ToDevice(device),\n",
    "            T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,\n",
    "                              add_negative_train_samples=False),\n",
    "        ])\n",
    "        path = osp.join('.', 'data', 'DBLPFairness')\n",
    "        dataset = DBLPFairness(path, transform=transform)\n",
    "        dataset.name = \"DBLPFairness\"\n",
    "    elif dataset_name in [\"Pokec-z\", \"Pokec-n\", \"NBA\"]:\n",
    "        transform = T.Compose([ \n",
    "            ConnectedClasses(),\n",
    "            NormalizeFeatures(),\n",
    "            T.ToDevice(device),\n",
    "            T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,\n",
    "                              add_negative_train_samples=False),\n",
    "        ])\n",
    "        path = osp.join('.', 'data', 'SayNo')\n",
    "        dataset = SayNo(path, name=dataset_name, transform=transform)\n",
    "    elif dataset_name in [\"Credit\", \"German\"]:\n",
    "        transform = T.Compose([ \n",
    "            ConnectedClasses(),\n",
    "            NormalizeFeatures(),\n",
    "            T.ToDevice(device),\n",
    "            T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,\n",
    "                              add_negative_train_samples=False),\n",
    "        ])\n",
    "        path = osp.join('.', 'data', 'Nifty')\n",
    "        dataset = Nifty(path, name=dataset_name, transform=transform)\n",
    "    else:\n",
    "        raise ValueError\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa0a34fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "dataset = load_dataset(dataset_name)\n",
    "test_data = dataset[0][2]\n",
    "num_labels = int(test_data.y.max()) + 1\n",
    "\n",
    "x_list = [[[] for _ in range(len(seeds))] for _ in range(num_labels)]\n",
    "y_list = [[[] for _ in range(len(seeds))] for _ in range(num_labels)]\n",
    "pcc_list = []\n",
    "nrmse_list = []\n",
    "auc_list = []\n",
    "\n",
    "feat_sim_list = []\n",
    "deg_sim_list = []\n",
    "colors_list = []\n",
    "\n",
    "x_fairness_scores = [[] for _ in range(num_labels)]\n",
    "y_fairness_scores = [[] for _ in range(num_labels)]\n",
    "\n",
    "for seed_idx, seed in enumerate(seeds):\n",
    "    seed_everything(0)\n",
    "    \n",
    "    dataset = load_dataset(dataset_name)\n",
    "    train_data, val_data, test_data = dataset[0]\n",
    "    \n",
    "    seed_everything(seed)\n",
    "    \n",
    "    num_feat = dataset.num_features\n",
    "    model = Net(num_feat, 128, 64, conv_type).to(device)\n",
    "    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)\n",
    "    criterion = torch.nn.BCEWithLogitsLoss()\n",
    "    \n",
    "    x_pcc = []\n",
    "    y_pcc = []\n",
    "    \n",
    "    x_nrmse = []\n",
    "    y_nrmse = []\n",
    "\n",
    "    best_val_auc = final_test_auc = 0\n",
    "    for epoch in range(1, 101):\n",
    "        loss = train(model, train_data)\n",
    "        val_auc = test(model, val_data)\n",
    "        test_auc = test(model, test_data)\n",
    "        if val_auc > best_val_auc:\n",
    "            best_val_auc = val_auc\n",
    "            final_test_auc = test_auc\n",
    "\n",
    "    print(f'Final Test: {final_test_auc:.4f}')\n",
    "    auc_list.append(final_test_auc)\n",
    "    \n",
    "    model.eval()\n",
    "    x = test_data.x\n",
    "    rep = model.encode(x, test_data.edge_index)\n",
    "\n",
    "    z = x\n",
    "    for conv in model.convs:\n",
    "        z = z @ conv.lin.weight.detach().t()\n",
    "    taylor_rep = torch.zeros_like(z)\n",
    "    \n",
    "    test_edge_index = add_remaining_self_loops(test_data.edge_index, num_nodes=test_data.x.size(0))[0]\n",
    "    print(\"Is undirected?\", is_undirected(test_edge_index))\n",
    "    print(\"Contains self-loops?\", contains_self_loops(test_edge_index))\n",
    "    \n",
    "    if inspect_deviations:\n",
    "        feat_rep = torch.zeros_like(z)\n",
    "        deg_sim = torch.zeros((test_data.x.size(0), test_data.x.size(0))).to(device)\n",
    "    \n",
    "    for c in range(int(test_data.y.max()) + 1):\n",
    "        c_mask = torch.nonzero(test_data.y.flatten() == c).flatten()\n",
    "        c_edge_index = subgraph(c_mask, test_edge_index, relabel_nodes=True, num_nodes=test_data.x.size(0))[0]\n",
    "        \n",
    "        c_edge_index = coalesce(c_edge_index)\n",
    "        \n",
    "        c_section = to_dense_adj(c_edge_index)[0]\n",
    "\n",
    "        deg_i = c_section.sum(dim=1)\n",
    "        deg_j = c_section.sum(dim=0)\n",
    "\n",
    "        if conv_type == \"sym\":\n",
    "            eigv_i = torch.sqrt(deg_i / deg_i.sum())\n",
    "            eigv_j = torch.sqrt(deg_j / deg_j.sum())\n",
    "        else:\n",
    "            eigv_i = torch.ones_like(deg_i)\n",
    "            eigv_j = deg_j / deg_j.sum()\n",
    "\n",
    "        taylor_rep[c_mask] = eigv_i.reshape(-1, 1) @ eigv_j.reshape(1, -1) @ z[c_mask]\n",
    "        \n",
    "        if inspect_deviations:\n",
    "            feat_rep[c_mask] = torch.sqrt(torch.ones_like(deg_i) / deg_i.sum()).reshape(-1, 1) @ eigv_j.reshape(1, -1) @ z[c_mask]\n",
    "            prod = torch.sqrt(deg_i.reshape(-1, 1) @ deg_j.reshape(1, -1))\n",
    "            for i, idx in enumerate(c_mask):\n",
    "                deg_sim[idx, c_mask] = prod[i]\n",
    "            \n",
    "    if inspect_deviations:\n",
    "        label_adj = to_dense_adj(coalesce(test_data.edge_label_index), max_num_nodes=test_data.x.size(0))[0]\n",
    "        deg_sim = deg_sim.cpu()\n",
    "\n",
    "    for c in range(int(test_data.y.max()) + 1):\n",
    "        c_mask = torch.nonzero(test_data.y.flatten() == c).flatten()\n",
    "        c_edge_label_index = subgraph(c_mask, test_data.edge_label_index, num_nodes=test_data.x.size(0))[0]\n",
    "        \n",
    "        x = model.decode(rep, c_edge_label_index).detach().cpu().numpy()\n",
    "        y = model.decode(taylor_rep, c_edge_label_index).detach().cpu().numpy()\n",
    "        \n",
    "        if len(x) <= 1:\n",
    "            continue\n",
    "        \n",
    "        rho, _, _, _ = np.linalg.lstsq(y[:,np.newaxis], x)\n",
    "        y *= rho[0]\n",
    "\n",
    "        x_list[c][seed_idx] = x.tolist()\n",
    "        y_list[c][seed_idx] = y.tolist()\n",
    "        \n",
    "        x_pcc.extend(x.tolist())\n",
    "        y_pcc.extend(y.tolist())\n",
    "        \n",
    "        x_nrmse.extend(x.tolist())\n",
    "        y_nrmse.extend(y.tolist())\n",
    "\n",
    "        if inspect_deviations:\n",
    "            c_mask = ((test_data.y.reshape(-1, 1) == c) & (test_data.y.reshape(1, -1) == c)) & (label_adj == 1)\n",
    "            colors = (x - y).tolist()\n",
    "            colors_list += colors\n",
    "\n",
    "            feat_sim = model.decode(feat_rep, c_edge_label_index).detach().cpu()\n",
    "            feat_sim_list += feat_sim.numpy().tolist()\n",
    "            deg_sim_list += deg_sim[c_mask].numpy().tolist()\n",
    "        \n",
    "        if dataset_name in fairness_datasets:\n",
    "            x_fairness_scores[c].append(fairness(c, test_data, test_data.edge_label_index, model.decode(rep, test_data.edge_label_index)).item())\n",
    "            y_fairness_scores[c].append(fairness(c, test_data, test_data.edge_label_index, model.decode(taylor_rep, test_data.edge_label_index) * rho[0]).item())\n",
    "    \n",
    "    pcc, _ = pearsonr(x_pcc, y_pcc)\n",
    "    pcc_list.append(pcc)\n",
    "    \n",
    "    nrmse = mean_squared_error(x_nrmse, y_nrmse, squared=False)\n",
    "    nrmse = nrmse / (max(x_nrmse) - min(x_nrmse))\n",
    "    nrmse_list.append(nrmse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd110b36",
   "metadata": {},
   "outputs": [],
   "source": [
    "if inspect_deviations:\n",
    "    feat_sim_list = np.array(feat_sim_list)\n",
    "    deg_sim_list = np.array(deg_sim_list)\n",
    "    colors_list = np.abs(np.array(colors_list))\n",
    "\n",
    "    tau = np.percentile(colors_list, 0)\n",
    "\n",
    "    print(\"Feature Similarity - Deviation:\", spearmanr(feat_sim_list[colors_list > tau], colors_list[colors_list > tau])[0])\n",
    "    print(\"Degree Product - Deviation:\", spearmanr(deg_sim_list[colors_list > tau], colors_list[colors_list > tau])[0])\n",
    "\n",
    "    plt.scatter(feat_sim_list[colors_list > tau], colors_list[colors_list > tau])\n",
    "    plt.xlabel(\"Feature similarity\")\n",
    "    plt.ylabel(\"Absolute deviation\")\n",
    "    plt.show()\n",
    "\n",
    "    plt.scatter(deg_sim_list[colors_list > tau], colors_list[colors_list > tau])\n",
    "    plt.xlabel(\"Degree product\")\n",
    "    plt.ylabel(\"Absolute deviation\")\n",
    "    plt.show()\n",
    "\n",
    "    plt.scatter(feat_sim_list[colors_list > tau], deg_sim_list[colors_list > tau], c=colors_list[colors_list > tau])\n",
    "    plt.xlabel(\"Feature similarity\")\n",
    "    plt.ylabel(\"Degree product\")\n",
    "    plt.colorbar()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22f14bec",
   "metadata": {},
   "outputs": [],
   "source": [
    "plots_data = {\n",
    "    'auc': auc_list,\n",
    "    'pcc': pcc_list,\n",
    "    'nrmse': nrmse_list,\n",
    "    'x': x_list,\n",
    "    'y': y_list\n",
    "}\n",
    "\n",
    "with open(\"plots_data/{}_{}_{}.pkl\".format(dataset.name, conv_type, k), 'wb') as handle:\n",
    "    pkl.dump(plots_data, handle, protocol=pkl.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a9f97e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"plots_data/{}_{}_{}.pkl\".format(dataset_name, conv_type, k), 'rb') as handle:\n",
    "    plots_data = pkl.load(handle)\n",
    "\n",
    "auc_list = plots_data['auc']\n",
    "pcc_list = plots_data['pcc']\n",
    "nrmse_list = plots_data['nrmse']\n",
    "x_list = plots_data['x']\n",
    "y_list = plots_data['y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a343c9d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "auc_all = torch.tensor(auc_list)\n",
    "auc_avg = auc_all.mean().item()\n",
    "auc_std = auc_all.std().item()\n",
    "auc_str = \"{0:.3f} ± \".format(auc_avg) + \"{0:.3f}\".format(auc_std)\n",
    "print(auc_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b10067b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "pcc_all = torch.tensor(pcc_list)\n",
    "pcc_avg = pcc_all.mean().item()\n",
    "pcc_std = pcc_all.std().item()\n",
    "pcc_str = \"{0:.3f} ± \".format(pcc_avg) + \"{0:.3f}\".format(pcc_std)\n",
    "print(pcc_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d88896f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "nrmse_all = torch.tensor(nrmse_list)\n",
    "nrmse_avg = nrmse_all.mean().item()\n",
    "nrmse_std = nrmse_all.std().item()\n",
    "nrmse_str = \"{0:.3f} ± \".format(nrmse_avg) + \"{0:.3f}\".format(nrmse_std)\n",
    "print(nrmse_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be965826",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_title = \" \".join([w.upper() for w in re.split('_| ', dataset.name)])\n",
    "\n",
    "with open(\"results.csv\", \"a\") as results_file:\n",
    "    results_file.write(plot_title + \",\")\n",
    "    results_file.write(nrmse_str + \",\")\n",
    "    results_file.write(pcc_str + \",\")\n",
    "    results_file.write(auc_str + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b59be7a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plotly_objs = []\n",
    "\n",
    "min_avg = 0\n",
    "max_avg = 0\n",
    "\n",
    "test_N = 0\n",
    "test_B = 0\n",
    "\n",
    "for c in range(len(x_list)):\n",
    "\n",
    "    x_all = torch.tensor(x_list[c])\n",
    "    y_all = torch.tensor(y_list[c])\n",
    "    \n",
    "    test_N += int(x_all.numel() / len(seeds))\n",
    "    if x_all.numel() > 0:\n",
    "        test_B += 1\n",
    "\n",
    "    if x_all.numel() == 0:\n",
    "        continue\n",
    "\n",
    "    for i in range(len(x_all)):\n",
    "        x, y = x_all[i], y_all[i]\n",
    "\n",
    "    x_avg = x_all.mean(dim=0)\n",
    "    y_avg = y_all.mean(dim=0)\n",
    "    \n",
    "    min_avg = min(min_avg, x_avg.min().item())\n",
    "    min_avg = min(min_avg, y_avg.min().item())\n",
    "    max_avg = max(max_avg, x_avg.max().item())\n",
    "    max_avg = max(max_avg, y_avg.max().item())\n",
    "    \n",
    "    x_range = x_all.max(dim=0).values - x_all.min(dim=0).values\n",
    "    y_range = y_all.max(dim=0).values - y_all.min(dim=0).values\n",
    "    marker_size = (0.1 * (x_range ** 2 + y_range ** 2)).tolist()\n",
    "\n",
    "    plotly_objs.extend([\n",
    "        go.Scatter(\n",
    "            x=x_avg.tolist(),\n",
    "            y=y_avg.tolist(),\n",
    "            mode='markers',\n",
    "            showlegend=False,\n",
    "            marker=dict(size=marker_size,\n",
    "                        sizemode='area',\n",
    "                        sizeref=2.*max(marker_size)/(30.**2))\n",
    "        )\n",
    "    ])\n",
    "\n",
    "dtick = int((max_avg - min_avg) / 5)\n",
    "diag_y = [min_avg - dtick // 2, max_avg + dtick // 2]\n",
    "diag_x = diag_y\n",
    "\n",
    "plotly_objs.append(go.Scatter(\n",
    "            x=diag_x,\n",
    "            y=diag_y,\n",
    "            mode='lines',\n",
    "            line={'dash': 'dash', 'color': 'black'},\n",
    "            showlegend=False\n",
    "        ))\n",
    "\n",
    "# avoid loading mathjax text\n",
    "fig=px.scatter(x=[0, 1, 2, 3, 4], y=[0, 1, 4, 9, 16])\n",
    "fig.write_image(\"plots/{}_{}_{}.png\".format(dataset_name, conv_type, k))\n",
    "time.sleep(2)\n",
    "\n",
    "plot_title = \" \".join([w.upper() for w in re.split('_| ', dataset_name)])\n",
    "plot_title += \" (N = {}, B = {})\".format(test_N, test_B)\n",
    "axis_titles = [r\"\\textrm{ link prediction score}}$\", r\"$\\Large{\\textrm{Theoretic link prediction score}}$\"]\n",
    "auc_annot = r\"\\textrm{ test AUC} = \" + auc_str + r\"$\"\n",
    "if conv_type == \"sym\":\n",
    "    axis_titles[0] = r\"$\\Large{\\Phi_s\" + axis_titles[0]\n",
    "    auc_annot = r\"$\\Phi_s\" + auc_annot\n",
    "else:\n",
    "    axis_titles[0] = r\"$\\Large{\\Phi_r\" + axis_titles[0]\n",
    "    auc_annot = r\"$\\Phi_r\" + auc_annot\n",
    "\n",
    "fig = go.Figure(plotly_objs)\n",
    "fig.update_yaxes(\n",
    "    scaleanchor=\"x\",\n",
    "    scaleratio=1,\n",
    "  )\n",
    "fig.update_layout(\n",
    "    height=600,\n",
    "    width=600,\n",
    "    margin=dict(l=5, r=5, t=5, b=5),\n",
    "    font=dict(size=20, color=\"black\"),\n",
    "    xaxis=go.layout.XAxis(\n",
    "        range=[min_avg - dtick // 2, max_avg + dtick // 2],\n",
    "        title=axis_titles[0],\n",
    "        showgrid=False,\n",
    "        zeroline=False,\n",
    "        showline=True,\n",
    "        linecolor='#000000',\n",
    "        dtick=dtick\n",
    "    ),\n",
    "    yaxis=go.layout.YAxis(\n",
    "        range=[min_avg - dtick // 2, max_avg + dtick // 2],\n",
    "        title=axis_titles[1],\n",
    "        showgrid=False,\n",
    "        zeroline=False,\n",
    "        showline=True,\n",
    "        linecolor='#000000',\n",
    "        dtick=dtick\n",
    "    ),\n",
    "    plot_bgcolor=\"white\",\n",
    "    title=dict(text=plot_title, automargin=True),\n",
    "    title_x=0.5\n",
    ")\n",
    "\n",
    "fig.write_image(\"plots/{}_{}_{}.png\".format(dataset_name, conv_type, k))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e2ed5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset_name in fairness_datasets:\n",
    "    plots_data = {\n",
    "        'auc': auc_str,\n",
    "        'x': x_fairness_scores,\n",
    "        'y': y_fairness_scores\n",
    "    }\n",
    "\n",
    "    with open(\"plots_data/{}_{}_{}_{}.pkl\".format(dataset.name, conv_type, k, fairness_regularizer), 'wb') as handle:\n",
    "        pkl.dump(plots_data, handle, protocol=pkl.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b84f39d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset_name in fairness_datasets:\n",
    "    with open(\"plots_data/{}_{}_{}_{}.pkl\".format(dataset_name, conv_type, k, fairness_regularizer), 'rb') as handle:\n",
    "        plots_data = pkl.load(handle)\n",
    "\n",
    "    auc_str = plots_data['auc']\n",
    "    x_fairness_scores = plots_data['x']\n",
    "    y_fairness_scores = plots_data['y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7af4596e",
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset_name in fairness_datasets:\n",
    "    plotly_objs = []\n",
    "    flattened_x = []\n",
    "    flattened_y = []\n",
    "    \n",
    "    for x, y in zip(x_fairness_scores, y_fairness_scores):\n",
    "        if len(x) == 0:\n",
    "            continue\n",
    "        \n",
    "        flattened_x.extend(x)\n",
    "        flattened_y.extend(y)\n",
    "        \n",
    "        plotly_objs.extend([\n",
    "            go.Scatter(\n",
    "                x=x,\n",
    "                y=y,\n",
    "                mode='markers',\n",
    "                showlegend=False,\n",
    "                marker=dict(size=12,\n",
    "                              line=dict(width=2,\n",
    "                                        color='DarkSlateGrey'))\n",
    "            )\n",
    "        ])\n",
    "        \n",
    "    min_xy = min(flattened_x + flattened_y)\n",
    "    max_xy = max(flattened_x + flattened_y)\n",
    "    dtick = round((max_xy - min_xy) / 5, 3)\n",
    "\n",
    "    plotly_objs.append(go.Scatter(\n",
    "            x=[min_xy - dtick, max_xy + dtick],\n",
    "            y=[min_xy - dtick, max_xy + dtick],\n",
    "            mode='lines',\n",
    "            line={'dash': 'dash', 'color': 'black'},\n",
    "            showlegend=False\n",
    "        ))\n",
    "    \n",
    "    fig = go.Figure(plotly_objs)\n",
    "    \n",
    "    pcc, _ = pearsonr(flattened_x, flattened_y)\n",
    "    pcc_str = \"{0:.3f}\".format(pcc)\n",
    "\n",
    "    nrmse = mean_squared_error(flattened_x, flattened_y, squared=False)\n",
    "    nrmse = nrmse / (max(flattened_x) - min(flattened_x))\n",
    "    nrmse_str = \"{0:.3f}\".format(nrmse)\n",
    "    \n",
    "    fig.update_yaxes(\n",
    "        scaleanchor=\"x\",\n",
    "        scaleratio=1,\n",
    "      )\n",
    "    fig.update_layout(\n",
    "        height=600,\n",
    "        width=600,\n",
    "        margin=dict(l=5, r=5, t=5, b=5),\n",
    "        font=dict(size=20, color=\"black\"),\n",
    "        xaxis=go.layout.XAxis(\n",
    "            range=[min_xy - dtick, max_xy + dtick],\n",
    "            title=r'$\\Large{\\Delta^{(b)}}$',\n",
    "            showgrid=False,\n",
    "            zeroline=False,\n",
    "            showline=True,\n",
    "            linecolor='#000000',\n",
    "            dtick=dtick\n",
    "        ),\n",
    "        yaxis=go.layout.YAxis(\n",
    "            range=[min_xy - dtick, max_xy + dtick],\n",
    "            title=r'$\\Large{\\widehat{\\Delta}^{(b)}}$',\n",
    "            showgrid=False,\n",
    "            zeroline=False,\n",
    "            showline=True,\n",
    "            linecolor='#000000',\n",
    "            dtick=dtick\n",
    "        ),\n",
    "        plot_bgcolor=\"white\",\n",
    "        title=dict(text=plot_title, automargin=True),\n",
    "        title_x=0.5\n",
    "    )\n",
    "    \n",
    "    fig.add_annotation(dict(x=0.55,\n",
    "                            y=0.25,\n",
    "                            text=r\"$\\Large{\\textrm{NRMSE} = \" + nrmse_str + r\"}$\",\n",
    "                            showarrow=False,\n",
    "                            textangle=0,\n",
    "                            xanchor='left',\n",
    "                            xref=\"paper\",\n",
    "                            yref=\"paper\"))\n",
    "    fig.add_annotation(dict(x=0.55,\n",
    "                            y=0.2,\n",
    "                            text=r\"$\\Large{\\textrm{PCC} = \" + pcc_str + r\"}$\",\n",
    "                            showarrow=False,\n",
    "                            textangle=0,\n",
    "                            xanchor='left',\n",
    "                            xref=\"paper\",\n",
    "                            yref=\"paper\"))\n",
    "        \n",
    "    fig.write_image(\"plots/{}_{}_{}_{}.png\".format(dataset_name, conv_type, k, fairness_regularizer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e21ac2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "if dataset_name in fairness_datasets:\n",
    "    filtered_fairness_scores = list(filter(None, x_fairness_scores))\n",
    "\n",
    "    fairness_all = torch.tensor(filtered_fairness_scores).mean(dim=0)\n",
    "    fairness_avg = fairness_all.mean().item()\n",
    "    fairness_std = fairness_all.std().item()\n",
    "    fairness_str = \"{0:.3f} ± \".format(fairness_avg) + \"{0:.3f}\".format(fairness_std)\n",
    "    \n",
    "    with open(\"fairness_results_{}.csv\".format(conv_type), \"a\") as results_file:\n",
    "        results_file.write(\" \".join([w.upper() for w in re.split('_| ', dataset_name)]) + \",\")\n",
    "        results_file.write(str(fairness_regularizer) + \",\")\n",
    "        results_file.write(fairness_str + \",\")\n",
    "        results_file.write(auc_str + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10cedcfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "051a6c7b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db4fa1b4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
