{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6e76ce02-cc6e-4a33-9be4-c9bd7b7f4622",
   "metadata": {},
   "source": [
    "# Results analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "061bb77f-6a92-4e60-a530-65e77bf96a08",
   "metadata": {},
   "source": [
    "# Re-import edited project modules (e.g. function_calling_pi) automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "591bcd75-3cb3-4b16-a7cd-133446e26c3e",
   "metadata": {},
   "source": [
    "import json\n",
    "import logging\n",
    "import re\n",
    "from collections import defaultdict\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from function_calling_pi.task_suites import SUITES\n",
    "from function_calling_pi.types import ChatMessage\n",
    "from matplotlib.lines import Line2D\n",
    "from pydantic import BaseModel, ConfigDict, model_validator\n",
    "from statsmodels.stats import proportion\n",
    "\n",
    "# Silence matplotlib's debug/info log messages.\n",
    "logging.getLogger('matplotlib').setLevel(logging.WARNING)\n",
    "\n",
    "# Shared seaborn look for every figure in this notebook.\n",
    "sns.set_style(\"whitegrid\")\n",
    "sns.set_context(\"paper\")\n",
    "# Font type 42 embeds TrueType fonts (instead of Type 3) in exported PDF/PS figures.\n",
    "matplotlib.rcParams.update({\"pdf.fonttype\": 42, \"ps.fonttype\": 42})"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3628a9d3-5800-42df-bf39-763cc986c6d4",
   "metadata": {},
   "source": [
    "# (pipeline_name, suite_name, user_task_id, injection_task_id) of every loaded run with fewer than 2 messages.\n",
    "errors = set()\n",
    "\n",
    "\n",
    "class TaskResults(BaseModel):\n",
    "    \"\"\"Result record of a single run (one user task, optionally paired with one injection task).\n",
    "\n",
    "    Validating an instance has two side effects: `security` is forced to 0 when the run\n",
    "    recorded an `error`, and runs with fewer than two messages are added to `errors`.\n",
    "    \"\"\"\n",
    "\n",
    "    model_config = ConfigDict(arbitrary_types_allowed=True)\n",
    "\n",
    "    suite_name: str\n",
    "    pipeline_name: str\n",
    "    user_task_id: str\n",
    "    injection_task_id: str | None\n",
    "    attack_type: str | None\n",
    "    injections: dict[str, str]\n",
    "    messages: list[ChatMessage]\n",
    "    error: str | None\n",
    "    utility: float\n",
    "    security: float\n",
    "    duration: float\n",
    "\n",
    "    @model_validator(mode='after')\n",
    "    def check_messages(self) -> \"Self\":\n",
    "        \"\"\"Zero out `security` for errored runs and record runs with fewer than two messages.\"\"\"\n",
    "        if self.error is not None:\n",
    "            self.security = 0\n",
    "        if len(self.messages) <= 1:\n",
    "            run_key = (self.pipeline_name, self.suite_name, self.user_task_id, self.injection_task_id)\n",
    "            errors.add(run_key)\n",
    "        return self"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2136b0fb-a3b7-48c6-8d5f-79b62c566065",
   "metadata": {},
   "source": [
    "def load_task_results(pipeline_name, suite_name, user_task, attack_name, injection_task, logdir=Path(\"../runs\")):\n",
    "    \"\"\"Parse <logdir>/<pipeline_name>/<suite_name>/<user_task>/<attack_name>/<injection_task>.json into TaskResults.\"\"\"\n",
    "    result_file = logdir.joinpath(pipeline_name, suite_name, user_task, attack_name, f\"{injection_task}.json\")\n",
    "    raw_results = json.loads(result_file.read_text())\n",
    "    return TaskResults(**raw_results)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a430ea39-0915-429b-b845-015f8e90bf27",
   "metadata": {},
   "source": [
    "def load_experiment_results(suite_name, pipeline_name, attack_name, logdir=Path(\"../runs\")):\n",
    "    \"\"\"Load every run of one suite for one pipeline (model) under `attack_name`.\n",
    "\n",
    "    Returns two dicts keyed by (user_task_id, injection_task_id):\n",
    "      * runs without injections (read from the \"none\" attack directory), keyed with injection_task_id=None;\n",
    "      * runs under `attack_name`: only the suite's first injection task for attacks whose name\n",
    "        contains \"dos\", every injection task for any other attack, and nothing for \"none\".\n",
    "    \"\"\"\n",
    "    suite = SUITES[suite_name]\n",
    "    results_without_injections, results_with_injections = {}, {}\n",
    "\n",
    "    def injection_task_ids_for_attack():\n",
    "        if \"dos\" in attack_name:\n",
    "            return [next(iter(suite.injection_tasks.keys()))]\n",
    "        if attack_name != \"none\":\n",
    "            return list(suite.injection_tasks.keys())\n",
    "        return []\n",
    "\n",
    "    for user_task_id in suite.user_tasks.keys():\n",
    "        results_without_injections[(user_task_id, None)] = load_task_results(\n",
    "            pipeline_name, suite_name, user_task_id, \"none\", \"none\", logdir\n",
    "        )\n",
    "        for injection_task_id in injection_task_ids_for_attack():\n",
    "            results_with_injections[(user_task_id, injection_task_id)] = load_task_results(\n",
    "                pipeline_name, suite_name, user_task_id, attack_name, injection_task_id, logdir\n",
    "            )\n",
    "    return results_without_injections, results_with_injections"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "617a5c56-10a7-4851-a81a-c6a23f9d559d",
   "metadata": {},
   "source": [
    "# Model identifier (also the pipeline/run directory name) -> display name used in plots and tables.\n",
    "model_names_map = dict([\n",
    "    (\"gpt-4o-2024-05-13\", \"GPT-4o\"),\n",
    "    (\"gpt-3.5-turbo-0125\", \"GPT-3.5 Turbo\"),\n",
    "    (\"gpt-4-turbo-2024-04-09\", \"GPT-4 Turbo\"),\n",
    "    (\"gpt-4-0125-preview\", \"GPT-4\"),\n",
    "    (\"claude-3-opus-20240229\", \"Claude Opus\"),\n",
    "    (\"claude-3-sonnet-20240229\", \"Claude Sonnet\"),\n",
    "    (\"claude-3-haiku-20240307\", \"Claude Haiku\"),\n",
    "    (\"command-r-plus\", \"Command-R+\"),\n",
    "    (\"command-r\", \"Command-R\"),\n",
    "    (\"meta-llama_Llama-3-70b-chat-hf\", \"Llama 3 70b\"),\n",
    "    (\"gemini-1.5-flash-001\", \"Gemini 1.5 Flash\"),\n",
    "    (\"gemini-1.5-pro-001\", \"Gemini 1.5 Pro\"),\n",
    "])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1628ca21-66be-4787-8124-c3c0fdb2e421",
   "metadata": {},
   "source": [
    "def find_information_tags(text):\n",
    "    \"\"\"Return (start, end, len(text)) for every <INFORMATION>...</INFORMATION> block in `text`.\n",
    "\n",
    "    Matching is non-greedy and spans newlines, so adjacent blocks are reported separately.\n",
    "    \"\"\"\n",
    "    text_length = len(text)\n",
    "    spans = re.finditer(r'<INFORMATION>.*?</INFORMATION>', text, flags=re.DOTALL)\n",
    "    return [(span.start(), span.end(), text_length) for span in spans]\n",
    "\n",
    "\n",
    "# Sanity check: three adjacent blocks, the last containing only newlines.\n",
    "find_information_tags(\"<INFORMATION>aaaaaaa</INFORMATION><INFORMATION>aaaaaaa</INFORMATION><INFORMATION>\\n\\n\\n</INFORMATION>\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "995cda86-27d2-4fe8-9467-bd56c1fe0f02",
   "metadata": {},
   "source": [
    "model_names = [\n",
    "    \"gpt-4o-2024-05-13\",\n",
    "    \"gpt-3.5-turbo-0125\",\n",
    "    \"gpt-4-turbo-2024-04-09\",\n",
    "    # \"gpt-4-0125-preview\",\n",
    "    \"claude-3-opus-20240229\",\n",
    "    \"claude-3-sonnet-20240229\",\n",
    "    # \"claude-3-haiku-20240307\",\n",
    "    \"command-r-plus\",\n",
    "    # \"command-r\",\n",
    "    # \"mistralai_Mixtral-8x7B-Instruct-v0.1\",\n",
    "    \"meta-llama_Llama-3-70b-chat-hf\",\n",
    "    \"gemini-1.5-flash-001\",\n",
    "    \"gemini-1.5-pro-001\",\n",
    "]\n",
    "\n",
    "\n",
    "def get_injection_position_in_tool_outputs(messages: list) -> list[list[tuple[int, int, int]]]:\n",
    "    \"\"\"Collect the <INFORMATION> block spans found in a run's tool messages.\n",
    "\n",
    "    Args:\n",
    "        messages: the run's chat messages; each one is read via [\"role\"] and [\"content\"].\n",
    "\n",
    "    Returns:\n",
    "        One inner list per tool message that contains at least one block, holding the\n",
    "        (start, end, len(content)) tuples from `find_information_tags`. Tool messages\n",
    "        without a block are skipped.\n",
    "    \"\"\"\n",
    "    injection_positions = []\n",
    "    for message in messages:\n",
    "        if message[\"role\"] != \"tool\":\n",
    "            continue\n",
    "        matches = find_information_tags(message[\"content\"])\n",
    "        if matches:\n",
    "            injection_positions.append(matches)\n",
    "    return injection_positions\n",
    "\n",
    "\n",
    "# Backward-compatible alias for the original, misspelled name.\n",
    "get_inection_position_in_tool_outputs = get_injection_position_in_tool_outputs\n",
    "\n",
    "\n",
    "def _result_row(model, suite, user_task, injection_task, result):\n",
    "    \"\"\"Summarise one TaskResults run as a dict of DataFrame columns.\"\"\"\n",
    "    injections_positions = get_injection_position_in_tool_outputs(result.messages)\n",
    "    # A relative position is only reported when exactly one tool output held exactly one block.\n",
    "    if len(injections_positions) == 1 and len(injections_positions[0]) == 1:\n",
    "        injection_start, injection_end, output_len = injections_positions[0][0]\n",
    "        start_injection_position = injection_start / output_len\n",
    "        end_injection_position = injection_end / output_len\n",
    "    else:\n",
    "        start_injection_position = None\n",
    "        end_injection_position = None\n",
    "    return {\n",
    "        \"model\": model,\n",
    "        \"suite\": suite,\n",
    "        \"user_task\": user_task,\n",
    "        \"injection_task\": injection_task,\n",
    "        \"utility\": result.utility,\n",
    "        \"security\": result.security,\n",
    "        \"injection_seen\": len(injections_positions) > 0,\n",
    "        \"times_injection_seen\": sum(len(positions) for positions in injections_positions),\n",
    "        \"start_injection_position\": start_injection_position,\n",
    "        \"end_injection_position\": end_injection_position,\n",
    "    }\n",
    "\n",
    "\n",
    "def make_dataframe(model_names, attack):\n",
    "    \"\"\"Load all suites' results for each model under `attack` into one row-per-run DataFrame.\n",
    "\n",
    "    For each (model, suite), the no-attack runs (injection_task is None) come first,\n",
    "    followed by the runs under `attack`.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for model in model_names:\n",
    "        for suite in SUITES:\n",
    "            results_without, results_with = load_experiment_results(suite, model, attack)\n",
    "            for results in (results_without, results_with):\n",
    "                for (user_task, injection_task), result in results.items():\n",
    "                    rows.append(_result_row(model, suite, user_task, injection_task, result))\n",
    "    return pd.DataFrame(rows)\n",
    "\n",
    "\n",
    "df = make_dataframe(model_names, \"none\")\n",
    "df_with = make_dataframe(model_names, \"important_instructions\")\n",
    "# Keep only the runs that actually carried an injection.\n",
    "df_with = df_with[df_with[\"injection_task\"].notna()]"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5c5d23f9-1a66-417b-9c0f-8b8333c7c0af",
   "metadata": {},
   "source": [
    "# Per-run rows without injections; seaborn draws the mean utility per (suite, model) bar.\n",
    "means = df.drop(columns=[\"user_task\", \"injection_task\"]).reset_index()\n",
    "means[\"model\"] = means[\"model\"].map(model_names_map)\n",
    "\n",
    "plt.figure(figsize=(4, 3))\n",
    "sns.barplot(data=means, x=\"suite\", y=\"utility\", hue=\"model\")\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)\n",
    "plt.xlabel(\"\")\n",
    "plt.ylim(0, 1)\n",
    "plt.savefig(\"/Users/edoardo/Downloads/pi-plots/utility-results-breakdown.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "57ad9cf5-73c1-4c4a-b39d-82a2480d10aa",
   "metadata": {},
   "source": [
    "# Per-model mean utility and security, lowest utility first. The metric columns are selected\n",
    "# explicitly: a list passed to GroupBy.mean() is taken as its `numeric_only` flag, not as columns.\n",
    "overall_means = df.groupby(\"model\")[[\"utility\", \"security\"]].mean().reset_index().sort_values(by=\"utility\")\n",
    "overall_means_with = df_with.groupby(\"model\")[[\"utility\", \"security\"]].mean().reset_index().sort_values(by=\"utility\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "efa2e209-1c30-45bb-ab52-b722a8235439",
   "metadata": {},
   "source": [
    "# Label the no-attack per-model averages with their source.\n",
    "overall_means = overall_means.assign(source=\"ours\")\n",
    "overall_means"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3e09ad17-8d16-40af-ad8e-7d1cc35ea47f",
   "metadata": {},
   "source": [
    "# Label the under-attack per-model averages with their source.\n",
    "overall_means_with = overall_means_with.assign(source=\"ours (attack)\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b12e9726-0614-4116-b934-5ef5d1dacfaa",
   "metadata": {},
   "source": [
    "# Per-model averages of the runs under the important_instructions attack.\n",
    "overall_means_with"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "aa0a3453-45ae-4ae1-be79-3c224b84c12f",
   "metadata": {},
   "source": [
    "# Reference utility scores attributed to Yan et al. (presumably the Berkeley Function Calling\n",
    "# Leaderboard), restricted to the models evaluated in this notebook.\n",
    "bfcl = \"Yan et al.\"\n",
    "bfcl_leaderboard = [\n",
    "    {\"model\": model, \"utility\": utility, \"source\": bfcl}\n",
    "    for model, utility in (\n",
    "        (\"command-r-plus\", 0.6947),\n",
    "        (\"meta-llama_Llama-3-70b-chat-hf\", 0.8212),\n",
    "        (\"gpt-3.5-turbo-0125\", 0.6247),\n",
    "        (\"gemini-1.5-flash-001\", 0.7888),\n",
    "        (\"gemini-1.5-pro-001\", 0.8453),\n",
    "        (\"gpt-4-turbo-2024-04-09\", 0.8124),\n",
    "        (\"gpt-4o-2024-05-13\", 0.8135),\n",
    "        (\"gpt-4-0125-preview\", 0.8312),\n",
    "        (\"claude-3-opus-20240229\", 0.7647),\n",
    "        (\"claude-3-sonnet-20240229\", 0.5918),\n",
    "        (\"claude-3-haiku-20240307\", 0.5371),\n",
    "    )\n",
    "    if model in model_names\n",
    "]\n",
    "\n",
    "bfcl_df = pd.DataFrame(bfcl_leaderboard)\n",
    "bfcl_df"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8668cd40-aa27-45a8-a509-96758fd58fc9",
   "metadata": {},
   "source": [
    "plt.figure(figsize=(4, 3))\n",
    "\n",
    "# Per-run utilities from this benchmark (without and with attack) next to the reference scores.\n",
    "ours_df = df[[\"utility\", \"model\"]].assign(source=\"Ours\")\n",
    "ours_df_with = df_with[[\"utility\", \"model\"]].assign(source=\"Ours (attack)\")\n",
    "concatenated_overall_means = pd.concat([bfcl_df, ours_df, ours_df_with]).reset_index()\n",
    "concatenated_overall_means[\"model\"] = concatenated_overall_means[\"model\"].map(model_names_map)\n",
    "\n",
    "ax = sns.barplot(\n",
    "    data=concatenated_overall_means,\n",
    "    x=\"model\",\n",
    "    y=\"utility\",\n",
    "    hue=\"source\",\n",
    "    errorbar=('ci', 95),\n",
    ")\n",
    "ax.legend(loc='upper center', ncol=3, columnspacing=0.9)\n",
    "ax.set_ylim(0, 1)\n",
    "plt.xticks(rotation=45, ha=\"right\")\n",
    "ax.set_xlabel(\"\")\n",
    "plt.savefig(\"/Users/edoardo/Downloads/pi-plots/challenging-benchmark.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9404a28c-bf19-4293-b023-711126543bd3",
   "metadata": {},
   "source": [
    "def make_attack_dataframe(attack_df):\n",
    "    \"\"\"Stack untargeted and targeted attack-success rows for the runs that carried an injection.\n",
    "\n",
    "    Rows with a null `injection_task` (no-attack runs) are dropped. Every remaining run then\n",
    "    appears twice in the result:\n",
    "      * type \"Untargeted\": ASR = 1 - utility, rows sorted by utility;\n",
    "      * type \"Targeted\":   ASR = security, rows sorted by security.\n",
    "    `reset_index()` keeps the input's row labels in an \"index\" column.\n",
    "    \"\"\"\n",
    "    injection_df = attack_df[attack_df[\"injection_task\"].notna()].reset_index()\n",
    "\n",
    "    untargeted = injection_df.sort_values(by=\"utility\")\n",
    "    untargeted[\"ASR\"] = 1 - untargeted[\"utility\"]\n",
    "    untargeted[\"type\"] = \"Untargeted\"\n",
    "\n",
    "    targeted = injection_df.sort_values(by=\"security\")\n",
    "    # Read this frame's own column rather than index-aligning against the untargeted frame.\n",
    "    targeted[\"ASR\"] = targeted[\"security\"]\n",
    "    targeted[\"type\"] = \"Targeted\"\n",
    "\n",
    "    return pd.concat([untargeted, targeted])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a0329a7e-603c-415e-8754-734b917a8839",
   "metadata": {},
   "source": [
    "# Untargeted (1 - utility) and targeted (security) ASR per model under the important_instructions attack.\n",
    "attack_df = (\n",
    "    make_attack_dataframe(df_with)\n",
    "    .sort_values(by=[\"type\", \"ASR\"], ascending=[True, False])\n",
    "    .assign(model=lambda frame: frame[\"model\"].map(model_names_map))\n",
    ")\n",
    "\n",
    "plt.figure(figsize=(4, 3))\n",
    "sns.barplot(\n",
    "    data=attack_df,\n",
    "    x=\"model\",\n",
    "    y=\"ASR\",\n",
    "    hue=\"type\",\n",
    "    hue_order=[\"Untargeted\", \"Targeted\"],\n",
    ")\n",
    "plt.legend(loc='upper center', ncol=2)\n",
    "plt.ylim(0, 1)\n",
    "plt.xticks(rotation=45, ha=\"right\")\n",
    "plt.xlabel(\"\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "1af82944-a7ed-4040-9b20-8f666ab1279e",
   "metadata": {},
   "source": [
    "def mean_confidence_interval(data, confidence=0.95):\n",
    "    \"\"\"Return (mean, half-width) of a binomial-proportion confidence interval.\n",
    "\n",
    "    Args:\n",
    "        data: per-run outcomes; their sum is used as the success count, so values are\n",
    "            expected to be 0/1 (or booleans).\n",
    "        confidence: two-sided confidence level, passed to statsmodels as alpha = 1 - confidence.\n",
    "\n",
    "    Returns:\n",
    "        (mean, mean - lower_bound). statsmodels' default (\"normal\") interval is symmetric; if the\n",
    "        returned bounds are not (e.g. because they were clipped to [0, 1]) a warning line is printed.\n",
    "    \"\"\"\n",
    "    outcomes = np.array(data)\n",
    "    n_trials = len(outcomes)\n",
    "    n_successes = outcomes.sum()\n",
    "    mean = np.mean(outcomes)\n",
    "    # proportion_confint(count, nobs, alpha): number of successes first, then number of trials.\n",
    "    lower, upper = proportion.proportion_confint(n_successes, n_trials, alpha=1 - confidence)\n",
    "    if not np.isclose(mean - lower, upper - mean):\n",
    "        print(\"Error in CI computation\")\n",
    "    return mean, mean - lower\n",
    "\n",
    "\n",
    "def make_attack_dataframe_v2(attack_df):\n",
    "    \"\"\"Per-model utility without/with attack and targeted ASR, each with a 95% CI half-width.\n",
    "\n",
    "    Every injected run is paired (on model, user_task, suite) with the no-attack run of the same\n",
    "    user task, so `attack_df` must contain both kinds of rows.\n",
    "    NOTE(review): the no-attack utility is therefore weighted by the number of injected runs per\n",
    "    user task; confirm this weighting is intended.\n",
    "    \"\"\"\n",
    "    non_injection_df = attack_df[attack_df[\"injection_task\"].isnull()].copy()\n",
    "    non_injection_df[\"Utility without attack\"] = non_injection_df[\"utility\"]\n",
    "    non_injection_df = non_injection_df[[\"model\", \"Utility without attack\", \"user_task\", \"suite\"]]\n",
    "\n",
    "    injection_df = attack_df[attack_df[\"injection_task\"].notna()].copy()\n",
    "    injection_df[\"Utility with attack\"] = injection_df[\"utility\"]\n",
    "    injection_df[\"Targeted ASR\"] = injection_df[\"security\"]\n",
    "    injection_df = injection_df[[\"model\", \"Utility with attack\", \"Targeted ASR\", \"user_task\", \"suite\"]]\n",
    "\n",
    "    merged_df = pd.merge(injection_df, non_injection_df, on=[\"model\", \"user_task\", \"suite\"])\n",
    "\n",
    "    def compute_conf_intervals(group):\n",
    "        # One CI computation per metric: mean and half-width come from the same call.\n",
    "        summary = {}\n",
    "        for column in (\"Utility without attack\", \"Utility with attack\", \"Targeted ASR\"):\n",
    "            mean, half_width = mean_confidence_interval(group[column])\n",
    "            summary[column] = mean\n",
    "            summary[f\"{column} 95% CI\"] = half_width\n",
    "        return pd.Series(summary)\n",
    "\n",
    "    return merged_df.groupby([\"model\"]).apply(compute_conf_intervals, include_groups=False).reset_index()"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "19d65f17-ad30-428a-96d5-ae9ef770b9d2",
   "metadata": {},
   "source": [
    "def make_pareto_front(attack_df):\n",
    "    \"\"\"Split models into the utility Pareto front and the dominated remainder.\n",
    "\n",
    "    \"Utility without attack\" and \"Utility with attack\" are both maximised, after averaging\n",
    "    each of them per model.\n",
    "\n",
    "    Returns:\n",
    "        (pareto_front_df, non_pareto_df): the front as a new DataFrame with columns\n",
    "        [\"Utility without attack\", \"Utility with attack\", \"model\"], ordered by decreasing\n",
    "        utility without attack, and the remaining per-model rows.\n",
    "    \"\"\"\n",
    "    metrics = [\"Utility without attack\", \"Utility with attack\"]\n",
    "    # Select the columns explicitly: a list passed to GroupBy.mean() is taken as `numeric_only`.\n",
    "    df = attack_df.groupby(\"model\")[metrics].mean().reset_index()\n",
    "\n",
    "    # Both keys descending: among models tied on utility without attack, the one with the best\n",
    "    # utility under attack must be visited first, otherwise a dominated tie joins the front.\n",
    "    df = df.sort_values(by=metrics, ascending=[False, False])\n",
    "\n",
    "    pareto_front = []\n",
    "    pareto_indices = []\n",
    "    current_max = float('-inf')\n",
    "    for index, row in df.iterrows():\n",
    "        # On the front iff strictly better under attack than every model visited before it.\n",
    "        if row['Utility with attack'] > current_max:\n",
    "            pareto_front.append((row['Utility without attack'], row['Utility with attack'], row['model']))\n",
    "            pareto_indices.append(index)\n",
    "            current_max = row['Utility with attack']\n",
    "\n",
    "    pareto_front_df = pd.DataFrame(pareto_front, columns=['Utility without attack', 'Utility with attack', 'model'])\n",
    "    non_pareto_df = df.drop(index=pareto_indices)\n",
    "    return pareto_front_df, non_pareto_df"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "0bffc937-31bf-4828-81ce-4e662680c57e",
   "metadata": {},
   "source": [
    "# Scatter marker per model display name.\n",
    "markers = dict([\n",
    "    ('Llama 3 70b', 'o'),\n",
    "    ('Command-R+', 's'),\n",
    "    ('Gemini 1.5 Flash', 'd'),\n",
    "    ('Gemini 1.5 Pro', '^'),\n",
    "    ('Claude Sonnet', 'v'),\n",
    "    ('Claude Opus', '*'),\n",
    "    ('GPT-3.5 Turbo', '>'),\n",
    "    ('GPT-4 Turbo', 'P'),\n",
    "    ('GPT-4o', 'X'),\n",
    "])\n",
    "\n",
    "# Colour per model display name (same keys and order as `markers`).\n",
    "colors = dict([\n",
    "    ('Llama 3 70b', 'cyan'),\n",
    "    ('Command-R+', 'green'),\n",
    "    ('Gemini 1.5 Flash', 'red'),\n",
    "    ('Gemini 1.5 Pro', 'purple'),\n",
    "    ('Claude Sonnet', 'orange'),\n",
    "    ('Claude Opus', 'brown'),\n",
    "    ('GPT-3.5 Turbo', 'pink'),\n",
    "    ('GPT-4 Turbo', 'olive'),\n",
    "    ('GPT-4o', 'blue'),\n",
    "])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "92403ce5-6e7c-4758-99ac-18ecf90b7702",
   "metadata": {},
   "source": [
    "# Reload the important_instructions runs *including* the no-attack rows, which\n",
    "# make_attack_dataframe_v2 needs to pair each injected run with its no-attack run. A separate\n",
    "# name keeps the injection-only `df_with` defined earlier intact, so earlier cells give the same\n",
    "# results when re-run after this one.\n",
    "df_with_all = make_dataframe(model_names, \"important_instructions\")\n",
    "attack_df_v2 = make_attack_dataframe_v2(df_with_all)\n",
    "attack_df_v2[\"model\"] = attack_df_v2[\"model\"].map(model_names_map)\n",
    "utility_pareto_front_df, utility_non_pareto_df = make_pareto_front(attack_df_v2)\n",
    "\n",
    "# Display copy of the summary table with every numeric column in percent, rounded to 2 decimals.\n",
    "numeric_cols = attack_df_v2.select_dtypes(include=['number']).columns\n",
    "tmp_df = attack_df_v2.copy()\n",
    "tmp_df[numeric_cols] = (attack_df_v2[numeric_cols] * 100).round(2)\n",
    "\n",
    "tmp_df"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "e354d6b1-0973-4479-ac85-ef8ccc5f389d",
   "metadata": {},
   "source": [
    "fig, ax = plt.subplots(figsize=(4, 3))\n",
    "\n",
    "\n",
    "def plot_model_utility_results(df, model, **plot_kwargs):\n",
    "    \"\"\"Scatter `model`'s (utility without attack, utility with attack) point(s) with its marker and colour.\"\"\"\n",
    "    model_results = df[df['model'] == model]\n",
    "    return sns.scatterplot(\n",
    "        data=model_results,\n",
    "        x=\"Utility without attack\",\n",
    "        y=\"Utility with attack\",\n",
    "        color=colors[model],\n",
    "        marker=markers[model],\n",
    "        legend=False,\n",
    "        s=200,\n",
    "        **plot_kwargs\n",
    "    )\n",
    "\n",
    "\n",
    "# Dominated models first, faded, so the Pareto-front models are drawn on top.\n",
    "for model in utility_non_pareto_df['model'].unique():\n",
    "    plot_model_utility_results(attack_df_v2, model, alpha=0.3, ax=ax)\n",
    "\n",
    "for model in utility_pareto_front_df['model'].unique():\n",
    "    plot_model_utility_results(attack_df_v2, model, ax=ax)\n",
    "\n",
    "# Style the explicitly created axes, so this does not depend on a loop above having run.\n",
    "ax.tick_params(axis='x', labelsize=14)\n",
    "ax.tick_params(axis='y', labelsize=14)\n",
    "ax.set_xlabel(\"Utility without attack\", fontsize=14)\n",
    "ax.set_ylabel(\"Utility with attack\", fontsize=14)\n",
    "\n",
    "# Diagonal reference line: points below it have lower utility under attack than without.\n",
    "ax.plot(np.linspace(0, 1), np.linspace(0, 1), ls=\"--\", color=\"gray\")\n",
    "\n",
    "ax.set_ylim(0, 1)\n",
    "ax.set_xlim(0, 1)\n",
    "fig.savefig(\"/Users/edoardo/Downloads/pi-plots/utility-attack-results.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d8d715c5-ebdd-4b4e-bd8a-32444691dbe3",
   "metadata": {},
   "source": [
    "def make_pareto_back(attack_df):\n",
    "    \"\"\"Split models into the (utility, targeted ASR) Pareto front and the dominated remainder.\n",
    "\n",
    "    \"Utility without attack\" is maximised and \"Targeted ASR\" minimised (by maximising its\n",
    "    negation, \"Inverted Targeted ASR\"), after averaging each of them per model.\n",
    "\n",
    "    Returns:\n",
    "        (pareto_front_df, non_pareto_df): the front as a new DataFrame with columns\n",
    "        [\"Utility without attack\", \"Targeted ASR\", \"model\"], ordered by decreasing utility\n",
    "        without attack, and the remaining per-model rows (with the inverted helper column).\n",
    "    \"\"\"\n",
    "    metrics = [\"Utility without attack\", \"Targeted ASR\"]\n",
    "    # Select the columns explicitly: a list passed to GroupBy.mean() is taken as `numeric_only`.\n",
    "    df = attack_df.groupby(\"model\")[metrics].mean().reset_index()\n",
    "\n",
    "    df['Inverted Targeted ASR'] = -df['Targeted ASR']\n",
    "\n",
    "    # Both keys descending: among models tied on utility, the one with the lowest ASR must be\n",
    "    # visited first, otherwise a dominated tie joins the front.\n",
    "    df = df.sort_values(by=['Utility without attack', 'Inverted Targeted ASR'], ascending=[False, False])\n",
    "\n",
    "    pareto_front = []\n",
    "    pareto_indices = []\n",
    "    current_max = float('-inf')\n",
    "    for index, row in df.iterrows():\n",
    "        # On the front iff its ASR is strictly lower than that of every model visited before it.\n",
    "        if row['Inverted Targeted ASR'] > current_max:\n",
    "            pareto_front.append((row['Utility without attack'], row['Targeted ASR'], row['model']))\n",
    "            pareto_indices.append(index)\n",
    "            current_max = row['Inverted Targeted ASR']\n",
    "\n",
    "    pareto_front_df = pd.DataFrame(pareto_front, columns=['Utility without attack', 'Targeted ASR', 'model'])\n",
    "    non_pareto_df = df.drop(index=pareto_indices)\n",
    "    return pareto_front_df, non_pareto_df"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7357878d-fba8-42d3-ba8d-319f6f7aa3f8",
   "metadata": {},
   "source": [
    "asr_pareto_front_df, asr_non_pareto_df = make_pareto_back(attack_df_v2)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 3))\n",
    "\n",
    "\n",
    "def _scatter_asr_points(points_df, **plot_kwargs):\n",
    "    \"\"\"Scatter each model in `points_df` on `ax` with that model's marker and colour.\"\"\"\n",
    "    for model_name in points_df['model'].unique():\n",
    "        sns.scatterplot(\n",
    "            data=points_df[points_df['model'] == model_name],\n",
    "            x=\"Utility without attack\",\n",
    "            y=\"Targeted ASR\",\n",
    "            color=colors[model_name],\n",
    "            marker=markers[model_name],\n",
    "            s=200,\n",
    "            legend=False,\n",
    "            ax=ax,\n",
    "            **plot_kwargs,\n",
    "        )\n",
    "\n",
    "\n",
    "# Dominated models first, faded, so the Pareto-front models are drawn on top.\n",
    "_scatter_asr_points(asr_non_pareto_df, alpha=0.3)\n",
    "_scatter_asr_points(asr_pareto_front_df)\n",
    "\n",
    "ax.tick_params(axis='x', labelsize=14)\n",
    "ax.tick_params(axis='y', labelsize=14)\n",
    "ax.set_xlabel(\"Utility without attack\", fontsize=14)\n",
    "ax.set_ylabel(\"Targeted ASR\", fontsize=14)\n",
    "\n",
    "# The legend lists every model in `markers`, whether or not it is plotted.\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], marker=markers[model], color='w', label=model,\n",
    "           markerfacecolor=colors[model], markersize=10)\n",
    "    for model in markers\n",
    "]\n",
    "ax.legend(handles=legend_elements, loc='upper center', ncol=2, fontsize=8, columnspacing=0.7)\n",
    "\n",
    "ax.set_ylim(0, 1)\n",
    "ax.set_xlim(0, 1)\n",
    "fig.savefig(\"/Users/edoardo/Downloads/pi-plots/security-attack-results.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "8ee56d7d-5f09-47d7-a96c-1c3265b922e8",
   "metadata": {},
   "source": [
    "## Attacks\n",
    "\n",
    "### Overall attacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e219b711-c7a1-4221-b188-499b4c39f907",
   "metadata": {},
   "source": [
    "# Attack identifier -> label used in plots (\"tool_knowledge\" is currently excluded).\n",
    "attacks = dict(\n",
    "    ignore_previous=\"Ignore previous\",\n",
    "    important_instructions=\"Important message\",\n",
    "    injecagent=\"InjecAgent\",\n",
    "    max=\"Max\",\n",
    "    direct=\"TODO\",  # TODO: choose a display label for the direct attack\n",
    ")\n",
    "\n",
    "# Model used for the attack comparisons below.\n",
    "attacks_model = \"gpt-4o-2024-05-13\""
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "cc3ca8e2-0d9e-447e-8ec6-5e382e1262ae",
   "metadata": {},
   "source": [
    "### Ablations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8670448b-ac52-45ac-bf1e-34231743dec5",
   "metadata": {},
   "source": [
    "# Variants of the important_instructions attack: attack identifier -> plot label.\n",
    "ablation_attacks = {\n",
    "    \"important_instructions\": \"User and model names\",\n",
    "    \"important_instructions_no_user_name\": \"No user name\",\n",
    "    \"important_instructions_no_model_name\": \"No model name\",\n",
    "    \"important_instructions_no_names\": \"No names\",\n",
    "    \"important_instructions_wrong_model_name\": \"Wrong model name\",\n",
    "    \"important_instructions_wrong_user_name\": \"Wrong user name\",\n",
    "}\n",
    "\n",
    "ablation_attacks_dfs = []\n",
    "for attack in ablation_attacks:\n",
    "    ablation_attack_overall_df = make_dataframe([attacks_model], attack)\n",
    "    ablation_attack_means_df = make_attack_dataframe(ablation_attack_overall_df).assign(attack=attack)\n",
    "    ablation_attacks_dfs.append(ablation_attack_means_df)\n",
    "\n",
    "ablation_attacks_df = (\n",
    "    pd.concat(ablation_attacks_dfs)\n",
    "    .sort_values(by=\"ASR\")\n",
    "    .assign(attack=lambda frame: frame[\"attack\"].map(ablation_attacks))\n",
    ")\n",
    "\n",
    "plt.figure(figsize=(4, 3))\n",
    "sns.barplot(data=ablation_attacks_df, x=\"attack\", y=\"ASR\", hue=\"type\")\n",
    "plt.legend(loc='upper center', ncol=2)\n",
    "plt.xticks(rotation=45, ha=\"right\")\n",
    "plt.xlabel(\"\")\n",
    "plt.ylim(0, 1)\n",
    "plt.savefig(\"/Users/edoardo/Downloads/pi-plots/ablations-attacks-results.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "f652c4ce-be7c-460c-b20d-2232bcd5dde3",
   "metadata": {},
   "source": [
    "### Max attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "c160dab7-5274-48fe-a03e-95ec456bf9df",
   "metadata": {},
   "source": [
    "# Per-attack ASR for the baseline attacks, plus a \"max\" attack that counts a\n",
    "# (suite, user task, injection task) triple as compromised if any attack succeeded.\n",
    "attacks_model = \"gpt-4o-2024-05-13\"\n",
    "\n",
    "baseline_attacks_dfs = []\n",
    "max_attacks_dfs = []\n",
    "\n",
    "for attack in [attack for attack in ablation_attacks | attacks if attack != \"max\"]:\n",
    "    attack_overall_df = make_dataframe([attacks_model], attack)\n",
    "    attack_means_df = make_attack_dataframe(attack_overall_df)\n",
    "    max_attacks_dfs.append(attack_overall_df)\n",
    "    attack_means_df[\"attack\"] = attack\n",
    "    baseline_attacks_dfs.append(attack_means_df)\n",
    "\n",
    "overall_max_attacks_df = pd.concat(max_attacks_dfs)\n",
    "\n",
    "# Best attack per triple: highest security and lowest utility over all attacks.\n",
    "# A single groupby keeps both aggregates aligned on the same group index. The first\n",
    "# positional argument of GroupBy.max/min is `numeric_only`, not a column name.\n",
    "max_attack_groups = overall_max_attacks_df.groupby([\"suite\", \"user_task\", \"injection_task\"])\n",
    "max_attack = max_attack_groups.max(numeric_only=True)\n",
    "max_attack[\"utility\"] = max_attack_groups[\"utility\"].min()\n",
    "max_attack = max_attack.reset_index()\n",
    "max_attack[\"model\"] = attacks_model\n",
    "max_attack_df = make_attack_dataframe(max_attack)\n",
    "max_attack_df[\"attack\"] = \"max\"\n",
    "\n",
    "baseline_attacks_df = pd.concat([*baseline_attacks_dfs, max_attack_df]).sort_values(by=\"ASR\")\n",
    "\n",
    "attacks_to_remove = {\n",
    "    \"important_instructions_no_user_name\",\n",
    "    \"important_instructions_no_model_name\",\n",
    "    \"important_instructions_no_names\",\n",
    "    \"important_instructions_wrong_model_name\",\n",
    "    \"important_instructions_wrong_user_name\"\n",
    "}\n",
    "# .copy() so the column assignment below does not write into a view of the filtered frame.\n",
    "baseline_attacks_df = baseline_attacks_df[~baseline_attacks_df[\"attack\"].isin(attacks_to_remove)].copy()\n",
    "baseline_attacks_df[\"attack\"] = baseline_attacks_df[\"attack\"].map(attacks)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 3))\n",
    "sns.barplot(\n",
    "    data=baseline_attacks_df,\n",
    "    x=\"attack\",\n",
    "    y=\"ASR\",\n",
    "    hue=\"type\",\n",
    "    # NOTE(review): \"TODO\" looks like a placeholder display name coming from `attacks`; rename before publishing.\n",
    "    order=[\"TODO\", \"Ignore previous\", \"InjecAgent\", \"Important message\", \"Max\"],\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_xticklabels([\"TODO\", \"Ignore\\nprevious\", \"InjecAgent\", \"Important\\nmessage\", \"Max\"])\n",
    "ax.legend(loc='upper center', ncol=3)\n",
    "ax.set_xlabel(\"\")\n",
    "ax.set_ylim(0, 1)\n",
    "fig.savefig(\"/Users/edoardo/Downloads/pi-plots/baseline-attacks-results.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "8bffebe3-75ac-4113-a936-86c983d6d90a",
   "metadata": {},
   "source": [
    "# Recompute the per-attack results and report mean untargeted/targeted ASR with 95% CIs.\n",
    "# Start from empty lists: appending to the lists filled by the previous cell duplicates every\n",
    "# attack's rows, which leaves the means unchanged but makes the CIs spuriously narrow.\n",
    "baseline_attacks_dfs = []\n",
    "max_attacks_dfs = []\n",
    "\n",
    "for attack in [attack for attack in ablation_attacks | attacks if attack != \"max\"]:\n",
    "    attack_overall_df = make_dataframe([attacks_model], attack)\n",
    "    attack_means_df = make_attack_dataframe(attack_overall_df)\n",
    "    max_attacks_dfs.append(attack_overall_df)\n",
    "    attack_means_df[\"attack\"] = attack\n",
    "    baseline_attacks_dfs.append(attack_means_df)\n",
    "\n",
    "overall_max_attacks_df = pd.concat(max_attacks_dfs)\n",
    "\n",
    "# Best attack per (suite, user task, injection task): max security, min utility.\n",
    "# The first positional argument of GroupBy.max/min is `numeric_only`, not a column name.\n",
    "max_attack_groups = overall_max_attacks_df.groupby([\"suite\", \"user_task\", \"injection_task\"])\n",
    "max_attack = max_attack_groups.max(numeric_only=True)\n",
    "max_attack[\"utility\"] = max_attack_groups[\"utility\"].min()\n",
    "max_attack = max_attack.reset_index()\n",
    "max_attack[\"model\"] = attacks_model\n",
    "max_attack_df = make_attack_dataframe(max_attack)\n",
    "max_attack_df[\"attack\"] = \"max\"\n",
    "\n",
    "baseline_attacks_df = pd.concat([*baseline_attacks_dfs, max_attack_df]).sort_values(by=\"ASR\")\n",
    "\n",
    "def compute_conf_intervals(group):\n",
    "    \"\"\"Summarise one attack's rows as untargeted and targeted ASR with 95% CIs.\n",
    "\n",
    "    Untargeted ASR is 1 - mean utility; targeted ASR is mean security. Each\n",
    "    (estimate, CI) pair comes from a single `mean_confidence_interval` call.\n",
    "    \"\"\"\n",
    "    utility_mean, utility_ci = mean_confidence_interval(group[\"utility\"])\n",
    "    security_mean, security_ci = mean_confidence_interval(group[\"security\"])\n",
    "    return pd.Series({\n",
    "        \"Untargeted ASR\": 1 - utility_mean,\n",
    "        \"Untargeted ASR 95% CI\": utility_ci,\n",
    "        \"Targeted ASR\": security_mean,\n",
    "        \"Targeted ASR 95% CI\": security_ci\n",
    "    })\n",
    "\n",
    "# One row per attack, then replace attack IDs with their display names.\n",
    "baseline_attacks_df = baseline_attacks_df.groupby([\"attack\"]).apply(compute_conf_intervals, include_groups=False).reset_index()\n",
    "baseline_attacks_df[\"attack\"] = baseline_attacks_df[\"attack\"].map(attacks | ablation_attacks)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "57cbd8ad-2fdb-436e-8c47-93de2b183126",
   "metadata": {},
   "source": [
    "# Baseline-attack table in percent: attacks listed in `attacks` plus the \"User and model names\" ablation.\n",
    "percent_cols = baseline_attacks_df.select_dtypes(include=['number']).columns\n",
    "tmp_df = baseline_attacks_df[\n",
    "    baseline_attacks_df[\"attack\"].isin(attacks.values())\n",
    "    | (baseline_attacks_df[\"attack\"] == \"User and model names\")\n",
    "].copy()\n",
    "tmp_df[percent_cols] = (tmp_df[percent_cols] * 100).round(2)\n",
    "tmp_df"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "9cce3e81-0843-4cc1-80f2-bf16adde5f71",
   "metadata": {},
   "source": [
    "## Defenses\n",
    "\n",
    "### Utility under no attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "28b03524-ba51-4b1b-b587-0d79093dfb06",
   "metadata": {},
   "source": [
    "# Utility with vs. without the important_instructions attack for GPT-4o under each defense.\n",
    "defenses = {\n",
    "    \"\": \"No defense\",\n",
    "    \"tool_filter\": \"Tool filter\",\n",
    "    \"transformers_pi_detector\": \"PI detector\",\n",
    "    \"spotlighting_with_delimiting\": \"Delimiting\",\n",
    "    \"repeat_user_prompt\": \"Repeat prompt\",\n",
    "}\n",
    "\n",
    "# Marker and colour per defense display name, shared by the defense plots.\n",
    "defenses_markers = {\n",
    "    \"No defense\": 'o',\n",
    "    \"Tool filter\": 's',\n",
    "    \"PI detector\": 'D',\n",
    "    \"Delimiting\": '^',\n",
    "    \"Repeat prompt\": 'v'\n",
    "}\n",
    "\n",
    "defenses_colors = {\n",
    "    \"No defense\": 'cyan',\n",
    "    \"Tool filter\": 'green',\n",
    "    \"PI detector\": 'red',\n",
    "    \"Delimiting\": 'purple',\n",
    "    \"Repeat prompt\": 'orange',\n",
    "}\n",
    "\n",
    "model = \"gpt-4o-2024-05-13\"\n",
    "\n",
    "# Pipeline name = model name and defense ID joined by \"-\"; the no-defense pipeline is the bare model name.\n",
    "pipeline_names = [\"-\".join([model, defense]).strip(\"-\") for defense in defenses]\n",
    "pipelines_map = dict(zip(pipeline_names, defenses.values()))\n",
    "\n",
    "defenses_df = make_dataframe(pipeline_names, \"important_instructions\")\n",
    "defenses_attack_df = make_attack_dataframe_v2(defenses_df)\n",
    "\n",
    "# Replace pipeline names with defense display names.\n",
    "defenses_attack_df[\"model\"] = defenses_attack_df[\"model\"].map(pipelines_map)\n",
    "defenses_utility_pareto_front_df, defenses_utility_non_pareto_df = make_pareto_front(defenses_attack_df)\n",
    "\n",
    "# Create the axes explicitly so the styling below works even when a point set is empty.\n",
    "fig, ax = plt.subplots(figsize=(4, 3))\n",
    "\n",
    "# Non-Pareto points, drawn semi-transparent.\n",
    "for defense in defenses_utility_non_pareto_df['model'].unique():\n",
    "    sns.scatterplot(\n",
    "        data=defenses_utility_non_pareto_df[defenses_utility_non_pareto_df['model'] == defense],\n",
    "        x=\"Utility without attack\",\n",
    "        y=\"Utility with attack\",\n",
    "        color=defenses_colors[defense],\n",
    "        marker=defenses_markers[defense],\n",
    "        s=200,\n",
    "        alpha=0.3,\n",
    "        legend=False,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "# Pareto-front points, fully opaque.\n",
    "for defense in defenses_utility_pareto_front_df['model'].unique():\n",
    "    sns.scatterplot(\n",
    "        data=defenses_utility_pareto_front_df[defenses_utility_pareto_front_df['model'] == defense],\n",
    "        x=\"Utility without attack\",\n",
    "        y=\"Utility with attack\",\n",
    "        color=defenses_colors[defense],\n",
    "        marker=defenses_markers[defense],\n",
    "        s=200,\n",
    "        legend=False,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "# Diagonal: utility unaffected by the attack.\n",
    "ax.plot(np.linspace(0, 1), np.linspace(0, 1), ls=\"--\", color=\"gray\")\n",
    "\n",
    "ax.tick_params(axis='x', labelsize=14)\n",
    "ax.tick_params(axis='y', labelsize=14)\n",
    "ax.set_xlabel(\"Utility without attack\", fontsize=14)\n",
    "ax.set_ylabel(\"Utility with attack\", fontsize=14)\n",
    "\n",
    "# No legend on this plot; the targeted-ASR plot of the next cell carries it.\n",
    "ax.set_ylim(0, 1)\n",
    "ax.set_xlim(0, 1)\n",
    "fig.savefig(\"/Users/edoardo/Downloads/pi-plots/utility-defenses-results.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "72ba80a2-10a5-47e7-98c4-b090b56685eb",
   "metadata": {},
   "source": [
    "# Targeted ASR vs. utility without attack for GPT-4o under each defense.\n",
    "asr_defenses_pareto_front_df, asr_defenses_non_pareto_df = make_pareto_back(defenses_attack_df)\n",
    "\n",
    "# Create the axes explicitly so the styling below works even when a point set is empty.\n",
    "fig, ax = plt.subplots(figsize=(4, 3))\n",
    "\n",
    "# Non-front points, drawn semi-transparent.\n",
    "for defense in asr_defenses_non_pareto_df['model'].unique():\n",
    "    sns.scatterplot(\n",
    "        data=asr_defenses_non_pareto_df[asr_defenses_non_pareto_df['model'] == defense],\n",
    "        x=\"Utility without attack\",\n",
    "        y=\"Targeted ASR\",\n",
    "        color=defenses_colors[defense],\n",
    "        marker=defenses_markers[defense],\n",
    "        s=200,\n",
    "        alpha=0.3,\n",
    "        legend=False,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "# Front points returned by make_pareto_back, fully opaque.\n",
    "for defense in asr_defenses_pareto_front_df['model'].unique():\n",
    "    sns.scatterplot(\n",
    "        data=asr_defenses_pareto_front_df[asr_defenses_pareto_front_df['model'] == defense],\n",
    "        x=\"Utility without attack\",\n",
    "        y=\"Targeted ASR\",\n",
    "        color=defenses_colors[defense],\n",
    "        marker=defenses_markers[defense],\n",
    "        s=200,\n",
    "        legend=False,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "ax.tick_params(axis='x', labelsize=14)\n",
    "ax.tick_params(axis='y', labelsize=14)\n",
    "ax.set_xlabel(\"Utility without attack\", fontsize=14)\n",
    "ax.set_ylabel(\"Targeted ASR\", fontsize=14)\n",
    "\n",
    "# Manual legend: one entry per defense with its marker and colour.\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], marker=defenses_markers[defense], color='w', label=defense,\n",
    "           markerfacecolor=defenses_colors[defense], markersize=10)\n",
    "    for defense in defenses_markers\n",
    "]\n",
    "ax.legend(handles=legend_elements, loc='upper left', fontsize=10)\n",
    "\n",
    "ax.set_ylim(0, 1)\n",
    "ax.set_xlim(0, 1)\n",
    "fig.savefig(\"/Users/edoardo/Downloads/pi-plots/security-defenses-results.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "6a6d88ac-83e3-4159-b22c-1eed161503ed",
   "metadata": {},
   "source": [
    "# Defense results table in percent, rounded to two decimals.\n",
    "tmp_df = defenses_attack_df.copy()\n",
    "percent_cols = tmp_df.select_dtypes(include=['number']).columns\n",
    "tmp_df[percent_cols] = tmp_df[percent_cols].mul(100).round(2)\n",
    "tmp_df"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "1cc72298-1dce-40b5-bfdf-4287065ff2fa",
   "metadata": {},
   "source": [
    "## Is a confused model more vulnerable?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "1271ef72-67da-4c47-ab4f-c22817ae303d",
   "metadata": {},
   "source": [
    "# GPT-4o runs only, split by whether an injection task was active.\n",
    "gpt_4o_df = df_with.loc[df_with[\"model\"] == \"gpt-4o-2024-05-13\"].drop(columns=[\"model\"])\n",
    "injected_tasks_df = gpt_4o_df.loc[gpt_4o_df[\"injection_task\"].notna()]\n",
    "non_injected_tasks_df = gpt_4o_df.loc[gpt_4o_df[\"injection_task\"].isna()]\n",
    "\n",
    "# User tasks solved when no injection task is active.\n",
    "solved_non_injected_tasks_df = non_injected_tasks_df.loc[non_injected_tasks_df[\"utility\"].eq(1)]"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "0109c11a-e5f7-4bda-b2c9-d50675e1fed0",
   "metadata": {},
   "source": [
    "# Split injected runs on the boolean `injection_seen` flag.\n",
    "seen_mask = injected_tasks_df[\"injection_seen\"]\n",
    "seen_injections_df = injected_tasks_df.loc[seen_mask]\n",
    "not_seen_injections_df = injected_tasks_df.loc[~seen_mask]"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "f56c4008-e7de-450f-b8c9-368976627f2d",
   "metadata": {},
   "source": [
    "# Fraction of injected runs in which the injection was not seen.\n",
    "not_seen_injections_df.shape[0] / injected_tasks_df.shape[0]"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "868e33e5-f238-42f3-9b6a-b27b47bafd92",
   "metadata": {},
   "source": [
    "# Sanity check: the seen / not-seen split partitions the injected runs.\n",
    "assert len(injected_tasks_df) == len(not_seen_injections_df) + len(seen_injections_df)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "5236b27c-9690-4735-a9d0-6c784f93afc2",
   "metadata": {},
   "source": [
    "# Split seen-injection runs by whether the same user task is solved without an attack.\n",
    "# Keys are deduplicated so the inner merge cannot duplicate runs.\n",
    "solved_user_tasks = solved_non_injected_tasks_df[[\"user_task\", \"suite\"]].drop_duplicates()\n",
    "utility_seen_injections_df = pd.merge(seen_injections_df, solved_user_tasks, on=[\"suite\", \"user_task\"])\n",
    "# Anti-join via a left merge with indicator; an outer merge would also build unused right-only rows.\n",
    "no_utility_seen_injections_df = pd.merge(seen_injections_df, solved_user_tasks, on=[\"suite\", \"user_task\"], how=\"left\", indicator=True)\n",
    "no_utility_seen_injections_df = no_utility_seen_injections_df[no_utility_seen_injections_df[\"_merge\"] == \"left_only\"].drop(columns=[\"_merge\"])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "5e46ad8a-84b2-4efd-b94e-9e66cf59adfe",
   "metadata": {},
   "source": [
    "# Sanity check: the utility / no-utility split partitions the seen-injection runs.\n",
    "assert len(seen_injections_df) == len(utility_seen_injections_df) + len(no_utility_seen_injections_df)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "c6b77599-5260-45d4-99f5-b62e5eef8b9b",
   "metadata": {},
   "source": [
    "# Share of all injected runs: task solved without attack, injection seen.\n",
    "f\"Utility without attack, injection is seen: {utility_seen_injections_df.shape[0] / injected_tasks_df.shape[0]}\""
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "0483dd7c-4dd3-4b6c-9caa-20ccfd071265",
   "metadata": {},
   "source": [
    "# Share of all injected runs: task not solved without attack, injection seen.\n",
    "f\"No utility without attack, injection is seen: {no_utility_seen_injections_df.shape[0] / injected_tasks_df.shape[0]}\""
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "50b9cb01-984b-4604-a192-eb26eaafc37b",
   "metadata": {},
   "source": [
    "# Split not-seen-injection runs by whether the same user task is solved without an attack.\n",
    "# Keys are deduplicated so the inner merge cannot duplicate runs.\n",
    "solved_user_tasks = solved_non_injected_tasks_df[[\"user_task\", \"suite\"]].drop_duplicates()\n",
    "utility_no_seen_injections_df = pd.merge(not_seen_injections_df, solved_user_tasks, on=[\"suite\", \"user_task\"])\n",
    "# Anti-join via a left merge with indicator; an outer merge would also build unused right-only rows.\n",
    "no_utility_no_seen_injections_df = pd.merge(not_seen_injections_df, solved_user_tasks, on=[\"suite\", \"user_task\"], how=\"left\", indicator=True)\n",
    "no_utility_no_seen_injections_df = no_utility_no_seen_injections_df[no_utility_no_seen_injections_df[\"_merge\"] == \"left_only\"].drop(columns=[\"_merge\"])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "665353df-b4de-4ed5-9538-c559826a3e7c",
   "metadata": {},
   "source": [
    "# Sanity check: the utility / no-utility split partitions the not-seen-injection runs.\n",
    "assert len(not_seen_injections_df) == len(no_utility_no_seen_injections_df) + len(utility_no_seen_injections_df)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "d73c316f-53f5-46d0-9e8d-6aa63e6699d6",
   "metadata": {},
   "source": [
    "# Share of all injected runs: task solved without attack, injection not seen.\n",
    "f\"Utility without attack, injection is not seen: {utility_no_seen_injections_df.shape[0] / injected_tasks_df.shape[0]}\""
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "7be7d51f-1946-4186-a346-ae5ac5fffcd3",
   "metadata": {},
   "source": [
    "# Share of all injected runs: task not solved without attack, injection not seen.\n",
    "f\"No utility without attack, injection is not seen: {no_utility_no_seen_injections_df.shape[0] / injected_tasks_df.shape[0]}\""
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "f0488e6b-82ef-45f2-a13a-c830182b216f",
   "metadata": {},
   "source": [
    "# Targeted ASR (share of runs with security == 1); NaN instead of ZeroDivisionError if the subset is empty.\n",
    "float(utility_seen_injections_df[\"security\"].eq(1).mean())"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "4ecbeb56-c85a-43f3-9b8c-382b1755f203",
   "metadata": {},
   "source": [
    "# Targeted ASR (share of runs with security == 1); NaN instead of ZeroDivisionError if the subset is empty.\n",
    "float(no_utility_seen_injections_df[\"security\"].eq(1).mean())"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "bb33a6db-a769-473c-add3-48234064d000",
   "metadata": {},
   "source": [
    "# Targeted ASR (share of runs with security == 1); NaN instead of ZeroDivisionError if the subset is empty.\n",
    "float(utility_no_seen_injections_df[\"security\"].eq(1).mean())"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "0553b513-ce6b-4bd0-a323-83ee0f561010",
   "metadata": {},
   "source": [
    "# Targeted ASR (share of runs with security == 1); NaN instead of ZeroDivisionError if the subset is empty.\n",
    "float(no_utility_no_seen_injections_df[\"security\"].eq(1).mean())"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "d105e0ed-d4a7-4361-aad3-80942f3d85ca",
   "metadata": {},
   "source": [
    "## When are injections successful?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "c9852180-b5ee-4374-9880-1d7c2d9018b6",
   "metadata": {},
   "source": [
    "# ASR table (split by `type`) for the \"max\" attack built in the Max attack section.\n",
    "max_attack_df"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "1f8eb376-6432-468c-b9ec-a56e0b8c0127",
   "metadata": {},
   "source": [
    "# Per-injection-task ASR for the attack model, one panel per suite (dashed line = suite mean).\n",
    "injected_tasks_df = df_with[~df_with[\"injection_task\"].isnull()].copy()\n",
    "\n",
    "injected_tasks_df = injected_tasks_df[injected_tasks_df[\"model\"] == attacks_model]\n",
    "\n",
    "# Copy rather than alias, so the columns added below do not leak into injected_tasks_df.\n",
    "mean_security_by_inj_task = injected_tasks_df.copy()\n",
    "\n",
    "# \"injection_task_<k>\" -> k. removeprefix drops the literal prefix; str.strip would treat it as a character set.\n",
    "mean_security_by_inj_task[\"Injection task\"] = mean_security_by_inj_task[\"injection_task\"].str.removeprefix(\"injection_task_\").astype(int)\n",
    "# Shift the IDs of every suite except slack by one.\n",
    "# NOTE(review): presumably so all suites are numbered from 1; confirm slack IDs already are.\n",
    "mean_security_by_inj_task.loc[mean_security_by_inj_task[\"suite\"] != \"slack\", \"Injection task\"] += 1\n",
    "mean_security_by_inj_task[\"ASR\"] = mean_security_by_inj_task[\"security\"]\n",
    "\n",
    "suite_means = mean_security_by_inj_task.groupby('suite')['ASR'].mean()\n",
    "\n",
    "# catplot builds its own figure, so no plt.figure() call is needed beforehand.\n",
    "g = sns.catplot(\n",
    "    data=mean_security_by_inj_task,\n",
    "    x=\"Injection task\",\n",
    "    y=\"ASR\",\n",
    "    hue=\"suite\",\n",
    "    kind=\"bar\",\n",
    "    col=\"suite\",  # one panel per suite\n",
    "    col_wrap=4,  # panels per row\n",
    "    sharey=True,\n",
    "    sharex=False,\n",
    "    legend=False,\n",
    "    col_order=[\"travel\", \"workspace\", \"banking\", \"slack\"]\n",
    ")\n",
    "\n",
    "for ax in g.axes.flat:\n",
    "    # Panel titles have the form \"suite = <name>\".\n",
    "    suite = ax.get_title().split(' = ')[1]\n",
    "    ax.axhline(suite_means[suite], color='dimgray', linestyle='--')\n",
    "    ax.set_xlabel(\"Injection task\", fontsize=22)\n",
    "    ax.set_ylabel(\"ASR\", fontsize=22)\n",
    "    ax.tick_params(axis='x', labelsize=22)\n",
    "    ax.tick_params(axis='y', labelsize=22)\n",
    "\n",
    "g.set_titles(size=22)\n",
    "g.set(ylim=(0, 1))\n",
    "g.savefig(\"/Users/edoardo/Downloads/pi-plots/successful-injections.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "6f882d5f-28bd-4d2f-bbf5-49f3dbbd8bca",
   "metadata": {},
   "source": [
    "# Mean ASR of banking injection_task_5 for the attack model.\n",
    "mean_security_by_inj_task.query('suite == \"banking\" and injection_task == \"injection_task_5\"')[\"security\"].mean()"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "2788d3bf-3a7c-4960-ab02-ea414109d966",
   "metadata": {},
   "source": [
    "### What is the ASR for seen injections?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "bb6ebd06-b0bd-4d19-90f4-923f5db33066",
   "metadata": {},
   "source": [
    "# ASR per model, restricted to injected runs whose `injection_seen` flag is set.\n",
    "seen_injections_df = injected_tasks_df[injected_tasks_df[\"injection_seen\"]]\n",
    "\n",
    "attack_df = make_attack_dataframe(seen_injections_df)\n",
    "attack_df = attack_df.sort_values(by=[\"type\", \"ASR\"], ascending=[True, False])\n",
    "attack_df[\"model\"] = attack_df[\"model\"].map(model_names_map)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 3))\n",
    "sns.barplot(data=attack_df, x=\"model\", y=\"ASR\", hue=\"type\", hue_order=[\"Untargeted\", \"Targeted\"], ax=ax)\n",
    "ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3)\n",
    "ax.set_ylim(0, 1)\n",
    "plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\")\n",
    "ax.set_xlabel(\"\")\n",
    "fig.savefig(\"/Users/edoardo/Downloads/pi-plots/seen-injection-attack-results.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "81d7a986-079d-4d20-9c49-da1102003503",
   "metadata": {},
   "source": [
    "## Are injections that are seen more times more successful?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "8c72c6da-c200-4876-93e0-10679f1b76c9",
   "metadata": {},
   "source": [
    "# GPT-4o: ASR as a function of how many times the injection was seen.\n",
    "gpt_4o_df = df_with[df_with[\"model\"] == \"gpt-4o-2024-05-13\"].drop(columns=[\"model\"])\n",
    "injected_tasks_df = gpt_4o_df[~gpt_4o_df[\"injection_task\"].isna()]\n",
    "\n",
    "# Select `security` explicitly: the first positional argument of GroupBy.mean is\n",
    "# `numeric_only`, not a column, so `.mean(\"security\")` would average every numeric column.\n",
    "asr_by_times_injection_seen = injected_tasks_df.groupby(\"times_injection_seen\")[\"security\"].mean().reset_index()\n",
    "asr_by_times_injection_seen[\"ASR\"] = asr_by_times_injection_seen[\"security\"]\n",
    "asr_by_times_injection_seen[\"Times an injection is seen\"] = asr_by_times_injection_seen[\"times_injection_seen\"]\n",
    "\n",
    "plt.figure(figsize=(4, 3))\n",
    "sns.lineplot(data=asr_by_times_injection_seen, x=\"Times an injection is seen\", y=\"ASR\")\n",
    "plt.ylim(0, 1);"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "d15bd16f-2001-4bf6-9e17-c3ff15cc404f",
   "metadata": {},
   "source": [
    "# Runs with a recorded end_injection_position (presumably only set when the injection was\n",
    "# seen exactly once); used below to relate the injection's position to ASR.\n",
    "one_injection_cases = injected_tasks_df.dropna(subset=[\"end_injection_position\"])\n",
    "len(one_injection_cases)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "aa7ed949-417b-4339-bebc-ccb9cc60f61b",
   "metadata": {},
   "source": [
    "# ASR by relative position of the injection's end (single-injection runs), with per-bin confidence intervals.\n",
    "# Binning approach adapted from:\n",
    "# https://stackoverflow.com/questions/65974938/matplotlib-sns-plot-histogram-but-with-average-of-another-variable-on-the-y-axi\n",
    "a = one_injection_cases[\"end_injection_position\"]\n",
    "b = one_injection_cases[\"security\"]\n",
    "\n",
    "n_bins = 6\n",
    "plt.figure(figsize=(4, 3))\n",
    "# n_bins equal-width bins (n_bins + 1 boundaries); the 1e-12 slack puts position 1.0 into the last bin.\n",
    "bins = np.linspace(0, 1 + 1e-12, n_bins + 1)\n",
    "c = np.digitize(a, bins)\n",
    "\n",
    "means = []\n",
    "conf_intervals = []\n",
    "for i in range(1, len(bins)):\n",
    "    data = b[c == i]\n",
    "    if len(data) > 0:\n",
    "        mean, ci = mean_confidence_interval(data)\n",
    "        means.append(mean)\n",
    "        conf_intervals.append(ci)\n",
    "    else:\n",
    "        # Empty bin: NaN height draws no bar, and no error bar.\n",
    "        means.append(np.nan)\n",
    "        conf_intervals.append(0)\n",
    "\n",
    "plt.bar(bins[:-1], means, yerr=conf_intervals, width=bins[1] - bins[0], align='edge')\n",
    "plt.xticks(np.linspace(0, 1 + 1e-12, 11))\n",
    "# Explicit limits turn off x autoscaling, so a plt.margins() call would have no effect.\n",
    "plt.xlim((0, 1.01))\n",
    "plt.ylim((0, 1))\n",
    "plt.xlabel(\"Relative position of the injection\")\n",
    "plt.ylabel(\"ASR\")\n",
    "plt.savefig(\"/Users/edoardo/Downloads/pi-plots/injection-position-results.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "8b8dc0b9-5592-4038-a8b9-3c862c31bb09",
   "metadata": {},
   "source": [
    "# Smallest end position among single-injection runs (sanity check for the binning).\n",
    "a.min()"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "7e30da72-014f-4e78-a4b7-bb7f4cfa6670",
   "metadata": {},
   "source": [
    "# ASR by relative position of the injection's start (single-injection runs, no confidence intervals).\n",
    "a = one_injection_cases[\"start_injection_position\"]\n",
    "b = one_injection_cases[\"security\"]\n",
    "\n",
    "plt.figure(figsize=(4, 3))\n",
    "bins = np.linspace(0, 1 + 1e-12, 6)  # 5 equal-width bins, so 6 bin boundaries\n",
    "c = np.digitize(a, bins)\n",
    "plt.bar(bins[:-1], [np.mean(b[c == i]) for i in range(1, len(bins))],\n",
    "        width=bins[1] - bins[0], align='edge')\n",
    "plt.xticks(bins)\n",
    "plt.margins(x=0.02)  # smaller margins\n",
    "plt.xlabel(\"Relative start position of the injection\")\n",
    "plt.ylabel(\"ASR\");"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "84c8296d-fbf1-4360-9bc7-f8f26d260eb8",
   "metadata": {},
   "source": [
    "## How many tasks can't be defended with tool planning?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "477d579e-421c-47d8-affb-70649246cb89",
   "metadata": {},
   "source": [
    "# Count (user task, injection task) pairs whose injection task only needs tools that the user\n",
    "# task's ground truth also calls: restricting tools to the planned ones cannot block those.\n",
    "count_per_suite = defaultdict(int)\n",
    "tot_tasks_per_suite = {}\n",
    "\n",
    "for suite_name, suite in SUITES.items():\n",
    "    env = suite.load_and_inject_default_environment({})\n",
    "    user_tasks = list(suite.user_tasks.values())\n",
    "    injection_tasks = list(suite.injection_tasks.values())\n",
    "    tot_tasks_per_suite[suite_name] = len(user_tasks) * len(injection_tasks)\n",
    "    for user_task in user_tasks:\n",
    "        planned_tools = set(call[\"function\"] for call in user_task.ground_truth(env))\n",
    "        for injection_task in injection_tasks:\n",
    "            needed_tools = set(call[\"function\"] for call in injection_task.ground_truth(env))\n",
    "            count_per_suite[suite_name] += needed_tools.issubset(planned_tools)\n",
    "\n",
    "dict(count_per_suite), tot_tasks_per_suite"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "e151abc6-ebc3-4506-9da6-405caa6f4964",
   "metadata": {},
   "source": [
    "# Fraction of undefendable (user task, injection task) pairs per suite. Suites are looked up by key\n",
    "# instead of zipping two dicts, which would silently mispair them if their key orders differed.\n",
    "fractions = {suite: count_per_suite.get(suite, 0) / tot for suite, tot in tot_tasks_per_suite.items()}\n",
    "fractions"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "cfbcdc56-a579-4b11-ab5c-402ca8d6abf2",
   "metadata": {},
   "source": [
    "# Overall fraction of undefendable pairs across all suites.\n",
    "n_undefendable_pairs = sum(count_per_suite.values())\n",
    "n_pairs = sum(tot_tasks_per_suite.values())\n",
    "n_undefendable_pairs / n_pairs"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "ecc7e029-8efa-4ad9-8e2e-71c68d5b5b09",
   "metadata": {},
   "source": [
    "### Where do agents struggle?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "c6a3ca4f-bf7d-43c5-878b-267690f489a2",
   "metadata": {},
   "source": [
    "# Models compared in the \"Where do agents struggle?\" analysis.\n",
    "agents_to_consider = set([\n",
    "    \"claude-3-opus-20240229\",\n",
    "    \"claude-3-sonnet-20240229\",\n",
    "    \"gemini-1.5-pro-001\",\n",
    "    \"gpt-4-turbo-2024-04-09\",\n",
    "    \"gpt-4o-2024-05-13\",\n",
    "])\n",
    "\n",
    "df_to_consider = df.loc[df[\"model\"].isin(agents_to_consider)]"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "bc1df30a-e9b8-43f8-bc3a-2dff7eda67cf",
   "metadata": {},
   "source": [
    "# Mean of every numeric column per (suite, model). GroupBy.mean's first positional argument is `numeric_only`, so pass it by name.\n",
    "df_to_consider.groupby([\"suite\", \"model\"]).mean(numeric_only=True)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "731179f9-15a7-48eb-8996-c50cdb2ab728",
   "metadata": {},
   "source": [
    "# Mean of every numeric column per suite. GroupBy.mean's first positional argument is `numeric_only`, so pass it by name.\n",
    "df_to_consider.groupby(\"suite\").mean(numeric_only=True)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "e482feeb-2d2e-4b00-ab3a-062a621c05e1",
   "metadata": {},
   "source": [
    "# Mean of every numeric column per model. GroupBy.mean's first positional argument is `numeric_only`, so pass it by name.\n",
    "df_to_consider.groupby(\"model\").mean(numeric_only=True)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "1e9dbaeb-ad98-47d9-9c87-b16b367580ce",
   "metadata": {},
   "source": [
    "# Mean of every numeric column per model under attack. GroupBy.mean's first positional argument is `numeric_only`, so pass it by name.\n",
    "df_with.groupby(\"model\").mean(numeric_only=True)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "3bc8c269-a263-4657-95f0-3ba79e5d61d8",
   "metadata": {},
   "source": [
    "## DoS Attacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "2a17355a-f41f-4502-bc61-80f895af3dae",
   "metadata": {},
   "source": [
    "# GPT-4o utility per suite under denial-of-service injections, with \"No attack\" and\n",
    "# \"Important instructions\" included for reference.\n",
    "dos_model = \"gpt-4o-2024-05-13\"\n",
    "\n",
    "dos_attacks = {\n",
    "    \"none\": \"No attack\",\n",
    "    \"dos\": \"Simple attack\",\n",
    "    \"swearwords_dos\": \"Swear words\",\n",
    "    \"captcha_dos\": \"Captcha\",\n",
    "    \"offensive_email_dos\": \"Offensive email\",\n",
    "    \"important_instructions\": \"Important instructions\",\n",
    "    \"felony_dos\": \"Felony\"\n",
    "}\n",
    "\n",
    "# make_dataframe output per attack (not aggregated here; barplot averages it), tagged with the display name.\n",
    "dos_attack_results = [\n",
    "    make_dataframe([dos_model], attack).assign(attack=attack_label)\n",
    "    for attack, attack_label in dos_attacks.items()\n",
    "]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 3))\n",
    "sns.barplot(data=pd.concat(dos_attack_results), x=\"suite\", y=\"utility\", hue=\"attack\", ax=ax)\n",
    "ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)\n",
    "ax.set_ylim(0, 1)\n",
    "fig.savefig(\"/Users/edoardo/Downloads/pi-plots/dos-attack-results.pdf\", bbox_inches='tight')"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "56b0d19d-01f0-4ec3-a8cb-e46ee36a9c8d",
   "metadata": {},
   "source": [
    "## Attacker control"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "08e3e0e1-2372-4a3e-b0ad-df40e8c60ed5",
   "metadata": {},
   "source": [
    "# Token counts per suite, computed manually on https://platform.openai.com/tokenizer with GPT-4/3.5.\n",
    "tokens_per_suite = dict(\n",
    "    banking=682,\n",
    "    slack=932,\n",
    "    travel=5819,\n",
    "    workspace=3023 + 6997 + 4268,  # NOTE(review): presumably three separately tokenized parts\n",
    ")\n",
    "\n",
    "np.mean([*tokens_per_suite.values()])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "be1e4a5a-ce41-4e0e-b554-cbc5ae48246a",
   "metadata": {},
   "source": [
    "# Display the per-suite token counts.\n",
    "tokens_per_suite"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "a64040a3-a262-4784-af2a-1309628c9afc",
   "metadata": {},
   "source": [
    "import tiktoken\n",
    "\n",
    "# Per suite, for the important_instructions attack on the attack model: mean token length of tool\n",
    "# outputs containing \"<INFORMATION>\" (presumably this attack's injection wrapper) and mean token\n",
    "# length of the injected texts themselves.\n",
    "pipeline_name = attacks_model\n",
    "attack_name = \"important_instructions\"\n",
    "enc = tiktoken.encoding_for_model(\"gpt-4\")\n",
    "\n",
    "suites_output_tokens = {}\n",
    "mean_injection_length_per_suite = {}\n",
    "\n",
    "for suite in SUITES:\n",
    "    results_without, results_with = load_experiment_results(suite, pipeline_name, attack_name)\n",
    "    runs = list(results_with.values())\n",
    "    tools_contents = [\n",
    "        message[\"content\"]\n",
    "        for run in runs\n",
    "        for message in run.messages\n",
    "        if message[\"role\"] == \"tool\" and \"<INFORMATION>\" in message[\"content\"]\n",
    "    ]\n",
    "    injections = [text for run in runs for text in run.injections.values()]\n",
    "    tools_contents_tokens = list(map(len, map(enc.encode, tools_contents)))\n",
    "    injections_tokens = list(map(len, map(enc.encode, injections)))\n",
    "    suites_output_tokens[suite] = np.mean(tools_contents_tokens)\n",
    "    mean_injection_length_per_suite[suite] = np.mean(injections_tokens)\n",
    "\n",
    "suites_output_tokens"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "f063d56d-ea06-42b1-bc00-1375abd126d4",
   "metadata": {},
   "source": [
    "# Display the per-suite mean injection length in tokens.\n",
    "mean_injection_length_per_suite"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "a110eece-1195-4a7c-850f-9c5f5ae26bcc",
   "metadata": {},
   "source": [
    "# Fraction of an injected tool output controlled by the attacker, per suite:\n",
    "# mean injected-payload length / mean length of tool outputs containing an injection.\n",
    "# Join on the suite key rather than zipping two dicts and trusting their insertion orders match.\n",
    "assert mean_injection_length_per_suite.keys() == suites_output_tokens.keys(), \"suite sets differ\"\n",
    "control = {\n",
    "    suite: mean_injection_length_per_suite[suite] / suites_output_tokens[suite]\n",
    "    for suite in suites_output_tokens\n",
    "}"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "7bdf9d21-7611-4246-83b1-cfc5572bb5f9",
   "metadata": {},
   "source": [
    "# Bar chart of the attacker-controlled fraction per suite (values from `control`).\n",
    "# NOTE(review): absolute, user-specific output directory kept from the original;\n",
    "# make it configurable (e.g. relative to the repo) before sharing this notebook.\n",
    "plots_dir = Path(\"/Users/edoardo/Downloads/pi-plots\")\n",
    "plots_dir.mkdir(parents=True, exist_ok=True)  # savefig fails if the directory is missing\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 3))\n",
    "sns.barplot(data=control, ax=ax)\n",
    "ax.set_ylim((0, 1))  # a fraction, so pin the axis to [0, 1] for comparability\n",
    "ax.set_xlabel(\"Suite\", fontsize=10)\n",
    "ax.set_ylabel(\"Fraction controlled by the attacker\", fontsize=10)\n",
    "ax.tick_params(axis=\"both\", labelsize=10)\n",
    "fig.savefig(plots_dir / \"attacker-control.pdf\", bbox_inches=\"tight\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "35c19496-cdd5-4c69-85f1-910de6741830",
   "metadata": {},
   "source": [
    "# Per-suite cost figures; the next cell scales each by the suite's number of injection tasks.\n",
    "# NOTE(review): units and provenance are not recorded here (presumably USD per injection task) -- TODO confirm.\n",
    "costs = dict(\n",
    "    workspace=1.83,\n",
    "    travel=1.52,\n",
    "    slack=0.52,\n",
    "    banking=0.23,\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "dbf94014-9e4a-4e9c-aa07-9e366ec32782",
   "metadata": {},
   "source": [
    "# Estimated total: each suite's cost weighted by how many injection tasks that suite defines.\n",
    "# Accumulated left to right (not via sum()) to keep float results identical across Python versions.\n",
    "total_cost = 0\n",
    "for suite_name, cost in costs.items():\n",
    "    n_injection_tasks = len(SUITES[suite_name].injection_tasks)\n",
    "    total_cost = total_cost + cost * n_injection_tasks\n",
    "\n",
    "total_cost"
   ],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
