{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch_geometric.transforms as T\n",
    "from sklearn.metrics import roc_auc_score, accuracy_score\n",
    "from torch.nn import Sequential, Linear, ReLU\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.utils import negative_sampling\n",
    "from torch_geometric.utils import train_test_split_edges\n",
    "\n",
    "from utils import (\n",
    "    get_link_labels,\n",
    "    prediction_fairness,\n",
    ")\n",
    "\n",
    "# Prefer the GPU when torch reports one; fall back to the CPU otherwise.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GCN(torch.nn.Module):\n",
    "    \"\"\"Two-layer GCN encoder paired with a dot-product link decoder.\n",
    "\n",
    "    Args:\n",
    "        in_channels: Input size handed to the first ``GCNConv``.\n",
    "        out_channels: Output size of the second ``GCNConv``.\n",
    "        hidden_channels: Width between the two convolutions. Defaults to\n",
    "            128, the value that was previously hard-coded.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, hidden_channels=128):\n",
    "        super().__init__()\n",
    "        self.conv1 = GCNConv(in_channels, hidden_channels)\n",
    "        self.conv2 = GCNConv(hidden_channels, out_channels)\n",
    "\n",
    "    def encode(self, x, pos_edge_index):\n",
    "        \"\"\"Run both convolutions over ``pos_edge_index``, with a ReLU in between.\n",
    "\n",
    "        Returns the output of the second convolution (no final activation).\n",
    "        \"\"\"\n",
    "        hidden = F.relu(self.conv1(x, pos_edge_index))\n",
    "        return self.conv2(hidden, pos_edge_index)\n",
    "\n",
    "    def decode(self, z, pos_edge_index, neg_edge_index):\n",
    "        \"\"\"Score candidate edges by the dot product of their endpoint rows in ``z``.\n",
    "\n",
    "        The positive and negative edge lists are concatenated along the last\n",
    "        dimension (positives first) and scored together.\n",
    "\n",
    "        Returns:\n",
    "            A ``(logits, edge_index)`` tuple: one raw score per concatenated\n",
    "            edge, and the concatenated edge list itself.\n",
    "        \"\"\"\n",
    "        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)\n",
    "        src, dst = edge_index[0], edge_index[1]\n",
    "        logits = (z[src] * z[dst]).sum(dim=-1)\n",
    "        return logits, edge_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Planetoid benchmark to load; \"cora\" and \"pubmed\" are the alternatives.\n",
    "dataset_name = \"citeseer\"\n",
    "# '__file__' is a plain string here (notebooks define no __file__), so the\n",
    "# realpath/dirname pair resolves to the current working directory.\n",
    "notebook_dir = osp.dirname(osp.realpath('__file__'))\n",
    "path = osp.join(notebook_dir, \"..\", \"data\", dataset_name)\n",
    "dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds for the repeated runs, plus the per-run result accumulators.\n",
    "test_seeds = list(range(6))\n",
    "acc_auc, fairness = [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Loss: 0.6261, Val: 0.8056, Test: 0.8075\n",
      "Epoch: 020, Loss: 0.5561, Val: 0.8056, Test: 0.8075\n"
     ]
    }
   ],
   "source": [
    "delta = 0.1\n",
    "epochs = 101\n",
    "# Start from empty accumulators so that re-running this cell does not append\n",
    "# a second batch of results to lists left over from a previous run.\n",
    "budget = []\n",
    "acc_auc = []\n",
    "fairness = []\n",
    "for random_seed in test_seeds:\n",
    "    # np.random.seed alone leaves torch's generator (torch.rand below, model\n",
    "    # weight init) unseeded. Python's `random` is seeded defensively as well;\n",
    "    # presumably some torch_geometric sampling helpers draw from it -- TODO confirm.\n",
    "    random.seed(random_seed)\n",
    "    np.random.seed(random_seed)\n",
    "    torch.manual_seed(random_seed)\n",
    "\n",
    "    data = dataset[0]\n",
    "    # The node labels serve as the protected attribute; grab them before they\n",
    "    # are cleared from the data object on the next line.\n",
    "    protected_attribute = data.y\n",
    "    data.train_mask = data.val_mask = data.test_mask = data.y = None\n",
    "    data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)\n",
    "    data = data.to(device)\n",
    "\n",
    "    num_classes = len(np.unique(protected_attribute))\n",
    "    N = data.num_nodes\n",
    "\n",
    "    model = GCN(data.num_features, 128).to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n",
    "\n",
    "    Y = protected_attribute.to(device=device, dtype=torch.long)\n",
    "    train_edges = data.train_pos_edge_index\n",
    "    # True for training edges whose endpoints differ in the protected attribute.\n",
    "    Y_aux = Y[train_edges[0]] != Y[train_edges[1]]\n",
    "    # One row of coin flips per epoch, each True with probability 0.5 + delta.\n",
    "    # Built directly on `device` so every mask derived from it sits on the same\n",
    "    # device as the edge tensors (mixing CPU and CUDA tensors here used to fail).\n",
    "    randomization = torch.rand(epochs, Y_aux.size(0), device=device) < 0.5 + delta\n",
    "\n",
    "    best_val_perf = test_perf = 0\n",
    "    for epoch in range(1, epochs):\n",
    "        # TRAINING\n",
    "        neg_edges_tr = negative_sampling(\n",
    "            edge_index=train_edges,\n",
    "            num_nodes=N,\n",
    "            num_neg_samples=train_edges.size(1) // 2,\n",
    "        ).to(device)\n",
    "\n",
    "        if epoch == 1 or epoch % 10 == 0:\n",
    "            # Refresh the edge mask on the first epoch and on every tenth one.\n",
    "            # An edge is kept when its coin flip agrees with Y_aux: a True flip\n",
    "            # keeps edges between differing groups, a False flip keeps edges\n",
    "            # whose endpoints share the same group.\n",
    "            keep = randomization[epoch] == Y_aux\n",
    "        if epoch == 1:\n",
    "            # Record how many training edges survive the first mask.\n",
    "            budget.append(int(keep.sum()))\n",
    "\n",
    "        kept_edges = train_edges[:, keep]\n",
    "\n",
    "        model.train()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        z = model.encode(data.x, kept_edges)\n",
    "        link_logits, _ = model.decode(z, kept_edges, neg_edges_tr)\n",
    "        tr_labels = get_link_labels(kept_edges, neg_edges_tr).to(device)\n",
    "\n",
    "        loss = F.binary_cross_entropy_with_logits(link_logits, tr_labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # EVALUATION\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            # Both splits are scored from the same embeddings, so encode once.\n",
    "            z = model.encode(data.x, train_edges)\n",
    "        perfs = []\n",
    "        for prefix in [\"val\", \"test\"]:\n",
    "            pos_edge_index = data[f\"{prefix}_pos_edge_index\"]\n",
    "            neg_edge_index = data[f\"{prefix}_neg_edge_index\"]\n",
    "            with torch.no_grad():\n",
    "                link_logits, edge_idx = model.decode(z, pos_edge_index, neg_edge_index)\n",
    "            link_probs = link_logits.sigmoid()\n",
    "            link_labels = get_link_labels(pos_edge_index, neg_edge_index)\n",
    "            auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())\n",
    "            perfs.append(auc)\n",
    "\n",
    "        # Track the test AUC observed at the epoch with the best validation AUC.\n",
    "        val_perf, tmp_test_perf = perfs\n",
    "        if val_perf > best_val_perf:\n",
    "            best_val_perf = val_perf\n",
    "            test_perf = tmp_test_perf\n",
    "        if epoch % 10 == 0:\n",
    "            log = \"Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}\"\n",
    "            print(log.format(epoch, loss.item(), best_val_perf, test_perf))\n",
    "\n",
    "    # FAIRNESS\n",
    "    # \"test\" is the last split evaluated above, so link_labels, link_probs and\n",
    "    # edge_idx hold the test-split predictions of the final epoch at this point.\n",
    "    # NOTE(review): the AUC below comes from the best-validation epoch while\n",
    "    # accuracy and fairness use the final epoch, and the cut-off is tuned on\n",
    "    # the test split itself -- confirm that this is intended.\n",
    "    auc = test_perf\n",
    "    cut = [0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75]\n",
    "    best_acc = 0\n",
    "    best_cut = 0.5\n",
    "    for i in cut:\n",
    "        acc = accuracy_score(link_labels.cpu(), link_probs.cpu() >= i)\n",
    "        if acc > best_acc:\n",
    "            best_acc = acc\n",
    "            best_cut = i\n",
    "    f = prediction_fairness(\n",
    "        edge_idx.cpu(), link_labels.cpu(), link_probs.cpu() >= best_cut, Y.cpu()\n",
    "    )\n",
    "    # Store everything as percentages.\n",
    "    acc_auc.append([best_acc * 100, auc * 100])\n",
    "    fairness.append([x * 100 for x in f])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the per-seed results: one row per seed, one column per metric.\n",
    "acc_auc_arr = np.asarray(acc_auc)\n",
    "fairness_arr = np.asarray(fairness)\n",
    "\n",
    "ma, sa = acc_auc_arr.mean(axis=0), acc_auc_arr.std(axis=0)\n",
    "mf, sf = fairness_arr.mean(axis=0), fairness_arr.std(axis=0)\n",
    "\n",
    "# Two decimals; the earlier \":2f\" spec set a field width of 2, not a precision.\n",
    "print(f\"ACC: {ma[0]:.2f} +- {sa[0]:.2f}\")\n",
    "print(f\"AUC: {ma[1]:.2f} +- {sa[1]:.2f}\")\n",
    "\n",
    "fairness_labels = [\"DP mix\", \"EoP mix\", \"DP group\", \"EoP group\", \"DP sub\", \"EoP sub\"]\n",
    "for col, label in enumerate(fairness_labels):\n",
    "    print(f\"{label}: {mf[col]:.2f} +- {sf[col]:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
