{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "from bon.utils import utils\n",
    "from bon.utils.plot_utils import set_plot_style\n",
    "from bon.utils.power_law_simple import fit_power_law\n",
    "from bon.utils.powerlaw_plot_utils import adjust_color, plot_fitted_asr, plot_mean_and_std\n",
    "from bon.utils.shotgun_utils import calculate_asr_trajectories, process_powerlaw_data\n",
    "\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "color_palette = set_plot_style()\n",
    "\n",
    "def convert_to_percentages(data):\n",
    "    return [[value * 100 for value in sublist] for sublist in data]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "direct_requests =  Path(\"./data/direct_request.jsonl\")\n",
    "df_direct = utils.load_jsonl_df(direct_requests)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_powerlaw_data(models, shotgun_type=\"text\", num_steps=8, train_split=1000, num_concurrent_k=10):\n",
    "    dfs = {}\n",
    "    asrs = {}\n",
    "    train_asrs = {}\n",
    "    perc_asrs = {}\n",
    "\n",
    "    for model_name, model_path in models.items():\n",
    "\n",
    "        print(f\"\\nProcessing {model_name}\")\n",
    "                \n",
    "        dfs[model_name] = process_powerlaw_data(\n",
    "            model_path, df_direct, \n",
    "            \"direct_request_search_steps.jsonl\",\n",
    "            num_steps, num_concurrent_k, 159, shotgun_type=shotgun_type, \n",
    "            pad_to_n_steps=False, overwrite=False,\n",
    "        )\n",
    "        \n",
    "        asrs[model_name] = calculate_asr_trajectories(\n",
    "            dfs[model_name], model_path, num_repeats=10\n",
    "        )\n",
    "        train_asrs[model_name] = calculate_asr_trajectories(\n",
    "            dfs[model_name], model_path, num_repeats=10, train_num_samples=train_split, num_samples=train_split\n",
    "        )\n",
    "        \n",
    "        perc_asrs[model_name] = convert_to_percentages(asrs[model_name])\n",
    "    \n",
    "    experiments = {\n",
    "        model_name: (perc_asrs[model_name], asrs[model_name], train_asrs[model_name], train_asrs[model_name], dfs[model_name])\n",
    "        for model_name in models.keys()\n",
    "    }\n",
    "    \n",
    "    return experiments\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "root_dir_text = Path(\"./exp/bon/text\")\n",
    "root_dir_audio = Path(\"./exp/bon/audio\")\n",
    "root_dir_image = Path(\"./exp/bon/image\")\n",
    "\n",
    "text_models = {\n",
    "    \"GPT-4o-Mini\": root_dir_text / \"gpt-4o-mini\",\n",
    "}\n",
    "\n",
    "audio_models = {\n",
    "    \"Gemini Flash\": root_dir_audio / \"gemini-1.5-flash-001\",\n",
    "}\n",
    "\n",
    "vision_models = {\n",
    "    \"GPT-4o-Mini\": root_dir_image / \"gpt-4o-mini\",\n",
    "}\n",
    "\n",
    "text_experiments = get_powerlaw_data(text_models, shotgun_type=\"text\", num_steps=8)\n",
    "vision_experiments = get_powerlaw_data(vision_models, shotgun_type=\"image\", num_steps=8)\n",
    "audio_experiments = get_powerlaw_data(audio_models, shotgun_type=\"audio\", num_steps=8)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "color_mapping = {\n",
    "    \"Claude 3 Opus\": adjust_color(color_palette[0], 0.5),\n",
    "    \"Claude 3.5 Sonnet\": adjust_color(color_palette[0], -0.1),\n",
    "    \"GPT-4o\": adjust_color(color_palette[2], -0.1),\n",
    "    \"GPT-4o-Mini\": adjust_color(color_palette[2], 0.5),\n",
    "    \"Gemini Flash\": adjust_color(color_palette[3], 0.5),\n",
    "    \"Gemini Pro\": adjust_color(color_palette[3], -0.1),\n",
    "    \"Llama3 8B\": color_palette[6],\n",
    "    \"Circuit Breaking\": color_palette[4],\n",
    "    \"DiVA\": color_palette[8],\n",
    "    \"Cygnet\": color_palette[5],\n",
    "    \"Cygnet w/ system prompt\": color_palette[5]\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mega Figure 1 Simplified"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "models_to_plot = [\"GPT-4o-Mini\", \"Gemini Flash\"]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(5.5, 2.5), dpi=600)\n",
    "\n",
    "all_asrs = []\n",
    "for i, (modality_exp, modality_name) in enumerate(zip([text_experiments, vision_experiments, audio_experiments], [\"Text\", \"Vision\", \"Audio\"])):\n",
    "    for j, (exp_name, asrs) in enumerate(modality_exp.items()):\n",
    "        if exp_name not in models_to_plot:\n",
    "            continue\n",
    "\n",
    "        # Prepare the data\n",
    "        prop_asr = np.array(asrs[1])\n",
    "        prop_asr_mean = np.mean(prop_asr, axis=0)\n",
    "        prop_asr_std = np.std(prop_asr, axis=0)\n",
    "\n",
    "        perc_asr = np.array(asrs[0])\n",
    "        perc_asr_mean = np.mean(perc_asr, axis=0)\n",
    "        perc_asr_std = np.std(perc_asr, axis=0)\n",
    "\n",
    "        steps = np.arange(1, len(prop_asr_mean)+1)\n",
    "        color = color_mapping[exp_name]\n",
    "\n",
    "        linewidth = 1.15\n",
    "\n",
    "        # Plot regular fit in top row\n",
    "        plot_mean_and_std(axes[i], perc_asr_mean, perc_asr_std, steps, exp_name=\"\", log_scale_x=False, log_scale_y=False, color=color, plot_std_err=True, linewidth=linewidth)\n",
    "\n",
    "        print(f\"Fitting {exp_name}, {modality_name}, final ASR: {perc_asr_mean[-1]:.2f}%, num steps: {len(steps)}\")\n",
    "        all_asrs.append(perc_asr_mean[-1])\n",
    "\n",
    "    \n",
    "    # Set properties for regular fit (top row)\n",
    "    ax = axes[i]\n",
    "    if i == 0:\n",
    "        ax.set_ylabel(\"ASR (%)\")\n",
    "    else:\n",
    "        ax.set_ylabel(\"\")\n",
    "    ax.set_title(modality_name)\n",
    "    ax.tick_params(axis='both', which='major', labelsize=8)\n",
    "    # ax.set_yticks([0, 20, 40, 60, 80, 100], [0, 20, 40, 60, 80, 100])\n",
    "    # if i != 0:\n",
    "    #     ax.set_xticks([0, 2000, 4000, 6000], [\"0\", \"2k\", \"4k\", \"6k\"])\n",
    "    # else:\n",
    "    #     ax.set_xticks([0, 2500, 5000, 7500, 10000], [0, \"2.5k\", \"5k\", \"7.5k\", \"10k\"])\n",
    "    # ax.set_xlim(right=7200 if i > 0 else 10000)\n",
    "    # ax.set_ylim(0, 102)\n",
    "    ax.set_xlabel(\"N\")\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Add legend elements\n",
    "legend_elements = []\n",
    "for exp_name, color in color_mapping.items():\n",
    "    if exp_name in models_to_plot:\n",
    "        legend_elements.append(Line2D([0], [0], color=color, linestyle=\"-\", lw=1, label=f\"{exp_name}\"))\n",
    "\n",
    "# Create a single legend below the plots\n",
    "fig.legend(handles=legend_elements, \n",
    "           loc='lower center', \n",
    "           ncol=4, \n",
    "           bbox_to_anchor=(0.52, -0.05),\n",
    "           columnspacing=1,\n",
    "           handlelength=1.5,\n",
    "           handletextpad=0.5,\n",
    "           borderpad=0.5,\n",
    "           labelspacing=0.5)\n",
    "\n",
    "plt.subplots_adjust(bottom=0.28)  # Increased bottom margin\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "print(np.mean(all_asrs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fitting Power Law"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 3, figsize=(5.5, 2.5), dpi=600, sharey=True)\n",
    "\n",
    "models_to_plot = [\"GPT-4o-Mini\", \"Gemini Flash\"]\n",
    "\n",
    "for i, (modality_exp, modality_name) in enumerate(zip([text_experiments, vision_experiments, audio_experiments], [\"Text\", \"Vision\", \"Audio\"])):\n",
    "    ax = axes[i]\n",
    "    for j, (exp_name, asrs) in enumerate(modality_exp.items()):\n",
    "        if exp_name not in models_to_plot:\n",
    "            continue\n",
    "        print(f\"Fitting {exp_name}\")\n",
    "\n",
    "        # Prepare the data\n",
    "        prop_asr = np.array(asrs[1])\n",
    "        prop_asr_mean = np.mean(prop_asr, axis=0)\n",
    "        prop_asr_std = np.std(prop_asr, axis=0)\n",
    "\n",
    "        perc_asr = np.array(asrs[0])\n",
    "        perc_asr_mean = np.mean(perc_asr, axis=0)\n",
    "        perc_asr_std = np.std(perc_asr, axis=0)\n",
    "\n",
    "        steps = np.arange(1, len(prop_asr_mean)+1)\n",
    "        color = color_mapping[exp_name]\n",
    "\n",
    "        linewidth = 1.15\n",
    "        \n",
    "        # Plot log space fit\n",
    "        plot_mean_and_std(ax, prop_asr_mean, prop_asr_std, steps, exp_name=\"\", log_scale_x=True, log_scale_y=True, color=color, plot_std_err=True, linewidth=linewidth, std_scale_factor=1, use_line=False)\n",
    "\n",
    "        # Fit the model\n",
    "        try:\n",
    "            params = fit_power_law(x=steps, y=prop_asr, fit_type=\"linear_log_spacing\", skip_first_points=5)\n",
    "        except Exception:\n",
    "            params = fit_power_law(x=steps, y=prop_asr, fit_type=\"linear\", skip_first_points=5)\n",
    "        \n",
    "        # Plot fitted ASR for log space fit\n",
    "        plot_fitted_asr(ax, steps, params, color=color, exp_name=\"\", log_scale_x=True, log_scale_y=True,linewidth=linewidth)\n",
    "\n",
    "        print(f\"{exp_name}, {modality_name}, {params}\")\n",
    "\n",
    "    # Create second y-axis for ASR percentage\n",
    "    if i == 2:\n",
    "        ax2 = ax.twinx()\n",
    "        ax2.set_yscale('log')\n",
    "        # ax2.set_ylim(0.03, 5)\n",
    "\n",
    "        # Set ticks for both axes\n",
    "        yticks = ax.get_yticks()\n",
    "        ax.set_yticklabels([f\"{y:.1f}\" for y in yticks])\n",
    "        ax2.set_yticklabels([f\"{np.exp(-y)*100:.0f}%\" for y in yticks])\n",
    "\n",
    "        # Set titles and labels\n",
    "        ax2.set_ylabel(\"ASR (%)\")\n",
    "        ax2.tick_params(axis='both', which='major')\n",
    "        \n",
    "\n",
    "    ax.set_title(modality_name)\n",
    "    if i == 0:\n",
    "        ax.set_ylabel(\"-log(ASR)\")\n",
    "    ax.tick_params(axis='both', which='major', labelsize=8)\n",
    "    # ax.set_yticks([0.01, 0.1, 1], [0.01, 0.1, 1])\n",
    "    if i == 1:\n",
    "        ax.set_xlabel(\"N\")\n",
    "    # ax.set_xlim(left=10)\n",
    "    # ax.set_xlim(right=7200)\n",
    "    \n",
    "    ax.set_ylim(0.03, 5)\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Add legend elements\n",
    "legend_elements = []\n",
    "for exp_name, color in color_mapping.items():\n",
    "    if exp_name in models_to_plot:\n",
    "        legend_elements.append(Line2D([0], [0], color=color, linestyle=\"-\", lw=1, label=f\"{exp_name}\"))\n",
    "\n",
    "# Create a single legend below the plots\n",
    "fig.legend(handles=legend_elements, \n",
    "           loc='lower center', \n",
    "           ncol=4, \n",
    "           bbox_to_anchor=(0.52, -0.05),  # Moved up slightly\n",
    "           columnspacing=1,\n",
    "           handlelength=1.5,\n",
    "           handletextpad=0.5,\n",
    "           borderpad=0.5,\n",
    "           labelspacing=0.5)  # Reduced font size slightly\n",
    "\n",
    "\n",
    "plt.subplots_adjust(bottom=0.28)  # Increased bottom margin\n",
    "plt.show()\n",
    "plt.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Micromamba Env",
   "language": "python",
   "name": "almj"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
