{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import networkx as nx\n",
    "import scipy\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['figure.figsize'] = [15, 10]\n",
    "sns.set_style('whitegrid')\n",
    "plt.rcParams['font.size'] = 28.0\n",
    "plt.rcParams['xtick.labelsize'] = 28.0\n",
    "plt.rcParams['ytick.labelsize'] = 28.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import graphons.bionetworks_utils as bu \n",
    "from graphons.utils import truncate_sparse_matrix\n",
    "from graphons.graphon_experiments import get_boolean_sample\n",
    "import graphons.graphon_est_ell2 as ge\n",
    "from graphons import graphon_families as gf\n",
    "from graphons import utils\n",
    "from ell2_run_all import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This notebook\n",
    "\n",
    "This notebook contains the code to make all visualizations.\n",
    "\n",
    "1. Visualizations from the real world experiments\n",
    "\n",
    "2. Getting results from the simulations on MMSB, graphon, etc. \n",
    "\n",
    "3. Making the side by side heatmaps for select families of graphons. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1 Real World Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## setup functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_results_dfs(results_df_list, titles_list, \n",
    "    quantile_min = 0.01, \n",
    "    quantile_max = 0.99,\n",
    "    p_flip_max = 0.6,\n",
    "    marker_size=24,\n",
    "    savepath=None):\n",
    "    \"\"\"Plot median MSE against n_Q, one panel per results frame.\n",
    "\n",
    "    Each panel shows the median of `our_algo` and `sbm_algo` (computed on the\n",
    "    rows of the first `p_flip` value found in the frame) and one `bitflip_algo`\n",
    "    oracle curve for every `p_flip` value not exceeding `p_flip_max`. A shaded\n",
    "    band spans the `quantile_min`..`quantile_max` quantiles of each curve.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_df_list : list of pandas.DataFrame\n",
    "        Non-empty list of frames with columns 'np', 'nq', 'p_flip',\n",
    "        'our_algo', 'sbm_algo' and 'bitflip_algo'.\n",
    "    titles_list : list of str or None\n",
    "        Per-panel titles; None leaves the panels untitled.\n",
    "    quantile_min, quantile_max : float\n",
    "        Lower and upper quantiles of the shaded band.\n",
    "    p_flip_max : float\n",
    "        Oracle curves are drawn only when `p_flip <= p_flip_max`.\n",
    "    marker_size : float\n",
    "        Marker size used for every curve.\n",
    "    savepath : str or None\n",
    "        When given, the figure is saved there at 700 dpi.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `results_df_list` is empty.\n",
    "    \"\"\"\n",
    "    num_exps = len(results_df_list)\n",
    "    if num_exps == 0:\n",
    "        raise ValueError('results_df_list must contain at least one DataFrame')\n",
    "\n",
    "    # squeeze=False keeps `axs` indexable even when there is only one panel.\n",
    "    fig, axs = plt.subplots(figsize=(20, 8), ncols=num_exps, sharey=True, squeeze=False)\n",
    "    axs = axs[0]\n",
    "    markers_list = ['D', 'o', 'P', 'X', 's']\n",
    "\n",
    "    def draw_curve(ax, grouped, column, label, marker_pos):\n",
    "        # x is taken from the groupby index so it always lines up with the\n",
    "        # per-nq statistics, whatever the row order of the frame.\n",
    "        stats = grouped[column]\n",
    "        median = stats.median()\n",
    "        ax.plot(median.index, median.values,\n",
    "            label=label,\n",
    "            # wrap around instead of indexing past the end of the marker list\n",
    "            marker=markers_list[marker_pos % len(markers_list)],\n",
    "            markersize=marker_size\n",
    "        )\n",
    "        ax.fill_between(median.index,\n",
    "            y1=stats.quantile(quantile_min),\n",
    "            y2=stats.quantile(quantile_max),\n",
    "            alpha=0.2,\n",
    "        )\n",
    "\n",
    "    # The x-axis labels report the 'np' value of the first frame only.\n",
    "    np_fixed = results_df_list[0]['np'].unique()[0]\n",
    "    for idx, results_df in enumerate(results_df_list):\n",
    "        ax = axs[idx]\n",
    "        for idx2, p_flip in enumerate(results_df['p_flip'].unique()):\n",
    "            grouped = results_df[results_df['p_flip'] == p_flip].groupby('nq')\n",
    "            if idx2 == 0:\n",
    "                # Drawn once per panel, from the first p_flip subset.\n",
    "                draw_curve(ax, grouped, 'our_algo', 'Our Algorithm', idx2)\n",
    "                draw_curve(ax, grouped, 'sbm_algo', 'SBM Algorithm', idx2 + 1)\n",
    "\n",
    "            pflip = float(p_flip)\n",
    "            if pflip <= p_flip_max:\n",
    "                draw_curve(ax, grouped, 'bitflip_algo', f'Oracle, $p={pflip}$', idx2 + 2)\n",
    "\n",
    "        ax.set_yscale('log')\n",
    "        if titles_list is not None:\n",
    "            ax.set_title(titles_list[idx])\n",
    "        # Label every panel, not only the first two.\n",
    "        ax.set_xlabel(f'$n_Q$ Value, with $n = {int(np_fixed)}$')\n",
    "\n",
    "    axs[0].set_ylabel('MSE (Log Scale)')\n",
    "\n",
    "    # The legend is built from the first panel's curves and attached to the\n",
    "    # last panel, anchored above the axes.\n",
    "    handles, labels = axs[0].get_legend_handles_labels()\n",
    "    axs[-1].legend(handles, labels, loc='upper center',\n",
    "                   shadow=True,\n",
    "                   ncol=3, bbox_to_anchor=(-0.2, 1.25))\n",
    "    fig.subplots_adjust(top=0.8, bottom=0.1)\n",
    "\n",
    "    if savepath is not None:\n",
    "        fig.savefig(savepath, dpi=700.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## metabolic data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metabolic_results = pd.read_csv('exp-results/results_3_methods/metabolic_seed_91_numtrials_50.csv').drop(columns=['Unnamed: 0'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metabolic_results['p_label'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df1 = metabolic_results[metabolic_results['p_label'] == 'iWFL_1372'].copy()\n",
    "df2 = metabolic_results[metabolic_results['p_label'] == 'iPC815'].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_results_dfs(\n",
    "    results_df_list=[df1, df2], \n",
    "    titles_list=None, \n",
    "    savepath='figs-ell2/results_3_methods/metabolic_iWFL_1372_left_iPC815_right_target_iJN1463.pdf', \n",
    "    quantile_min=0.01,  \n",
    "    quantile_max=0.99, \n",
    "    marker_size=16\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## email"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "email_results = pd.read_csv('exp-results/results_3_methods/email_eu_seed_91_numtrials_50_smaller_p_flips.csv');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "email_results['p_label'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "email_results['q_label'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df1_email = email_results[email_results['q_label'] == 't1'].copy()\n",
    "df2_email = email_results[email_results['q_label'] == 't4'].copy()\n",
    "df3_email = email_results[email_results['q_label'] == 't7'].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_results_dfs(\n",
    "    results_df_list=[df1_email, df3_email], \n",
    "    titles_list=None, \n",
    "    quantile_min=0.01,  \n",
    "    quantile_max=0.99, \n",
    "    p_flip_max=0.06, \n",
    "    marker_size=16,\n",
    "    savepath='figs-ell2/results_3_methods/email_source_t0_targets_t1_t7_smaller_p.pdf'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## plot side by side"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colors_list = sns.color_palette(n_colors=7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def plot_results_dfs(results_df_list, titles_list, \n",
    "#     quantile_min = 0.01, \n",
    "#     quantile_max = 0.99,\n",
    "#     p_flip_max = 0.6,\n",
    "#     marker_size=24,\n",
    "#     savepath=None):\n",
    "sns.set_palette(\"tab10\")\n",
    "results_df_list = [\n",
    "    df1, df2, \n",
    "    df1_email, df3_email\n",
    "]\n",
    "num_exps = 4\n",
    "marker_size = 16\n",
    "fig, axs = plt.subplots(figsize=(32, 8), ncols=num_exps, sharey=False)\n",
    "quantile_min = 0.01\n",
    "quantile_max = 0.99 \n",
    "\n",
    "\n",
    "\n",
    "# def add_frame(ax, color='red', linewidth=2):\n",
    "#     # Create a rectangle patch with the size of the axis\n",
    "#     rect = Rectangle((0, 0), 1, 1, transform=ax.transAxes,\n",
    "#                      color=color, linewidth=linewidth, fill=False, clip_on=False)\n",
    "#     ax.add_patch(rect)\n",
    "\n",
    "\n",
    "# cmap = plt.cm.viridis\n",
    "\n",
    "def draw_error_panel(ax, results_df, markers, p_flip_max, oracle_color_fn=None):\n",
    "    \"\"\"Draw the Alg. 1 / Alg. 2 / oracle curves of one results frame on `ax`.\n",
    "\n",
    "    Medians are plotted against `nq` with a band between the `quantile_min`\n",
    "    and `quantile_max` quantiles; `marker_size`, `quantile_min` and\n",
    "    `quantile_max` are the variables set earlier in this cell.\n",
    "    Alg. 1 and Alg. 2 come from the first `p_flip` subset only; an oracle\n",
    "    curve is added for each `p_flip <= p_flip_max`. When `oracle_color_fn`\n",
    "    is given it maps the float p_flip to the colour of its oracle curve.\n",
    "    \"\"\"\n",
    "    def draw_curve(grouped, column, label, marker, color=None):\n",
    "        # Pass `color` only when one was chosen, so that the axes' default\n",
    "        # colour cycle is used otherwise.\n",
    "        extra = {} if color is None else {'color': color}\n",
    "        stats = grouped[column]\n",
    "        median = stats.median()\n",
    "        # x comes from the groupby index so it always matches the statistics.\n",
    "        ax.plot(median.index, median.values,\n",
    "            label=label,\n",
    "            marker=marker,\n",
    "            markersize=marker_size,\n",
    "            **extra\n",
    "        )\n",
    "        ax.fill_between(median.index,\n",
    "            y1=stats.quantile(quantile_min),\n",
    "            y2=stats.quantile(quantile_max),\n",
    "            alpha=0.2,\n",
    "            **extra\n",
    "        )\n",
    "\n",
    "    for idx2, p_flip in enumerate(results_df['p_flip'].unique()):\n",
    "        grouped = results_df[results_df['p_flip'] == p_flip].groupby('nq')\n",
    "        if idx2 == 0:\n",
    "            draw_curve(grouped, 'our_algo', 'Alg. 1', markers[idx2])\n",
    "            draw_curve(grouped, 'sbm_algo', 'Alg. 2', markers[idx2 + 1])\n",
    "\n",
    "        # Every threshold test uses the float value of p_flip.\n",
    "        pflip = float(p_flip)\n",
    "        if pflip <= p_flip_max:\n",
    "            color = None if oracle_color_fn is None else oracle_color_fn(pflip)\n",
    "            draw_curve(grouped, 'bitflip_algo', f'Oracle, $p={pflip}$',\n",
    "                       # wrap around instead of indexing past the marker list\n",
    "                       markers[(idx2 + 2) % len(markers)], color)\n",
    "    ax.set_yscale('log')\n",
    "\n",
    "\n",
    "np_fixed_1 = results_df_list[0]['np'].unique()[0]\n",
    "np_fixed_2 = results_df_list[-1]['np'].unique()[0]\n",
    "\n",
    "# Panels 0-1: metabolic networks, oracle curves for p_flip <= 0.6.\n",
    "markers_list = ['D', 'o', 'P', 'X', 's']\n",
    "for idx in [0, 1]:\n",
    "    draw_error_panel(axs[idx], results_df_list[idx], markers_list, p_flip_max=0.6)\n",
    "\n",
    "axs[0].set_ylabel('MSE (Log Scale)')\n",
    "axs[0].set_xlabel(f'$n_Q$, with $n = {int(np_fixed_1)}$')\n",
    "axs[1].set_xlabel(f'$n_Q$, with $n = {int(np_fixed_1)}$')\n",
    "\n",
    "# Panels 2-3: email networks, oracle curves for p_flip <= 0.06 drawn in the\n",
    "# last two palette colours.\n",
    "markers_list = ['D', 'o', 6, 'H']\n",
    "\n",
    "def email_oracle_color(pflip):\n",
    "    return colors_list[-2] if pflip >= 0.03 else colors_list[-1]\n",
    "\n",
    "for idx in [2, 3]:\n",
    "    draw_error_panel(axs[idx], results_df_list[idx], markers_list,\n",
    "                     p_flip_max=0.06, oracle_color_fn=email_oracle_color)\n",
    "\n",
    "axs[2].set_xlabel(f'$n_Q$, with $n = {int(np_fixed_2)}$')\n",
    "axs[3].set_xlabel(f'$n_Q$, with $n = {int(np_fixed_2)}$')\n",
    "# axs[3].tick_params(labelleft=False)\n",
    "# axs[1].set(yticks=[])\n",
    "\n",
    "axs[0].sharey(axs[1])\n",
    "axs[2].sharey(axs[3])\n",
    "\n",
    "\n",
    "axs[0].sharex(axs[1])\n",
    "axs[2].sharex(axs[3])\n",
    "handles, labels = [], []\n",
    "for ax in axs:\n",
    "    for handle, label in zip(*ax.get_legend_handles_labels()):\n",
    "        if label not in labels:\n",
    "            handles.append(handle)\n",
    "            labels.append(label)\n",
    "\n",
    "fig.legend(handles, labels, loc='upper center', ncol=7, \n",
    "           fontsize=28,\n",
    "           bbox_to_anchor=(0.5, 1.15))\n",
    "\n",
    "plt.subplots_adjust(wspace=0.5)  # Increase horizontal space between subplots\n",
    "\n",
    "# rec = plt.Rectangle((ax[0] - 0.7, ax[2] - 0.2), \n",
    "#                     (ax[1] - ax[0]) + 1, (ax[3] - ax[2]) + 0.4, \n",
    "#                     fill=False, lw=2, linestyle=\"dotted\")\n",
    "\n",
    "# rec = sub.add_patch(rec)\n",
    "\n",
    "# rec.set_clip_on(False)\n",
    "\n",
    "\n",
    "# Function to add a frame around a specific axis\n",
    "from matplotlib.patches import Rectangle\n",
    "bbox0 = axs[0].get_position()\n",
    "bbox1 = axs[1].get_position()\n",
    "\n",
    "# Find the minimum and maximum extents of the combined bounding box\n",
    "x0 = min(bbox0.x0, bbox1.x0) - 0.00\n",
    "y0 = min(bbox0.y0, bbox1.y0) - 0.0\n",
    "x1 = max(bbox0.x1, bbox1.x1) + 0.0\n",
    "y1 = max(bbox0.y1, bbox1.y1) + 0.00\n",
    "\n",
    "# Add a rectangle patch that spans the combined bounding box\n",
    "rect = Rectangle((x0, y0), x1 - x0, y1 - y0, ls='--',\n",
    "                 transform=fig.transFigure, color='k', linewidth=2, fill=False, clip_on=False)\n",
    "# fig.patches.append(rect)\n",
    "fig.text(0.5*(bbox0.x0 + bbox1.x1), 1.08*bbox0.y1, 'Metabolic Networks', \n",
    "    ha='center', va='bottom', fontsize=40)\n",
    "\n",
    "bbox0 = axs[2].get_position()\n",
    "bbox1 = axs[3].get_position()\n",
    "\n",
    "# Find the minimum and maximum extents of the combined bounding box\n",
    "x0 = min(bbox0.x0, bbox1.x0) - 0.00\n",
    "y0 = min(bbox0.y0, bbox1.y0) - 0.0\n",
    "x1 = max(bbox0.x1, bbox1.x1) + 0.0\n",
    "y1 = max(bbox0.y1, bbox1.y1) + 0.00\n",
    "\n",
    "# Add a rectangle patch that spans the combined bounding box\n",
    "rect = Rectangle((x0, y0), x1 - x0, y1 - y0, ls='--',\n",
    "                 transform=fig.transFigure, color='k', linewidth=2, fill=False, clip_on=False)\n",
    "# fig.patches.append(rect)\n",
    "\n",
    "fig.text(0.5*(bbox0.x0 + bbox1.x1), 1.08*bbox0.y1,\n",
    "    'Email Networks', ha='center', va='bottom', fontsize=40)\n",
    "\n",
    "\n",
    "# plt.gcf().subplots_adjust(top=0.8)\n",
    "# plt.gcf().subplots_adjust(bottom=0.1)\n",
    "# axs[2].set_yticks([])\n",
    "# axs[3].set_yticks([])\n",
    "# axs[0].set_yscale('log')\n",
    "# axs[1].set_yscale('log')\n",
    "# handles, labels = axs[0].get_legend_handles_labels()\n",
    "# plt.legend(handles, labels, loc='upper center', \n",
    "#             shadow=True,\n",
    "#             ncol=3, bbox_to_anchor=(-0.2, 1.25))\n",
    "# plt.gcf().subplots_adjust(top=0.8)\n",
    "# plt.gcf().subplots_adjust(bottom=0.1)\n",
    "\n",
    "plt.tight_layout()\n",
    "savepath='figs-ell2/results_3_methods/real_world_side_by_side.pdf'\n",
    "if savepath is not None: \n",
    "    plt.savefig(savepath, dpi=700.0, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2 Simulations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
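    "We output, for each $n_Q$, the mean error and twice the sample standard deviation (the `stdevs` table holds 2 x std, not the raw std)."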
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_results_df(res_df): \n",
    "    \"\"\"Summarise the errors of a results frame per `nq`.\n",
    "\n",
    "    Returns a pair `(means, stdevs)` of DataFrames indexed by `nq`. The\n",
    "    columns 'our_algo', 'sbm_algo' and 'bitflip_algo' are computed on the rows\n",
    "    with `p_flip == 0.1`; 'bitflip_03' and 'bitflip_05' hold the\n",
    "    'bitflip_algo' statistic of the rows with `p_flip == 0.3` and `0.5`.\n",
    "    Note that `stdevs` contains twice the sample standard deviation.\n",
    "    \"\"\"\n",
    "    algo_cols = ['our_algo', 'sbm_algo', 'bitflip_algo']\n",
    "    base_groups = res_df[res_df['p_flip'] == 0.1].groupby('nq')[algo_cols]\n",
    "    means = pd.DataFrame(base_groups.mean())\n",
    "    stdevs = pd.DataFrame(2.0 * base_groups.std())\n",
    "\n",
    "    # Extra oracle columns for the two larger flip probabilities.\n",
    "    for col_name, flip_prob in (('bitflip_03', 0.3), ('bitflip_05', 0.5)):\n",
    "        oracle_groups = res_df[res_df['p_flip'] == flip_prob].groupby('nq')[['bitflip_algo']]\n",
    "        means[col_name] = oracle_groups.mean()\n",
    "        stdevs[col_name] = 2.0 * oracle_groups.std()\n",
    "\n",
    "    return means, stdevs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_d_dim = pd.read_csv('exp-results/results_3_methods/d_dim_scale_p_25_q_10_seed_91_numtrials_50.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_graphon = pd.read_csv('exp-results/results_3_methods/graphon_alpha_01_beta_05_seed_91_numtrials_50.csv')\n",
    "res_graphon_wavy = pd.read_csv('exp-results/results_3_methods/graphon_liza_period_3_rotated_liza_period_3_ordinary_seed_91_numtrials_50.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_mmsb = pd.read_csv('exp-results/results_3_methods/mmsb_seed_91_numtrials_50.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "process_results_df(res_mmsb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "process_results_df(res_d_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "process_results_df(res_graphon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "process_results_df(res_graphon_wavy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## function that runs all three algos and returns results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## plotting fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_heatmaps_with_source(out_d, plot_titles=False, cmap='Spectral', \n",
    "                                savepath=None, dpi_setting=100.0):\n",
    "    \"\"\"Draw four side-by-side heatmaps comparing the source and each estimate with the truth.\n",
    "\n",
    "    The strict lower triangle of every panel holds the 'true' matrix and the\n",
    "    strict upper triangle holds, from left to right, the 'source', 'ours',\n",
    "    'sbm' and 'oracle_10' matrices. All panels share the colour range [0, 1];\n",
    "    a colourbar fills the fifth, narrow axis.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    out_d : dict\n",
    "        Matrices keyed 'source_<nq>', 'true_<nq>', 'ours_<nq>', 'sbm_<nq>' and\n",
    "        'oracle_10_<nq>'; '<nq>' is read from the suffix of the first key.\n",
    "    plot_titles : bool\n",
    "        Whether to title the four panels.\n",
    "    cmap : str\n",
    "        Colour map handed to seaborn.\n",
    "    savepath : str or None\n",
    "        When given, the figure is saved there and then closed. Without a\n",
    "        savepath the figure is left open so the notebook can display it.\n",
    "    dpi_setting : float\n",
    "        Resolution used when saving.\n",
    "    \"\"\"\n",
    "    nq = list(out_d.keys())[0].split('_')[-1]\n",
    "\n",
    "    def upper(name):\n",
    "        # Strict upper triangle of the named matrix.\n",
    "        return np.triu(out_d[f'{name}_{nq}'], k = 1).copy()\n",
    "\n",
    "    # Transposing moves the ground truth into the strict lower triangle.\n",
    "    true_mat = upper('true').T\n",
    "    panels = [\n",
    "        upper('source') + true_mat,\n",
    "        true_mat + upper('ours'),\n",
    "        true_mat + upper('sbm'),\n",
    "        true_mat + upper('oracle_10'),\n",
    "    ]\n",
    "\n",
    "    fig, axs = plt.subplots(ncols=5, \n",
    "                            gridspec_kw=dict(width_ratios=[2, 2, 2, 2, 0.2]), \n",
    "                            figsize=(15, 4))\n",
    "\n",
    "    # zip stops after the four panels; the fifth axis is kept for the colourbar.\n",
    "    for ax, mat in zip(axs, panels):\n",
    "        sns.heatmap(mat, cmap=cmap, square=True,\n",
    "                    xticklabels=False, yticklabels=False, cbar=False, ax=ax,\n",
    "                    vmin=0.0, vmax=1.0)\n",
    "\n",
    "    fig.colorbar(axs[2].collections[0], cax=axs[-1])\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if plot_titles:\n",
    "        for ax, title in zip(axs, ['Source', 'Our Algorithm', 'Spectral', 'Oracle']):\n",
    "            ax.set_title(title)\n",
    "    # Second pass so the layout also accounts for the titles.\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if savepath:\n",
    "        save_kwargs = {'dpi': dpi_setting}\n",
    "        # Request EPS explicitly when the path has an .eps extension.\n",
    "        if str(savepath).lower().endswith('.eps'):\n",
    "            save_kwargs['format'] = 'eps'\n",
    "        fig.savefig(savepath, **save_kwargs)\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3 Heatmaps"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first time you run this notebook, you have to use reload_results = False. Afterwards it will save the results in a pickle file and you can use reload_results = True to load them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set this to be False if you want to recompute all the heatmap results \n",
    "# If True it will load the saved matrix outputs from storage. \n",
    "reload_results = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## plotting fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_heatmaps_with_source(out_d, plot_titles=False, cmap='Spectral', savepath=None, dpi_setting=200.0):\n",
    "    \"\"\"Draw four side-by-side heatmaps comparing the source and each estimate with the truth.\n",
    "\n",
    "    The strict lower triangle of every panel holds the 'true' matrix and the\n",
    "    strict upper triangle holds, from left to right, the 'source', 'ours',\n",
    "    'sbm' and 'oracle_10' matrices. All panels share the colour range [0, 1];\n",
    "    a colourbar fills the fifth, narrow axis.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    out_d : dict\n",
    "        Matrices keyed 'source_<nq>', 'true_<nq>', 'ours_<nq>', 'sbm_<nq>' and\n",
    "        'oracle_10_<nq>'; '<nq>' is read from the suffix of the first key.\n",
    "    plot_titles : bool\n",
    "        Whether to title the four panels.\n",
    "    cmap : str\n",
    "        Colour map handed to seaborn.\n",
    "    savepath : str or None\n",
    "        When given, the figure is saved there and then closed; a path ending\n",
    "        in '.eps' is written as EPS. Without a savepath the figure is left\n",
    "        open so the notebook can display it.\n",
    "    dpi_setting : float\n",
    "        Resolution used when saving.\n",
    "    \"\"\"\n",
    "    nq = list(out_d.keys())[0].split('_')[-1]\n",
    "\n",
    "    def upper(name):\n",
    "        # Strict upper triangle of the named matrix.\n",
    "        return np.triu(out_d[f'{name}_{nq}'], k = 1).copy()\n",
    "\n",
    "    # Transposing moves the ground truth into the strict lower triangle.\n",
    "    true_mat = upper('true').T\n",
    "    panels = [\n",
    "        upper('source') + true_mat,\n",
    "        true_mat + upper('ours'),\n",
    "        true_mat + upper('sbm'),\n",
    "        true_mat + upper('oracle_10'),\n",
    "    ]\n",
    "\n",
    "    fig, axs = plt.subplots(ncols=5, \n",
    "                            gridspec_kw=dict(width_ratios=[2, 2, 2, 2, 0.2]), \n",
    "                            figsize=(15, 4))\n",
    "\n",
    "    # zip stops after the four panels; the fifth axis is kept for the colourbar.\n",
    "    for ax, mat in zip(axs, panels):\n",
    "        sns.heatmap(mat, cmap=cmap, square=True,\n",
    "                    xticklabels=False, yticklabels=False, cbar=False, ax=ax,\n",
    "                    vmin=0.0, vmax=1.0)\n",
    "\n",
    "    fig.colorbar(axs[2].collections[0], cax=axs[-1])\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if plot_titles:\n",
    "        for ax, title in zip(axs, ['Source', 'Our Algorithm', 'Spectral', 'Oracle']):\n",
    "            ax.set_title(title)\n",
    "    # Second pass so the layout also accounts for the titles.\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if savepath:\n",
    "        save_kwargs = {'dpi': dpi_setting}\n",
    "        # Decide on EPS from the file extension: a plain substring test would\n",
    "        # also match unrelated paths such as 'steps.png'.\n",
    "        if str(savepath).lower().endswith('.eps'):\n",
    "            save_kwargs['format'] = 'eps'\n",
    "        fig.savefig(savepath, **save_kwargs)\n",
    "        plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3a SBM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kp = 4\n",
    "kq = 2\n",
    "np_fixed = 2000 \n",
    "nq = 500\n",
    "pp = 0.8\n",
    "pq = 0.2\n",
    "qp = 0.9 \n",
    "qq = 0.1 \n",
    "\n",
    "graphon_fn_alpha = lambda x, y: gf.graphon0(x, y, \n",
    "                                            k=kp, p=pp, q=pq)\n",
    "graphon_fn_beta = lambda x, y: gf.graphon0(x, y, \n",
    "                                           k=kq, p=qp, q=qq)\n",
    "# for trial_number in range(num_trials): \n",
    "# alpha = np.round(alpha, 2)\n",
    "# beta = np.round(beta_val, 2)\n",
    "if not reload_results:\n",
    "    np.random.seed(42)\n",
    "    \n",
    "    gp = gf.GraphonPair(p_graphon_fn=graphon_fn_alpha, \n",
    "                            q_graphon_fn=graphon_fn_beta,\n",
    "                            n_p = np_fixed, \n",
    "                            n_q = nq)\n",
    "\n",
    "    p_sample = gp.P.g1_sample.toarray().astype(np.float64)\n",
    "    q_sample = gp.Q.g1_sample.toarray().astype(np.float64)\n",
    "    q_full_ground_truth = gp.Q_extended.g1.copy()\n",
    "    subset_indices = gp.subset_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not reload_results:\n",
    "\n",
    "    np.random.seed(42)\n",
    "    q_hat_ours, q_hat_sbm, q_hat_bitflips_list = ge.est_q_three_methods(\n",
    "        p_sample, q_sample, \n",
    "        q_full_ground_truth, \n",
    "        subset_indices, \n",
    "        p_flip_list=[0.01], \n",
    "        kp=4, \n",
    "        kq=2, \n",
    "        h_quantile=None,\n",
    "    )\n",
    "\n",
    "    out_d_sbm = {\n",
    "            f'source_{nq}': gp.P.g1.copy(),\n",
    "            f'true_{nq}': q_full_ground_truth, \n",
    "            f'ours_{nq}': q_hat_ours, \n",
    "            f'sbm_{nq}': q_hat_sbm, \n",
    "            f'oracle_10_{nq}': q_hat_bitflips_list[0], \n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not reload_results:\n",
    "\n",
    "    for k in out_d_sbm.keys(): \n",
    "        fpath = 'exp-results/results_3_methods/mat_pickles_for_heatmap/sbm_' + k + '.pkl'\n",
    "        utils.pickle_dump(out_d_sbm[k].astype(np.float16), fpath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if reload_results:\n",
    "    # Rebuild the file names from `nq` (set in the SBM parameter cell above)\n",
    "    # instead of a hard-coded 500, matching the 'sbm_' + key names that the\n",
    "    # save cell writes.\n",
    "    # 'source' stays first: the plotting function reads nq from the first key.\n",
    "    # NOTE(review): pickle files should only be loaded from trusted, locally\n",
    "    # generated results.\n",
    "    mat_names = ['source', 'true', 'ours', 'sbm', 'oracle_10']\n",
    "    out_d_sbm = {}\n",
    "    for mat_name in mat_names:\n",
    "        key = f'{mat_name}_{nq}'\n",
    "        fpath = f'exp-results/results_3_methods/mat_pickles_for_heatmap/sbm_{key}.pkl'\n",
    "        out_d_sbm[key] = utils.pickle_load(fpath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# savepath = 'figs-ell2/results_3_methods/sbm_transfer_side_by_side_lowdpi.png'\n",
    "# plot_heatmaps_with_source(\n",
    "#     out_d_sbm, \n",
    "#     cmap='Spectral',\n",
    "#     savepath=savepath,\n",
    "#     dpi_setting=700.0\n",
    "# )\n",
    "# # plt.savefig(savepath, format='eps', dpi=200.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3b Visualization on Wavy Graphon\n",
    "\n",
    "3 sources:\n",
    "\n",
    "a) flipped from 0 to 1 \n",
    "\n",
    "b) inverted around x = 0.5 \n",
    "\n",
    "c) SBM source "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3b(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment 3b(a): source P from gf.graphon_liza_rotated (invert=True),\n",
    "# target Q from gf.graphon_liza_2; estimate Q with the three methods.\n",
    "np_fixed = 2000\n",
    "nq = 500\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "if not reload_results:\n",
    "    def graphon_fn_alpha(x, y):\n",
    "        return gf.graphon_liza_rotated(x, y, period=3, bias=0.5,\n",
    "                                       invert=True, phase_shift=False)\n",
    "\n",
    "    def graphon_fn_beta(x, y):\n",
    "        return gf.graphon_liza_2(x, y, period=3, bias=0.5,\n",
    "                                 invert=False, phase_shift=False)\n",
    "\n",
    "    gp = gf.GraphonPair(\n",
    "        p_graphon_fn=graphon_fn_alpha,\n",
    "        q_graphon_fn=graphon_fn_beta,\n",
    "        n_p=np_fixed,\n",
    "        n_q=nq,\n",
    "    )\n",
    "\n",
    "    # The estimator is handed dense float64 copies of the sampled graphs.\n",
    "    p_sample = gp.P.g1_sample.toarray().astype(np.float64)\n",
    "    q_sample = gp.Q.g1_sample.toarray().astype(np.float64)\n",
    "    q_full_ground_truth = gp.Q_extended.g1.copy()\n",
    "    subset_indices = gp.subset_indices\n",
    "\n",
    "    # Re-seed immediately before estimation.\n",
    "    np.random.seed(42)\n",
    "    # NOTE(review): p_flip_list is [0.01], yet the result is stored under\n",
    "    # 'oracle_10' and the stacked figure titles it 'p = 0.1' -- confirm.\n",
    "    q_hat_ours, q_hat_sbm, q_hat_bitflips_list = ge.est_q_three_methods(\n",
    "        p_sample, q_sample,\n",
    "        q_full_ground_truth,\n",
    "        subset_indices,\n",
    "        p_flip_list=[0.01],\n",
    "        kp=4,\n",
    "        kq=2,\n",
    "        h_quantile=None,\n",
    "    )\n",
    "\n",
    "    out_d_wavy_1 = {\n",
    "        f'source_{nq}': gp.P.g1.copy(),\n",
    "        f'true_{nq}': q_full_ground_truth,\n",
    "        f'ours_{nq}': q_hat_ours,\n",
    "        f'sbm_{nq}': q_hat_sbm,\n",
    "        f'oracle_10_{nq}': q_hat_bitflips_list[0],\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache the 3b(a) matrices as float16 pickles for later reloading.\n",
    "if not reload_results:\n",
    "    for mat_name, mat in out_d_wavy_1.items():\n",
    "        pickle_path = 'exp-results/results_3_methods/mat_pickles_for_heatmap/wavy1_' + mat_name + '.pkl'\n",
    "        utils.pickle_dump(mat.astype(np.float16), pickle_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild out_d_wavy_1 from the cached 3b(a) pickles instead of recomputing.\n",
    "if reload_results:\n",
    "    mats_to_load = [\n",
    "        '_source_500', \n",
    "        '_true_500',\n",
    "        '_ours_500',\n",
    "        '_sbm_500',\n",
    "        '_oracle_10_500', \n",
    "    ]\n",
    "    out_d_wavy_1 = {}\n",
    "    for x in mats_to_load: \n",
    "        fpath = f'exp-results/results_3_methods/mat_pickles_for_heatmap/wavy1{x}.pkl' \n",
    "        # Strip the leading underscore so the keys match the computed dict.\n",
    "        out_d_wavy_1[x[1:]] = utils.pickle_load(fpath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# savepath = 'figs-ell2/results_3_methods/wavy_1_transfer_side_by_side.png'\n",
    "# plot_heatmaps_with_source(\n",
    "#     out_d_wavy_1, \n",
    "#     cmap='Spectral', \n",
    "#     savepath=savepath,\n",
    "#     dpi_setting=700.0\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3b(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment 3b(b): source P from gf.graphon_liza_rotated (phase_shift=True),\n",
    "# target Q from gf.graphon_liza_2; estimate Q with the three methods.\n",
    "np_fixed = 2000\n",
    "nq = 500\n",
    "\n",
    "if not reload_results:\n",
    "    np.random.seed(42)\n",
    "\n",
    "    def graphon_fn_alpha(x, y):\n",
    "        return gf.graphon_liza_rotated(x, y, period=3, bias=0.5,\n",
    "                                       invert=False, phase_shift=True)\n",
    "\n",
    "    def graphon_fn_beta(x, y):\n",
    "        return gf.graphon_liza_2(x, y, period=3, bias=0.5,\n",
    "                                 invert=False, phase_shift=False)\n",
    "\n",
    "    gp = gf.GraphonPair(\n",
    "        p_graphon_fn=graphon_fn_alpha,\n",
    "        q_graphon_fn=graphon_fn_beta,\n",
    "        n_p=np_fixed,\n",
    "        n_q=nq,\n",
    "    )\n",
    "\n",
    "    # The estimator is handed dense float64 copies of the sampled graphs.\n",
    "    p_sample = gp.P.g1_sample.toarray().astype(np.float64)\n",
    "    q_sample = gp.Q.g1_sample.toarray().astype(np.float64)\n",
    "    q_full_ground_truth = gp.Q_extended.g1.copy()\n",
    "    subset_indices = gp.subset_indices\n",
    "\n",
    "    # Re-seed immediately before estimation.\n",
    "    np.random.seed(42)\n",
    "    q_hat_ours, q_hat_sbm, q_hat_bitflips_list = ge.est_q_three_methods(\n",
    "        p_sample, q_sample,\n",
    "        q_full_ground_truth,\n",
    "        subset_indices,\n",
    "        p_flip_list=[0.01],\n",
    "        kp=4,\n",
    "        kq=2,\n",
    "        h_quantile=None,\n",
    "    )\n",
    "\n",
    "    # The stacked-figure cell reads out_d_wavy_2; previously that name was only\n",
    "    # bound on the reload path, so a fresh run failed with a NameError.\n",
    "    out_d_wavy_2 = {\n",
    "        f'source_{nq}': gp.P.g1.copy(),\n",
    "        f'true_{nq}': q_full_ground_truth,\n",
    "        f'ours_{nq}': q_hat_ours,\n",
    "        f'sbm_{nq}': q_hat_sbm,\n",
    "        f'oracle_10_{nq}': q_hat_bitflips_list[0],\n",
    "    }\n",
    "    # Keep the old name bound too: the save cell below still reads out_d.\n",
    "    out_d = out_d_wavy_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache the 3b(b) matrices (held in out_d) as float16 pickles.\n",
    "if not reload_results:\n",
    "    for mat_name, mat in out_d.items():\n",
    "        pickle_path = 'exp-results/results_3_methods/mat_pickles_for_heatmap/wavy2_' + mat_name + '.pkl'\n",
    "        utils.pickle_dump(mat.astype(np.float16), pickle_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild out_d_wavy_2 from the cached 3b(b) pickles instead of recomputing.\n",
    "if reload_results:\n",
    "    mats_to_load = [\n",
    "        '_source_500', \n",
    "        '_true_500',\n",
    "        '_ours_500',\n",
    "        '_sbm_500',\n",
    "        '_oracle_10_500', \n",
    "    ]\n",
    "    # Strip the leading underscore so the keys match the computed dict.\n",
    "    out_d_wavy_2 = {\n",
    "        suffix[1:]: utils.pickle_load(\n",
    "            f'exp-results/results_3_methods/mat_pickles_for_heatmap/wavy2{suffix}.pkl'\n",
    "        )\n",
    "        for suffix in mats_to_load\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot_heatmaps_with_source(\n",
    "#     out_d_wavy_2, \n",
    "#     cmap='Spectral', \n",
    "#     savepath='figs-ell2/results_3_methods/wavy_2_transfer_side_by_side.png',\n",
    "#     dpi_setting=700.0\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the dicts stacked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stacked comparison figure: one row per experiment dict, four heatmap\n",
    "# columns (source / ours / sbm / oracle) plus a narrow colorbar column.\n",
    "list_of_dicts = [out_d_sbm, out_d_wavy_1, out_d_wavy_2]\n",
    "cmap = 'Spectral'\n",
    "\n",
    "# squeeze=False keeps axs two-dimensional even for a single row.\n",
    "fig, axs = plt.subplots(ncols=5, nrows=len(list_of_dicts), squeeze=False,\n",
    "                        gridspec_kw=dict(width_ratios=[2, 2, 2, 2, 0.2]),\n",
    "                        figsize=(15, 12))\n",
    "\n",
    "vmin_value = 0.0\n",
    "vmax_value = 1.0\n",
    "heatmap_kwargs = dict(cmap=cmap, square=True, cbar=False,\n",
    "                      xticklabels=False, yticklabels=False,\n",
    "                      vmin=vmin_value, vmax=vmax_value)\n",
    "\n",
    "# Dedicated loop names: the loop no longer rebinds the notebook-level\n",
    "# out_d and nq (nq used to be left holding a string afterwards).\n",
    "for row_idx, result_dict in enumerate(list_of_dicts):\n",
    "    # Keys look like 'source_500'; the trailing token is shared by all keys.\n",
    "    nq_label = next(iter(result_dict)).split('_')[-1]\n",
    "\n",
    "    # The true matrix goes in the strict lower triangle and each compared\n",
    "    # matrix in the strict upper triangle, so one panel shows both.\n",
    "    true_lower = np.triu(result_dict[f'true_{nq_label}'], k=1).T\n",
    "    panels = [\n",
    "        np.triu(result_dict[f'{name}_{nq_label}'], k=1) + true_lower\n",
    "        for name in ('source', 'ours', 'sbm', 'oracle_10')\n",
    "    ]\n",
    "\n",
    "    for col_idx, panel in enumerate(panels):\n",
    "        sns.heatmap(panel, ax=axs[row_idx, col_idx], **heatmap_kwargs)\n",
    "\n",
    "    # Drive the row's colorbar from the last heatmap drawn in that row.\n",
    "    fig.colorbar(axs[row_idx, 3].collections[0], cax=axs[row_idx, 4])\n",
    "\n",
    "# Column titles appear on the top row only.\n",
    "column_titles = ['Source', 'Algorithm 1', 'Algorithm 2', 'Oracle ($p = 0.1$)']\n",
    "for col_idx, title in enumerate(column_titles):\n",
    "    axs[0, col_idx].set_title(title)\n",
    "\n",
    "# Lay out once, after every panel and colorbar exists.\n",
    "plt.tight_layout()\n",
    "\n",
    "savepath='figs-ell2/results_3_methods/transfer_heatmap_stacked_side_by_side.png'\n",
    "plt.savefig(savepath, dpi=700.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Old Stuff \n",
    "\n",
    "You can ignore the stuff below this line. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3b(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment 3b(c): source P from gf.graphon0 (k=3, p=0.9, q=1),\n",
    "# target Q from gf.graphon_liza_2; estimate Q with the three methods.\n",
    "np_fixed = 2000\n",
    "nq = 500\n",
    "\n",
    "if not reload_results:\n",
    "    np.random.seed(42)\n",
    "\n",
    "    def graphon_fn_alpha(x, y):\n",
    "        return gf.graphon0(x, y, k=3, p=0.9, q=1)\n",
    "\n",
    "    def graphon_fn_beta(x, y):\n",
    "        return gf.graphon_liza_2(x, y, period=3, bias=0.5,\n",
    "                                 invert=False, phase_shift=False)\n",
    "\n",
    "    gp = gf.GraphonPair(\n",
    "        p_graphon_fn=graphon_fn_alpha,\n",
    "        q_graphon_fn=graphon_fn_beta,\n",
    "        n_p=np_fixed,\n",
    "        n_q=nq,\n",
    "    )\n",
    "\n",
    "    # The estimator is handed dense float64 copies of the sampled graphs.\n",
    "    p_sample = gp.P.g1_sample.toarray().astype(np.float64)\n",
    "    q_sample = gp.Q.g1_sample.toarray().astype(np.float64)\n",
    "    q_full_ground_truth = gp.Q_extended.g1.copy()\n",
    "    subset_indices = gp.subset_indices\n",
    "\n",
    "    # Re-seed immediately before estimation.\n",
    "    np.random.seed(42)\n",
    "    q_hat_ours, q_hat_sbm, q_hat_bitflips_list = ge.est_q_three_methods(\n",
    "        p_sample, q_sample,\n",
    "        q_full_ground_truth,\n",
    "        subset_indices,\n",
    "        p_flip_list=[0.01],\n",
    "        kp=4,\n",
    "        kq=2,\n",
    "        h_quantile=None,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the 3b(c) matrices. The plotting cell reads out_d_wavy_3, which was\n",
    "# previously only bound on the reload path (this cell used out_d_wavy3), so a\n",
    "# fresh run failed with a NameError.\n",
    "if not reload_results:\n",
    "    out_d_wavy_3 = {\n",
    "        f'source_{nq}': gp.P.g1.copy(),\n",
    "        f'true_{nq}': q_full_ground_truth,\n",
    "        f'ours_{nq}': q_hat_ours,\n",
    "        f'sbm_{nq}': q_hat_sbm,\n",
    "        f'oracle_10_{nq}': q_hat_bitflips_list[0],\n",
    "    }\n",
    "    # Keep the old spelling bound: the save cell below still reads out_d_wavy3.\n",
    "    out_d_wavy3 = out_d_wavy_3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache the 3b(c) matrices as float16 pickles for later reloading.\n",
    "if not reload_results:\n",
    "    for mat_name, mat in out_d_wavy3.items():\n",
    "        pickle_path = 'exp-results/results_3_methods/mat_pickles_for_heatmap/wavy3_' + mat_name + '.pkl'\n",
    "        utils.pickle_dump(mat.astype(np.float16), pickle_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild out_d_wavy_3 from the cached 3b(c) pickles instead of recomputing.\n",
    "if reload_results:\n",
    "    mats_to_load = [\n",
    "        '_source_500', \n",
    "        '_true_500',\n",
    "        '_ours_500',\n",
    "        '_sbm_500',\n",
    "        '_oracle_10_500', \n",
    "    ]\n",
    "    # Strip the leading underscore so the keys match the computed dict.\n",
    "    out_d_wavy_3 = {\n",
    "        suffix[1:]: utils.pickle_load(\n",
    "            f'exp-results/results_3_methods/mat_pickles_for_heatmap/wavy3{suffix}.pkl'\n",
    "        )\n",
    "        for suffix in mats_to_load\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_heatmaps_with_source(\n",
    "    out_d_wavy_3, \n",
    "    cmap='Spectral',\n",
    "    savepath='figs-ell2/results_3_methods/wavy_3_transfer_side_by_side.png',\n",
    "    dpi_setting=700.0\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4 Check Degrees of Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "email_data = load_email_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every entry of email_data: print the mean of the row sums of its matrix\n",
    "# and record the median of those row sums in deg_list.\n",
    "deg_list = []\n",
    "for k in email_data.keys(): \n",
    "    adj = email_data[k]\n",
    "    # Row sums are computed once and reused for both statistics.\n",
    "    row_sums = np.sum(adj, axis=1)\n",
    "    print(f'Time Period {int(k)}: {np.mean(row_sums)}')\n",
    "    deg_list.append(np.median(row_sums))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "meta_nets = load_metabolic_networks()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "uq_species = meta_nets.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "species = ['iWFL_1372', 'iJN1463', 'iPC815']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row sums of each selected network's table, leaving out its last column.\n",
    "species_degs = {\n",
    "    name: np.sum(np.array(meta_nets[name].iloc[:, :-1]), axis=1)\n",
    "    for name in species\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the median row sum for each selected species.\n",
    "for name in species:\n",
    "    median_degree = np.median(species_degs[name])\n",
    "    print(name, median_degree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
