{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tables and Figures\n",
    "This notebook contains the code we used to generate some of the tables and figures in the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "import glob\n",
    "import json\n",
    "import textwrap\n",
    "\n",
    "import requests\n",
    "from PIL import Image\n",
    "from io import BytesIO\n",
    "sns.set('talk')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantitative Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Steerability of Image Generation Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('data/summary_metrics.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sorted = df.sort_values('rating_avg')\n",
    "df_sorted = df_sorted[df_sorted['model'] != 'Average']  # Remove average if present\n",
    "x = np.arange(len(df_sorted))\n",
    "width = 0.35\n",
    "\n",
    "# Create figure with two subplots\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))\n",
    "\n",
    "# Left panel - Ratings\n",
    "rating_bars = ax1.bar(x, df_sorted['rating_avg'], \n",
    "                     yerr=df_sorted['rating_sem'],\n",
    "                     capsize=5,\n",
    "                     label='All attempts')\n",
    "ax1.set_ylabel('Similarity to goal image')\n",
    "ax1.set_xticks(x)\n",
    "ax1.axhline(10, color='darkred', linestyle='--', linewidth=1, label='Perfect similarity')\n",
    "ax1.set_ylim(0, 11)\n",
    "ax1.set_xticklabels(df_sorted['model'], rotation=45, ha='right')\n",
    "ax1.legend(loc='upper right')\n",
    "# Right panel - Attempts\n",
    "color_palette = sns.color_palette()\n",
    "first_bars = ax2.bar(x - width/2, df_sorted['first_attempt_avg'], width, \n",
    "                    label='First attempt', \n",
    "                    yerr=df_sorted['first_attempt_sem'],\n",
    "                    capsize=5, color=color_palette[1])\n",
    "last_bars = ax2.bar(x + width/2, df_sorted['last_attempt_avg'], width, \n",
    "                    label='Last attempt',\n",
    "                    yerr=df_sorted['last_attempt_sem'],\n",
    "                    capsize=5, color=color_palette[2])\n",
    "# ax2.set_ylabel('Similarity to goal image\\n (human rating)')\n",
    "ax2.set_xticks(x)\n",
    "ax2.set_xticklabels(df_sorted['model'], rotation=45, ha='right')\n",
    "\n",
    "# Add a dashed line at 10 for perfect similarity in right panel\n",
    "ax2.axhline(10, color='darkred', linestyle='--', linewidth=1, label='Perfect similarity')\n",
    "ax2.set_ylim(0, 11)\n",
    "ax2.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig('all_attempts_barchart.pdf', bbox_inches='tight', dpi=300)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[['model', 'last_over_first_avg', 'last_over_first_sem', 'first_matches_goal_avg', 'first_matches_goal_sem', 'last_matches_goal_avg', 'last_matches_goal_sem']]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text Steering vs Image Steering vs. Image Steering with RL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_attempts= [1, 2, 3, 4, 5]\n",
    "means_text_steering= [0.0, 0.002495847021540006, 0.01192834104100863, 0.02294766716659069, 0.02898228106399377]\n",
    "sems_text_steering= [0.0, 0.014831217294403013, 0.015463105048206157, 0.014722723334135723, 0.01510288011453466]\n",
    "means_random_stepsize= [0.0, 0.02797472513980747, 0.04062674078882111, 0.04724131867011882, 0.052637251082414425]\n",
    "sems_random_stepsize= [0.0, 0.004340646216735713, 0.004960735710954918, 0.005284981183033871, 0.005362842282494922]\n",
    "means_new_rl= [0.0, 0.04049745921430917, 0.05952271690656399, 0.06805097260351839, 0.0715540506202599]\n",
    "sems_new_rl= [0.0, 0.0057135848847513545, 0.006633413178525755, 0.00707276899252904, 0.007134218214304489]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set(\"talk\")\n",
    "plt.figure(figsize=(8.6, 6))\n",
    "plt.errorbar(unique_attempts, means_text_steering, yerr=sems_text_steering, fmt='o-', capsize=5, label='Text Steering')\n",
    "plt.errorbar(unique_attempts, means_random_stepsize, yerr=sems_random_stepsize, fmt='o-', capsize=5, label='Image Steering (random)')\n",
    "plt.errorbar(unique_attempts, means_new_rl, yerr=sems_new_rl, fmt='o-', capsize=5, label='Image Steering (RLHS)')\n",
    "plt.xlabel('Attempt Number', fontsize=23)\n",
    "plt.ylabel('Improvement over first attempt', fontsize=23)\n",
    "plt.xticks(unique_attempts)\n",
    "# Set label sizes larger\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "plt.grid(True)\n",
    "plt.legend(loc='upper left')\n",
    "# legend = plt.legend(bbox_to_anchor=(0.32, 1.55), loc='upper center', ncol=1)\n",
    "plt.tight_layout()\n",
    "plt.savefig('improvement_by_attempt.pdf', bbox_inches='tight', dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Image Steering with RL Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_stay = np.array([[0.        , 0.        , 0.        , 0.        , 0.        ,\n",
    "        0.        , 0.        , 0.        , 0.        , 0.        ],\n",
    "       [0.04642136, 0.05038638, 0.04917789, 0.05267864, 0.05179059,\n",
    "        0.05574815, 0.05335869, 0.05451705, 0.0551046 , 0.05405061],\n",
    "       [0.08927953, 0.08747603, 0.08821033, 0.0898371 , 0.08980018,\n",
    "        0.08727004, 0.09088007, 0.08836148, 0.089135  , 0.09181658],\n",
    "       [0.09804241, 0.09788405, 0.09963065, 0.10032585, 0.09884076,\n",
    "        0.09839644, 0.0972433 , 0.09653404, 0.09663419, 0.09819333]])\n",
    "X_not_stay = np.array([[0.09855232, 0.10071277, 0.10350775, 0.10599847, 0.10902616,\n",
    "        0.11065273, 0.11191053, 0.11209451, 0.11139376, 0.11168785],\n",
    "       [0.13162178, 0.13280695, 0.13582534, 0.13732035, 0.13839655,\n",
    "        0.1380477 , 0.13738207, 0.13916574, 0.13703005, 0.13656748],\n",
    "       [0.12548979, 0.12919237, 0.12751702, 0.13099927, 0.133716  ,\n",
    "        0.12982938, 0.12979796, 0.13273502, 0.13228459, 0.13002534],\n",
    "       [0.12670392, 0.12695364, 0.12751615, 0.1273911 , 0.12708219,\n",
    "        0.12593613, 0.12495567, 0.12699805, 0.12644036, 0.12650806]])\n",
    "counts_stay = np.array([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0],\n",
    "       [1961, 2199, 2200, 2500, 2362, 2729, 2415, 2587, 2794, 2674],\n",
    "       [3780, 3661, 3862, 3996, 3843, 3470, 4354, 3746, 3771, 4259],\n",
    "       [4351, 4581, 5105, 5562, 4800, 4546, 4349, 4139, 4262, 4717]])\n",
    "counts_not_stay = np.array([[4250, 4761, 5531, 6158, 7524, 8539, 9355, 9755, 8894, 8969],\n",
    "       [4046, 4233, 4856, 5180, 5126, 5093, 5017, 5588, 5052, 5124],\n",
    "       [2978, 3386, 2923, 3545, 4047, 3417, 3472, 3811, 3869, 3546],\n",
    "       [2819, 2797, 2752, 2743, 2780, 2612, 2702, 2732, 2678, 2709]])\n",
    "counts_stay_se = 0.3 / np.sqrt(counts_stay)\n",
    "counts_not_stay_se = 0.3 / np.sqrt(counts_not_stay)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create x-axis values\n",
    "sns.set_context('talk')\n",
    "x = np.linspace(0.1, 1, 10)\n",
    "\n",
    "# Create figure with 4 subplots\n",
    "fig, axes = plt.subplots(2, 2, figsize=(12, 10))\n",
    "axes = axes.flatten()\n",
    "\n",
    "# Colors for consistency\n",
    "color_palette = sns.color_palette()\n",
    "stay_color = color_palette[0]\n",
    "not_stay_color = color_palette[1]\n",
    "\n",
    "# Create each subplot\n",
    "for i in range(4):\n",
    "    # Plot not_stay data\n",
    "    if i != 0:\n",
    "        axes[i].errorbar(x, X_not_stay[i], #yerr=counts_not_stay_se[i], \n",
    "                        color=not_stay_color, label='Previous switch', \n",
    "                        capsize=5, marker='o')\n",
    "        \n",
    "        # Plot stay data\n",
    "        axes[i].errorbar(x, X_stay[i], #yerr=counts_stay_se[i], \n",
    "                        color=stay_color, label='Previous stay', \n",
    "                        capsize=5, marker='o')\n",
    "    else:\n",
    "        axes[i].errorbar(x, X_not_stay[i], #yerr=counts_not_stay_se[i], \n",
    "                        color=sns.color_palette()[2], label='First turn', \n",
    "                        capsize=5, marker='o')\n",
    "    \n",
    "    axes[i].set_xlabel('Proposal step size')\n",
    "    axes[i].set_ylabel('Value')\n",
    "    axes[i].legend()\n",
    "    axes[i].grid(True, alpha=0.3)\n",
    "    axes[i].set_title(f'Iteration {i+1}')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('rl_results.pdf', bbox_inches='tight', dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Qualitative Results: Image Generation Attempts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Text Steering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "attempts = pd.read_csv('data/steering.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_multiple_attempts(df, indices, figsize=(24, 6)):\n",
    "    # Create a figure with num_rows rows, each containing 6 images\n",
    "    num_rows = len(indices)\n",
    "    fig, axes = plt.subplots(num_rows, 6, figsize=(figsize[0], figsize[1] * num_rows))\n",
    "    \n",
    "    # If there's only one row, wrap axes in a list to make it 2D\n",
    "    if num_rows == 1:\n",
    "        axes = axes.reshape(1, -1)\n",
    "\n",
    "    # Add a vertical line to separate goal from attempts\n",
    "    divider_ax = fig.add_axes([0.16, 0.0, 0.001, 0.99])  # Adjusted height to go from bottom to top\n",
    "    divider_ax.axvline(x=0, color='black', alpha=0.75)\n",
    "    divider_ax.axis('off')\n",
    "    \n",
    "    \n",
    "    # Loop through each row\n",
    "    for row_idx in range(num_rows):\n",
    "        start_idx = indices[row_idx]\n",
    "        # Get data for this row\n",
    "        # row_data = random_df.iloc[row_idx]  # Start from the first row\n",
    "        # row_data = random_df.iloc[-(row_idx + 1)] # Start from the last row\n",
    "        row_data = df.iloc[start_idx]\n",
    "        \n",
    "        # Add the goal image\n",
    "        response = requests.get(row_data['goal_image'])\n",
    "        goal_img = Image.open(BytesIO(response.content))\n",
    "        goal_prompt = row_data['goal_prompt']\n",
    "        axes[row_idx, 0].imshow(goal_img)\n",
    "        axes[row_idx, 0].axis('off')\n",
    "        axes[row_idx, 0].set_title('Goal Image', pad=10, fontsize=20)\n",
    "        \n",
    "        # Loop through each attempt and display the image\n",
    "        for col_idx in range(5):\n",
    "            idx = start_idx + col_idx\n",
    "            attempt = df.iloc[idx]\n",
    "            # Fetch the image from the URL\n",
    "            response = requests.get(attempt['generated_image'])\n",
    "            img = Image.open(BytesIO(response.content))\n",
    "            \n",
    "            # Display the image\n",
    "            axes[row_idx, col_idx+1].imshow(img)\n",
    "            axes[row_idx, col_idx+1].axis('off')\n",
    "            axes[row_idx, col_idx+1].set_title(f'Attempt {attempt[\"attempt\"]}', \n",
    "                                                    fontsize=20, # 18\n",
    "                                                    fontstyle='italic')\n",
    "            \n",
    "            # Add prompt text below the image\n",
    "            # prompt_text = textwrap.fill('\"' + attempt['prompt'] + '\"', width=38, break_long_words=False)\n",
    "            prompt_text = textwrap.fill('\"' + attempt['prompt'] + '\"', width=35, break_long_words=True)\n",
    "            axes[row_idx, col_idx+1].text(-0.0, -0.05, prompt_text.strip(),\n",
    "                                              ha='left', va='top',\n",
    "                                              transform=axes[row_idx, col_idx+1].transAxes,\n",
    "                                              fontsize=20)#,fontstyle='italic') # 14\\\n",
    "    plt.tight_layout()\n",
    "    # Add more bottom margin to accommodate the text\n",
    "    if num_rows != 1:\n",
    "        plt.subplots_adjust(bottom=0.3/num_rows, wspace=0.1, hspace=0.4)\n",
    "    \n",
    "    return fig, axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = [1170]\n",
    "fig, axes = plot_multiple_attempts(attempts, indices, figsize=(30, 6.75))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = [2515]\n",
    "fig, axes = plot_multiple_attempts(attempts, indices, figsize=(30, 6.75))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = [2515, 305, 735, 2260]\n",
    "fig, axes = plot_multiple_attempts(attempts, indices, figsize=(30, 7.25))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Image Steering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_multiple_attempts_image_steering(all_attempts, indices, figsize=(24, 6)):\n",
    "    # Create a figure with num_rows rows, each containing 6 images\n",
    "    num_rows = len(indices)\n",
    "    fig, axes = plt.subplots(num_rows, 6, figsize=(figsize[0], figsize[1] * num_rows))\n",
    "    \n",
    "    if num_rows == 1:\n",
    "        axes = axes.reshape(1, -1)\n",
    "\n",
    "    # Add a vertical line to separate goal from attempts\n",
    "    divider_ax = fig.add_axes([0.167, 0.0, 0.001, 0.99])\n",
    "    divider_ax.axvline(x=0, color='black', alpha=0.75)\n",
    "    divider_ax.axis('off')\n",
    "    \n",
    "\n",
    "    random_attempts = [all_attempts[i] for i in indices]\n",
    "    \n",
    "    # Loop through each row\n",
    "    for row_idx in range(num_rows):\n",
    "        row_data = random_attempts[row_idx]\n",
    "        \n",
    "        # Add the goal image\n",
    "        response = requests.get(row_data['goal_url'])\n",
    "        goal_img = Image.open(BytesIO(response.content))\n",
    "        goal_prompt = row_data['goal_prompt']\n",
    "        axes[row_idx, 0].imshow(goal_img)\n",
    "        axes[row_idx, 0].axis('off')\n",
    "        axes[row_idx, 0].set_title(f'Goal Image', pad=10, fontsize=20)\n",
    "        \n",
    "        # Get the list of attempts\n",
    "        attempts = row_data['rounds']\n",
    "\n",
    "        image_0_url = row_data['rounds'][0]['generated_images'][0]\n",
    "        response = requests.get(image_0_url)\n",
    "        img = Image.open(BytesIO(response.content))\n",
    "        axes[row_idx, 1].imshow(img)\n",
    "        axes[row_idx, 1].axis('off')\n",
    "        axes[row_idx, 1].set_title(f'Attempt 1', fontsize=20)  \n",
    "\n",
    "        prompt_text = textwrap.fill(row_data['user_prompt'], width=35, break_long_words=False)\n",
    "        axes[row_idx, 1].text(0., -0.05, prompt_text, \n",
    "                  ha='left', va='top', \n",
    "                  transform=axes[row_idx, 1].transAxes,\n",
    "                  fontsize=20)#, fontstyle='italic')\n",
    "\n",
    "        # Loop through each attempt and display the image\n",
    "        for attempt_idx, attempt in enumerate(attempts):\n",
    "            # Fetch the image from the URL\n",
    "            response = requests.get(attempt['selected_image_url'])\n",
    "            img = Image.open(BytesIO(response.content))\n",
    "            \n",
    "            # Display the image\n",
    "            axes[row_idx, attempt_idx + 2].imshow(img)\n",
    "            axes[row_idx, attempt_idx + 2].axis('off')\n",
    "            axes[row_idx, attempt_idx + 2].set_title(f'Attempt {1+attempt[\"iteration\"]}', \n",
    "                                                    fontsize=20)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.subplots_adjust(bottom=0.3/num_rows, wspace=0.1, hspace=0.4)\n",
    "    \n",
    "    return fig, axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ast import literal_eval\n",
    "def parse(x):\n",
    "    try:\n",
    "        data = literal_eval(x)\n",
    "        return data\n",
    "    except (ValueError, SyntaxError) as e:\n",
    "        print(f\"Error evaluating string: {e}\")\n",
    "\n",
    "df = pd.read_csv('data/img_steering.csv')\n",
    "df['all_attempts'] = df['all_attempts'].apply(parse)\n",
    "all_attempts = df['all_attempts'].tolist()\n",
    "# Only include rows where all_attempts has 2 items\n",
    "all_attempts = [item for sublist in all_attempts for item in sublist]\n",
    "len(all_attempts)\n",
    "\n",
    "df = pd.read_csv('data/img_steering_rl.csv')\n",
    "df['all_attempts'] = df['all_attempts'].apply(parse)\n",
    "all_attempts_2 = df['all_attempts'].tolist()\n",
    "# Only include rows where all_attempts has 2 items\n",
    "all_attempts += [item for sublist in all_attempts_2 for item in sublist]\n",
    "len(all_attempts)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = [-14, -18, -58, -1]\n",
    "fig, axes = plot_multiple_attempts_image_steering(all_attempts, indices, figsize=(30, 7))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Image Steering on Tiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('data/img_steering_rl_tiles.csv')\n",
    "df = df[pd.to_datetime(df['StartDate'], errors='coerce').notnull()]\n",
    "df = df[~pd.isnull(df['ID'])]\n",
    "df.shape\n",
    "df['all_attempts'] = df['all_attempts'].apply(parse_json_safely)\n",
    "# Only include rows where all_attempts is not None\n",
    "df = df[df['all_attempts'].notna()]\n",
    "df = df[df['all_attempts'].apply(lambda x: len(x) == 2)]\n",
    "all_attempts = df['all_attempts'].tolist()\n",
    "# Only include rows where all_attempts has 2 items\n",
    "all_attempts = [item for sublist in all_attempts for item in sublist]\n",
    "len(all_attempts)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = [68, 87, 2, 19]#, 22]\n",
    "fig, axes = plot_multiple_attempts_image_steering(all_attempts, indices, figsize=(30, 7.6))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Blind Steering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(\"data/steering.csv\")\n",
    "llm_baseline4 = pd.read_csv('data/blind_steering_4.csv')\n",
    "llm_baseline10 = pd.read_csv('data/blind_steering_10.csv')\n",
    "llm_baseline20 = pd.read_csv('data/blind_steering_20.csv')\n",
    "llm_baseline7 = pd.read_csv('data/blind_steering_7.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the max score for each goal image\n",
    "max_scores = data.groupby('goal_image')['dreamsim'].max()\n",
    "first_scores = data.groupby('goal_image')['dreamsim'].first()\n",
    "np.mean((max_scores - first_scores) / first_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_improvement_stats(baseline_df):\n",
    "  # For each URL, it it's in data and baseline_df, include it in the analysis\n",
    "  common_urls = set(data['goal_image']).intersection(set(baseline_df['goal_url']))\n",
    "  human_improvements = []\n",
    "  llm_improvements = []\n",
    "  for common_url in common_urls:\n",
    "      # First, get human improvement\n",
    "      relevant_data = data[data['goal_image'] == common_url]\n",
    "      original_score = relevant_data[relevant_data['attempt'] == 1]['dreamsim'].values[0].item()\n",
    "      human_improvement = relevant_data['dreamsim'].max() - original_score\n",
    "      # Now, get LLM improvement\n",
    "      relevant_llm = baseline_df[baseline_df['goal_url'] == common_url]\n",
    "      llm_improvement = max(relevant_llm['score'].max() - original_score, 0.0) # Need to do 0 because sometimes the first prompt is the best and we don't store it\n",
    "      # print(human_improvement, llm_improvement)\n",
    "      human_improvements.append(human_improvement)\n",
    "      llm_improvements.append(llm_improvement)\n",
    "\n",
    "  human_improvements = np.array(human_improvements)\n",
    "  llm_improvements = np.array(llm_improvements)\n",
    "  mean_ratio = np.mean(llm_improvements) / np.mean(human_improvements)\n",
    "  # Estimate se by bootstrapping\n",
    "  num_bootstraps = 1000\n",
    "  bootstrap_ratios = []\n",
    "  for _ in range(num_bootstraps):\n",
    "    bootstrap_inds = np.random.randint(0, len(human_improvements), size=len(human_improvements))\n",
    "    bootstrap_human = human_improvements[bootstrap_inds]\n",
    "    bootstrap_llm = llm_improvements[bootstrap_inds]\n",
    "    bootstrap_ratios.append(np.mean(bootstrap_llm) / np.mean(bootstrap_human))\n",
    "  se = np.std(bootstrap_ratios)\n",
    "  return mean_ratio, se\n",
    "\n",
    "improvement_means = [0]\n",
    "improvement_ses = [0]\n",
    "for baseline in [llm_baseline4, llm_baseline7, llm_baseline10, llm_baseline20]:\n",
    "    mean_ratio, se = get_improvement_stats(baseline)\n",
    "    improvement_means.append(mean_ratio)\n",
    "    improvement_ses.append(se)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = [0, 4, 7, 10, 20]\n",
    "ys = improvement_means\n",
    "\n",
    "# plt.figure(figsize=(8, 4))\n",
    "plt.figure(figsize=(7, 4))\n",
    "plt.plot(xs, ys, 'b-')\n",
    "# plt.scatter(xs[1], ys[1], color='red', s=100)\n",
    "# Make the second point red but all others blue\n",
    "for i in range(len(xs)):\n",
    "    # if i != 1:\n",
    "        plt.errorbar(xs[i], ys[i], yerr=improvement_ses[i], color=sns.color_palette()[0], fmt='o', capsize=5)\n",
    "    # else:\n",
    "        # plt.errorbar(xs[i], ys[i], yerr=improvement_ses[i], color='gray', fmt='o', capsize=5)\n",
    "        \n",
    "# plt.errorbar(xs, ys, yerr=improvement_ses, fmt='o', capsize=5)\n",
    "plt.axhline(y=1, color='k', linestyle='--', label='Human performance')\n",
    "# plt.axvline(x=4, color='red', linestyle='--', label='Humans have 4 attempts to reprompt')\n",
    "plt.xticks(xs)\n",
    "# Make the 4 tick red\n",
    "# plt.gca().get_xticklabels()[1].set_color(sns.color_palette()[2])\n",
    "# plt.gca().get_xticklabels()[1].set_color('gray')\n",
    "plt.ylim([-0.05, 1.05])\n",
    "plt.legend()\n",
    "plt.xlabel('Number of blind prompt rewrites')\n",
    "# plt.xlabel('Number of blind attempts')# (Note: humans have 4 attempts to reprompt)')\n",
    "# plt.text(10., -0.3, \"(Humans have 4 attempts to reprompt)\", ha='center', va='top', fontsize=12, color=sns.color_palette()[2])\n",
    "plt.text(10., -0.305, \"(Humans have 4 attempts to reprompt)\", ha='center', va='top', fontsize=15, color='dimgray')\n",
    "# plt.ylabel(\"Fraction of human\\n improvement attained by LLM\")\n",
    "plt.ylabel(\"Fraction of human\\n performance attained\")\n",
    "plt.title(\"How close is human steering to blind steering?\")\n",
    "plt.savefig('human_improvement_vs_llm_improvement.pdf', dpi=300, bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Steerability vs. Producibility Frontier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_seeds = [1, 2, 3, 1e9]\n",
    "steering_scores = [0.716, 0.689, 0.680, 0.652]\n",
    "producibility_scores = [0.684, 0.689, 0.697, 0.742]\n",
    "procudibility_errs = [0.01, 0.01, 0.01, 0.01]\n",
    "steering_errs = [0.0129, 0.0127, 0.0121, 0.010]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(7, 4))\n",
    "plt.scatter(steering_scores, producibility_scores, s=100)\n",
    "\n",
    "plt.plot(steering_scores, producibility_scores, 'b--')\n",
    "plt.xlim(0.645, 0.725)\n",
    "plt.ylim(0.6726, 0.7526)\n",
    "plt.xlabel('Steerability score')\n",
    "plt.ylabel('Producibility score')\n",
    "plt.text(0.6645, 0.7425, 'Default model', ha='center', va='center', fontsize=15, color='black')\n",
    "plt.text(0.715, 0.693, 'Constrained\\n to 1 seed', ha='center', va='center', fontsize=15, color='black')\n",
    "plt.title(\"Tradeoff between steerability and producibility\")\n",
    "plt.savefig('frontier.pdf', dpi=35, bbox_inches='tight')\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
