{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from bon.utils import utils\n",
    "from bon.utils.shotgun_utils import get_powerlaw_baseline, process_powerlaw_data\n",
    "\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load direct requests and audio/ text repeated sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "direct_requests =  Path(\"./data/direct_request.jsonl\")\n",
    "df_direct = utils.load_jsonl_df(direct_requests)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Baselines\n",
    "root_dir_baseline = Path(\"./exp/figure1/baselines/audio-only\")\n",
    "flash_pattern = \"flash_direct_request_high_temp1.0_step*.jsonl\"\n",
    "flash_files = glob.glob(str(root_dir_baseline / flash_pattern))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "flash_baselines = []\n",
    "for i, file in enumerate(flash_files):\n",
    "    flash_df = utils.load_jsonl_df(file)\n",
    "    flash_df[\"step\"] = i\n",
    "    flash_baselines.append(flash_df)\n",
    "df_flash_baseline = pd.concat(flash_baselines)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "flash_grp = df_flash_baseline.groupby([\"step\"])['proportion_flagged_audio'].sum().reset_index().sort_values('proportion_flagged_audio', ascending = False)\n",
    "flash_grp['proportion_flagged_audio'] = flash_grp['proportion_flagged_audio']/159"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dr_idx_map = dict(zip(df_direct.rewrite.tolist(), df_direct.index.tolist()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_flash_baseline = pd.DataFrame(get_powerlaw_baseline(df_flash_baseline, dr_idx_map))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load and Process Shotgun"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "root_dir = Path(\"./exp/figure1\")\n",
    "aug_attack_flash = root_dir / \"gemini-1.5-flash-001\"\n",
    "df_flash = process_powerlaw_data(aug_attack_flash, df_direct, \"direct_request_search_steps.jsonl\", 8, 10, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@file_cache()\n",
    "def calculate_asr_trajectories(df, num_repeats=20, num_tiles = 1):\n",
    "    asrs = []\n",
    "    for num_repeat in range(num_repeats):\n",
    "        shuffled_df = df.groupby(\"i\").apply(lambda x: x.sample(frac=1, random_state=num_repeat)).reset_index(drop=True)\n",
    "\n",
    "        # size (N * T) -> (T, N)\n",
    "        # N = number of unique i\n",
    "        # T = number of steps (len(df) / N)\n",
    "        arr = np.stack(shuffled_df.groupby(\"i\")[\"flagged\"].apply(lambda x: x.to_numpy())).T\n",
    "        T, N = arr.shape\n",
    "        arr = np.tile(arr, (num_tiles,1))\n",
    "\n",
    "        for col in range(arr.shape[1]):\n",
    "            np.random.shuffle(arr[:, col])\n",
    "\n",
    "        # Use cumsum to efficiently calculate the cumulative flagged count\n",
    "        cumulative_flagged = np.cumsum(arr, axis=0)\n",
    "        idx_flagged = cumulative_flagged > 0\n",
    "\n",
    "        # Calculate ASR trajectory efficiently\n",
    "        asr_trajectory = np.sum(idx_flagged, axis=1) / N\n",
    "\n",
    "        asr_trajectory = asr_trajectory[:T]\n",
    "        asrs.append(asr_trajectory.tolist())\n",
    "    \n",
    "    print(T,N)\n",
    "    print(len(asrs))\n",
    "    return asrs\n",
    "\n",
    "asrs_flash = calculate_asr_trajectories(df_flash, 100, num_tiles=10)\n",
    "baseline_flash = calculate_asr_trajectories(df_flash_baseline, 100, num_tiles=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define plotting functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bon.utils.power_law import (\n",
    "    exp_power_law,\n",
    "    exp_power_law_single_term_no_constant,\n",
    "    o1_simple_fit_power_law_no_constant,\n",
    "    power_law,\n",
    "    power_law_single_term_no_constant,\n",
    ")\n",
    "\n",
    "\n",
    "def plot_mean_and_std(ax, asr_mean, asr_std, steps, log_scale=True, epsilon=1e-5, exp_name=\"ASR\", color=None):\n",
    "    if log_scale:\n",
    "        line, = ax.plot(steps, -np.log(asr_mean), label=exp_name, color=color)\n",
    "        ax.fill_between(steps, \n",
    "                         -np.log(np.maximum(asr_mean - asr_std, epsilon)), \n",
    "                         -np.log(np.maximum(asr_mean + asr_std, epsilon)), \n",
    "                         alpha=0.3)\n",
    "        ax.set_ylabel(\"-log(ASR)\")\n",
    "        ax.set_xscale(\"log\")\n",
    "        ax.set_yscale(\"log\")\n",
    "    else:\n",
    "        line, = ax.plot(steps, asr_mean, label=exp_name, color=color)\n",
    "        ax.fill_between(steps, \n",
    "                         (asr_mean - asr_std), \n",
    "                         (asr_mean + asr_std), \n",
    "                         alpha=0.3)\n",
    "        ax.set_ylabel(\"ASR\")\n",
    "    \n",
    "    ax.set_xlabel(\"N\")\n",
    "    ax.set_xlim(left=1)\n",
    "\n",
    "\n",
    "def plot_fitted_asr(ax, steps, params, color, log_scale=True, exp_name=\"\", use_label=True):\n",
    "    if len(params) == 2:\n",
    "        fitted_asr = power_law_single_term_no_constant(steps, *params) if log_scale else exp_power_law_single_term_no_constant(steps, *params)\n",
    "        linestyle = '--'\n",
    "        label = f\"{exp_name}: fitted\"\n",
    "    elif len(params) == 3:\n",
    "        fitted_asr = power_law(steps, *params) if log_scale else exp_power_law(steps, *params)\n",
    "        linestyle = '--'\n",
    "        label = f'{exp_name}: {params[0]:.2f}x^{-params[1]:.2f} + {params[2]:.2f}'\n",
    "    else:\n",
    "        raise ValueError(f\"Unexpected number of params: {len(params)}\")\n",
    "    \n",
    "    if use_label:\n",
    "        ax.plot(steps, fitted_asr, linestyle=linestyle, color=color, label=label)\n",
    "    else:\n",
    "        ax.plot(steps, fitted_asr, linestyle=linestyle, color=color)\n",
    "\n",
    "    ax.legend()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from matplotlib.ticker import FuncFormatter\n",
    "\n",
    "experiments = {\n",
    "    \"flash\": asrs_flash,\n",
    "    \"flash_baseline\": baseline_flash,\n",
    "}\n",
    "\n",
    "method_name = \"single_term_no_constant\"\n",
    "method_func = o1_simple_fit_power_law_no_constant\n",
    "start_n = 0\n",
    "train_split = None\n",
    "\n",
    "# Create a new figure and axis for each log_scale value\n",
    "fig, axs = plt.subplots(1, 2, figsize=(15, 5))\n",
    "\n",
    "# Define a color palette for color blindness\n",
    "color_palette = sns.color_palette(\"tab10\")\n",
    "\n",
    "for i, log_scale in enumerate([False, True]):\n",
    "    ax = axs[i]\n",
    "    for j, (exp_name, asrs) in enumerate(experiments.items()):\n",
    "\n",
    "\n",
    "        # Prepare the data\n",
    "        asr = np.array(asrs)\n",
    "        asr_mean = np.mean(asr, axis=0)\n",
    "        asr_std = np.std(asr, axis=0)\n",
    "        steps = np.arange(1, len(asr_mean)+1)\n",
    "        color = color_palette[j % len(color_palette)]\n",
    "        plot_mean_and_std(ax, asr_mean, asr_std, steps, exp_name=exp_name, log_scale=log_scale, color=color)\n",
    "\n",
    "        param_steps = steps[start_n:]\n",
    "        param_asr = asr[:,start_n:]\n",
    "        param_asr_mean = np.mean(param_asr, axis=0)\n",
    "\n",
    "        # Get the train and test splits\n",
    "        if train_split is not None:\n",
    "            train_steps = param_steps[:train_split]\n",
    "            train_asr = param_asr[:,:train_split]\n",
    "            train_asr_mean = np.mean(train_asr, axis=0)\n",
    "            test_split=train_split\n",
    "        else:\n",
    "            train_steps = param_steps\n",
    "            train_asr = param_asr\n",
    "            train_asr_mean = param_asr_mean\n",
    "            test_split=0\n",
    "\n",
    "        # Check if train_asr is empty\n",
    "        if len(train_asr) == 0 or train_asr.size == 0:\n",
    "            print(f\"Skipping {exp_name} with {method_name} due to empty train_asr\")\n",
    "            continue\n",
    "\n",
    "        # Fit the model\n",
    "        if \"chinchilla\" in method_name:\n",
    "            params = method_func(x=train_steps, y=-np.log(train_asr_mean + 1e-5))\n",
    "        else:\n",
    "            params = method_func(x=train_steps, y=train_asr)\n",
    "\n",
    "        steps_fitting = np.arange(1, 10000)\n",
    "        plot_fitted_asr(ax, steps_fitting, params, color=color, exp_name=exp_name, log_scale=log_scale, use_label=False)\n",
    "\n",
    "    # Add shaded area to show start and end of train split    \n",
    "    train_split = train_split if train_split else 0\n",
    "    start_n = start_n if start_n else 0\n",
    "    ax.axvspan(start_n, train_split + start_n, color='lightgrey', alpha=0.4)\n",
    "\n",
    "    # ax.set_title(f'Train Split: {train_split if train_split else \"All\"}, Start N: {start_n}, Log Scale: {log_scale}')\n",
    "    ax.grid(True, which=\"both\", ls=\"-\", alpha=0.2)\n",
    "    if log_scale:\n",
    "        ax.set_xlim(left=10)\n",
    "        ax.set_ylim(top=5)\n",
    "        ax.set_yticks([0.1, 0.2, 0.5, 1, 2, 5])\n",
    "        ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{x:g}'))\n",
    "        ax.set_xticks([10, 100, 1000, 10000])\n",
    "        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{x:g}'))\n",
    "# Move legend outside the plot\n",
    "axs[0].legend(bbox_to_anchor=(1, -0.15), loc='upper center', borderaxespad=0., ncol=len(experiments), fontsize=14)\n",
    "# remove legend from second plot\n",
    "axs[1].legend().remove()\n",
    "\n",
    "# Set font size for all axes ticks and labels\n",
    "for ax in axs:\n",
    "    ax.tick_params(axis='both', which='major', labelsize=10)\n",
    "    ax.tick_params(axis='both', which='minor', labelsize=10)\n",
    "    ax.xaxis.label.set_size(14)\n",
    "    ax.yaxis.label.set_size(14)\n",
    "\n",
    "# tight layout\n",
    "# plt.tight_layout()\n",
    "\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Micromamba Env",
   "language": "python",
   "name": "almj"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
