{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import pandas as pd\n",
    "\n",
    "with open(\"data/cascade_comparison_records.pkl\", \"rb\") as file:\n",
    "    ALL_RECORDS = pickle.load(file)\n",
    "\n",
    "DATA = ALL_RECORDS\n",
    "df = pd.DataFrame(DATA)\n",
    "df['cascade_symbol'] = [ \"->\".join([str(y) for y in x]) for x in df['cascade']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def compute_minmax(data_cts, data_grid):\n",
    "    overall_min = max(\n",
    "        np.min([ x[0] for x in data_cts ]),\n",
    "        np.min([ x[0] for x in data_grid ])\n",
    "    )\n",
    "    overall_max = min(\n",
    "        np.max([ x[0] for x in data_grid ]), \n",
    "        np.max([ x[0] for x in data_cts ])\n",
    "    )\n",
    "    return overall_max-overall_min\n",
    "\n",
    "for i in range(int(len(df)/2)):\n",
    "    assert df.loc[2*i, 'method'] == 'continuous_optimization'\n",
    "    assert df.loc[2*i+1, 'method'] == 'gridsearch'\n",
    "    assert df.loc[2*i, 'benchmark'] == df.loc[2*i+1, 'benchmark']\n",
    "    assert df.loc[2*i, 'cascade_symbol'] == df.loc[2*i+1, 'cascade_symbol']\n",
    "    cts_data = df.loc[2*i, 'data'] \n",
    "    grid_data = df.loc[2*i + 1, 'data']\n",
    "    \n",
    "    df.loc[2*i, 'auc_norm'] = compute_minmax(cts_data, grid_data)\n",
    "    df.loc[2*i+1, 'auc_norm'] = compute_minmax(cts_data, grid_data)\n",
    "\n",
    "df['auc'] = 1 - df['performance']/df['auc_norm']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import stats\n",
    "\n",
    "### WILCOXON TESTS BY CASCADE LEN\n",
    "\n",
    "def run_paired_wilcoxon_tests_by_len(df):\n",
    "    results = []\n",
    "    \n",
    "    # For each cascade length\n",
    "    for cascade_len in df['cascade_len'].unique():\n",
    "        # Get all data for this cascade length\n",
    "        mask = (df['cascade_len'] == cascade_len)\n",
    "        data = df[mask]\n",
    "        \n",
    "        # Create paired samples using benchmark and cascade_symbol\n",
    "        continuous = data[data['method'] == 'continuous_optimization'].set_index(['benchmark', 'cascade_symbol'])['auc']\n",
    "        gridsearch = data[data['method'] == 'gridsearch'].set_index(['benchmark', 'cascade_symbol'])['auc']\n",
    "        \n",
    "        # Make sure we have the same pairs in both\n",
    "        common_pairs = continuous.index.intersection(gridsearch.index)\n",
    "        continuous = continuous[common_pairs]\n",
    "        gridsearch = gridsearch[common_pairs]\n",
    "        \n",
    "        # Run test if we have data\n",
    "        if len(continuous) > 0 and len(gridsearch) > 0:\n",
    "            stat, pval = stats.wilcoxon(gridsearch, continuous, alternative='greater')\n",
    "            results.append({\n",
    "                'cascade_len': cascade_len,\n",
    "                'n_pairs': len(continuous),\n",
    "                'statistic': stat,\n",
    "                'pvalue': pval\n",
    "            })\n",
    "    \n",
    "    return pd.DataFrame(results)\n",
    "\n",
    "run_paired_wilcoxon_tests_by_len(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Plot reduction in AUC by cascade length\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def compute_reduction_stats(df, cascade_len):\n",
    "    continuous = df[(df['method'] == 'continuous_optimization') & (df['cascade_len'] == cascade_len)]['auc']\n",
    "    gridsearch = df[(df['method'] == 'gridsearch') & (df['cascade_len'] == cascade_len)]['auc']\n",
    "    \n",
    "    pct_changes = 100 * (continuous.values - gridsearch.values) / gridsearch.values\n",
    "    n = len(pct_changes)\n",
    "    return np.mean(pct_changes), np.std(pct_changes, ddof=1)/np.sqrt(n)\n",
    "\n",
    "cascade_lengths = sorted(df['cascade_len'].unique())\n",
    "means = []\n",
    "sems = []\n",
    "for length in cascade_lengths:\n",
    "    mean, sem = compute_reduction_stats(df, length)\n",
    "    means.append(mean)\n",
    "    sems.append(sem)\n",
    "\n",
    "# Plot line first\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(cascade_lengths, means, '-')\n",
    "\n",
    "# Then plot points with different fills\n",
    "for i, (x, y, sem) in enumerate(zip(cascade_lengths, means, sems)):\n",
    "    if i == 0:  # first point (not significant)\n",
    "        ax.scatter(x, y, color='white', edgecolor='tab:blue', zorder=3)\n",
    "    else:  # other points\n",
    "        ax.scatter(x, y, color='tab:blue', zorder=3)\n",
    "\n",
    "# Add error bars\n",
    "ax.errorbar(cascade_lengths, means, yerr=sems, fmt='none', capsize=3, color='tab:blue')\n",
    "\n",
    "ax.axhline(y=0, linestyle='dashed', color='gray')\n",
    "ax.set_xlabel(\"Cascade Length\", fontsize=14)\n",
    "ax.set_ylabel(\"\\\\% Reduction in AUC\", fontsize=14)\n",
    "# ax.set_title(\"\\% Reduction in AUC using Continuous Optimization\")\n",
    "\n",
    "ax.spines[\"top\"].set_visible(False)\n",
    "ax.spines[\"right\"].set_visible(False)\n",
    "plt.tight_layout()\n",
    "\n",
    "ax.set_xticks([2,3,4,5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### PAIRED WILCOXON TEST BY BENCHMARK\n",
    "\n",
    "def analyze_auc_by_benchmark(df):\n",
    "   results = []\n",
    "   \n",
    "   for benchmark in df['benchmark'].unique():\n",
    "       # Get all data for this benchmark\n",
    "       data = df[df['benchmark'] == benchmark]\n",
    "       \n",
    "       # Get AUC for each method\n",
    "       continuous = data[data['method'] == 'continuous_optimization'].set_index(['cascade_len', 'cascade_symbol'])['auc']\n",
    "       gridsearch = data[data['method'] == 'gridsearch'].set_index(['cascade_len', 'cascade_symbol'])['auc']\n",
    "       \n",
    "       # Make sure we have the same cases\n",
    "       common_pairs = continuous.index.intersection(gridsearch.index)\n",
    "       continuous = continuous[common_pairs]\n",
    "       gridsearch = gridsearch[common_pairs]\n",
    "       \n",
    "       # Compute percentage differences\n",
    "       percent_diffs = 100 * (continuous - gridsearch) / gridsearch\n",
    "       \n",
    "       # Conduct Wilcoxon test\n",
    "       stat, pval = stats.wilcoxon(gridsearch, continuous, alternative='greater')\n",
    "       \n",
    "       results.append({\n",
    "           'benchmark': benchmark,\n",
    "           'mean_percent_change': percent_diffs.mean(),\n",
    "           'continuous_auc': continuous.mean(),\n",
    "           'grid_auc': gridsearch.mean(),\n",
    "           'n_pairs': len(continuous),\n",
    "           'wilcoxon_stat': stat,\n",
    "           'p_value': pval\n",
    "       })\n",
    "   \n",
    "   return pd.DataFrame(results)\n",
    "\n",
    "# Create and display results\n",
    "results = analyze_auc_by_benchmark(df)\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.groupby(by=['method', 'cascade_len'])['time'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Plot runtime comparison\n",
    "\n",
    "import seaborn as sns\n",
    "from matplotlib import rcParams\n",
    "\n",
    "# Set seaborn style first\n",
    "sns.set_style(\"white\")\n",
    "sns.set_context(\"paper\", font_scale=1.0)\n",
    "\n",
    "# Then matplotlib settings\n",
    "rcParams['text.usetex'] = False #True\n",
    "rcParams['font.family'] = 'serif'\n",
    "rcParams['font.serif'] = ['Computer Modern Roman']\n",
    "rcParams['font.size'] = 10\n",
    "\n",
    "runtime_by_method = df.groupby(by=['method', 'cascade_len'])['time'].mean()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5,4))\n",
    "\n",
    "cascade_lens = [2,3,4,5]\n",
    "\n",
    "color_grid = 'gray'\n",
    "color_cts = 'tab:blue'\n",
    "ax.scatter(cascade_lens, runtime_by_method.loc['gridsearch'], color=color_grid)\n",
    "ax.plot(cascade_lens, runtime_by_method.loc['gridsearch'], color=color_grid)\n",
    "\n",
    "ax.fill_between(cascade_lens, \n",
    "                runtime_by_method.loc['gridsearch'] - df.groupby(by=['method', 'cascade_len'])['time'].std().loc['gridsearch'],  # Lower bound\n",
    "                runtime_by_method.loc['gridsearch'] + df.groupby(by=['method', 'cascade_len'])['time'].std().loc['gridsearch'],  # Upper bound\n",
    "                color=color_grid, \n",
    "                alpha=0.2,  # Transparency\n",
    "                label='±1 std')\n",
    "\n",
    "ax.text(4.1, 38, \"grid search\", color=color_grid, fontweight='bold').set_rotation(38)\n",
    "ax.scatter(cascade_lens, runtime_by_method.loc['continuous_optimization'], color=color_cts)\n",
    "ax.plot(cascade_lens, runtime_by_method.loc['continuous_optimization'], color=color_cts)\n",
    "ax.text(4.2, 5, \"continuous\", color=color_cts, fontweight='bold').set_rotation(7)\n",
    "\n",
    "ax.fill_between(cascade_lens, \n",
    "                runtime_by_method.loc['continuous_optimization'] - df.groupby(by=['method', 'cascade_len'])['time'].std().loc['continuous_optimization'],  # Lower bound\n",
    "                runtime_by_method.loc['continuous_optimization'] + df.groupby(by=['method', 'cascade_len'])['time'].std().loc['continuous_optimization'],  # Upper bound\n",
    "                color=color_cts, \n",
    "                alpha=0.2,  # Transparency\n",
    "                label='±1 std')\n",
    "\n",
    "ax.set_xlabel(\"Cascade Length\", fontsize=14)\n",
    "ax.set_xticks(cascade_lens)\n",
    "ax.set_xticklabels(cascade_lens)\n",
    "ax.set_ylabel(\"Runtime (s)\", fontsize=14)\n",
    "\n",
    "ax.set_yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
