{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "\n",
    "from torch_geometric.datasets import Planetoid\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.utils import negative_sampling\n",
    "from torch_geometric.nn import GCNConv\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.nn import Sequential, Linear, ReLU\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import roc_auc_score, accuracy_score\n",
    "\n",
    "from utils import (\n",
    "    get_link_labels,\n",
    "    prediction_fairness,\n",
    ")\n",
    "\n",
    "from torch_geometric.utils import train_test_split_edges\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCN(torch.nn.Module):\n",
    "    def __init__(self, in_channels, out_channels):\n",
    "        super(GCN, self).__init__()\n",
    "        self.conv1 = GCNConv(in_channels, 128)\n",
    "        self.conv2 = GCNConv(128, out_channels)\n",
    "\n",
    "    def encode(self, x, pos_edge_index):\n",
    "        x = F.relu(self.conv1(x, pos_edge_index))\n",
    "        x = self.conv2(x, pos_edge_index)\n",
    "        return x\n",
    "\n",
    "    def decode(self, z, pos_edge_index, neg_edge_index):\n",
    "        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)\n",
    "        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)\n",
    "        return logits, edge_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dropout_adj_fair(edge_index, y, sens, r, pk=0.9, pmin=0.4, pmax=1):\n",
    "    row, col = edge_index\n",
    "\n",
    "    ma= edge_index.new_full((row.size(0),), pk, dtype=torch.float)\n",
    "\n",
    "    ma[np.where((sens[row] == sens[col]) == True)[0]]= pk\n",
    "    for i,s in enumerate(np.unique(y)):\n",
    "        ma[np.where((torch.logical_and(sens[row] != sens[col], sens[row]==s*torch.ones(row.size(0)))) == True)[0]] = min(pmax,max(pk*r[s],pmin))\n",
    "    ma = torch.bernoulli(ma).to(torch.bool)\n",
    "\n",
    "    row, col = filter_adj_fair(row, col, ma)\n",
    "    edge_index = torch.stack([row, col], dim=0)\n",
    "\n",
    "    return edge_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_adj_fair(row, col, mask):\n",
    "    return row[mask], col[mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def graph_attrs_idel(edges, y, sens):\n",
    "    row, col = edges\n",
    "\n",
    "    inter=np.where(sens==1)[0]\n",
    "    intra=np.where(sens==0)[0]\n",
    "\n",
    "    edges=np.array(edges).T\n",
    "    r={}\n",
    "    for i,s in enumerate(np.unique(y)):\n",
    "        a=(len(np.where(y[edges[intra,0]]==s)[0]))\n",
    "        if a>0:\n",
    "            r[s]=float((len(np.where(y[edges[inter,1]]==s)[0])+len(np.where(y[edges[inter,0]]==s)[0])))/(2*len(np.where(y[edges[intra,0]]==s)[0]))\n",
    "        else:\n",
    "            print('check')\n",
    "    return r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2394\n"
     ]
    }
   ],
   "source": [
    "print(len(np.where(sens==1)[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4998\n"
     ]
    }
   ],
   "source": [
    "print(len(np.where(sens==0)[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"cora\" #\"citeseer\"  \"pubmed\"\n",
    "path = osp.join(osp.dirname(osp.realpath('__file__')), \"..\", \"data\", dataset)\n",
    "dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_seeds = [0,1,2,3,4,5]\n",
    "acc_auc = []\n",
    "fairness = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6592, Val: 0.7861, Test: 0.7919\n",
      "Epoch: 020, Loss: 0.5932, Val: 0.8319, Test: 0.8057\n",
      "Epoch: 030, Loss: 0.5483, Val: 0.8784, Test: 0.8489\n",
      "Epoch: 040, Loss: 0.5572, Val: 0.8908, Test: 0.8575\n",
      "Epoch: 050, Loss: 0.5373, Val: 0.8951, Test: 0.8575\n",
      "Epoch: 060, Loss: 0.5275, Val: 0.8979, Test: 0.8680\n",
      "Epoch: 070, Loss: 0.5311, Val: 0.9016, Test: 0.8783\n",
      "Epoch: 080, Loss: 0.5184, Val: 0.9061, Test: 0.8858\n",
      "Epoch: 090, Loss: 0.5110, Val: 0.9061, Test: 0.8858\n",
      "Epoch: 100, Loss: 0.5066, Val: 0.9067, Test: 0.8932\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6636, Val: 0.7858, Test: 0.7891\n",
      "Epoch: 020, Loss: 0.6061, Val: 0.7858, Test: 0.7891\n",
      "Epoch: 030, Loss: 0.5584, Val: 0.8547, Test: 0.8763\n",
      "Epoch: 040, Loss: 0.5510, Val: 0.8671, Test: 0.8927\n",
      "Epoch: 050, Loss: 0.5324, Val: 0.8702, Test: 0.8977\n",
      "Epoch: 060, Loss: 0.5325, Val: 0.8770, Test: 0.9022\n",
      "Epoch: 070, Loss: 0.5147, Val: 0.8794, Test: 0.9035\n",
      "Epoch: 080, Loss: 0.5156, Val: 0.8794, Test: 0.9035\n",
      "Epoch: 090, Loss: 0.5132, Val: 0.8814, Test: 0.9141\n",
      "Epoch: 100, Loss: 0.5049, Val: 0.8858, Test: 0.9151\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6504, Val: 0.7811, Test: 0.7964\n",
      "Epoch: 020, Loss: 0.6105, Val: 0.7811, Test: 0.7964\n",
      "Epoch: 030, Loss: 0.5902, Val: 0.7907, Test: 0.8031\n",
      "Epoch: 040, Loss: 0.5451, Val: 0.8446, Test: 0.8569\n",
      "Epoch: 050, Loss: 0.5489, Val: 0.8499, Test: 0.8650\n",
      "Epoch: 060, Loss: 0.5428, Val: 0.8564, Test: 0.8743\n",
      "Epoch: 070, Loss: 0.5258, Val: 0.8623, Test: 0.8840\n",
      "Epoch: 080, Loss: 0.5313, Val: 0.8658, Test: 0.8875\n",
      "Epoch: 090, Loss: 0.5175, Val: 0.8740, Test: 0.8910\n",
      "Epoch: 100, Loss: 0.5175, Val: 0.8793, Test: 0.8938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6654, Val: 0.8010, Test: 0.8073\n",
      "Epoch: 020, Loss: 0.6036, Val: 0.8010, Test: 0.8073\n",
      "Epoch: 030, Loss: 0.5596, Val: 0.8393, Test: 0.8554\n",
      "Epoch: 040, Loss: 0.5500, Val: 0.8563, Test: 0.8757\n",
      "Epoch: 050, Loss: 0.5336, Val: 0.8709, Test: 0.8821\n",
      "Epoch: 060, Loss: 0.5309, Val: 0.8728, Test: 0.8809\n",
      "Epoch: 070, Loss: 0.5251, Val: 0.8756, Test: 0.8840\n",
      "Epoch: 080, Loss: 0.5287, Val: 0.8773, Test: 0.8862\n",
      "Epoch: 090, Loss: 0.5260, Val: 0.8789, Test: 0.8893\n",
      "Epoch: 100, Loss: 0.5239, Val: 0.8823, Test: 0.8890\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6670, Val: 0.8029, Test: 0.8315\n",
      "Epoch: 020, Loss: 0.5942, Val: 0.8042, Test: 0.8074\n",
      "Epoch: 030, Loss: 0.5501, Val: 0.8573, Test: 0.8752\n",
      "Epoch: 040, Loss: 0.5402, Val: 0.8733, Test: 0.8911\n",
      "Epoch: 050, Loss: 0.5280, Val: 0.8821, Test: 0.8977\n",
      "Epoch: 060, Loss: 0.5211, Val: 0.8851, Test: 0.9028\n",
      "Epoch: 070, Loss: 0.5135, Val: 0.8895, Test: 0.9081\n",
      "Epoch: 080, Loss: 0.5043, Val: 0.8953, Test: 0.9109\n",
      "Epoch: 090, Loss: 0.5152, Val: 0.8974, Test: 0.9110\n",
      "Epoch: 100, Loss: 0.5000, Val: 0.8981, Test: 0.9110\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6594, Val: 0.7844, Test: 0.7955\n",
      "Epoch: 020, Loss: 0.5861, Val: 0.8048, Test: 0.7950\n",
      "Epoch: 030, Loss: 0.5590, Val: 0.8609, Test: 0.8409\n",
      "Epoch: 040, Loss: 0.5418, Val: 0.8766, Test: 0.8658\n",
      "Epoch: 050, Loss: 0.5269, Val: 0.8862, Test: 0.8750\n",
      "Epoch: 060, Loss: 0.5252, Val: 0.8926, Test: 0.8799\n",
      "Epoch: 070, Loss: 0.5116, Val: 0.8941, Test: 0.8802\n",
      "Epoch: 080, Loss: 0.5202, Val: 0.8990, Test: 0.8849\n",
      "Epoch: 090, Loss: 0.5089, Val: 0.9054, Test: 0.8878\n",
      "Epoch: 100, Loss: 0.5089, Val: 0.9104, Test: 0.8919\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    }
   ],
   "source": [
    "delta = 0.28\n",
    "budget=[]\n",
    "for random_seed in test_seeds:\n",
    "    np.random.seed(random_seed)\n",
    "    data = dataset[0]\n",
    "    protected_attribute = data.y\n",
    "    data.train_mask = data.val_mask = data.test_mask = data.y = None\n",
    "    data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)\n",
    "    data = data.to(device)\n",
    "\n",
    "    num_classes = len(np.unique(protected_attribute))\n",
    "    N = data.num_nodes\n",
    "    \n",
    "    \n",
    "    epochs = 101\n",
    "    model = GCN(data.num_features, 128).to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n",
    "    \n",
    "\n",
    "    Y = torch.LongTensor(protected_attribute).to(device)\n",
    "    Y_aux = (\n",
    "        Y[data.train_pos_edge_index[0, :]] != Y[data.train_pos_edge_index[1, :]]\n",
    "    ).to(device)\n",
    "    randomization = (\n",
    "        torch.FloatTensor(epochs, Y_aux.size(0)).uniform_() < 0.5 + delta\n",
    "    ).to(device)\n",
    "    \n",
    "    Y_temp=torch.LongTensor(Y_aux.size(0))\n",
    "    Y_temp[Y_aux==True]=1\n",
    "    Y_temp[Y_aux==False]=0\n",
    "    Y_temp2=torch.LongTensor(Y_aux.size(0))\n",
    "    Y_temp2[Y_aux==False]=1\n",
    "    Y_temp2[Y_aux==True]=0\n",
    "    best_val_perf = test_perf = 0\n",
    "    for epoch in range(1, epochs):\n",
    "        # TRAINING    \n",
    "        neg_edges_tr = negative_sampling(\n",
    "            edge_index=data.train_pos_edge_index,\n",
    "            num_nodes=N,\n",
    "            num_neg_samples=data.train_pos_edge_index.size(1) // 2,\n",
    "                    ).to(device)\n",
    "\n",
    "        if epoch == 1 or epoch % 10 == 0:\n",
    "            sens = torch.where(randomization[epoch], Y_temp, Y_temp2)\n",
    "            keep=torch.BoolTensor(Y_aux.size(0))\n",
    "            keep[sens==1]=True\n",
    "            keep[sens==0]=False\n",
    "        if epoch ==1:\n",
    "            budget.append(len(np.where(sens==1)[0]))\n",
    "        \n",
    "        \n",
    "        model.train()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        z = model.encode(data.x, data.train_pos_edge_index[:, keep])\n",
    "        link_logits, _ = model.decode(\n",
    "            z, data.train_pos_edge_index[:, keep], neg_edges_tr\n",
    "        )\n",
    "        tr_labels = get_link_labels(\n",
    "            data.train_pos_edge_index[:, keep], neg_edges_tr\n",
    "        ).to(device)\n",
    "        \n",
    "        loss = F.binary_cross_entropy_with_logits(link_logits, tr_labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # EVALUATION\n",
    "        model.eval()\n",
    "        perfs = []\n",
    "        for prefix in [\"val\", \"test\"]:\n",
    "            pos_edge_index = data[f\"{prefix}_pos_edge_index\"]\n",
    "            neg_edge_index = data[f\"{prefix}_neg_edge_index\"]\n",
    "            with torch.no_grad():\n",
    "                z = model.encode(data.x, data.train_pos_edge_index)\n",
    "                link_logits, edge_idx = model.decode(z, pos_edge_index, neg_edge_index)\n",
    "            link_probs = link_logits.sigmoid()\n",
    "            link_labels = get_link_labels(pos_edge_index, neg_edge_index)\n",
    "            auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())\n",
    "            perfs.append(auc)\n",
    "\n",
    "        val_perf, tmp_test_perf = perfs\n",
    "        if val_perf > best_val_perf:\n",
    "            best_val_perf = val_perf\n",
    "            test_perf = tmp_test_perf\n",
    "        if epoch%10==0:\n",
    "            log = \"Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}\"\n",
    "            print(log.format(epoch, loss, best_val_perf, test_perf))\n",
    "    # FAIRNESS\n",
    "    auc = test_perf\n",
    "    cut = [0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75]\n",
    "    best_acc = 0\n",
    "    best_cut = 0.5\n",
    "    for i in cut:\n",
    "        acc = accuracy_score(link_labels.cpu(), link_probs.cpu() >= i)\n",
    "        if acc > best_acc:\n",
    "            best_acc = acc\n",
    "            best_cut = i\n",
    "    f = prediction_fairness(\n",
    "        edge_idx.cpu(), link_labels.cpu(), link_probs.cpu() >= best_cut, Y.cpu()\n",
    "    )\n",
    "    acc_auc.append([best_acc * 100, auc * 100])\n",
    "    fairness.append([x * 100 for x in f])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACC: 82.622433 +- 1.129072\n",
      "AUC: 89.898640 +- 1.011346\n",
      "DP mix: 53.965220 +- 2.475914\n",
      "EoP mix: 33.169862 +- 4.524290\n",
      "DP group: 12.392372 +- 3.848252\n",
      "EoP group: 12.377178 +- 3.201244\n",
      "DP sub: 90.651213 +- 3.195307\n",
      "EoP sub: 100.000000 +- 0.000000\n"
     ]
    }
   ],
   "source": [
    "ma = np.mean(np.asarray(acc_auc), axis=0)\n",
    "mf = np.mean(np.asarray(fairness), axis=0)\n",
    "\n",
    "sa = np.std(np.asarray(acc_auc), axis=0)\n",
    "sf = np.std(np.asarray(fairness), axis=0)\n",
    "\n",
    "print(f\"ACC: {ma[0]:2f} +- {sa[0]:2f}\")\n",
    "print(f\"AUC: {ma[1]:2f} +- {sa[1]:2f}\")\n",
    "\n",
    "print(f\"DP mix: {mf[0]:2f} +- {sf[0]:2f}\")\n",
    "print(f\"EoP mix: {mf[1]:2f} +- {sf[1]:2f}\")\n",
    "print(f\"DP group: {mf[2]:2f} +- {sf[2]:2f}\")\n",
    "print(f\"EoP group: {mf[3]:2f} +- {sf[3]:2f}\")\n",
    "print(f\"DP sub: {mf[4]:2f} +- {sf[4]:2f}\")\n",
    "print(f\"EoP sub: {mf[5]:2f} +- {sf[5]:2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2499, 2366, 2446, 2398, 2421, 2469]\n"
     ]
    }
   ],
   "source": [
    "print(budget)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_seeds = [0,1,2,3,4,5]\n",
    "acc_auc = []\n",
    "fairness = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6608, Val: 0.8181, Test: 0.8112\n",
      "Epoch: 020, Loss: 0.5784, Val: 0.8409, Test: 0.8269\n",
      "Epoch: 030, Loss: 0.5434, Val: 0.8695, Test: 0.8698\n",
      "Epoch: 040, Loss: 0.5237, Val: 0.8868, Test: 0.8844\n",
      "Epoch: 050, Loss: 0.5261, Val: 0.8912, Test: 0.8850\n",
      "Epoch: 060, Loss: 0.5227, Val: 0.8941, Test: 0.8879\n",
      "Epoch: 070, Loss: 0.5210, Val: 0.8941, Test: 0.8879\n",
      "Epoch: 080, Loss: 0.5180, Val: 0.8943, Test: 0.8909\n",
      "Epoch: 090, Loss: 0.5099, Val: 0.8943, Test: 0.8909\n",
      "Epoch: 100, Loss: 0.5094, Val: 0.8967, Test: 0.8988\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6701, Val: 0.8661, Test: 0.8655\n",
      "Epoch: 020, Loss: 0.5903, Val: 0.8661, Test: 0.8655\n",
      "Epoch: 030, Loss: 0.5546, Val: 0.8795, Test: 0.8795\n",
      "Epoch: 040, Loss: 0.5393, Val: 0.8982, Test: 0.9031\n",
      "Epoch: 050, Loss: 0.5261, Val: 0.9055, Test: 0.9083\n",
      "Epoch: 060, Loss: 0.5287, Val: 0.9089, Test: 0.9096\n",
      "Epoch: 070, Loss: 0.5235, Val: 0.9129, Test: 0.9152\n",
      "Epoch: 080, Loss: 0.5209, Val: 0.9153, Test: 0.9155\n",
      "Epoch: 090, Loss: 0.5123, Val: 0.9173, Test: 0.9171\n",
      "Epoch: 100, Loss: 0.5116, Val: 0.9183, Test: 0.9181\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6563, Val: 0.7997, Test: 0.7987\n",
      "Epoch: 020, Loss: 0.5877, Val: 0.8118, Test: 0.8322\n",
      "Epoch: 030, Loss: 0.5590, Val: 0.8668, Test: 0.8658\n",
      "Epoch: 040, Loss: 0.5366, Val: 0.8753, Test: 0.8834\n",
      "Epoch: 050, Loss: 0.5308, Val: 0.8865, Test: 0.8959\n",
      "Epoch: 060, Loss: 0.5270, Val: 0.8992, Test: 0.9068\n",
      "Epoch: 070, Loss: 0.5207, Val: 0.9021, Test: 0.9090\n",
      "Epoch: 080, Loss: 0.5134, Val: 0.9035, Test: 0.9099\n",
      "Epoch: 090, Loss: 0.5149, Val: 0.9035, Test: 0.9099\n",
      "Epoch: 100, Loss: 0.5109, Val: 0.9042, Test: 0.9118\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6593, Val: 0.8678, Test: 0.8317\n",
      "Epoch: 020, Loss: 0.5824, Val: 0.8678, Test: 0.8317\n",
      "Epoch: 030, Loss: 0.5589, Val: 0.8678, Test: 0.8317\n",
      "Epoch: 040, Loss: 0.5443, Val: 0.8859, Test: 0.8859\n",
      "Epoch: 050, Loss: 0.5375, Val: 0.8940, Test: 0.8981\n",
      "Epoch: 060, Loss: 0.5210, Val: 0.8979, Test: 0.9057\n",
      "Epoch: 070, Loss: 0.5210, Val: 0.9012, Test: 0.9080\n",
      "Epoch: 080, Loss: 0.5249, Val: 0.9012, Test: 0.9080\n",
      "Epoch: 090, Loss: 0.5170, Val: 0.9012, Test: 0.9080\n",
      "Epoch: 100, Loss: 0.5168, Val: 0.9012, Test: 0.9080\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6644, Val: 0.8318, Test: 0.8318\n",
      "Epoch: 020, Loss: 0.5810, Val: 0.8452, Test: 0.8442\n",
      "Epoch: 030, Loss: 0.5433, Val: 0.8802, Test: 0.8794\n",
      "Epoch: 040, Loss: 0.5338, Val: 0.8975, Test: 0.8920\n",
      "Epoch: 050, Loss: 0.5400, Val: 0.9039, Test: 0.8964\n",
      "Epoch: 060, Loss: 0.5220, Val: 0.9054, Test: 0.8994\n",
      "Epoch: 070, Loss: 0.5188, Val: 0.9078, Test: 0.9010\n",
      "Epoch: 080, Loss: 0.5212, Val: 0.9096, Test: 0.9038\n",
      "Epoch: 090, Loss: 0.5094, Val: 0.9110, Test: 0.9064\n",
      "Epoch: 100, Loss: 0.5124, Val: 0.9110, Test: 0.9064\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6623, Val: 0.8665, Test: 0.8589\n",
      "Epoch: 020, Loss: 0.5739, Val: 0.8688, Test: 0.8576\n",
      "Epoch: 030, Loss: 0.5533, Val: 0.8771, Test: 0.8714\n",
      "Epoch: 040, Loss: 0.5453, Val: 0.8803, Test: 0.8961\n",
      "Epoch: 050, Loss: 0.5342, Val: 0.8928, Test: 0.8989\n",
      "Epoch: 060, Loss: 0.5268, Val: 0.8968, Test: 0.9052\n",
      "Epoch: 070, Loss: 0.5241, Val: 0.8982, Test: 0.9051\n",
      "Epoch: 080, Loss: 0.5190, Val: 0.8982, Test: 0.9051\n",
      "Epoch: 090, Loss: 0.5161, Val: 0.9020, Test: 0.9054\n",
      "Epoch: 100, Loss: 0.5079, Val: 0.9103, Test: 0.9101\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    }
   ],
   "source": [
    "delta = 0.28\n",
    "\n",
    "for random_seed in test_seeds:\n",
    "\n",
    "    np.random.seed(random_seed)\n",
    "    data = dataset[0]\n",
    "    protected_attribute = data.y\n",
    "    data.train_mask = data.val_mask = data.test_mask = data.y = None\n",
    "    data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)\n",
    "    data = data.to(device)\n",
    "\n",
    "    num_classes = len(np.unique(protected_attribute))\n",
    "    N = data.num_nodes\n",
    "    \n",
    "    \n",
    "    epochs = 101\n",
    "    model = GCN(data.num_features, 128).to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n",
    "    \n",
    "\n",
    "    Y = torch.LongTensor(protected_attribute).to(device)\n",
    "    Y_aux = (\n",
    "        Y[data.train_pos_edge_index[0, :]] != Y[data.train_pos_edge_index[1, :]]\n",
    "    ).to(device)\n",
    "    randomization = (\n",
    "        torch.FloatTensor(epochs, Y_aux.size(0)).uniform_() < 0.5 + delta\n",
    "    ).to(device)\n",
    "    \n",
    "    best_val_perf = test_perf = 0\n",
    "    for epoch in range(1, epochs):\n",
    "        # TRAINING    \n",
    "        neg_edges_tr = negative_sampling(\n",
    "            edge_index=data.train_pos_edge_index,\n",
    "            num_nodes=N,\n",
    "            num_neg_samples=data.train_pos_edge_index.size(1) // 2,\n",
    "                    ).to(device)\n",
    "        \n",
    "        arr = np.arange(data.train_pos_edge_index.size(1))\n",
    "        np.random.shuffle(arr)\n",
    "        used_edges=arr[:budget[random_seed]]\n",
    "        model.train()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        z = model.encode(data.x, data.train_pos_edge_index[:,used_edges])\n",
    "        link_logits, _ = model.decode(\n",
    "            z, data.train_pos_edge_index[:,used_edges], neg_edges_tr\n",
    "        )\n",
    "        tr_labels = get_link_labels(\n",
    "            data.train_pos_edge_index[:,used_edges], neg_edges_tr\n",
    "        ).to(device)\n",
    "        \n",
    "        loss = F.binary_cross_entropy_with_logits(link_logits, tr_labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # EVALUATION\n",
    "        model.eval()\n",
    "        perfs = []\n",
    "        for prefix in [\"val\", \"test\"]:\n",
    "            pos_edge_index = data[f\"{prefix}_pos_edge_index\"]\n",
    "            neg_edge_index = data[f\"{prefix}_neg_edge_index\"]\n",
    "            with torch.no_grad():\n",
    "                z = model.encode(data.x, data.train_pos_edge_index)\n",
    "                link_logits, edge_idx = model.decode(z, pos_edge_index, neg_edge_index)\n",
    "            link_probs = link_logits.sigmoid()\n",
    "            link_labels = get_link_labels(pos_edge_index, neg_edge_index)\n",
    "            auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())\n",
    "            perfs.append(auc)\n",
    "\n",
    "        val_perf, tmp_test_perf = perfs\n",
    "        if val_perf > best_val_perf:\n",
    "            best_val_perf = val_perf\n",
    "            test_perf = tmp_test_perf\n",
    "        if epoch%10==0:\n",
    "            log = \"Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}\"\n",
    "            print(log.format(epoch, loss, best_val_perf, test_perf))\n",
    "    # FAIRNESS\n",
    "    auc = test_perf\n",
    "    cut = [0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75]\n",
    "    best_acc = 0\n",
    "    best_cut = 0.5\n",
    "    for i in cut:\n",
    "        acc = accuracy_score(link_labels.cpu(), link_probs.cpu() >= i)\n",
    "        if acc > best_acc:\n",
    "            best_acc = acc\n",
    "            best_cut = i\n",
    "    f = prediction_fairness(\n",
    "        edge_idx.cpu(), link_labels.cpu(), link_probs.cpu() >= best_cut, Y.cpu()\n",
    "    )\n",
    "    acc_auc.append([best_acc * 100, auc * 100])\n",
    "    fairness.append([x * 100 for x in f])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACC: 83.009479 +- 1.007132\n",
      "AUC: 90.887424 +- 0.580346\n",
      "DP mix: 57.963170 +- 2.598799\n",
      "EoP mix: 39.591562 +- 5.329511\n",
      "DP group: 11.625557 +- 3.373486\n",
      "EoP group: 16.073516 +- 3.120178\n",
      "DP sub: 90.435491 +- 3.856193\n",
      "EoP sub: 98.888889 +- 2.484520\n"
     ]
    }
   ],
   "source": [
    "ma = np.mean(np.asarray(acc_auc), axis=0)\n",
    "mf = np.mean(np.asarray(fairness), axis=0)\n",
    "\n",
    "sa = np.std(np.asarray(acc_auc), axis=0)\n",
    "sf = np.std(np.asarray(fairness), axis=0)\n",
    "\n",
    "print(f\"ACC: {ma[0]:2f} +- {sa[0]:2f}\")\n",
    "print(f\"AUC: {ma[1]:2f} +- {sa[1]:2f}\")\n",
    "\n",
    "print(f\"DP mix: {mf[0]:2f} +- {sf[0]:2f}\")\n",
    "print(f\"EoP mix: {mf[1]:2f} +- {sf[1]:2f}\")\n",
    "print(f\"DP group: {mf[2]:2f} +- {sf[2]:2f}\")\n",
    "print(f\"EoP group: {mf[3]:2f} +- {sf[3]:2f}\")\n",
    "print(f\"DP sub: {mf[4]:2f} +- {sf[4]:2f}\")\n",
    "print(f\"EoP sub: {mf[5]:2f} +- {sf[5]:2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_seeds = [0,1,2,3,4,5]\n",
    "acc_auc = []\n",
    "fairness = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n"
     ]
    }
   ],
   "source": [
    "print(budget)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.47959183673469385\n"
     ]
    }
   ],
   "source": [
    "print(r[6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6928, Val: 0.7985, Test: 0.8090\n",
      "Epoch: 020, Loss: 0.6812, Val: 0.8373, Test: 0.8433\n",
      "Epoch: 030, Loss: 0.6488, Val: 0.8373, Test: 0.8433\n",
      "Epoch: 040, Loss: 0.6289, Val: 0.8438, Test: 0.8441\n",
      "Epoch: 050, Loss: 0.6229, Val: 0.8553, Test: 0.8676\n",
      "Epoch: 060, Loss: 0.6154, Val: 0.8667, Test: 0.8819\n",
      "Epoch: 070, Loss: 0.6059, Val: 0.8758, Test: 0.8935\n",
      "Epoch: 080, Loss: 0.6055, Val: 0.8853, Test: 0.8982\n",
      "Epoch: 090, Loss: 0.5953, Val: 0.8903, Test: 0.8987\n",
      "Epoch: 100, Loss: 0.5927, Val: 0.8990, Test: 0.9065\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6930, Val: 0.8541, Test: 0.8278\n",
      "Epoch: 020, Loss: 0.6889, Val: 0.8750, Test: 0.8447\n",
      "Epoch: 030, Loss: 0.6424, Val: 0.8750, Test: 0.8447\n",
      "Epoch: 040, Loss: 0.6423, Val: 0.8750, Test: 0.8447\n",
      "Epoch: 050, Loss: 0.6344, Val: 0.8800, Test: 0.8613\n",
      "Epoch: 060, Loss: 0.6205, Val: 0.8936, Test: 0.8780\n",
      "Epoch: 070, Loss: 0.6114, Val: 0.9008, Test: 0.8827\n",
      "Epoch: 080, Loss: 0.6071, Val: 0.9080, Test: 0.8872\n",
      "Epoch: 090, Loss: 0.6065, Val: 0.9105, Test: 0.8919\n",
      "Epoch: 100, Loss: 0.6049, Val: 0.9118, Test: 0.8945\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6924, Val: 0.8248, Test: 0.8216\n",
      "Epoch: 020, Loss: 0.6725, Val: 0.8306, Test: 0.8241\n",
      "Epoch: 030, Loss: 0.6611, Val: 0.8306, Test: 0.8241\n",
      "Epoch: 040, Loss: 0.6374, Val: 0.8306, Test: 0.8241\n",
      "Epoch: 050, Loss: 0.6318, Val: 0.8515, Test: 0.8598\n",
      "Epoch: 060, Loss: 0.6161, Val: 0.8658, Test: 0.8769\n",
      "Epoch: 070, Loss: 0.6097, Val: 0.8738, Test: 0.8851\n",
      "Epoch: 080, Loss: 0.6016, Val: 0.8780, Test: 0.8860\n",
      "Epoch: 090, Loss: 0.6072, Val: 0.8809, Test: 0.8917\n",
      "Epoch: 100, Loss: 0.5981, Val: 0.8809, Test: 0.8917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6930, Val: 0.8120, Test: 0.8012\n",
      "Epoch: 020, Loss: 0.6871, Val: 0.8467, Test: 0.8347\n",
      "Epoch: 030, Loss: 0.6491, Val: 0.8467, Test: 0.8347\n",
      "Epoch: 040, Loss: 0.6277, Val: 0.8600, Test: 0.8708\n",
      "Epoch: 050, Loss: 0.6213, Val: 0.8793, Test: 0.8878\n",
      "Epoch: 060, Loss: 0.6156, Val: 0.8885, Test: 0.8946\n",
      "Epoch: 070, Loss: 0.6061, Val: 0.8916, Test: 0.9013\n",
      "Epoch: 080, Loss: 0.6108, Val: 0.9006, Test: 0.9062\n",
      "Epoch: 090, Loss: 0.5978, Val: 0.9058, Test: 0.9095\n",
      "Epoch: 100, Loss: 0.5975, Val: 0.9063, Test: 0.9089\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6930, Val: 0.8029, Test: 0.7835\n",
      "Epoch: 020, Loss: 0.6865, Val: 0.8367, Test: 0.8111\n",
      "Epoch: 030, Loss: 0.6708, Val: 0.8367, Test: 0.8111\n",
      "Epoch: 040, Loss: 0.6408, Val: 0.8628, Test: 0.8538\n",
      "Epoch: 050, Loss: 0.6240, Val: 0.8841, Test: 0.8710\n",
      "Epoch: 060, Loss: 0.6132, Val: 0.9035, Test: 0.8905\n",
      "Epoch: 070, Loss: 0.6058, Val: 0.9105, Test: 0.8980\n",
      "Epoch: 080, Loss: 0.6072, Val: 0.9126, Test: 0.9005\n",
      "Epoch: 090, Loss: 0.5983, Val: 0.9126, Test: 0.9005\n",
      "Epoch: 100, Loss: 0.5962, Val: 0.9127, Test: 0.9076\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6930, Val: 0.7909, Test: 0.7973\n",
      "Epoch: 020, Loss: 0.6881, Val: 0.8401, Test: 0.8508\n",
      "Epoch: 030, Loss: 0.6535, Val: 0.8401, Test: 0.8508\n",
      "Epoch: 040, Loss: 0.6376, Val: 0.8401, Test: 0.8508\n",
      "Epoch: 050, Loss: 0.6315, Val: 0.8751, Test: 0.8698\n",
      "Epoch: 060, Loss: 0.6177, Val: 0.8938, Test: 0.8876\n",
      "Epoch: 070, Loss: 0.6065, Val: 0.9008, Test: 0.8968\n",
      "Epoch: 080, Loss: 0.6078, Val: 0.9029, Test: 0.8969\n",
      "Epoch: 090, Loss: 0.6043, Val: 0.9061, Test: 0.9021\n",
      "Epoch: 100, Loss: 0.5935, Val: 0.9108, Test: 0.9090\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found 28 subgroups. Evaluation may be slow\n",
      "Found 28 subgroups. Evaluation may be slow\n"
     ]
    }
   ],
   "source": [
    "# Randomized-response FairDrop experiment: for every seed, train a GCN link\n",
    "# predictor on a fairness-biased edge dropout of the training graph, track the\n",
    "# test AUC at the best validation AUC, then score accuracy and fairness.\n",
    "delta = 0.28\n",
    "budget=[]\n",
    "for random_seed in test_seeds:\n",
    "    # Seed NumPy AND PyTorch: the tensors drawn below (randomization mask, model\n",
    "    # weights) come from torch's RNG, which np.random.seed alone does not control.\n",
    "    np.random.seed(random_seed)\n",
    "    torch.manual_seed(random_seed)\n",
    "    data = dataset[0]\n",
    "    protected_attribute = data.y\n",
    "    data.train_mask = data.val_mask = data.test_mask = data.y = None\n",
    "    data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)\n",
    "    data = data.to(device)\n",
    "\n",
    "    num_classes = len(np.unique(protected_attribute))\n",
    "    N = data.num_nodes\n",
    "\n",
    "    epochs = 101\n",
    "    model = GCN(data.num_features, 128).to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n",
    "\n",
    "    Y = torch.LongTensor(protected_attribute).to(device)\n",
    "\n",
    "    # True for training edges whose two endpoints carry different protected labels.\n",
    "    Y_aux = (\n",
    "        Y[data.train_pos_edge_index[0, :]] != Y[data.train_pos_edge_index[1, :]]\n",
    "    ).to(device)\n",
    "    # One uniform draw per (epoch, edge), thresholded at 0.5 + delta; row `epoch`\n",
    "    # selects between Y_temp and Y_temp2 whenever sens is refreshed.\n",
    "    randomization = (\n",
    "        torch.FloatTensor(epochs, Y_aux.size(0)).uniform_() < 0.5 + delta\n",
    "    ).to(device)\n",
    "\n",
    "    # Y_temp is Y_aux as 0/1 and Y_temp2 is its complement. Deriving both from\n",
    "    # Y_aux keeps them on the same device as randomization, which torch.where needs.\n",
    "    Y_temp = Y_aux.long()\n",
    "    Y_temp2 = 1 - Y_temp\n",
    "    best_val_perf = test_perf = 0\n",
    "\n",
    "    for epoch in range(1, epochs):\n",
    "        # TRAINING\n",
    "        neg_edges_tr = negative_sampling(\n",
    "            edge_index=data.train_pos_edge_index,\n",
    "            num_nodes=N,\n",
    "            num_neg_samples=data.train_pos_edge_index.size(1) // 2,\n",
    "        ).to(device)\n",
    "\n",
    "        # Refresh the per-edge sensitive flag and r on the first epoch and then\n",
    "        # every 10th epoch; the values are reused in between.\n",
    "        if epoch == 1 or epoch % 10 == 0:\n",
    "            sens = torch.where(randomization[epoch], Y_temp, Y_temp2)\n",
    "            r = graph_attrs_idel(data.train_pos_edge_index, Y, sens)\n",
    "\n",
    "        model.train()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        new_edges = dropout_adj_fair(data.train_pos_edge_index, Y, sens, r, 0.2, 0, 1)\n",
    "        z = model.encode(data.x, new_edges)\n",
    "        link_logits, _ = model.decode(z, new_edges, neg_edges_tr)\n",
    "        tr_labels = get_link_labels(new_edges, neg_edges_tr).to(device)\n",
    "\n",
    "        loss = F.binary_cross_entropy_with_logits(link_logits, tr_labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # EVALUATION\n",
    "        model.eval()\n",
    "        # The embeddings only depend on the full training graph, so encode once\n",
    "        # and reuse them for both the validation and the test split.\n",
    "        with torch.no_grad():\n",
    "            z = model.encode(data.x, data.train_pos_edge_index)\n",
    "        perfs = []\n",
    "        for prefix in [\"val\", \"test\"]:\n",
    "            pos_edge_index = data[f\"{prefix}_pos_edge_index\"]\n",
    "            neg_edge_index = data[f\"{prefix}_neg_edge_index\"]\n",
    "            with torch.no_grad():\n",
    "                link_logits, edge_idx = model.decode(z, pos_edge_index, neg_edge_index)\n",
    "            link_probs = link_logits.sigmoid()\n",
    "            link_labels = get_link_labels(pos_edge_index, neg_edge_index)\n",
    "            auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())\n",
    "            perfs.append(auc)\n",
    "\n",
    "        # Model selection: report the test AUC of the epoch with the best val AUC.\n",
    "        val_perf, tmp_test_perf = perfs\n",
    "        if val_perf > best_val_perf:\n",
    "            best_val_perf = val_perf\n",
    "            test_perf = tmp_test_perf\n",
    "        if epoch % 10 == 0:\n",
    "            log = \"Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}\"\n",
    "            print(log.format(epoch, loss.item(), best_val_perf, test_perf))\n",
    "    # FAIRNESS\n",
    "    # NOTE(review): link_labels, link_probs and edge_idx hold the test-split values\n",
    "    # of the LAST epoch, while auc is the test AUC at the best validation epoch.\n",
    "    # The decision threshold below is also picked on the test split. Confirm both\n",
    "    # are intended before comparing these numbers with other runs.\n",
    "    auc = test_perf\n",
    "    cut = [0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75]\n",
    "    best_acc = 0\n",
    "    best_cut = 0.5\n",
    "    for i in cut:\n",
    "        acc = accuracy_score(link_labels.cpu(), link_probs.cpu() >= i)\n",
    "        if acc > best_acc:\n",
    "            best_acc = acc\n",
    "            best_cut = i\n",
    "    f = prediction_fairness(\n",
    "        edge_idx.cpu(), link_labels.cpu(), link_probs.cpu() >= best_cut, Y.cpu()\n",
    "    )\n",
    "    # Stored as percentages; aggregated over seeds in the next cell.\n",
    "    acc_auc.append([best_acc * 100, auc * 100])\n",
    "    fairness.append([x * 100 for x in f])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACC: 82.148499 +- 0.505035\n",
      "AUC: 90.303213 +- 0.710821\n",
      "DP mix: 54.885595 +- 2.405820\n",
      "EoP mix: 28.965345 +- 4.403350\n",
      "DP group: 12.980171 +- 2.544054\n",
      "EoP group: 16.538005 +- 2.227237\n",
      "DP sub: 88.364346 +- 4.559300\n",
      "EoP sub: 100.000000 +- 0.000000\n"
     ]
    }
   ],
   "source": [
    "# Aggregate the per-seed results: mean and (population) standard deviation of\n",
    "# each accuracy and fairness column collected by the training cell.\n",
    "ma = np.mean(np.asarray(acc_auc), axis=0)\n",
    "mf = np.mean(np.asarray(fairness), axis=0)\n",
    "\n",
    "sa = np.std(np.asarray(acc_auc), axis=0)\n",
    "sf = np.std(np.asarray(fairness), axis=0)\n",
    "\n",
    "# Column order follows what the training cell appends: [accuracy, AUC].\n",
    "performance_names = [\"ACC\", \"AUC\"]\n",
    "# Labels for the first six values returned by prediction_fairness, in order.\n",
    "fairness_names = [\"DP mix\", \"EoP mix\", \"DP group\", \"EoP group\", \"DP sub\", \"EoP sub\"]\n",
    "\n",
    "# `.2f` (two decimals) replaces the original `2f`, which only set a minimum\n",
    "# width of 2 and therefore printed six decimals.\n",
    "for metric_name, mean, std in zip(performance_names, ma, sa):\n",
    "    print(f\"{metric_name}: {mean:.2f} +- {std:.2f}\")\n",
    "\n",
    "for metric_name, mean, std in zip(fairness_names, mf, sf):\n",
    "    print(f\"{metric_name}: {mean:.2f} +- {std:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 6374])\n"
     ]
    }
   ],
   "source": [
    "# Shape of the edge index returned by the most recent dropout_adj_fair call.\n",
    "print(new_edges.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2291\n"
     ]
    }
   ],
   "source": [
    "# The training cell resets budget to an empty list, so guard the lookup to\n",
    "# keep this cell from raising IndexError on a fresh Restart & Run All.\n",
    "if budget:\n",
    "    print(budget[0])\n",
    "else:\n",
    "    print(\"budget is empty\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
