{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cvxpy as cp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import seaborn as sns\n",
    "\n",
    "from cert_collective.cert_collective import collective_certificate_grid\n",
    "from cert_collective.utils import gcn_receptive_field_mask"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load graph and base certificate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node attributes and adjacency: a pair of binary scipy sparse matrices, presumably\n",
    "# pickled as 0-d object arrays in the .npz (hence allow_pickle=True and the [()] unwrap).\n",
    "# NOTE: allow_pickle can execute arbitrary code from the file -- only load trusted data.\n",
    "# The `with` block closes the file handle that the NpzFile keeps open.\n",
    "with np.load('data/graphs/cora_ml.npz', allow_pickle=True) as graph:\n",
    "    attr = graph['attr'][()]\n",
    "    adj = graph['adj'][()]\n",
    "n_nodes = attr.shape[0]\n",
    "assert adj.shape == (n_nodes, n_nodes), 'adjacency must be square and match the attribute rows'\n",
    "\n",
    "base_cert_grid = np.load('data/base_certs/cora_ml/attr_grid.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base certificate averaged over axis 0, shown per (additions, deletions) radius.\n",
    "sns.set_context('talk')\n",
    "mean_base_cert = base_cert_grid.mean(0)\n",
    "n_additions, n_deletions = mean_base_cert.shape\n",
    "fig, ax = plt.subplots()\n",
    "sns.heatmap(mean_base_cert, cmap='Blues', vmin=0, vmax=1, square=True,\n",
    "            cbar_kws={\"shrink\": .5}, ax=ax)\n",
    "# seaborn draws row 0 at the top; these limits put zero additions at the bottom.\n",
    "ax.set_xlim(0, n_deletions)\n",
    "ax.set_ylim(0, n_additions)\n",
    "ax.set_xlabel('Deletions')\n",
    "ax.set_ylabel('Additions')\n",
    "ax.set_title('Base certificate')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specify base certificate dimensions and receptive fields"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(base_cert_grid.shape)\n",
    "\n",
    "# One label per radius axis of base_cert_grid (every axis after the leading one,\n",
    "# which the plots average over).\n",
    "dim_labels = ['attr_add', 'attr_del']\n",
    "assert base_cert_grid.ndim == 1 + len(dim_labels), (\n",
    "    f'expected {1 + len(dim_labels)} axes in base_cert_grid, got {base_cert_grid.ndim}')\n",
    "\n",
    "# Number of GCN layers used to build the receptive-field masks.\n",
    "n_layers = 2\n",
    "# List of binary sparse matrices, one per entry of dim_labels\n",
    "receptive_field_masks = [gcn_receptive_field_mask(adj, dim_label, n_layers, n_nodes)\n",
    "                         for dim_label in dim_labels]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attribute additions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "installed_solvers = cp.installed_solvers()\n",
    "print(f'Installed solvers: {installed_solvers}')\n",
    "\n",
    "# MOSEK is a commercial solver that needs a license; fail early with a clear message\n",
    "# rather than later, when the certificates are being solved.\n",
    "solver = 'MOSEK'\n",
    "assert solver in installed_solvers, f'{solver} is not installed; available: {installed_solvers}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep only the middle radius (attribute additions, per this section) up to 9;\n",
    "# the other two radii stay at 0, so only the [0, :, 0] slice is kept.\n",
    "grid_add_full = collective_certificate_grid(\n",
    "    base_cert_grid,\n",
    "    dim_labels,\n",
    "    receptive_field_masks,\n",
    "    attr,\n",
    "    adj,\n",
    "    max_rad=[0, 9, 0],\n",
    "    solver=solver,\n",
    ")\n",
    "collective_grid_add = grid_add_full[0, :, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collective certificate (normalised by the node count) vs. the naive baseline:\n",
    "# the base certificate averaged over axis 0, at zero deletions.\n",
    "sns.set()\n",
    "pal = sns.color_palette('colorblind', 2)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(collective_grid_add / n_nodes, color=pal[0], label='Proposed collective certificate')\n",
    "ax.plot(base_cert_grid.mean(axis=0)[:, 0], color=pal[1], label='Naïve collective certificate')\n",
    "ax.legend()\n",
    "ax.set(xlabel='Attribute additions', ylabel='Certified ratio',\n",
    "       title='Collective certificate: attribute additions', xlim=(0, 9), ylim=(0, 1))\n",
    "# Render the figure explicitly so the cell does not echo the last call's return value.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attribute deletions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep only the last radius (attribute deletions, per this section) up to 32;\n",
    "# the other two radii stay at 0, so only the [0, 0, :] slice is kept.\n",
    "grid_del_full = collective_certificate_grid(\n",
    "    base_cert_grid,\n",
    "    dim_labels,\n",
    "    receptive_field_masks,\n",
    "    attr,\n",
    "    adj,\n",
    "    max_rad=[0, 0, 32],\n",
    "    solver=solver,\n",
    ")\n",
    "collective_grid_del = grid_del_full[0, 0, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collective certificate (normalised by the node count) vs. the naive baseline:\n",
    "# the base certificate averaged over axis 0, at zero additions.\n",
    "sns.set()\n",
    "pal = sns.color_palette('colorblind', 2)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(collective_grid_del / n_nodes, color=pal[0], label='Proposed collective certificate')\n",
    "ax.plot(base_cert_grid.mean(axis=0)[0, :], color=pal[1], label='Naïve collective certificate')\n",
    "ax.legend()\n",
    "ax.set(xlabel='Attribute deletions', ylabel='Certified ratio',\n",
    "       title='Collective certificate: attribute deletions', xlim=(0, 32), ylim=(0, 1))\n",
    "# Render the figure explicitly so the cell does not echo the last call's return value.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Local constraints and limited attacker-controlled nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same deletion sweep, now with an absolute local budget of 4 on the last radius and\n",
    "# num_attackers=6 (presumably the number of attacker-controlled nodes, per the heading).\n",
    "local_budget = np.array([0, 0, 4])\n",
    "grid_del_local_full = collective_certificate_grid(\n",
    "    base_cert_grid,\n",
    "    dim_labels,\n",
    "    receptive_field_masks,\n",
    "    attr,\n",
    "    adj,\n",
    "    max_rad=[0, 0, 32],\n",
    "    local_budget_descriptor=local_budget,\n",
    "    local_budget_mode='absolute',\n",
    "    num_attackers=6,\n",
    "    solver=solver,\n",
    ")\n",
    "collective_grid_del_local = grid_del_local_full[0, 0, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locally constrained collective certificate (normalised by the node count) vs. the\n",
    "# naive baseline: the base certificate averaged over axis 0, at zero additions.\n",
    "sns.set()\n",
    "pal = sns.color_palette('colorblind', 2)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(collective_grid_del_local / n_nodes, color=pal[0], label='Collective certificate w/ local constraints')\n",
    "ax.plot(base_cert_grid.mean(axis=0)[0, :], color=pal[1], label='Naïve collective certificate')\n",
    "ax.legend()\n",
    "ax.set(xlabel='Attribute deletions', ylabel='Certified ratio',\n",
    "       title='Attribute deletions with local constraints', xlim=(0, 32), ylim=(0, 1))\n",
    "# Render the figure explicitly so the cell does not echo the last call's return value.\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
