{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from sklearn.manifold import TSNE\n",
    "from torch_geometric.utils import negative_sampling\n",
    "from fairn2v import Node2Vec\n",
    "from os.path import join, dirname, realpath\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import preprocessing\n",
    "from utils import (\n",
    "    encode_classes,\n",
    "    emb_fairness,\n",
    "    train_rn2v_adaptive,\n",
    "    emblink_fairness,\n",
    ")\n",
    "\n",
    "# Run on the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Seed both RNGs used in this notebook: NumPy drives the train/test edge split\n",
    "# (np.random.choice) and PyTorch drives the randomization mask (Tensor.uniform_()).\n",
    "# Seeding only NumPy leaves the torch-side draws different on every fresh run.\n",
    "SEED = 332\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters read by the model, optimizer and randomization cells below.\n",
    "config = dict(\n",
    "    learning_rate=0.01, walk_length=30, walks_per_node=10, p=0.50, q=0.75, delta=0.3\n",
    ")\n",
    "\n",
    "# A notebook has no __file__, so the dataset is located relative to the kernel's\n",
    "# working directory. (The previous dirname(realpath(\"__file__\")) resolved the\n",
    "# literal string \"__file__\" and only reached this directory by accident.)\n",
    "dataset_path = join(realpath(os.getcwd()), \"data\", \"dblp\")\n",
    "\n",
    "# Edge list: one comma-separated row of integers per edge.\n",
    "with open(\n",
    "    join(dataset_path, \"author-author.csv\"), mode=\"r\", encoding=\"ISO-8859-1\"\n",
    ") as edges_file:\n",
    "    edges = np.genfromtxt(edges_file, delimiter=\",\", dtype=int)\n",
    "\n",
    "# Node attributes, read as strings; column 1 holds the sensitive attribute.\n",
    "with open(\n",
    "    join(dataset_path, \"countries.csv\"), mode=\"r\", encoding=\"ISO-8859-1\"\n",
    ") as attributes_file:\n",
    "    attributes = np.genfromtxt(attributes_file, delimiter=\",\", dtype=str)\n",
    "\n",
    "# NOTE(review): encode_classes lives in utils; presumably it maps each distinct\n",
    "# label to an integer id -- confirm against its definition.\n",
    "sensitive = encode_classes(attributes[:, 1])\n",
    "num_classes = len(np.unique(sensitive))\n",
    "# Number of nodes = number of rows in the attribute file.\n",
    "N = sensitive.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `sensitive` is one-dimensional (one entry per node), so shape[1] does not exist\n",
    "# and indexing it raises IndexError; show the whole shape tuple instead.\n",
    "print(sensitive.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random 80/20 split of the edge list into training and test edges.\n",
    "num_edges = len(edges)\n",
    "train_idx = np.random.choice(num_edges, int(num_edges * 0.8), replace=False)\n",
    "tr_mask = np.zeros(num_edges, dtype=bool)\n",
    "tr_mask[train_idx] = True\n",
    "\n",
    "# Test split: transpose so each column is one edge (assumes `edges` is\n",
    "# (num_edges, 2) -- TODO confirm), then draw as many negatives as positives.\n",
    "pos_edges_te = torch.LongTensor(edges[~tr_mask].T).to(device)\n",
    "neg_edges_te = negative_sampling(\n",
    "    edge_index=pos_edges_te,\n",
    "    num_nodes=N,\n",
    "    num_neg_samples=pos_edges_te.size(1),\n",
    ").to(device)\n",
    "\n",
    "# Training split, built the same way.\n",
    "pos_edges_tr = torch.LongTensor(edges[tr_mask].T).to(device)\n",
    "neg_edges_tr = negative_sampling(\n",
    "    edge_index=pos_edges_tr,\n",
    "    num_nodes=N,\n",
    "    num_neg_samples=pos_edges_tr.size(1),\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The training loop iterates over range(1, epochs), i.e. epochs - 1 passes.\n",
    "epochs = 51\n",
    "\n",
    "# Node2Vec settings: walk parameters come from `config`, the rest are fixed here.\n",
    "n2v_kwargs = dict(\n",
    "    embedding_dim=128,\n",
    "    walk_length=config[\"walk_length\"],\n",
    "    context_size=10,\n",
    "    walks_per_node=config[\"walks_per_node\"],\n",
    "    p=config[\"p\"],\n",
    "    q=config[\"q\"],\n",
    "    num_negative_samples=1,\n",
    "    sparse=True,\n",
    ")\n",
    "model = Node2Vec(pos_edges_tr, **n2v_kwargs).to(device)\n",
    "\n",
    "loader = model.loader(batch_size=64, shuffle=True, num_workers=8)\n",
    "\n",
    "# NOTE(review): SparseAdam presumably pairs with sparse=True above -- confirm in fairn2v.\n",
    "lr = config[\"learning_rate\"]\n",
    "optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sensitive attribute of every node as a LongTensor on the compute device.\n",
    "Y = torch.LongTensor(sensitive).to(device)\n",
    "\n",
    "# True for each training edge whose two endpoints have different sensitive values.\n",
    "src_labels = Y[pos_edges_tr[0]]\n",
    "dst_labels = Y[pos_edges_tr[1]]\n",
    "Y_aux = (src_labels != dst_labels).to(device)\n",
    "\n",
    "# One boolean row per epoch index: each entry is True with probability 0.5 + delta.\n",
    "threshold = 0.5 + config[\"delta\"]\n",
    "draws = torch.FloatTensor(epochs, Y_aux.size(0)).uniform_()\n",
    "randomization = (draws < threshold).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epochs are numbered from 1, so row 0 of `randomization` is never consumed.\n",
    "for epoch in range(1, epochs):\n",
    "    epoch_mask = randomization[epoch]\n",
    "    loss = train_rn2v_adaptive(\n",
    "        model, loader, optimizer, device, pos_edges_tr, Y_aux, epoch_mask, N\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to eval mode and pull the embeddings off the graph and onto the CPU.\n",
    "model.eval()\n",
    "embeddings = model().detach().cpu()\n",
    "\n",
    "# Standardise the embeddings with StandardScaler before the fairness evaluation.\n",
    "scaler = preprocessing.StandardScaler()\n",
    "XB = scaler.fit_transform(embeddings)\n",
    "YB = sensitive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the standardised embeddings against the sensitive attribute.\n",
    "# NOTE(review): the exact metric is defined by utils.emb_fairness -- see there.\n",
    "node_rb = emb_fairness(XB, YB)\n",
    "print(node_rb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# emblink_fairness receives the edge tensors on the CPU, alongside the NumPy embeddings.\n",
    "train_edges_cpu = pos_edges_tr.cpu()\n",
    "test_edges_cpu = pos_edges_te.cpu()\n",
    "link_rb = emblink_fairness(XB, YB, train_edges_cpu, test_edges_cpu)\n",
    "print(link_rb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Pytorch",
   "language": "python",
   "name": "torch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
