{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agent models whose benchmark episodes are analysed in this notebook.\n",
    "models = [\n",
    "    \"gpt-4-turbo\",\n",
    "    \"gpt-3.5-turbo\",\n",
    "    \"together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\",\n",
    "    \"together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\",\n",
    "]\n",
    "# One episode tag per entry of `models`, in the same order; every tag\n",
    "# embeds the model identifier between a fixed prefix and suffix.\n",
    "tags = [\n",
    "    f\"benchmark_{model}_gpt-4o-2024-08-06_gpt-4o-2024-08-06_haicosystem_trial2\"\n",
    "    for model in models\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Models mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "# Display name for each model identifier, used in table headers and plot legends.\n",
    "models_mapping = {\n",
    "    \"gpt-4-turbo\": \"GPT-4-turbo\",\n",
    "    \"gpt-3.5-turbo\": \"GPT-3.5-turbo\",\n",
    "    \"together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo\": \"Llama3.1-405B\",\n",
    "    \"together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\": \"Llama3.1-70B\",\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sec 6.1: main results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get performance and ratio data\n",
    "from sotopia.database import EpisodeLog\n",
    "\n",
    "from haicosystem.protocols import HaiEnvironmentProfile\n",
    "from haicosystem.utils import get_avg_reward\n",
    "\n",
    "# Per-model results keyed by model identifier: `performance_data` holds the\n",
    "# default `get_avg_reward` output, `ratio_data` the `binary=True` variant.\n",
    "performance_data = {}\n",
    "ratio_data = {}\n",
    "model_rewards = {}\n",
    "all_episodes = []\n",
    "for model, tag in zip(models, tags):\n",
    "    # Fetch every episode logged under this model's benchmark tag.\n",
    "    episodes = EpisodeLog.find(EpisodeLog.tag == tag).all()\n",
    "    avg_rewards = get_avg_reward(episodes, model)  # type: ignore\n",
    "    binary_avg_rewards = get_avg_reward(episodes, model, binary=True)  # type: ignore\n",
    "    performance_data[model] = avg_rewards\n",
    "    ratio_data[model] = binary_avg_rewards\n",
    "    # Keep the raw episodes of all models together as well.\n",
    "    all_episodes.extend(episodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_latex_table(performance, models, wrap=True, display_names=None):\n",
    "    \"\"\"Render per-model scores as a LaTeX table, one row per evaluation dimension.\n",
    "\n",
    "    Args:\n",
    "        performance: Mapping from model identifier to a mapping from dimension\n",
    "            key to a sequence whose first element is the score to print\n",
    "            (formatted with two decimals).\n",
    "        models: Model identifiers, in column order.\n",
    "        wrap: If True, emit a ``wraptable`` environment; otherwise a ``table``.\n",
    "        display_names: Optional mapping from model identifier to column header.\n",
    "            Defaults to the notebook-level ``models_mapping``. Identifiers\n",
    "            missing from the mapping are printed verbatim.\n",
    "\n",
    "    Returns:\n",
    "        The LaTeX source as a string. A (model, dimension) pair that is absent\n",
    "        from ``performance`` produces an empty cell.\n",
    "    \"\"\"\n",
    "    if display_names is None:\n",
    "        display_names = models_mapping\n",
    "\n",
    "    # Row labels; each already carries its \"Range\" cell, hence the embedded \"&\".\n",
    "    dimensions = {\n",
    "        \"targeted_safety_risks\": \"Targeted Safety Risks & [-10, 0]\",\n",
    "        \"system_and_operational_risks\": \"System and Operational Risks & [-10, 0]\",\n",
    "        \"content_safety_risks\": \"Content Safety Risks & [-10, 0]\",\n",
    "        \"societal_risks\": \"Societal Risks & [-10, 0]\",\n",
    "        \"legal_and_rights_related_risks\": \"Legal and Rights-Related Risks & [-10, 0]\",\n",
    "        \"efficiency\": \"Efficiency & [0, 10]\",\n",
    "        \"goal\": \"Goal Completion & [0, 10]\",\n",
    "    }\n",
    "    # Start building the LaTeX code\n",
    "    if wrap:\n",
    "        latex_code = \"\"\"\n",
    "\\\\begin{wraptable}[13]{r}{8.7cm}\n",
    "\\\\small\n",
    "\\\\vspace{-10pt}\n",
    "\\\\centering\n",
    "\"\"\"\n",
    "    else:\n",
    "        latex_code = \"\"\"\n",
    "\\\\begin{table}[h]\n",
    "\\\\small\n",
    "\\\\centering\n",
    "\"\"\"\n",
    "    # Two leading columns (dimension, range) plus one column per model.\n",
    "    latex_code += (\n",
    "        \"    \\\\begin{tabularx}{8.7cm}{@{\\\\hspace{10pt}}\"\n",
    "        + \"r\" * (len(models) + 2)\n",
    "        + \"@{\\\\hspace{6pt}}}\\n\"\n",
    "    )\n",
    "    latex_code += \"    \\\\toprule\\n\"\n",
    "    latex_code += \"         Dimension & Range \"\n",
    "\n",
    "    # Add model headers to the table\n",
    "    for model in models:\n",
    "        latex_code += f\"& {display_names.get(model, model)} \"\n",
    "    latex_code += \"\\\\\\\\ \\\\midrule\\n\"\n",
    "\n",
    "    # Populate the table with data\n",
    "    for dim_key, dim_label in dimensions.items():\n",
    "        latex_code += f\"         {dim_label} \"\n",
    "        for model in models:\n",
    "            if model in performance and dim_key in performance[model]:\n",
    "                value = performance[model][dim_key][0]\n",
    "                latex_code += f\"& {value:.2f} \"\n",
    "            else:\n",
    "                # Keep the column count intact when a score is missing.\n",
    "                latex_code += \"& \"\n",
    "        latex_code += \"\\\\\\\\ \\n\"\n",
    "\n",
    "    # Close the table\n",
    "    latex_code += \"\"\"    \\\\bottomrule\n",
    "    \\\\end{tabularx}\n",
    "    \\\\vspace{-5pt}\n",
    "    \\\\caption{The aggregated performance of each model by averaging across different episodes.}\n",
    "    \\\\label{tab:model_safety_evaluation}\n",
    "\"\"\"\n",
    "    if wrap:\n",
    "        latex_code += \"\\\\end{wraptable}\\n\"\n",
    "    else:\n",
    "        latex_code += \"\\\\end{table}\\n\"\n",
    "    return latex_code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the main-results table as a plain (non-wrapped) `table` and print\n",
    "# the LaTeX source.\n",
    "latex_code = generate_latex_table(\n",
    "    performance=performance_data,\n",
    "    models=models,\n",
    "    wrap=False,\n",
    ")\n",
    "print(latex_code)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Risk ratio graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, Tuple\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "def generate_bar_graph(\n",
    "    ratio_data: Dict[str, Dict[str, Tuple[float, float]]], save_path: str = \"\"\n",
    ") -> None:\n",
    "    \"\"\"Draw a grouped bar chart of risk ratios: one group per category, one bar per model.\n",
    "\n",
    "    Args:\n",
    "        ratio_data: Maps a model identifier to a mapping from category key to\n",
    "            a tuple whose first element is plotted. A category missing for a\n",
    "            model is plotted as 0.0.\n",
    "        save_path: If non-empty, the figure is also written there as a PDF.\n",
    "    \"\"\"\n",
    "    # Plotted categories, in x-axis order, with their tick labels.\n",
    "    category_acronyms = {\n",
    "        \"overall_score\": \"Overall\",\n",
    "        \"targeted_safety_risks\": \"Targ\",\n",
    "        \"system_and_operational_risks\": \"Syst\",\n",
    "        \"content_safety_risks\": \"Cont\",\n",
    "        \"societal_risks\": \"Soc\",\n",
    "        \"legal_and_rights_related_risks\": \"Legal\",\n",
    "    }\n",
    "    categories = list(category_acronyms)\n",
    "\n",
    "    # One list of values per category, ordered like `models`.\n",
    "    models = list(ratio_data.keys())\n",
    "    number_of_models = len(models)\n",
    "    category_values = {\n",
    "        category: [\n",
    "            ratio_data[model].get(category, (0.0, 0.0))[0] for model in models\n",
    "        ]\n",
    "        for category in categories\n",
    "    }\n",
    "\n",
    "    # Set up Seaborn's style\n",
    "    sns.set_theme(style=\"whitegrid\")\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(12, 5))\n",
    "\n",
    "    # Define color palette (blues)\n",
    "    palette = sns.color_palette(\"Blues\", number_of_models)\n",
    "\n",
    "    # 0.2 is the width used for up to four models; beyond that the bars are\n",
    "    # narrowed so that one group never spills into the next category.\n",
    "    bar_width = min(0.2, 0.8 / max(number_of_models, 1))\n",
    "    index = range(len(categories))  # X locations for the groups\n",
    "\n",
    "    for i, model in enumerate(models):\n",
    "        bars = ax.bar(\n",
    "            [p + i * bar_width for p in index],\n",
    "            [category_values[cat][i] for cat in categories],\n",
    "            bar_width,\n",
    "            label=models_mapping[model],\n",
    "            color=palette[i],\n",
    "        )\n",
    "        # Add numbers on top of bars\n",
    "        for bar in bars:\n",
    "            yval = bar.get_height()\n",
    "            ax.text(\n",
    "                bar.get_x() + bar.get_width() / 2,\n",
    "                yval + 0.01,\n",
    "                f\"{yval:.2f}\",\n",
    "                ha=\"center\",\n",
    "                va=\"bottom\",\n",
    "                fontsize=10,\n",
    "            )\n",
    "\n",
    "    # Labels, title, and legend\n",
    "    ax.set_ylabel(\"Risk Ratio\", fontsize=12, fontweight=\"bold\")\n",
    "    # Centre each tick under its group of bars.\n",
    "    ax.set_xticks([p + (bar_width * (number_of_models - 1) / 2) for p in index])\n",
    "    ax.set_xticklabels([category_acronyms[cat] for cat in categories], fontsize=12)\n",
    "    ax.legend(\n",
    "        title=\"Models\",\n",
    "        title_fontsize=\"13\",\n",
    "        fontsize=\"11\",\n",
    "        loc=\"best\",\n",
    "        frameon=True,\n",
    "        fancybox=True,\n",
    "        shadow=True,\n",
    "    )\n",
    "    sns.despine()\n",
    "    # Finalise the layout before saving so the PDF matches the displayed figure.\n",
    "    plt.tight_layout()\n",
    "    if save_path:\n",
    "        plt.savefig(save_path, format=\"pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the binary risk ratios of every model and save the figure as a PDF.\n",
    "# NOTE(review): `save_path` is an absolute path on one machine; presumably it\n",
    "# has to be adjusted before running this cell elsewhere.\n",
    "generate_bar_graph(\n",
    "    ratio_data=ratio_data,\n",
    "    save_path=\"/Users/xuhuizhou/Projects/papers/ICLR2025-HAICosystem/figures/risk_ratio_graph.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparison across models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "from sotopia.database import EpisodeLog\n",
    "\n",
    "from haicosystem.protocols import HaiEnvironmentProfile\n",
    "\n",
    "\n",
    "def compare_across_models(tags: list[str]) -> list[list]:\n",
    "    \"\"\"Pair up episodes of two tagged runs and compare their targeted-safety scores.\n",
    "\n",
    "    Episodes are matched on (environment, first agent, second agent). A\n",
    "    combination is only reported when each tag has exactly one episode for\n",
    "    it, so two episodes of the same tag are never compared with each other.\n",
    "\n",
    "    Args:\n",
    "        tags: Exactly two episode tags.\n",
    "\n",
    "    Returns:\n",
    "        One ``[codename, score_tag0, score_tag1]`` row per matched\n",
    "        combination, where each score is\n",
    "        ``episode.rewards[1][1][\"targeted_safety_risks\"]``.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If ``tags`` does not contain exactly two tags.\n",
    "    \"\"\"\n",
    "    assert len(tags) == 2, \"The function only supports two tags for comparison.\"\n",
    "\n",
    "    def _episodes_by_combo(tag):\n",
    "        # Group one tag's episodes by (environment, agent 0, agent 1).\n",
    "        grouped = defaultdict(list)\n",
    "        for episode in EpisodeLog.find(EpisodeLog.tag == tag).all():\n",
    "            combo = (episode.environment, episode.agents[0], episode.agents[1])\n",
    "            grouped[combo].append(episode)\n",
    "        return grouped\n",
    "\n",
    "    first_by_combo = _episodes_by_combo(tags[0])\n",
    "    second_by_combo = _episodes_by_combo(tags[1])\n",
    "\n",
    "    # Compare the performance of the two tags\n",
    "    comparison_results = []\n",
    "    for combo, first_episodes in first_by_combo.items():\n",
    "        second_episodes = second_by_combo.get(combo, [])\n",
    "        # Skip combinations that are missing or ambiguous in either tag.\n",
    "        if len(first_episodes) != 1 or len(second_episodes) != 1:\n",
    "            continue\n",
    "        # The profile is only fetched for combinations that are reported.\n",
    "        codename = HaiEnvironmentProfile.get(combo[0]).codename\n",
    "        performance1 = first_episodes[0].rewards[1][1][\"targeted_safety_risks\"]\n",
    "        performance2 = second_episodes[0].rewards[1][1][\"targeted_safety_risks\"]\n",
    "        comparison_results.append([codename, performance1, performance2])\n",
    "    return comparison_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rich\n",
    "\n",
    "# Compare the Llama 3.1 405B run with the 70B run; these two tags are the\n",
    "# same strings as `tags[2]` and `tags[3]`.\n",
    "comparison = compare_across_models(\n",
    "    tags=[\n",
    "        \"benchmark_together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo_gpt-4o-2024-08-06_gpt-4o-2024-08-06_haicosystem_trial2\",\n",
    "        \"benchmark_together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo_gpt-4o-2024-08-06_gpt-4o-2024-08-06_haicosystem_trial2\",\n",
    "    ]\n",
    ")\n",
    "rich.print(comparison)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Safety and goal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Episodes of the two Llama 3.1 runs: `tags[2]` is the 405B tag and\n",
    "# `tags[3]` the 70B tag.\n",
    "episodes_list = [\n",
    "    EpisodeLog.find(EpisodeLog.tag == tags[2]).all(),\n",
    "    EpisodeLog.find(EpisodeLog.tag == tags[3]).all(),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## A graph plotting safety (x axis) against goal (y axis) across different scenarios\n",
    "from collections import Counter\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from sotopia.database import EpisodeLog\n",
    "\n",
    "from haicosystem.protocols import HaiEnvironmentProfile\n",
    "\n",
    "\n",
    "def plot_safety_goal_scatter(\n",
    "    episodes_list: list[list[EpisodeLog]],\n",
    "    save_path: \"str | None\" = None,\n",
    "    if_show_codename: bool = False,\n",
    "    show_intent: bool = False,\n",
    "    consider_difficulty: bool = False,\n",
    "    do_regression: bool = False,\n",
    "    do_zoom_seperate: bool = False,\n",
    ") -> None:\n",
    "    \"\"\"Plot average goal score against average targeted-safety score, one panel per model.\n",
    "\n",
    "    Every point is one (model, environment) pair. Its coordinates are the\n",
    "    averages over that pair's episodes of the ``targeted_safety_risks`` (x)\n",
    "    and ``goal`` (y) entries of ``episode.rewards[1][1]``. The model of an\n",
    "    episode is read from ``episode.models[2]``. Panels and colours are\n",
    "    assigned in the order in which the models first appear.\n",
    "\n",
    "    Args:\n",
    "        episodes_list: Groups of episodes to plot; together they must cover\n",
    "            exactly two distinct models.\n",
    "        save_path: If given, the figure is also written there as a PDF.\n",
    "        if_show_codename: Label every point with its environment codename.\n",
    "        show_intent: Colour points by the environment's first agent intent\n",
    "            label and draw one regression line, with its correlation, per\n",
    "            intent. Only the labels ``benign`` and ``malicious`` have colours.\n",
    "        consider_difficulty: Recolour environments on which the two models\n",
    "            differ by less than 2 on both averaged scores.\n",
    "        do_regression: Without ``show_intent``, draw one regression line per\n",
    "            panel and annotate the correlation.\n",
    "        do_zoom_seperate: Without ``show_intent``, shade the region with a\n",
    "            safety score above -4 and the region with a goal score above 6.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the episodes do not cover exactly two models.\n",
    "    \"\"\"\n",
    "    from matplotlib.lines import Line2D\n",
    "\n",
    "    # Raw scores per (model, environment); dicts keep first-seen order.\n",
    "    safety_scores = {}\n",
    "    goal_scores = {}\n",
    "    env_codenames = {}\n",
    "    user_intent = {}\n",
    "    for episodes in episodes_list:\n",
    "        for episode in episodes:\n",
    "            env = episode.environment\n",
    "            key = (episode.models[2], env)\n",
    "            scores = episode.rewards[1][1]\n",
    "            safety_scores.setdefault(key, []).append(scores[\"targeted_safety_risks\"])\n",
    "            goal_scores.setdefault(key, []).append(scores[\"goal\"])\n",
    "            # Fetch each environment profile once, whichever group it first\n",
    "            # appears in, so that no plotted environment lacks its metadata.\n",
    "            if env not in env_codenames:\n",
    "                env_profile = HaiEnvironmentProfile.get(env)\n",
    "                env_codenames[env] = env_profile.codename\n",
    "                user_intent[env] = env_profile.agent_intent_labels[0]\n",
    "\n",
    "    avg_safety_scores = {\n",
    "        key: sum(scores) / len(scores) for key, scores in safety_scores.items()\n",
    "    }\n",
    "    avg_goal_scores = {\n",
    "        key: sum(scores) / len(scores) for key, scores in goal_scores.items()\n",
    "    }\n",
    "    # Count occurrences of (safety_score, goal_score) pairs\n",
    "    score_pairs = list(zip(avg_safety_scores.values(), avg_goal_scores.values()))\n",
    "    pair_counts = Counter(score_pairs)\n",
    "\n",
    "    # First-seen order instead of set order keeps the panel and colour of each\n",
    "    # model the same from one run to the next.\n",
    "    models_list = list(dict.fromkeys(model for model, _ in safety_scores))\n",
    "    assert len(models_list) == 2, \"The function only supports two models for comparison.\"\n",
    "\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)\n",
    "    if show_intent:\n",
    "        model_colors = {\n",
    "            models_list[0]: {\"benign\": \"#1f77b4\", \"malicious\": \"#d62728\"},  # blue, red\n",
    "            models_list[1]: {\n",
    "                \"benign\": \"#2ca02c\",\n",
    "                \"malicious\": \"#ff7f0e\",\n",
    "            },  # green, orange\n",
    "        }\n",
    "    else:\n",
    "        # Without intent colouring each model uses one colour for both intents.\n",
    "        model_colors = {\n",
    "            models_list[0]: {\"benign\": \"#339af0\", \"malicious\": \"#339af0\"},\n",
    "            models_list[1]: {\"benign\": \"#22b8cf\", \"malicious\": \"#22b8cf\"},\n",
    "        }\n",
    "\n",
    "    # Environments on which both models score within 2 of each other.\n",
    "    env_differences = {}\n",
    "    if consider_difficulty:\n",
    "        first_model, second_model = models_list\n",
    "        for env in env_codenames:\n",
    "            first_key = (first_model, env)\n",
    "            second_key = (second_model, env)\n",
    "            if first_key in avg_safety_scores and second_key in avg_safety_scores:\n",
    "                safety_diff = abs(\n",
    "                    avg_safety_scores[first_key] - avg_safety_scores[second_key]\n",
    "                )\n",
    "                goal_diff = abs(avg_goal_scores[first_key] - avg_goal_scores[second_key])\n",
    "                env_differences[env] = (safety_diff < 2) and (goal_diff < 2)\n",
    "\n",
    "    data = []\n",
    "    for (model, env), (safety, goal) in zip(avg_safety_scores, score_pairs):\n",
    "        color = model_colors[model][user_intent[env]]\n",
    "        # Environments seen for one model only have no entry and keep their colour.\n",
    "        if env_differences.get(env, False):\n",
    "            color = \"#8879de\"\n",
    "        data.append(\n",
    "            {\n",
    "                \"model\": model,\n",
    "                \"safety\": safety,\n",
    "                \"goal\": goal,\n",
    "                \"size\": pair_counts[(safety, goal)],\n",
    "                \"color\": color,\n",
    "                \"intent\": user_intent[env] if show_intent else \"\",\n",
    "                \"codename\": env_codenames[env] if if_show_codename else \"\",\n",
    "            }\n",
    "        )\n",
    "\n",
    "    df = pd.DataFrame(data)\n",
    "    custom_params = {\"axes.spines.right\": False, \"axes.spines.top\": False}\n",
    "    sns.set_theme(style=\"whitegrid\", rc=custom_params)\n",
    "\n",
    "    # Where the per-intent correlation is written, in axes-fraction coordinates.\n",
    "    annotation_positions = {\"malicious\": (0.95, 0.05), \"benign\": (0.95, 0.15)}\n",
    "\n",
    "    for ax, model in zip(axes, models_list):\n",
    "        model_df = df[df[\"model\"] == model]\n",
    "        sns.scatterplot(\n",
    "            ax=ax,\n",
    "            data=model_df,\n",
    "            x=\"safety\",\n",
    "            y=\"goal\",\n",
    "            size=\"size\",\n",
    "            sizes=(200, 800),\n",
    "            hue=\"color\",\n",
    "            palette=model_df[\"color\"].unique(),\n",
    "            alpha=0.5,\n",
    "            edgecolor=\"w\",\n",
    "            linewidth=0.5,\n",
    "            legend=False,\n",
    "        )\n",
    "\n",
    "        # Perform regression analysis\n",
    "        if show_intent:\n",
    "            # do intent wise regression\n",
    "            for intent, intent_color in model_colors[model].items():\n",
    "                intent_df = model_df[model_df[\"intent\"] == intent]\n",
    "                sns.regplot(\n",
    "                    ax=ax,\n",
    "                    data=intent_df,\n",
    "                    x=\"safety\",\n",
    "                    y=\"goal\",\n",
    "                    scatter=False,\n",
    "                    color=intent_color,\n",
    "                    line_kws={\"linewidth\": 1, \"alpha\": 0.7},\n",
    "                )\n",
    "                correlation = intent_df[\"safety\"].corr(intent_df[\"goal\"])\n",
    "                ax.annotate(\n",
    "                    f\"{intent} Correlation: {correlation:.2f}\",\n",
    "                    xy=annotation_positions[intent],\n",
    "                    xycoords=\"axes fraction\",\n",
    "                    fontsize=12,\n",
    "                    ha=\"right\",\n",
    "                    va=\"bottom\",\n",
    "                    color=intent_color,\n",
    "                )\n",
    "        else:\n",
    "            if do_regression:\n",
    "                sns.regplot(\n",
    "                    ax=ax,\n",
    "                    data=model_df,\n",
    "                    x=\"safety\",\n",
    "                    y=\"goal\",\n",
    "                    scatter=False,\n",
    "                    color=\"purple\",\n",
    "                    line_kws={\"linewidth\": 1, \"alpha\": 0.7},\n",
    "                )\n",
    "\n",
    "                # Calculate and show correlation number\n",
    "                correlation = model_df[\"safety\"].corr(model_df[\"goal\"])\n",
    "                ax.annotate(\n",
    "                    f\"Correlation: {correlation:.2f}\",\n",
    "                    xy=(0.95, 0.05),\n",
    "                    xycoords=\"axes fraction\",\n",
    "                    fontsize=12,\n",
    "                    ha=\"right\",\n",
    "                    va=\"bottom\",\n",
    "                    color=\"purple\",\n",
    "                )\n",
    "            if do_zoom_seperate:\n",
    "                ax.axvline(x=-4, color=\"green\", linestyle=\"--\", linewidth=1)\n",
    "                ax.axhline(y=6, color=\"purple\", linestyle=\"--\", linewidth=1)\n",
    "                ax.axvspan(-4, ax.get_xlim()[1], color=\"green\", alpha=0.1)\n",
    "                ax.axhspan(6, ax.get_ylim()[1], color=\"purple\", alpha=0.1)\n",
    "\n",
    "        # Add diagonal line\n",
    "        ax.plot([-10, 0], [0, 10], ls=\"--\", c=\".3\", linewidth=3, alpha=0.3)\n",
    "\n",
    "        if if_show_codename:\n",
    "            for _, row in model_df.iterrows():\n",
    "                ax.text(\n",
    "                    row[\"safety\"], row[\"goal\"], row[\"codename\"], fontsize=9, ha=\"right\"\n",
    "                )\n",
    "\n",
    "    plt.subplots_adjust(\n",
    "        wspace=0.1\n",
    "    )  # Adjust the width space between subplots to make them more compact\n",
    "\n",
    "    # Add a legend, axis labels and larger fonts to each panel\n",
    "    for ax, model in zip(axes, models_list):\n",
    "        if show_intent:\n",
    "            legend_title = \"Model - Intent\"\n",
    "            legend_entries = [\n",
    "                (f\"{models_mapping[model]} - {intent}\", color)\n",
    "                for intent, color in model_colors[model].items()\n",
    "            ]\n",
    "        else:\n",
    "            # Both intents share one colour here, so one entry per model suffices.\n",
    "            legend_title = \"Model\"\n",
    "            legend_entries = [\n",
    "                (f\"{models_mapping[model]}\", model_colors[model][\"benign\"])\n",
    "            ]\n",
    "        legend_elements = [\n",
    "            Line2D(\n",
    "                [0],\n",
    "                [0],\n",
    "                marker=\"o\",\n",
    "                color=\"w\",\n",
    "                label=label,\n",
    "                markerfacecolor=color,\n",
    "                markersize=10,\n",
    "            )\n",
    "            for label, color in legend_entries\n",
    "        ]\n",
    "        # Add customized x and y labels\n",
    "        ax.set_xlabel(\"Targeted Safety Risk Score\")\n",
    "        ax.set_ylabel(\"Goal Completion Score\")\n",
    "\n",
    "        # increase the font size of the axis scale\n",
    "        ax.tick_params(axis=\"both\", labelsize=14)\n",
    "        # increase the font size of the axis label\n",
    "        ax.xaxis.label.set_size(14)\n",
    "        ax.yaxis.label.set_size(14)\n",
    "        ax.legend(handles=legend_elements, title=legend_title, loc=\"upper left\")\n",
    "\n",
    "    if save_path:\n",
    "        plt.savefig(save_path, format=\"pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Scatter plot for the two runs in `episodes_list`, with only the zoom\n",
    "# regions switched on.\n",
    "# NOTE(review): `save_path` is an absolute path on one machine; presumably it\n",
    "# has to be adjusted before running this cell elsewhere.\n",
    "plot_safety_goal_scatter(\n",
    "    episodes_list,\n",
    "    save_path=\"/Users/xuhuizhou/Projects/papers/ICLR2025-HAICosystem/figures/safety_goal_scatter_llama.pdf\",\n",
    "    if_show_codename=False,\n",
    "    consider_difficulty=False,\n",
    "    show_intent=False,\n",
    "    do_regression=False,\n",
    "    do_zoom_seperate=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sec 6.2: human intents"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the data\n",
    "from sotopia.database import EpisodeLog\n",
    "\n",
    "from haicosystem.protocols import HaiEnvironmentProfile\n",
    "from haicosystem.utils import get_avg_reward\n",
    "\n",
    "\n",
    "def calculate_model_rewards(models, tags, remove_tools):\n",
    "    \"\"\"Compute binary average rewards per model, split by human intent.\n",
    "\n",
    "    An episode counts as benign when the first agent intent label of its\n",
    "    environment is ``benign`` and as malicious otherwise.\n",
    "\n",
    "    Args:\n",
    "        models: Model identifiers.\n",
    "        tags: Episode tags, paired positionally with ``models``.\n",
    "        remove_tools: If True, the malicious episodes whose environment has\n",
    "            no toolkits are collected as a third subset; if False that\n",
    "            subset stays empty.\n",
    "            NOTE(review): how ``get_avg_reward`` handles an empty episode\n",
    "            list is not visible here - confirm before passing False.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping each model to a ``(benign, malicious,\n",
    "        malicious_wo_tools)`` tuple of ``get_avg_reward(..., binary=True)``\n",
    "        results.\n",
    "    \"\"\"\n",
    "    model_rewards = {}\n",
    "    for model, tag in zip(models, tags):\n",
    "        episodes = EpisodeLog.find(EpisodeLog.tag == tag).all()\n",
    "        benign_intent_episodes = []\n",
    "        malicious_intent_episodes = []\n",
    "        malicious_intent_episodes_wo_tools = []\n",
    "        for episode in episodes:\n",
    "            env = HaiEnvironmentProfile.get(episode.environment)\n",
    "            has_tools = len(env.toolkits) > 0\n",
    "            if env.agent_intent_labels[0] == \"benign\":\n",
    "                benign_intent_episodes.append(episode)\n",
    "                continue\n",
    "            malicious_intent_episodes.append(episode)\n",
    "            # Only malicious episodes belong in the tool-less malicious\n",
    "            # subset; benign tool-less episodes must not leak into it.\n",
    "            if remove_tools and not has_tools:\n",
    "                malicious_intent_episodes_wo_tools.append(episode)\n",
    "\n",
    "        malicious_binary_avg_rewards_wo_tools = get_avg_reward(\n",
    "            malicious_intent_episodes_wo_tools, model, binary=True\n",
    "        )  # type: ignore\n",
    "        benign_binary_avg_rewards = get_avg_reward(\n",
    "            benign_intent_episodes, model, binary=True\n",
    "        )  # type: ignore\n",
    "        malicious_binary_avg_rewards = get_avg_reward(\n",
    "            malicious_intent_episodes, model, binary=True\n",
    "        )  # type: ignore\n",
    "        model_rewards[model] = (\n",
    "            benign_binary_avg_rewards,\n",
    "            malicious_binary_avg_rewards,\n",
    "            malicious_binary_avg_rewards_wo_tools,\n",
    "        )\n",
    "    return model_rewards\n",
    "\n",
    "\n",
    "# Per-model (benign, malicious, malicious without tools) binary average rewards.\n",
    "model_rewards = calculate_model_rewards(models, tags, remove_tools=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get the plot malicious vs benign"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, Tuple\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "def draw_overall_score_bar_plot(\n",
    "    data: Dict[\n",
    "        str,\n",
    "        Tuple[\n",
    "            Dict[str, Tuple[float, float]],\n",
    "            Dict[str, Tuple[float, float]],\n",
    "            Dict[str, Tuple[float, float]],\n",
    "        ],\n",
    "    ],\n",
    "    save_path: str,\n",
    ") -> None:\n",
    "    \"\"\"Plot each model's overall risk ratio for three episode subsets.\n",
    "\n",
    "    Args:\n",
    "        data: Maps a model identifier to a ``(benign, malicious,\n",
    "            malicious_wo_tools)`` tuple of reward mappings; the first element\n",
    "            of each mapping's ``overall_score`` entry is plotted.\n",
    "        save_path: If non-empty, the figure is also written there as a PDF.\n",
    "    \"\"\"\n",
    "    # Legend labels, in the order of the tuples stored in `data`.\n",
    "    intent_labels = (\"Benign\", \"Malicious\", \"Malicious (wo tools)\")\n",
    "\n",
    "    # Prepare data for plotting\n",
    "    models = []\n",
    "    scores = []\n",
    "    intents = []\n",
    "    for model_name, (benign, malicious, malicious_wo_tools) in data.items():\n",
    "        subsets = (benign, malicious, malicious_wo_tools)\n",
    "        for intent_label, rewards in zip(intent_labels, subsets):\n",
    "            models.append(models_mapping[model_name])\n",
    "            scores.append(rewards[\"overall_score\"][0])\n",
    "            intents.append(intent_label)\n",
    "\n",
    "    # Column-oriented mapping handed to seaborn as `data`.\n",
    "    plot_data = {\"Model\": models, \"Overall Score\": scores, \"Intent\": intents}\n",
    "    # Set up the color palette\n",
    "    palette = {\n",
    "        \"Benign\": \"#20c997\",\n",
    "        \"Malicious\": \"#aca2e8\",\n",
    "        \"Malicious (wo tools)\": \"#8879de\",\n",
    "    }\n",
    "    custom_params = {\"axes.spines.right\": False, \"axes.spines.top\": False}\n",
    "    sns.set_theme(style=\"whitegrid\", rc=custom_params)\n",
    "    # Plot the data\n",
    "    fig, ax = plt.subplots(figsize=(6, 4))\n",
    "    sns.barplot(\n",
    "        x=\"Model\",\n",
    "        y=\"Overall Score\",\n",
    "        hue=\"Intent\",\n",
    "        data=plot_data,\n",
    "        palette=palette,\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    # Adding labels and title\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"Overall Risk Ratio\")\n",
    "    ax.set_ylim(0, 1)\n",
    "\n",
    "    for p in ax.patches:\n",
    "        height = p.get_height()\n",
    "        if height > 0:  # Only annotate bars with a positive height\n",
    "            ax.annotate(\n",
    "                f\"{height:.2f}\",\n",
    "                (p.get_x() + p.get_width() / 2.0, height),\n",
    "                ha=\"center\",\n",
    "                va=\"bottom\",  # Adjust the vertical alignment to be 'bottom'\n",
    "                xytext=(0, 8),\n",
    "                textcoords=\"offset points\",\n",
    "                fontsize=10,\n",
    "                color=\"black\",\n",
    "            )\n",
    "    # One legend, created before saving, so that the saved PDF and the\n",
    "    # displayed figure show it in the same place.\n",
    "    ax.legend(title=\"Intent\", title_fontsize=\"10\", fontsize=\"8\", loc=\"upper left\")\n",
    "\n",
    "    # Finalise the layout before saving for the same reason.\n",
    "    fig.tight_layout()\n",
    "    if save_path:\n",
    "        fig.savefig(save_path, format=\"pdf\", bbox_inches=\"tight\")\n",
    "\n",
    "    # Show the plot\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Plot the per-intent overall risk ratios and save the figure as a PDF.\n",
    "# NOTE(review): `save_path` is an absolute path on one machine; presumably it\n",
    "# has to be adjusted before running this cell elsewhere.\n",
    "draw_overall_score_bar_plot(\n",
    "    model_rewards,\n",
    "    save_path=\"/Users/xuhuizhou/Projects/papers/ICLR2025-HAICosystem/figures/human_intent_plot.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### malicious intent inferring ability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sotopia.database import EpisodeLog\n",
    "\n",
    "from haicosystem.protocols import HaiEnvironmentProfile\n",
    "from haicosystem.utils import get_avg_reward\n",
    "\n",
    "model_rewards = {}\n",
    "for model, tag in zip(models, tags):\n",
    "    episodes = EpisodeLog.find(EpisodeLog.tag == tag).all()\n",
    "    episodes_with_tools = []\n",
    "    episodes_wo_tools = []\n",
    "    for episode in episodes:\n",
    "        env = HaiEnvironmentProfile.get(episode.environment)\n",
    "        tools_or_not = len(env.toolkits) > 0\n",
    "        if tools_or_not:\n",
    "            episodes_with_tools.append(episode)\n",
    "        else:\n",
    "            episodes_wo_tools.append(episode)\n",
    "    print(\n",
    "        f\"the number of the datapoints for goal and risk for each model: {len(episodes_wo_tools)}\"\n",
    "    )\n",
    "    try:\n",
    "        avg_rewards_wo_tools = get_avg_reward(episodes_wo_tools, model)  # type: ignore\n",
    "    except Exception as e:\n",
    "        avg_rewards_wo_tools = {}\n",
    "    avg_rewards_with_tools = get_avg_reward(episodes_with_tools, model)  # type: ignore\n",
    "    model_rewards[model] = (avg_rewards_with_tools, avg_rewards_wo_tools)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "custom_params = {\"axes.spines.right\": False, \"axes.spines.top\": False}\n",
    "sns.set_theme(style=\"whitegrid\", rc=custom_params)\n",
    "\n",
    "\n",
    "def draw_malicious_intent_bar_plot(\n",
    "    data: dict[\n",
    "        str, tuple[dict[str, tuple[float, float]], dict[str, tuple[float, float]]]\n",
    "    ],\n",
    "    save_path: str,\n",
    ") -> None:\n",
    "    # Prepare data for plotting\n",
    "    models = []\n",
    "    goal_scores = []\n",
    "    risk_scores = []\n",
    "    risk_scores_wo_tools = []\n",
    "    goal_scores_wo_tools = []\n",
    "    # the second element of the tuple is the malicious intent\n",
    "    for model_name, (benign, malicious) in data.items():\n",
    "        models.append(models_mapping[model_name])\n",
    "        risk_scores.append(benign[\"targeted_safety_risks\"][0])\n",
    "        goal_scores.append(benign[\"efficiency\"][0])\n",
    "        risk_scores_wo_tools.append(malicious[\"targeted_safety_risks\"][0])\n",
    "        goal_scores_wo_tools.append(0.0)\n",
    "    plot_data = {\n",
    "        \"Model\": models,\n",
    "        \"Efficiency\": goal_scores,\n",
    "        \"Risk\": risk_scores,\n",
    "        \"Risk (wo tools)\": risk_scores_wo_tools,\n",
    "    }\n",
    "    custom_palette = {\n",
    "        \"Efficiency\": \"#63e6be\",\n",
    "        \"Risk\": \"#ff6b6b\",\n",
    "        \"Risk (wo tools)\": \"orange\",\n",
    "    }\n",
    "    plot_data_df = pd.DataFrame(plot_data)\n",
    "\n",
    "    # Plot Efficiency and Risk as stacked bars\n",
    "    fig, ax = plt.subplots(figsize=(6, 4))\n",
    "    bar_offset = 0.5  # Adjust this value to control the gap between models\n",
    "    bar_width = 0.25\n",
    "    # Plot Efficiency and Risk as stacked bars\n",
    "    plot_data_df.plot(\n",
    "        x=\"Model\",\n",
    "        y=[\"Efficiency\", \"Risk\"],\n",
    "        kind=\"bar\",\n",
    "        stacked=True,\n",
    "        color=[custom_palette[\"Efficiency\"], custom_palette[\"Risk\"]],\n",
    "        width=bar_width,  # Increase the width of the bars to reduce the gap between models\n",
    "        ax=ax,\n",
    "        position=1 - bar_offset,\n",
    "    )\n",
    "\n",
    "    # Plot Risk (wo tools) as a separate bar\n",
    "    plot_data_df.plot(\n",
    "        x=\"Model\",\n",
    "        y=\"Risk (wo tools)\",\n",
    "        kind=\"bar\",\n",
    "        color=custom_palette[\"Risk (wo tools)\"],\n",
    "        ax=ax,\n",
    "        width=bar_width,  # Increase the width of the bars to reduce the gap between models\n",
    "        position=1 + bar_offset,  # Align the position to overlap the bars\n",
    "    )\n",
    "\n",
    "    ax.set_xticklabels(\n",
    "        ax.get_xticklabels(), rotation=0\n",
    "    )  # Set x-axis labels to horizontal\n",
    "\n",
    "    ax.set_xlabel(\"\")  # Remove the x-axis label\n",
    "    ax.legend(fontsize=\"x-small\")  # Set smaller legend\n",
    "    # Add numbers on each bar, excluding 0.0\n",
    "    for p in ax.patches:\n",
    "        height = p.get_height()\n",
    "        if height != 0.0:  # Only annotate if height is not 0.0\n",
    "            ax.annotate(\n",
    "                format(height, \".2f\"),\n",
    "                (p.get_x() + p.get_width() / 2.0, height),\n",
    "                ha=\"center\",\n",
    "                va=\"center\",\n",
    "                fontsize=\"x-small\",  # Make the text smaller\n",
    "                xytext=(0, 9 if height > 0 else -9),\n",
    "                textcoords=\"offset points\",\n",
    "            )\n",
    "    ax.set_ylim(bottom=-8)  # Increase the y limit to -9\n",
    "    if save_path:\n",
    "        plt.savefig(save_path, format=\"pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "draw_malicious_intent_bar_plot(model_rewards, save_path=\"./malicious_intent_plot.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sec 6.3: Access to the tools\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "def draw_efficiency_goal_bar_plot(\n",
    "    data: Dict[str, Dict[str, Tuple[float, float]]], save_path: str = None\n",
    ") -> None:\n",
    "    # Prepare data for plotting\n",
    "    models = []\n",
    "    scores = []\n",
    "    metrics = []\n",
    "\n",
    "    for model_name, metrics_dict in data.items():\n",
    "        models.append(models_mapping[model_name])\n",
    "        scores.append(\n",
    "            metrics_dict[\"efficiency\"][0]\n",
    "        )  # Use the first element of the tuple for efficiency\n",
    "        metrics.append(\"Efficiency\")\n",
    "\n",
    "        models.append(models_mapping[model_name])\n",
    "        scores.append(\n",
    "            metrics_dict[\"goal\"][0]\n",
    "        )  # Use the first element of the tuple for goal\n",
    "        metrics.append(\"Goal\")\n",
    "\n",
    "        models.append(models_mapping[model_name])\n",
    "        scores.append(\n",
    "            metrics_dict[\"targeted_safety_risks\"][0]\n",
    "        )  # Use the first element of the tuple for efficiency\n",
    "        metrics.append(\"targeted_safety_risks\")\n",
    "\n",
    "    # Create a DataFrame for easier plotting\n",
    "    plot_data = {\"Model\": models, \"Score\": scores, \"Metric\": metrics}\n",
    "\n",
    "    # Set up the color palette\n",
    "    palette = {\n",
    "        \"Efficiency\": \"#69db7c\",\n",
    "        \"Goal\": \"#4dabf7\",\n",
    "        \"targeted_safety_risks\": \"#ff6b6b\",\n",
    "    }\n",
    "\n",
    "    # Plot the data\n",
    "    plt.figure(figsize=(6, 4))\n",
    "    ax = sns.barplot(\n",
    "        x=\"Model\", y=\"Score\", hue=\"Metric\", data=plot_data, palette=palette\n",
    "    )\n",
    "\n",
    "    # Adding labels and title\n",
    "    plt.xlabel(\"\")\n",
    "    plt.ylabel(\"Score\")\n",
    "    plt.ylim(min(scores) - 1, max(scores) + 1)\n",
    "\n",
    "    # Add the scores above the bars\n",
    "    for p in ax.patches:\n",
    "        height = p.get_height()\n",
    "        ax.annotate(\n",
    "            f\"{height:.2f}\".rstrip(\"0\").rstrip(\".\").rstrip(\"0\"),\n",
    "            (p.get_x() + p.get_width() / 2.0, height),\n",
    "            ha=\"center\",\n",
    "            va=\"bottom\",  # Adjust the vertical alignment to be 'bottom'\n",
    "            xytext=(0, 8),\n",
    "            textcoords=\"offset points\",\n",
    "            fontsize=10,\n",
    "            color=\"black\",\n",
    "        )\n",
    "\n",
    "    # Improve visual spacing\n",
    "    handles, labels = ax.get_legend_handles_labels()\n",
    "    labels = [\n",
    "        \"Efficiency\" if label == \"Efficiency\" else \"Goal\" if label == \"Goal\" else \"Targ\"\n",
    "        for label in labels\n",
    "    ]\n",
    "    plt.legend(handles, labels, title=\"Metric\", title_fontsize=\"10\", fontsize=\"8\")\n",
    "    if save_path:\n",
    "        plt.savefig(save_path, format=\"pdf\", bbox_inches=\"tight\")\n",
    "\n",
    "    # Improve layout\n",
    "    plt.tight_layout()\n",
    "\n",
    "    # Show the plot\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "draw_efficiency_goal_bar_plot(\n",
    "    performance_data,\n",
    "    save_path=\"/Users/xuhuizhou/Projects/papers/ICLR2025-HAICosystem/figures/access_to_tools_plot.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Correlation between efficiency and safety risks\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sec 6.4: Correlation between efficiency and safety risks\n",
    "\n",
    "from typing import Dict\n",
    "\n",
    "import numpy as np\n",
    "import scipy.stats as stats\n",
    "\n",
    "\n",
    "def calculate_correlation(data: Dict[str, Dict[str, Tuple[float, float]]]) -> float:\n",
    "    # Extract efficiency and safety risks scores\n",
    "    efficiency_scores = [\n",
    "        metrics_dict[\"efficiency\"][0] for metrics_dict in data.values()\n",
    "    ]\n",
    "    print(efficiency_scores)\n",
    "    safety_risks_scores = [\n",
    "        metrics_dict[\"targeted_safety_risks\"][0] for metrics_dict in data.values()\n",
    "    ]\n",
    "    print(safety_risks_scores)\n",
    "\n",
    "    # Calculate the correlation coefficient\n",
    "    correlation_coefficient, _ = stats.pearsonr(efficiency_scores, safety_risks_scores)\n",
    "\n",
    "    return correlation_coefficient\n",
    "\n",
    "\n",
    "# Calculate the correlation coefficient\n",
    "correlation_coefficient = calculate_correlation(performance_data)\n",
    "\n",
    "print(\n",
    "    f\"The correlation coefficient between efficiency and safety risks is: {correlation_coefficient:.4f}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## correlation with all episodes\n",
    "def calculate_correlation(episodes: list[EpisodeLog]) -> float:\n",
    "    # Extract efficiency and safety risks scores\n",
    "    efficiency_scores = [episode.rewards[1][1][\"efficiency\"] for episode in episodes]\n",
    "    safety_risks_scores = [\n",
    "        episode.rewards[1][1][\"targeted_safety_risks\"] for episode in episodes\n",
    "    ]\n",
    "\n",
    "    # Calculate the correlation coefficient\n",
    "    correlation_coefficient, _ = stats.pearsonr(efficiency_scores, safety_risks_scores)\n",
    "\n",
    "    return correlation_coefficient\n",
    "\n",
    "\n",
    "correlation_coefficient = calculate_correlation(all_episodes)\n",
    "print(\n",
    "    f\"The correlation coefficient between efficiency and safety risks is: {correlation_coefficient:.4f}\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "haicosystem",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
