{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Collection\n",
    "This file contains all code for generating graphs and tables."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "\n",
    "# Benchmark environments, grouped by model class; the three lists are\n",
    "# concatenated into envs_all at the end of this block.\n",
    "# NOTE: the original remark here reads 'CNC_DeCtection is a typo'. The\n",
    "# identifier \"CNC_Detection\" is presumably kept to match the stored result\n",
    "# files -- confirm before renaming.\n",
    "envs_rpomdp = [\n",
    "    \"Toy\",\n",
    "    \"Machine\",\n",
    "    \"ChainInf\",\n",
    "    \"Chain10\",\n",
    "]\n",
    "# envs_rpomdp = [\"Machine\", \"ChainInf\", \"Chain10\"]\n",
    "\n",
    "envs_pomdp = [\n",
    "    \"Tiger\",\n",
    "    \"HeavenOrHell5\",\n",
    "    \"HeavenOrHell10\",\n",
    "    \"MiniHallway\",\n",
    "    \"Aloha10\",\n",
    "]\n",
    "# envs_pomdp = [\"HeavenOrHell5\", \"HeavenOrHell10\", \"RockSample5\", \"Aloha10\"]\n",
    "\n",
    "envs_rmdp = [\"Replacement\", \"CNC_Detection\"]\n",
    "\n",
    "# envs_all = envs_rpomdp + [name+\"_add_rel\" for name in envs_pomdp]\n",
    "# The concatenation order is the environment order used everywhere else.\n",
    "envs_all = [*envs_rpomdp, *envs_pomdp, *envs_rmdp]\n",
    "\n",
    "def get_latex_name_command(name):\n",
    "    \"\"\"Return the LaTeX macro string used for an environment identifier.\n",
    "\n",
    "    Matching is by prefix, so suffixed identifiers resolve like their base\n",
    "    name. For sized environments the text between the prefix and the first\n",
    "    '_' becomes the macro argument. An unrecognised name is reported on\n",
    "    stdout and returned unchanged.\n",
    "    \"\"\"\n",
    "    if name.startswith(\"Toy\"):\n",
    "        return \"\\\\toy\"\n",
    "    elif name.startswith(\"ChainInf\"):  # must be tested before the generic \"Chain\"\n",
    "        return \"$\\\\chain(\\\\infty)$\"\n",
    "    elif name.startswith(\"Chain\"):\n",
    "        nmbr = name[5:].partition(\"_\")[0]\n",
    "        return f\"\\\\chain({nmbr})\"\n",
    "    elif name.startswith(\"Machine\"):\n",
    "        return \"\\\\machine\"\n",
    "    elif name.startswith(\"CNC_Detection\"): #TODO: typo\n",
    "        return \"\\\\crc\"\n",
    "    elif name.startswith(\"Tiger\"):\n",
    "        return \"\\\\tiger\"\n",
    "    elif name.startswith(\"HeavenOrHell\"):\n",
    "        nmbr = name[12:].partition(\"_\")[0]\n",
    "        return f\"$\\\\heaven({nmbr})$\"\n",
    "    elif name.startswith(\"RockSample\"):\n",
    "        # This return used to sit (unreachable) inside the MiniHallway branch,\n",
    "        # so RockSample names fell through and produced None.\n",
    "        nmbr = name[10:].partition(\"_\")[0]\n",
    "        return f\"\\\\rocksample~({nmbr})\"\n",
    "    elif name.startswith(\"MiniHallway\"):\n",
    "        return \"\\\\minihall\"\n",
    "    elif name.startswith(\"Aloha\"):\n",
    "        nmbr = name[5:].partition(\"_\")[0]\n",
    "        return f\"$\\\\aloha({nmbr})$\"\n",
    "    elif name.startswith(\"Replacement\"):\n",
    "        return \"\\\\replacement\"\n",
    "    else:\n",
    "        print(f\"Error: environment {name} name not recognized\")\n",
    "        return name\n",
    "\n",
    "def get_latex_name(name):\n",
    "    \"\"\"Return the small-caps LaTeX display name for an environment identifier.\n",
    "\n",
    "    Prefixes are tried in the listed order (so \"ChainInf\" wins over\n",
    "    \"Chain\"). Sized environments append, in parentheses, the text between\n",
    "    the prefix and the first '_'. An unrecognised name is reported on stdout\n",
    "    and returned unchanged.\n",
    "    \"\"\"\n",
    "    # (prefix, label, has_size_suffix)\n",
    "    known_prefixes = (\n",
    "        (\"Toy\", r\"$\\textsc{Toy}^*$\", False),\n",
    "        (\"ChainInf\", r\"\\textsc{Parity}($\\infty$)\", False),\n",
    "        (\"Chain\", r\"\\textsc{Parity}\", True),\n",
    "        (\"Machine\", r\"\\textsc{Machine}\", False),\n",
    "        (\"CNC_Detection\", r\"\\textsc{HealthDetection}\", False),  # TODO: typo\n",
    "        (\"Tiger\", r\"\\textsc{Tiger}\", False),\n",
    "        (\"HeavenOrHell\", r\"\\textsc{HeavenOrHell}\", True),\n",
    "        (\"RockSample\", r\"\\textsc{RockSample}\", True),\n",
    "        (\"MiniHallway\", r\"\\textsc{MiniHallway}\", False),\n",
    "        (\"Aloha\", r\"\\textsc{Aloha}\", True),\n",
    "        (\"Replacement\", r\"\\textsc{Replacement}\", False),\n",
    "    )\n",
    "    for env_prefix, label, has_size_suffix in known_prefixes:\n",
    "        if not name.startswith(env_prefix):\n",
    "            continue\n",
    "        if has_size_suffix:\n",
    "            size = name[len(env_prefix):].partition(\"_\")[0]\n",
    "            return label + f\"({size})\"\n",
    "        return label\n",
    "\n",
    "    print(f\"Error: environment {name} name not recognized\")\n",
    "    return name\n",
    "\n",
    "\n",
    "# Model variants evaluated for every environment. The position in this\n",
    "# list is the index along the second axis of the Data array.\n",
    "all_rtypes = [\"full\", \"mid\", \"maxent\", \"rmdp\"]\n",
    "\n",
    "# root is the current working directory; the other paths are relative to it.\n",
    "root = os.path.abspath(\"\")\n",
    "path = \"Data/Tests\"          # folder scanned by get_data\n",
    "path_figs = \"Data/Figures\"   # folder figures are saved to\n",
    "prefix = \"RPolicyTest\"       # file-name prefix of the result files\n",
    "\n",
    "\n",
    "def get_data(env, alg, rtype):\n",
    "    \"\"\"Load the result file for one (environment, solver, model) triple.\n",
    "\n",
    "    Scans the folder root/path for files whose name starts with\n",
    "    '{prefix}_{env}_{alg}_{rtype}' and parses the last match (in sorted\n",
    "    order) as JSON.\n",
    "\n",
    "    Returns the parsed JSON content, or None (after printing an error) when\n",
    "    no file matches.\n",
    "    \"\"\"\n",
    "    tests_folder = os.path.join(root, path)\n",
    "    start_string = f\"{prefix}_{env}_{alg}_{rtype}\"\n",
    "\n",
    "    # The per-file debug print was removed: it listed the whole folder on\n",
    "    # every call and buried the notebook output.\n",
    "    matching_files = sorted(\n",
    "        filename\n",
    "        for filename in os.listdir(tests_folder)\n",
    "        if filename.startswith(start_string)\n",
    "    )\n",
    "\n",
    "    if not matching_files:\n",
    "        print(f\"ERROR: could not find the following file: {start_string}\")\n",
    "        return None\n",
    "\n",
    "    # Take the lexicographically last match; presumably the file names end in\n",
    "    # a sortable timestamp so this is the most recent run -- TODO confirm.\n",
    "    filename = os.path.join(tests_folder, matching_files[-1])\n",
    "    with open(filename, \"r\") as file:\n",
    "        return json.load(file)\n",
    "    \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def value_to_string(nmbr:float, is_percentage=False):\n",
    "    \"\"\"Format a number as a LaTeX table entry.\n",
    "\n",
    "    With is_percentage=True the value is shown as a whole percentage followed\n",
    "    by an escaped percent sign. Otherwise 2, 1 or 0 decimals are used for\n",
    "    magnitudes below 10, below 100, and from 100 upwards respectively.\n",
    "    Negative values above -0.005 are printed without their sign so that no\n",
    "    negative zero shows up in the table.\n",
    "    \"\"\"\n",
    "    # Both output modes drop the sign of tiny negatives, so do it once here.\n",
    "    if -0.005 < nmbr < 0:\n",
    "        nmbr = -nmbr\n",
    "\n",
    "    if is_percentage:\n",
    "        # The backslash is written doubled: a bare backslash-percent is an\n",
    "        # invalid escape sequence (SyntaxWarning on Python 3.12). The\n",
    "        # resulting string is unchanged.\n",
    "        return f'{nmbr*100:.0f}\\\\%'\n",
    "    if abs(nmbr) < 10:\n",
    "        return f'{nmbr:.2f}'\n",
    "    elif abs(nmbr) < 100:\n",
    "        return f'{nmbr:.1f}'\n",
    "    else:\n",
    "        return f'{nmbr:.0f}'\n",
    "\n",
    "def getvalues(data):\n",
    "    \"\"\"Return (value_sol, value_adv, relative difference) for one result record.\n",
    "\n",
    "    Reads data[\"value_sol\"] and data[\"value_adv\"], converting both with\n",
    "    float(). The relative difference is (value_sol - value_adv) / |value_adv|,\n",
    "    so a value_adv of 0 raises ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    value_sol = float(data[\"value_sol\"])\n",
    "    value_adv = float(data[\"value_adv\"])\n",
    "    absolute_gap = value_sol - value_adv\n",
    "    return value_sol, value_adv, absolute_gap / abs(value_adv)\n",
    "\n",
    "\n",
    "# Data collection\n",
    "# Data axes: (environment, rtype, solver, value type) with\n",
    "#   solver:     0 = RHSVI, 1 = RQMDP, 2 = RFIB\n",
    "#   value type: 0 = value_sol, 1 = value_adv, 2 = relative difference\n",
    "# Entries that are never filled stay at -inf.\n",
    "SOLVERS = (\"RHSVI\", \"RQMDP\", \"RFIB\")\n",
    "\n",
    "rows = []    # LaTeX display name per environment, in envs_all order\n",
    "stats = {}   # LaTeX macro -> model sizes read from the result file\n",
    "stds = {}    # LaTeX macro -> formatted value_adv / std_adv strings\n",
    "Data = np.full((len(envs_all), len(all_rtypes), len(SOLVERS), 3), -np.inf, dtype=float)\n",
    "\n",
    "for envidx, env in enumerate(envs_all):\n",
    "    rows.append(get_latex_name(env))\n",
    "    env_cmd = get_latex_name_command(env)\n",
    "    env_stds = stds.setdefault(env_cmd, {})\n",
    "\n",
    "    for rtypeidx, rtype in enumerate(all_rtypes):\n",
    "        # RQMDP and RFIB results are only read for the \"full\" model.\n",
    "        solvers = SOLVERS if rtype == \"full\" else SOLVERS[:1]\n",
    "\n",
    "        for solveridx, alg in enumerate(solvers):\n",
    "            data = get_data(env, alg, rtype)\n",
    "            if data is None:\n",
    "                # get_data returns None for a missing file; fail with a clear\n",
    "                # message instead of an opaque TypeError further down.\n",
    "                raise FileNotFoundError(\n",
    "                    f\"No result file for env={env!r}, alg={alg!r}, rtype={rtype!r}\"\n",
    "                )\n",
    "\n",
    "            val_sol, val_adv, val_diff = getvalues(data)\n",
    "            Data[envidx, rtypeidx, solveridx, :] = (val_sol, val_adv, val_diff)\n",
    "\n",
    "            if alg == \"RHSVI\":\n",
    "                label = f'RHSVI({rtype})'\n",
    "                # Model sizes are taken once per environment, from the first\n",
    "                # RHSVI record that is read.\n",
    "                if env_cmd not in stats:\n",
    "                    stats[env_cmd] = {\n",
    "                        '$|S|$': int(data['states']),\n",
    "                        r'$|\\Omega|$': int(data['observations']),\n",
    "                        '$|A|$': int(data['actions']),\n",
    "                    }\n",
    "            else:\n",
    "                label = alg\n",
    "\n",
    "            # Key insertion order matters: later cells select columns of the\n",
    "            # stds table by position.\n",
    "            std_adv = data['std_adv']\n",
    "            env_stds[f'{label}(min)'] = f\"${val_adv:.2f}$\"\n",
    "            # A null standard deviation is shown as '?' instead of crashing\n",
    "            # on float(None).\n",
    "            env_stds[f'{label}(std)'] = f\"${float(std_adv):.2f}$\" if std_adv is not None else \"$?$\"\n",
    "\n",
    "\n",
    "# Render every entry of Data as a table-ready string.\n",
    "Data_str = np.empty_like(Data, dtype=object)\n",
    "for entry_index, entry in np.ndenumerate(Data):\n",
    "    Data_str[entry_index] = value_to_string(entry)\n",
    "\n",
    "# Bold the best (largest) value at value-type index 1 for each environment,\n",
    "# together with every entry within BOLD_TOLERANCE of it.\n",
    "BOLD_TOLERANCE = 0.01\n",
    "for envidx in range(len(envs_all)):\n",
    "    eval_values = Data[envidx, :, :, 1]  # axes: (rtype, solver)\n",
    "    best_value = np.max(eval_values)\n",
    "\n",
    "    rtype_indices, solver_indices = np.where(eval_values >= best_value - BOLD_TOLERANCE)\n",
    "    for rtypeidx, solveridx in zip(rtype_indices, solver_indices):\n",
    "        cell = Data_str[envidx, rtypeidx, solveridx, 1]\n",
    "        Data_str[envidx, rtypeidx, solveridx, 1] = \"\\\\textbf{\" + cell + \"}\"\n",
    "\n",
    "# Add &'s where necessary. Plain + concatenates element-wise on the object\n",
    "# array and does not rely on np.char.add accepting a non-string dtype.\n",
    "Data_str[:, 0, 0, 1] = Data_str[:, 0, 0, 1] + \" &\"\n",
    "Data_str[:, 0, 2, 1] = Data_str[:, 0, 2, 1] + \" &\"\n",
    "\n",
    "# Column selections used by the plots.\n",
    "# indexes: env, rtype, solver, valuetype\n",
    "#   rtype follows all_rtypes (0 = full, 1 = mid, 2 = maxent, 3 = rmdp)\n",
    "#   solver 0 / 1 / 2 is filled from the RHSVI / RQMDP / RFIB results\n",
    "#   valuetype 0 = value_sol, 1 = value_adv\n",
    "\n",
    "# Element-wise best value_adv over the three non-full models (solver 0).\n",
    "best_naive = np.max(np.stack([Data[:, 1, 0, 1], Data[:, 2, 0, 1], Data[:, 3, 0, 1]]), axis=0)\n",
    "\n",
    "# First column of every selection: value_sol of solver 0 on the full model.\n",
    "full_sol = Data[:, 0, 0, 0]\n",
    "\n",
    "# Solver 0 across models: value_adv for full, mid, maxent, rmdp.\n",
    "Data_selected = np.column_stack([\n",
    "    full_sol,\n",
    "    Data[:, 0, 0, 1],\n",
    "    Data[:, 1, 0, 1],\n",
    "    Data[:, 2, 0, 1],\n",
    "    Data[:, 3, 0, 1],\n",
    "])\n",
    "\n",
    "# Same layout, with value_sol in every column.\n",
    "Data_selected_ub = np.column_stack([\n",
    "    full_sol,\n",
    "    full_sol,\n",
    "    Data[:, 1, 0, 0],\n",
    "    Data[:, 2, 0, 0],\n",
    "    Data[:, 3, 0, 0],\n",
    "])\n",
    "\n",
    "# Full model across solvers: value_adv for solver 0, 1, 2.\n",
    "Data_selected_2 = np.column_stack([\n",
    "    full_sol,\n",
    "    Data[:, 0, 0, 1],\n",
    "    Data[:, 0, 1, 1],\n",
    "    Data[:, 0, 2, 1],\n",
    "])\n",
    "\n",
    "# Same layout, with value_sol in every column.\n",
    "Data_selected_2_ub = np.column_stack([\n",
    "    full_sol,\n",
    "    full_sol,\n",
    "    Data[:, 0, 1, 0],\n",
    "    Data[:, 0, 2, 0],\n",
    "])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per environment; keep the first six columns of the stds table.\n",
    "# Given the order in which stds is filled, these are the min/std entries\n",
    "# for RHSVI(full), RQMDP and RFIB.\n",
    "stds_by_env = pd.DataFrame(stds).transpose()\n",
    "df = stds_by_env.iloc[:, :6]\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-level column header: one (Min., Std.) pair per algorithm.\n",
    "algorithm_labels = ['RHSVI', 'RQMDP', 'RFIB']\n",
    "metric_labels = ['Min.', 'Std.']\n",
    "idx = pd.MultiIndex(\n",
    "    levels=[algorithm_labels, metric_labels],\n",
    "    codes=[\n",
    "        [i for i in range(len(algorithm_labels)) for _ in metric_labels],\n",
    "        [j for _ in algorithm_labels for j in range(len(metric_labels))],\n",
    "    ],\n",
    "    names=['Algorithm', 'Metric'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the two-level (Algorithm, Metric) header to the table.\n",
    "df = df.set_axis(idx, axis=\"columns\")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Emit the table as LaTeX source.\n",
    "latex_table = df.style.to_latex(hrules=True)\n",
    "print(latex_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remaining columns of the stds table (everything after the first six).\n",
    "# Given the order in which stds is filled, these are the min/std entries\n",
    "# for RHSVI on the mid, maxent and rmdp models.\n",
    "stds_by_env = pd.DataFrame(stds).transpose()\n",
    "df2 = stds_by_env.iloc[:, 6:]\n",
    "df2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-level column header: one (Min., Std.) pair per model variant.\n",
    "model_labels = ['RHSVI($\\\\Mcenter$)', 'RHSVI($\\\\Ment$)', 'RHSVI($\\\\Mrmdp$)']\n",
    "metric_labels = ['Min.', 'Std.']\n",
    "idx2 = pd.MultiIndex(\n",
    "    levels=[model_labels, metric_labels],\n",
    "    codes=[\n",
    "        [i for i in range(len(model_labels)) for _ in metric_labels],\n",
    "        [j for _ in model_labels for j in range(len(metric_labels))],\n",
    "    ],\n",
    "    names=['Algorithm', 'Metric'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the two-level (Algorithm, Metric) header to the second table.\n",
    "df2 = df2.set_axis(idx2, axis=\"columns\")\n",
    "df2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Emit the second table as LaTeX source.\n",
    "latex_table_2 = df2.style.to_latex(hrules=True)\n",
    "print(latex_table_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model sizes (states, observations, actions), one row per environment.\n",
    "stats_by_env = pd.DataFrame(stats).transpose()\n",
    "print(stats_by_env.style.to_latex(hrules=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot check: all value types for environment 0, rtype 0, solver 2.\n",
    "Data[0, 0, 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot check: first environment's row of the solver comparison.\n",
    "Data_selected_2[0, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot check: first environment's row of the value_sol counterpart.\n",
    "Data_selected_2_ub[0, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot data: relative value gaps per environment, one panel per comparison.\n",
    "\n",
    "def plot_relative_gaps(ax, values, labels, colours, markers, max_abs_value):\n",
    "    \"\"\"Scatter the gap of every column of values relative to column 0.\n",
    "\n",
    "    Column m >= 1 is drawn as (values[:, m] - values[:, 0]) / |values[:, 0]|.\n",
    "    Gaps whose magnitude reaches max_abs_value are clipped to the axis limit\n",
    "    and drawn as a small arrow instead of a marker. Row i is placed at height\n",
    "    n_rows - 1 - i plus a per-column vertical offset.\n",
    "    \"\"\"\n",
    "    n_rows = values.shape[0]\n",
    "    reference = values[:, 0]\n",
    "    for m in range(1, values.shape[1]):\n",
    "        raw_vals = (values[:, m] - reference) / abs(reference)\n",
    "        clipped_vals = np.clip(raw_vals, -max_abs_value, max_abs_value)\n",
    "        clipped_mask = np.abs(raw_vals) >= max_abs_value\n",
    "        y_vals = np.arange(n_rows)[::-1] + 1/8 * (3 - m) - 1/16\n",
    "        colour = colours[m - 1]\n",
    "\n",
    "        # Unclipped points\n",
    "        ax.scatter(clipped_vals[~clipped_mask], y_vals[~clipped_mask],\n",
    "                   label=labels[m - 1], color=colour, alpha=1.0, marker=markers[m - 1],\n",
    "                   edgecolors='black', linewidth=0.2)\n",
    "\n",
    "        # Clipped points: arrow head plus a short shaft, both shifted inwards.\n",
    "        arrows = (\n",
    "            (clipped_mask & (raw_vals > 0), '>', -0.02),  # positive outliers\n",
    "            (clipped_mask & (raw_vals < 0), '<', 0.02),   # negative outliers\n",
    "        )\n",
    "        for mask, head, shift in arrows:\n",
    "            ax.scatter(clipped_vals[mask] + shift, y_vals[mask],\n",
    "                       color=colour, marker=head, edgecolor=colour, linewidth=1.0, zorder=5, label=None, s=25.0)\n",
    "            ax.scatter(clipped_vals[mask] + 2 * shift, y_vals[mask],\n",
    "                       color=colour, marker='_', edgecolor=colour, linewidth=1.0, zorder=5, label=None)\n",
    "\n",
    "\n",
    "def style_gap_axis(ax, gridlines_y, bold_gridlines_y, max_abs_value):\n",
    "    \"\"\"Add the zero line, row and group separators, x label, x limits and legend.\"\"\"\n",
    "    ax.axvline(0, color='lightgray', linestyle='-', linewidth=1.0, zorder=0)\n",
    "    for y in gridlines_y:\n",
    "        ax.axhline(y, color='lightgray', linestyle='--', linewidth=0.6, zorder=0)\n",
    "    for y in bold_gridlines_y:\n",
    "        # White line underneath so the black dashes stand out from the grid.\n",
    "        ax.axhline(y, color='white', linestyle='-', linewidth=1.0, zorder=0)\n",
    "        ax.axhline(y, color='black', linestyle='--', linewidth=1.0, zorder=0)\n",
    "    ax.set_xlabel(r'Relative value gap ($V_\\text{gap}$)')\n",
    "    ax.set_xlim(-max_abs_value, max_abs_value)\n",
    "    ax.legend(framealpha=1.0, facecolor='white', edgecolor='black', loc='upper right')\n",
    "\n",
    "\n",
    "max_abs_value = 0.8\n",
    "markers = [\"s\", \"D\", \"o\", \"X\"]\n",
    "\n",
    "colours_naive = ['#E26BBA', '#286d90', '#44b7f1', '#a4dbf7']\n",
    "colours_heuristics = ['#E26BBA', '#3f7a52', '#64c384']\n",
    "labels_naive = [r'$\\mathcal{M}$', r'$M_\\text{Center}$', r'$M_\\text{Ent}$', r'$M_\\text{RMDP}$']\n",
    "labels_heuristics = [r'RHSVI', r'RQMDP', r'RFIB']\n",
    "\n",
    "N = len(envs_all)\n",
    "gridlines_y = np.arange(N) + 0.5\n",
    "# Group separators after the envs_rpomdp and envs_pomdp blocks (envs_all is\n",
    "# built in that order); derived from the list lengths instead of hard-coded.\n",
    "bold_gridlines_y = N - np.cumsum([len(envs_rpomdp), len(envs_pomdp)]) - 0.5\n",
    "\n",
    "# Create subplots\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), sharey=True)\n",
    "\n",
    "# Subplot 1: RHSVI on the different models\n",
    "plot_relative_gaps(ax1, Data_selected, labels_naive, colours_naive, markers, max_abs_value)\n",
    "style_gap_axis(ax1, gridlines_y, bold_gridlines_y, max_abs_value)\n",
    "# rows is used as-is (no in-place reverse), so re-running this cell does not\n",
    "# flip the tick labels. The first environment is at the top.\n",
    "ax1.set_yticks(np.arange(N)[::-1])\n",
    "ax1.set_yticklabels(rows)\n",
    "\n",
    "# Subplot 2: Heuristics\n",
    "plot_relative_gaps(ax2, Data_selected_2, labels_heuristics, colours_heuristics, markers, max_abs_value)\n",
    "style_gap_axis(ax2, gridlines_y, bold_gridlines_y, max_abs_value)\n",
    "\n",
    "plt.tight_layout()\n",
    "# Make sure the output folder exists before saving.\n",
    "os.makedirs(path_figs, exist_ok=True)\n",
    "plt.savefig(os.path.join(path_figs, \"solvability_comparison.pdf\"), transparent=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
