{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3dc97f04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "\n",
    "# Make the notebook's parent directory importable (presumably where utils,\n",
    "# utils_plot and utils_simple_access live). The membership check keeps the cell\n",
    "# idempotent: re-running it does not append duplicate entries to sys.path.\n",
    "PARENT_DIR = os.path.abspath(\"..\")\n",
    "if PARENT_DIR not in sys.path:\n",
    "    sys.path.append(PARENT_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "799f34cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle as pkl\n",
    "from utils import *\n",
    "from utils_plot import *\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "from utils_simple_access import *\n",
    "import pandas as pd\n",
    "import glob\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib as mpl\n",
    "from matplotlib.gridspec import GridSpec\n",
    "\n",
    "\n",
    "# Run scan_L_A.py to generate the results for Fig2.ipynb\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "abfdba15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect per-run metrics from the sweep generated by scan_L_A.py (results for Fig2.ipynb).\n",
    "# Produces `df` with one row per result file: L, A, NC1, margins and order (stored as 1-order).\n",
    "\n",
    "records = []\n",
    "# Sorted so that the row order and the printed diagnostics are reproducible.\n",
    "result_files = sorted(glob.glob('../results/sweep_L_A/*'))\n",
    "\n",
    "for file in tqdm(result_files):\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced results.\n",
    "    with open(file, 'rb') as f:\n",
    "        data_dict = pkl.load(f)\n",
    "    L = data_dict['C'].L\n",
    "    if L == 0:\n",
    "        continue  # L == 0 runs are excluded from the figure\n",
    "    A = data_dict['C'].max_move\n",
    "\n",
    "    # Flag runs whose last recorded accuracy is below 0.99 and show their accuracy curve.\n",
    "    final_accuracy = data_dict['accuracy_l'][-1]\n",
    "    if final_accuracy < 0.99:\n",
    "        print(f\"A={A}, L={L}, accuracy={final_accuracy}\")\n",
    "        fig, ax = plt.subplots()\n",
    "        ax.plot(data_dict['accuracy_l'])\n",
    "        ax.set_title(f\"A={A}, L={L}\")\n",
    "        plt.show()\n",
    "\n",
    "    # Restrict every per-sample array to samples with |action_taken| <= 1 (mask computed once).\n",
    "    keep = abs(data_dict['action_taken']) <= 1\n",
    "    for key in ('X', 'y', 'loc_y'):\n",
    "        data_dict[key] = data_dict[key][keep]\n",
    "    data_dict['hidden_states'][-1] = data_dict['hidden_states'][-1][keep]\n",
    "\n",
    "    records.append({\n",
    "        'L': L,\n",
    "        'A': A,\n",
    "        'NC1': calc_NC1_from_data_dict(data_dict),\n",
    "        'margins': multiclass_functional_margin_from_data_dict(data_dict, reducer=np.min)[0],\n",
    "        'order': 1 - get_order(data_dict),\n",
    "    })\n",
    "\n",
    "# Build the frame once: pd.concat inside the loop copies the whole frame on every\n",
    "# iteration (quadratic). `columns=` keeps the schema even when no result files are found.\n",
    "df = pd.DataFrame(records, columns=['L', 'A', 'NC1', 'margins', 'order'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "3faea6f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib formatting for ICLR-paper figures (serif font, Times New Roman first).\n",
    "#\n",
    "# The seaborn context/style are applied FIRST on purpose: sns.set_context and sns.set_style\n",
    "# both write into mpl.rcParams (font sizes, line widths, tick sizes/directions, spines and\n",
    "# font family), so calling them after the update below would silently undo it.\n",
    "sns.set_context(\"paper\")\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "mpl.rcParams.update({\n",
    "    'font.size': 18,\n",
    "    'axes.labelsize': 18,\n",
    "    'axes.titlesize': 20,\n",
    "    'xtick.labelsize': 16,\n",
    "    'ytick.labelsize': 16,\n",
    "    'legend.fontsize': 16,\n",
    "    'figure.titlesize': 22,\n",
    "    'axes.linewidth': 1.2,\n",
    "    'lines.linewidth': 2.0,\n",
    "    'lines.markersize': 8,\n",
    "    'xtick.direction': 'in',\n",
    "    'ytick.direction': 'in',\n",
    "    'xtick.major.size': 6,\n",
    "    'ytick.major.size': 6,\n",
    "    'xtick.minor.size': 3,\n",
    "    'ytick.minor.size': 3,\n",
    "    'xtick.major.width': 1.2,\n",
    "    'ytick.major.width': 1.2,\n",
    "    'xtick.minor.width': 1.0,\n",
    "    'ytick.minor.width': 1.0,\n",
    "    'legend.frameon': False,\n",
    "    'figure.dpi': 100,\n",
    "    'savefig.dpi': 300,\n",
    "    'figure.figsize': (6, 4),\n",
    "    'pdf.fonttype': 42,  # TrueType fonts for compatibility\n",
    "    'ps.fonttype': 42,\n",
    "    'text.usetex': False,  # Set to True if you want LaTeX rendering and have it installed\n",
    "    'axes.spines.top': False,\n",
    "    'axes.spines.right': False,\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif', 'serif'],\n",
    "})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "0beb3533",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 2. Top row: 2D PCA of the last hidden states for two values of A.\n",
    "# Bottom row: log10 heatmaps of NC1, margin and 1-order over (A, L).\n",
    "\n",
    "# Numeric axes for the pivot tables; L is shown shifted by +1 (TODO confirm the reason).\n",
    "df['L_num'] = pd.to_numeric(df['L']) + 1\n",
    "df['A_num'] = pd.to_numeric(df['A'])\n",
    "\n",
    "# One (A x L) table per metric; duplicate (A, L) entries are averaged.\n",
    "pivot_NC1 = df.pivot_table(index='A_num', columns='L_num', values='NC1', aggfunc='mean')\n",
    "pivot_margins = df.pivot_table(index='A_num', columns='L_num', values='margins', aggfunc='mean')\n",
    "pivot_order = df.pivot_table(index='A_num', columns='L_num', values='order', aggfunc='mean')\n",
    "\n",
    "# Use the global matplotlib styling set above\n",
    "fig = plt.figure(figsize=(18, 10))\n",
    "gs = GridSpec(2, 6, height_ratios=[1, 1], hspace=0.3, wspace=0.25)\n",
    "\n",
    "# Result files shown in the two PCA panels.\n",
    "file1 = '../results/sweep_L_A/data_max_move_1_L_9.pkl'\n",
    "file2 = '../results/sweep_L_A/data_max_move_10_L_9.pkl'\n",
    "\n",
    "\n",
    "def plot_pca_subplot(ax, file, title, cb=False):\n",
    "    \"\"\"Scatter the 2D PCA projection of the last hidden states stored in `file`.\n",
    "\n",
    "    Only samples with |action_taken| <= 1 are plotted, coloured by `loc_y`\n",
    "    (its first column when `loc_y` is 2D).\n",
    "\n",
    "    ax    -- matplotlib Axes to draw on\n",
    "    file  -- path to a pickled sweep result dict\n",
    "    title -- Axes title\n",
    "    cb    -- if True, attach a colorbar labelled 'location' with its tick labels hidden\n",
    "    \"\"\"\n",
    "    from sklearn.decomposition import PCA\n",
    "\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load trusted results.\n",
    "    with open(file, 'rb') as f:\n",
    "        data_dict = pkl.load(f)\n",
    "    keep = abs(data_dict['action_taken']) <= 1\n",
    "    h = data_dict['hidden_states'][-1].cpu().numpy()[keep]\n",
    "    loc_y = data_dict['loc_y']\n",
    "    color = (loc_y[:, 0] if loc_y.ndim > 1 else loc_y)[keep]\n",
    "\n",
    "    h_pca = PCA(n_components=2).fit_transform(h)\n",
    "    sc = ax.scatter(\n",
    "        h_pca[:, 0], h_pca[:, 1], c=color, cmap='coolwarm',\n",
    "        s=150, alpha=1, edgecolor='none'\n",
    "    )\n",
    "    ax.set_title(title, fontsize=20, pad=10)\n",
    "    ax.axis('equal')\n",
    "    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)\n",
    "    for spine in ['top', 'right', 'left', 'bottom']:\n",
    "        ax.spines[spine].set_visible(False)\n",
    "    if cb:\n",
    "        cbar = plt.colorbar(sc, ax=ax, pad=0.01, fraction=0.05)\n",
    "        cbar.ax.set_yticklabels([])  # hide the colorbar's tick labels\n",
    "        cbar.set_label('location', fontsize=16)\n",
    "\n",
    "\n",
    "def plot_log_heatmap(ax, pivot, label, vmax, cmap='plasma'):\n",
    "    \"\"\"Draw log10(pivot) as a heatmap on `ax`, using `label` for both title and colorbar.\n",
    "\n",
    "    Values are floored at 1e-8 before the log to avoid log10(0); `vmax` caps the\n",
    "    colour scale (in log10 units). The x axis is labelled 'L'.\n",
    "    \"\"\"\n",
    "    pivot_log = np.log10(np.clip(pivot, 1e-8, None))\n",
    "    sns.heatmap(pivot_log, ax=ax, cmap=cmap, annot=False, vmax=vmax, cbar_kws={'label': label})\n",
    "    ax.set_title(label, fontsize=20, pad=10)\n",
    "    ax.set_xlabel('L', fontsize=18)\n",
    "    ax.tick_params(axis='x', which='major', labelsize=16)\n",
    "\n",
    "\n",
    "# --- First row: 2D PCA plots for two datasets ---\n",
    "ax_pca1 = fig.add_subplot(gs[0, 0:3])\n",
    "plot_pca_subplot(ax_pca1, file1, r'$A=1$')\n",
    "\n",
    "ax_pca2 = fig.add_subplot(gs[0, 3:6])\n",
    "plot_pca_subplot(ax_pca2, file2, r'$A=\\frac{1}{2}S$', cb=True)\n",
    "\n",
    "# --- Second row: heatmaps over (A, L) ---\n",
    "axs = [fig.add_subplot(gs[1, i*2:i*2+2]) for i in range(3)]\n",
    "plot_log_heatmap(axs[0], pivot_NC1, r'$\\log_{10}(\\mathrm{NC}_1)$', vmax=0.1)\n",
    "plot_log_heatmap(axs[1], pivot_margins, r'$\\log_{10}(\\mathrm{Margin})$', vmax=0.03)\n",
    "# The 'order' column already holds 1-order, so title and colorbar say log10(1-Order).\n",
    "plot_log_heatmap(axs[2], pivot_order, r'$\\log_{10}(1-\\mathrm{Order})$', vmax=0.1)\n",
    "\n",
    "# Only the first heatmap shows the A axis. Heatmap row i is centred at y = i + 0.5\n",
    "# (pivot_table sorts A ascending, so the smallest A is the top row); put the ticks on\n",
    "# the centres of the first, middle and last rows instead of on row boundaries.\n",
    "axs[0].set_ylabel('A', fontsize=18)\n",
    "axs[0].tick_params(axis='y', which='major', labelsize=16)\n",
    "n_rows = len(pivot_NC1.index)\n",
    "axs[0].set_yticks([0.5, (n_rows - 1) // 2 + 0.5, n_rows - 0.5])\n",
    "axs[0].set_yticklabels(['1', r'$\\frac{1}{2}S$', r'$S$'])\n",
    "for heat_ax in axs[1:]:\n",
    "    heat_ax.set_yticks([])\n",
    "    heat_ax.set_ylabel('')\n",
    "\n",
    "plt.tight_layout()\n",
    "os.makedirs('./final_results/figures', exist_ok=True)\n",
    "fig.savefig('./final_results/figures/fig2.png', dpi=300)\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
