{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Project-local loader: `init_graph` is the only name this notebook needs from it,\n",
    "# so import it by name rather than with a star import.\n",
    "from init_graph import init_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import seaborn\n",
    "import seaborn as sns\n",
    "\n",
    "# Resolution for inline display vs. saved figures\n",
    "plt.rcParams[\"figure.dpi\"] = 100\n",
    "plt.rcParams['savefig.dpi'] = 300\n",
    "\n",
    "# Plot theme: seaborn \"whitegrid\" with a near-white axes background.\n",
    "# One set_theme call is enough: it sets style, context and palette together.\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "sns.set_style({\"axes.facecolor\": \".98\"})\n",
    "\n",
    "# NOTE(review): `colors` and `labels` are not referenced by any cell of this notebook\n",
    "colors = sns.color_palette(\"muted\", 7)\n",
    "labels = [\"$FairRARI$\", \"PR+proj\", \"Inbetween\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Where figures go, and whether they are written at all\n",
    "save_source_path = \"figures/\"\n",
    "save_ = 1  # 1 -> also save each figure as a PDF; any other value -> only display it\n",
    "\n",
    "# Root folders of the logged runs (one sub-folder per dataset)\n",
    "logs_source_path_FairRARI = \"logs/FairRARI/\"\n",
    "logs_source_path_post_processing = \"logs/post_processing/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset to analyse; it is part of every log and figure path built below\n",
    "dataset_name = \"polbooks\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_path = \"datasets/\"\n",
    "\n",
    "# Load the graph together with its node groups\n",
    "G, protected_nodes, blue_nodes, red_nodes = init_graph(dataset_name, source_path)\n",
    "unprotected_nodes = blue_nodes  # the blue group is used as the unprotected set\n",
    "\n",
    "n, m = G.number_of_nodes(), G.number_of_edges()\n",
    "print(f'Number of Nodes: {n}')\n",
    "print(f'Number of Edges: {m}')\n",
    "\n",
    "# 0/1 indicator vectors (int32): S_p marks protected nodes, S_up is its complement\n",
    "S_p = torch.zeros(n).int()\n",
    "S_p[protected_nodes] = 1\n",
    "S_up = 1 - S_p\n",
    "\n",
    "n_p = len(protected_nodes)\n",
    "print(f'Number of Protected Nodes: {n_p}')\n",
    "print(f'Number of UnProtected Nodes: {len(unprotected_nodes)}')\n",
    "print(f'Number of Protected Nodes over Number of Nodes: {n_p / n:.6f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run parameters; each of them is encoded in the log file names loaded below\n",
    "gamma = 0.15  # NOTE(review): presumably the PageRank teleport/restart probability -- confirm\n",
    "max_iters_eucl_iters = 1000        # iteration count in the FairRARI log name\n",
    "max_iters_eucl_iters_1proj = 1000  # iteration count in the post-processing log name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run parameter that appears in both the log file names and the saved figure names.\n",
    "# NOTE(review): presumably the target score share of the protected group -- confirm.\n",
    "phi = 0.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dir_path_FairRARI = f'{logs_source_path_FairRARI}{dataset_name}/'\n",
    "\n",
    "# Log file name: <dataset>_phi<phi:.2f>_gamma<gamma:.2f>_iters<iters>_log.npy\n",
    "load_path_FairRARI = (\n",
    "    f'{dir_path_FairRARI}{dataset_name}'\n",
    "    f'_phi{phi:.2f}_gamma{gamma:.2f}'\n",
    "    f'_iters{max_iters_eucl_iters}_log.npy'\n",
    ")\n",
    "\n",
    "# The .npy wraps a single pickled dict-like object, hence allow_pickle=True.\n",
    "# NOTE(review): unpickling can execute arbitrary code -- only load trusted logs.\n",
    "variables_dict_FairRARI = np.load(load_path_FairRARI, allow_pickle=True)\n",
    "\n",
    "fairrari_log = variables_dict_FairRARI.item()\n",
    "FairRARI_scores = fairrari_log.get('x_opt')  # FairRARI scores, indexed by node id\n",
    "opr_scores = fairrari_log.get('opr_scores')  # original PageRank scores from the same log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dir_path_post_processing = f'{logs_source_path_post_processing}{dataset_name}/'\n",
    "\n",
    "# Log file name: <dataset>_phi<phi:.2f>_gamma<gamma:.2f>_iters<iters_1proj>_log.npy\n",
    "load_path_post_processing = (\n",
    "    f'{dir_path_post_processing}{dataset_name}'\n",
    "    f'_phi{phi:.2f}_gamma{gamma:.2f}'\n",
    "    f'_iters{max_iters_eucl_iters_1proj}_log.npy'\n",
    ")\n",
    "\n",
    "# NOTE(review): allow_pickle=True unpickles the stored object -- only load trusted logs.\n",
    "variables_dict_post_processing = np.load(load_path_post_processing, allow_pickle=True)\n",
    "\n",
    "# Post-processed scores, indexed by node id\n",
    "post_processing_scores = variables_dict_post_processing.item().get('x_opt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so the spring layout -- and hence every network figure -- is reproducible\n",
    "layout_seed = 0\n",
    "pos = nx.spring_layout(G, seed=layout_seed)  # node -> position, shared by all plots below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zero_pr_nodes = np.where(FairRARI_scores == 0)[0].tolist()\n",
    "\n",
    "# Group membership as sets of plain ints: O(1) membership tests instead of\n",
    "# scanning the node lists once per node and once per edge endpoint\n",
    "unprotected_nodes = blue_nodes\n",
    "protected_set = {int(v) for v in protected_nodes}\n",
    "unprotected_set = {int(v) for v in unprotected_nodes}\n",
    "\n",
    "# Node colour: 'k' for zero-score nodes, otherwise C3 (protected) / C0 (unprotected).\n",
    "# Decided per node in a single pass, so colours stay aligned with G.nodes() order\n",
    "# even if node ids are not iterated as 0..n-1.\n",
    "node_colors = [\n",
    "    'k' if FairRARI_scores[node] == 0 else ('C3' if node in protected_set else 'C0')\n",
    "    for node in G.nodes()\n",
    "]\n",
    "\n",
    "# Edge colour: C3 = both ends protected, C0 = both ends unprotected, C4 = mixed\n",
    "edge_colors = []\n",
    "for u, v in G.edges():\n",
    "    if u in protected_set and v in protected_set:\n",
    "        edge_colors.append('C3')\n",
    "    elif u in unprotected_set and v in unprotected_set:\n",
    "        edge_colors.append('C0')\n",
    "    else:\n",
    "        edge_colors.append('C4')\n",
    "\n",
    "# Marker area proportional to the score (scaled by 1e4 for visibility)\n",
    "node_sizes = [FairRARI_scores[node] * 1e4 for node in G.nodes()]\n",
    "\n",
    "# One axes covering the whole canvas: the same geometry nx.draw creates on a fresh figure\n",
    "fig = plt.figure(figsize=(6, 4.8))\n",
    "ax = fig.add_axes((0, 0, 1, 1))\n",
    "\n",
    "nx.draw(\n",
    "    G,\n",
    "    pos,\n",
    "    ax=ax,\n",
    "    with_labels=False,\n",
    "    node_color=node_colors,  # group colour, black for zero score\n",
    "    node_size=node_sizes,    # size ~ FairRARI score\n",
    "    edge_color=edge_colors,  # colour by the groups of the two endpoints\n",
    "    alpha=0.85,\n",
    ")\n",
    "\n",
    "# Zero-score nodes have zero marker area; outline them with small hollow circles\n",
    "nx.draw_networkx_nodes(\n",
    "    G, pos, nodelist=zero_pr_nodes, ax=ax,\n",
    "    node_color='none', edgecolors='k', node_shape='o', node_size=0.003e4,\n",
    ")\n",
    "\n",
    "ax.set_title(\"FairRARI\")\n",
    "\n",
    "# NOTE(review): figure names use str(phi) (e.g. 0.9) while log names use '{:.2f}' (0.90)\n",
    "save_folder = save_source_path + 'network/' + dataset_name + '/'\n",
    "save_path = save_folder + dataset_name + '_network_' + str(phi) + '_FairRARI.pdf'\n",
    "if save_ == 1:\n",
    "    os.makedirs(save_folder, exist_ok=True)\n",
    "    fig.savefig(save_path, format='pdf', bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zero_pr_nodes = np.where(post_processing_scores == 0)[0].tolist()\n",
    "\n",
    "# Group membership as sets of plain ints: O(1) membership tests instead of\n",
    "# scanning the node lists once per node and once per edge endpoint\n",
    "unprotected_nodes = blue_nodes\n",
    "protected_set = {int(v) for v in protected_nodes}\n",
    "unprotected_set = {int(v) for v in unprotected_nodes}\n",
    "\n",
    "# Node colour: 'k' for zero-score nodes, otherwise C3 (protected) / C0 (unprotected).\n",
    "# Decided per node in a single pass, so colours stay aligned with G.nodes() order\n",
    "# even if node ids are not iterated as 0..n-1.\n",
    "node_colors = [\n",
    "    'k' if post_processing_scores[node] == 0 else ('C3' if node in protected_set else 'C0')\n",
    "    for node in G.nodes()\n",
    "]\n",
    "\n",
    "# Edge colour: C3 = both ends protected, C0 = both ends unprotected, C4 = mixed\n",
    "edge_colors = []\n",
    "for u, v in G.edges():\n",
    "    if u in protected_set and v in protected_set:\n",
    "        edge_colors.append('C3')\n",
    "    elif u in unprotected_set and v in unprotected_set:\n",
    "        edge_colors.append('C0')\n",
    "    else:\n",
    "        edge_colors.append('C4')\n",
    "\n",
    "# Marker area proportional to the score (scaled by 1e4 for visibility)\n",
    "node_sizes = [post_processing_scores[node] * 1e4 for node in G.nodes()]\n",
    "\n",
    "# One axes covering the whole canvas: the same geometry nx.draw creates on a fresh figure\n",
    "fig = plt.figure(figsize=(6, 4.8))\n",
    "ax = fig.add_axes((0, 0, 1, 1))\n",
    "\n",
    "nx.draw(\n",
    "    G,\n",
    "    pos,\n",
    "    ax=ax,\n",
    "    with_labels=False,\n",
    "    node_color=node_colors,  # group colour, black for zero score\n",
    "    node_size=node_sizes,    # size ~ post-processed score\n",
    "    edge_color=edge_colors,  # colour by the groups of the two endpoints\n",
    "    alpha=0.85,\n",
    ")\n",
    "\n",
    "# Zero-score nodes have zero marker area; outline them with small hollow circles\n",
    "nx.draw_networkx_nodes(\n",
    "    G, pos, nodelist=zero_pr_nodes, ax=ax,\n",
    "    node_color='none', edgecolors='k', node_shape='o', node_size=0.003e4,\n",
    ")\n",
    "\n",
    "ax.set_title(\"Post-Processing\")\n",
    "\n",
    "# NOTE(review): figure names use str(phi) (e.g. 0.9) while log names use '{:.2f}' (0.90)\n",
    "save_folder = save_source_path + 'network/' + dataset_name + '/'\n",
    "save_path = save_folder + dataset_name + '_network_' + str(phi) + '_post_processing.pdf'\n",
    "if save_ == 1:\n",
    "    os.makedirs(save_folder, exist_ok=True)\n",
    "    fig.savefig(save_path, format='pdf', bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count zero-score nodes with (x == 0).sum(): vectorised, works for numpy arrays and\n",
    "# torch tensors alike, and (unlike builtin sum) still has .item() on an empty array.\n",
    "print(\"Number of Nodes with 0 Score:\")\n",
    "print(\"FairRARI:        \", (FairRARI_scores == 0).sum().item())\n",
    "print(\"Post-Processing: \", (post_processing_scores == 0).sum().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zero_pr_nodes = np.where(opr_scores == 0)[0].tolist()\n",
    "\n",
    "# Group membership as sets of plain ints: O(1) membership tests instead of\n",
    "# scanning the node lists once per node and once per edge endpoint\n",
    "unprotected_nodes = blue_nodes\n",
    "protected_set = {int(v) for v in protected_nodes}\n",
    "unprotected_set = {int(v) for v in unprotected_nodes}\n",
    "\n",
    "# Node colour: 'k' for zero-score nodes, otherwise C3 (protected) / C0 (unprotected).\n",
    "# Decided per node in a single pass, so colours stay aligned with G.nodes() order\n",
    "# even if node ids are not iterated as 0..n-1.\n",
    "node_colors = [\n",
    "    'k' if opr_scores[node] == 0 else ('C3' if node in protected_set else 'C0')\n",
    "    for node in G.nodes()\n",
    "]\n",
    "\n",
    "# Edge colour: C3 = both ends protected, C0 = both ends unprotected, C4 = mixed\n",
    "edge_colors = []\n",
    "for u, v in G.edges():\n",
    "    if u in protected_set and v in protected_set:\n",
    "        edge_colors.append('C3')\n",
    "    elif u in unprotected_set and v in unprotected_set:\n",
    "        edge_colors.append('C0')\n",
    "    else:\n",
    "        edge_colors.append('C4')\n",
    "\n",
    "# Marker area proportional to the score (scaled by 1e4 for visibility)\n",
    "node_sizes = [opr_scores[node] * 1e4 for node in G.nodes()]\n",
    "\n",
    "# One axes covering the whole canvas: the same geometry nx.draw creates on a fresh figure\n",
    "fig = plt.figure(figsize=(6, 4.8))\n",
    "ax = fig.add_axes((0, 0, 1, 1))\n",
    "\n",
    "nx.draw(\n",
    "    G,\n",
    "    pos,\n",
    "    ax=ax,\n",
    "    with_labels=False,\n",
    "    node_color=node_colors,  # group colour, black for zero score\n",
    "    node_size=node_sizes,    # size ~ original PageRank score\n",
    "    edge_color=edge_colors,  # colour by the groups of the two endpoints\n",
    "    alpha=0.85,\n",
    ")\n",
    "\n",
    "# Zero-score nodes have zero marker area; outline them with small hollow circles\n",
    "# (no-op when no node has a zero score)\n",
    "nx.draw_networkx_nodes(\n",
    "    G, pos, nodelist=zero_pr_nodes, ax=ax,\n",
    "    node_color='none', edgecolors='k', node_shape='o', node_size=0.003e4,\n",
    ")\n",
    "\n",
    "ax.set_title(\"Original PageRank\")\n",
    "\n",
    "# This figure's file name carries no phi suffix\n",
    "save_folder = save_source_path + 'network/' + dataset_name + '/'\n",
    "save_path = save_folder + dataset_name + '_network_OPR.pdf'\n",
    "if save_ == 1:\n",
    "    os.makedirs(save_folder, exist_ok=True)\n",
    "    fig.savefig(save_path, format='pdf', bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
