{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0XpUAPcdgNrX",
        "outputId": "9886c826-561e-4c75-8954-e75c0cf021c1"
      },
      "outputs": [],
      "source": [
        "!pip install datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yGD9qWwfwxyx",
        "outputId": "a405e973-28c6-41b5-fa3d-16a151dfb981"
      },
      "outputs": [],
      "source": [
        "!pip install lmfit"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 41,
      "metadata": {
        "id": "XZU-6FikfK90"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "import tqdm\n",
        "# import torch\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import json\n",
        "from datasets import load_dataset\n",
        "import os\n",
        "import math"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "\n",
        "root = os.path.dirname(os.getcwd())\n",
        "info_dict_harmfulness1 = {'output_dir': f'{root}/data/harmfulness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct',\n",
        "            'tasks': 'harmfulness',\n",
        "            'do_sample': True}\n",
        "info_dict_harmfulness2 = {'output_dir': f'{root}/data/harmfulness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct',\n",
        "            'tasks': 'harmfulness',\n",
        "            'do_sample': False}\n",
        "\n",
        "info_dict_harmfulness3 = {'output_dir': f'{root}/data/harmfulness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Llama-2-13b-chat-hf',\n",
        "            'tasks': 'harmfulness',\n",
        "            'do_sample': True}\n",
        "info_dict_harmfulness4 = {'output_dir': f'{root}/data/harmfulness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Llama-2-13b-chat-hf',\n",
        "            'tasks': 'harmfulness',\n",
        "            'do_sample': False}\n",
        "\n",
        "info_dict_harmfulness5 = {'output_dir': f'{root}/data/harmfulness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Llama-2-13b-hf',\n",
        "            'tasks': 'harmfulness',\n",
        "            'do_sample': True}\n",
        "info_dict_harmfulness6 = {'output_dir': f'{root}/data/harmfulness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Llama-2-13b-hf',\n",
        "            'tasks': 'harmfulness',\n",
        "            'do_sample': False}\n",
        "\n",
        "info_dict_harmfulness7 = {'output_dir': f'{root}/data/harmfulness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B',\n",
        "            'tasks': 'harmfulness',\n",
        "            'do_sample': True}\n",
        "info_dict_harmfulness8 = {'output_dir': f'{root}/data/harmfulness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B',\n",
        "            'tasks': 'harmfulness',\n",
        "            'do_sample': False}\n",
        "\n",
        "\n",
        "\n",
        "info_dict_fairness1 = {'output_dir': f'{root}/data/fairness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct',\n",
        "            'tasks': 'fairness',\n",
        "            'do_sample': True}\n",
        "info_dict_fairness2 = {'output_dir': f'{root}/data/fairness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct',\n",
        "            'tasks': 'fairness',\n",
        "            'do_sample': False}\n",
        "\n",
        "info_dict_fairness3 = {'output_dir': f'{root}/data/fairness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B',\n",
        "            'tasks': 'fairness',\n",
        "            'do_sample': True}\n",
        "info_dict_fairness4 = {'output_dir': f'{root}/data/fairness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B',\n",
        "            'tasks': 'fairness',\n",
        "            'do_sample': False}\n",
        "\n",
        "info_dict_fairness5 = {'output_dir': f'{root}/data/fairness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Llama-2-13b-chat-hf',\n",
        "            'tasks': 'fairness',\n",
        "            'do_sample': True}\n",
        "info_dict_fairness6 = {'output_dir': f'{root}/data/fairness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Llama-2-13b-chat-hf',\n",
        "            'tasks': 'fairness',\n",
        "            'do_sample': False}\n",
        "\n",
        "info_dict_fairness7 = {'output_dir': f'{root}/data/fairness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Llama-2-13b-hf',\n",
        "            'tasks': 'fairness',\n",
        "            'do_sample': True}\n",
        "info_dict_fairness8 = {'output_dir': f'{root}/data/fairness_experiments_outputs/...',\n",
        "            'model_name': 'meta-llama/Llama-2-13b-hf',\n",
        "            'tasks': 'fairness',\n",
        "            'do_sample': False}\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 42,
      "metadata": {},
      "outputs": [],
      "source": [
        "def get_elements_until_key(my_dict, stop_key):\n",
        "    result = {}\n",
        "    for key, value in my_dict.items():\n",
        "        result[key] = value\n",
        "        if key == stop_key:\n",
        "            break\n",
        "    return result\n",
        "\n",
        "def get_elements_from_a_key(my_dict, start_key):\n",
        "    result = {}\n",
        "    flag = False\n",
        "    for key, value in my_dict.items():\n",
        "        if key == start_key:\n",
        "            flag = True\n",
        "        if flag:\n",
        "            result[key] = value\n",
        "    return result"
      ]
    },
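    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative sanity check (hypothetical toy dict, not analysis data): the two\n",
        "# helpers above slice an insertion-ordered dict inclusively at a boundary key.\n",
        "_demo = {'a': 1, 'b': 2, 'c': 3}\n",
        "assert get_elements_until_key(_demo, 'b') == {'a': 1, 'b': 2}\n",
        "assert get_elements_from_a_key(_demo, 'b') == {'b': 2, 'c': 3}"
      ]
    },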
    {
      "cell_type": "code",
      "execution_count": 43,
      "metadata": {
        "id": "xp3W78WrWGQA"
      },
      "outputs": [],
      "source": [
        "# Define the quadratic functions\n",
        "def corollary_1_quadratic_function(x, a, b):\n",
        "    return (0.5 / (1+ a + a * (((b**2)/2) * (x**2)))) + 0.25\n",
        "\n",
        "def corollary_1_quadratic_function_no_bias(x, a, b):\n",
        "    return (1 / (1+ a + a * (((b**2)/2) * (x**2))))"
      ]
    },
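    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Quick numeric check of the two fit forms (illustrative parameters a=b=1):\n",
        "# at x=0 the biased form gives 0.5/(1+a) + 0.25 and the no-bias form 1/(1+a);\n",
        "# both decay monotonically toward their floor (0.25 and 0) as |x| grows.\n",
        "print(corollary_1_quadratic_function(0, a=1, b=1))          # 0.5/2 + 0.25 = 0.5\n",
        "print(corollary_1_quadratic_function_no_bias(0, a=1, b=1))  # 1/2 = 0.5\n",
        "print(corollary_1_quadratic_function(10, a=1, b=1))         # ~0.26, approaching 0.25"
      ]
    },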
    {
      "cell_type": "code",
      "execution_count": 44,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from scipy.optimize import curve_fit\n",
        "from lmfit import Model\n",
        "\n",
        "model = Model(corollary_1_quadratic_function)\n",
        "model_no_bias = Model(corollary_1_quadratic_function_no_bias)\n",
        "\n",
        "# Define your (x, y) coordinates\n",
        "x_data = list(np.round(np.arange(-5, 5.2, 0.25), 2))\n",
        "x_full = list(np.round(np.arange(-10, 10.2, 0.25), 2))\n",
        "x_data_idx = [i for i, v in enumerate(x_full) if v in x_data]\n",
        "\n",
        "\n",
        "ylabel = 'probability of correct answer'\n",
        "ylabel_relative = 'probability of correct answer out of A, B, C, D'\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 106,
      "metadata": {},
      "outputs": [],
      "source": [
        "info_dict = info_dict_fairness8\n",
        "\n",
        "x = list(np.round(np.arange(-10, 10.2, 0.25), 2))\n",
        "dataset_names = ['medical_genetics', 'high_school_computer_science', 'international_law']\n",
        "model_names = [info_dict['model_name']]\n",
        "tasks = [info_dict['tasks']]\n",
        "output_dir = info_dict['output_dir']\n",
        "\n",
        "p_mean = {task_key: {model_name_key: {mmlu_key: {key: 0 for key in x} for mmlu_key in dataset_names} for model_name_key in model_names} for task_key in tasks}\n",
        "p_mean_relative = {task_key: {model_name_key: {mmlu_key: {key: 0 for key in x} for mmlu_key in dataset_names} for model_name_key in model_names} for task_key in tasks}\n",
        "acc_mean = {task_key: {model_name_key: {mmlu_key: {key: 0 for key in x} for mmlu_key in dataset_names} for model_name_key in model_names} for task_key in tasks}\n",
        "p_std = {task_key: {model_name_key: {mmlu_key: {key: 0 for key in x} for mmlu_key in dataset_names} for model_name_key in model_names} for task_key in tasks}\n",
        "p_std_relative = {task_key: {model_name_key: {mmlu_key: {key: 0 for key in x} for mmlu_key in dataset_names} for model_name_key in model_names} for task_key in tasks}\n",
        "acc_std = {task_key: {model_name_key: {mmlu_key: {key: 0 for key in x} for mmlu_key in dataset_names} for model_name_key in model_names} for task_key in tasks}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 107,
      "metadata": {},
      "outputs": [],
      "source": [
        "from itertools import chain\n",
        "\n",
        "start_key = -10.0\n",
        "stop_key = 10.1\n",
        "\n",
        "for task in tasks:\n",
        "    for model_name in model_names:\n",
        "        for dataset_name in dataset_names:\n",
        "        # for i, coeff in enumerate(x):\n",
        "            with open(f'{output_dir}/{dataset_name}/helpfulness_{task}_{model_name.replace(\"/\",\"_\")}_stats_sample.json', 'r') as file:\n",
        "                stats_dict = json.load(file)\n",
        "                stats_dict = eval(stats_dict.strip())                \n",
        "            p_mean[task][model_name][dataset_name] = stats_dict['p_mean']\n",
        "            p_mean_relative[task][model_name][dataset_name] = stats_dict['p_mean_relative']\n",
        "            p_std[task][model_name][dataset_name] = stats_dict['p_std']\n",
        "            p_std_relative[task][model_name][dataset_name] = stats_dict['p_std_relative']\n",
        "            acc_mean[task][model_name][dataset_name] = stats_dict['acc_mean']\n",
        "            acc_std[task][model_name][dataset_name] = stats_dict['acc_std']\n",
        "        "
      ]
    },
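    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Why literal_eval above: each stats file stores a stringified Python dict\n",
        "# inside a JSON string, so json.load returns a str that still needs parsing.\n",
        "# Toy round-trip with made-up values (not a real stats file):\n",
        "_raw = json.dumps(str({'p_mean': {0.0: 0.5}, 'acc_mean': {0.0: 0.5}}))\n",
        "_parsed = ast.literal_eval(json.loads(_raw).strip())\n",
        "print(type(_parsed), _parsed['p_mean'])"
      ]
    },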
    {
      "cell_type": "code",
      "execution_count": 108,
      "metadata": {
        "id": "BnlLvNIDgWdc"
      },
      "outputs": [],
      "source": [
        "def plot_figure_mmlu_multi(x, y, y_err, y_data_fit, ylabel, y_tagging, mmlu_dataset_names, task, model_name, f_fit, fit_label=''):\n",
        "    def moving_average(data, window_size):\n",
        "      return np.convolve(data, np.ones(window_size)/window_size, mode='same')\n",
        "\n",
        "    # Set initial parameter values and bounds (if needed)\n",
        "    params = f_fit.make_params(a=1, b=1)\n",
        "    # Optionally, you can set parameter bounds\n",
        "    params['a'].min = 0\n",
        "    params['b'].min = 0\n",
        "    # Perform the fit\n",
        "    result = f_fit.fit(y_data_fit, x=x_data, params=params)\n",
        "    if 'A, B, C, D' in ylabel:\n",
        "      # the current plot is over the resricted vocabulary (|V|={A,B,C,D}) -> bias=0.5\n",
        "      y_fit = [corollary_1_quadratic_function(x_elem, result.params['a'].value, result.params['b'].value/2.5) for x_elem in x]\n",
        "    else:\n",
        "      # the current plot is over the entire vocabulary meaning bias->2\n",
        "      y_fit = [corollary_1_quadratic_function_no_bias(x_elem, result.params['a'].value, result.params['b'].value/2.5) for x_elem in x]\n",
        "\n",
        "    # Create a plot\n",
        "    short_model_name = model_name.split('/')[1]\n",
        "    folder_name = f'{root}/data/plots/helpfulness_plots/{short_model_name}_{task}/'\n",
        "    os.makedirs(folder_name, exist_ok=True)\n",
        "    plt.figure(figsize=(8, 6))  # Adjust the figure size as needed\n",
        "\n",
        "    # Plot x vs y\n",
        "    window_size=5\n",
        "    fitted_label=''\n",
        "    for mmlu_dataset in mmlu_dataset_names:\n",
        "        y_plot = np.array(list(y[mmlu_dataset].values()))\n",
        "        if y_tagging != 'p' or 'A, B, C, D' not in ylabel:\n",
        "            y_plot = moving_average(y_plot, window_size)\n",
        "        y_err_plot = np.array(list(y_err[mmlu_dataset].values())) / 10 # The standard error is std/sqrt(n). in our case n=100 for all mmlu sub-datasets\n",
        "        mmlu_dataset_label = mmlu_dataset.replace(\"computer_science\", 'CS').replace('_', ' ')\n",
        "        plt.plot(np.array(x), y_plot, label = f'{mmlu_dataset_label}')  # Adjust marker and linestyle as needed\n",
        "        plt.fill_between(np.array(x), y_plot - y_err_plot, y_plot + y_err_plot, alpha=0.2)\n",
        "    # if 'Llama-2' in model_name and 'A, B, C, D' not in ylabel and y_tagging == 'p':\n",
        "    if fit_label == '_fitted':\n",
        "      plt.plot(x, y_fit, label = f'fitted')\n",
        "\n",
        "    # Add labels and title\n",
        "    plt.xlabel(r\"$r_e$\")\n",
        "    if y_tagging == 'acc':\n",
        "      plt.ylabel('Accuracy')\n",
        "    else:\n",
        "      plt.ylabel('P(correct)')\n",
        "\n",
        "    \n",
        "    tail_title = ''\n",
        "    if 'A, B, C, D' in ylabel:\n",
        "      tail_title = '- restricted to A, B, C, D'\n",
        "    if task == 'harmfulness':\n",
        "      plt.title(f'Harmfulness behavior - {short_model_name} {tail_title}')\n",
        "    else:\n",
        "      plt.title(f'Fairness behavior - {short_model_name} {tail_title}')\n",
        "    plt.legend()\n",
        "\n",
        "    # Display the plot\n",
        "    # plt.grid(True)  # Add gridlines if desired\n",
        "    do_sample = info_dict['do_sample']\n",
        "    if y_tagging == 'acc':\n",
        "      plt.savefig(os.path.join(folder_name, f'acc_{short_model_name}_{task}_helpfulness_{do_sample}{fit_label}.png'))\n",
        "    elif 'A, B, C, D' in ylabel and y_tagging == 'p':\n",
        "      plt.savefig(os.path.join(folder_name, f'p_relative_{short_model_name}_{task}_helpfulness_{do_sample}{fit_label}.png'))\n",
        "    else:\n",
        "      plt.savefig(os.path.join(folder_name, f'p_{short_model_name}_{task}_helpfulness_{do_sample}{fit_label}.png'))"
      ]
    },
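    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Edge behavior of the smoother used above (illustrative): np.convolve with\n",
        "# mode='same' zero-pads outside the signal, so the first and last\n",
        "# window_size//2 points of a smoothed curve are pulled toward zero; the\n",
        "# interior is a plain centered mean. Worth remembering at the plot edges.\n",
        "print(np.convolve(np.ones(10), np.ones(5) / 5, mode='same'))\n",
        "# -> [0.6 0.8 1.  1.  1.  1.  1.  1.  0.8 0.6]"
      ]
    },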
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "3lcoUvyBqTXB",
        "outputId": "1895b069-2b96-44cd-e6b5-4c3795eee64c"
      },
      "outputs": [],
      "source": [
        "for task in tasks:\n",
        "  for model_name in model_names:\n",
        "    dataset_to_fit = 'medical_genetics' if \"Llama-3\" in model_name else 'international_law'\n",
        "    y_data_fit_relative = [round(v,2) for i, v in enumerate(list(p_mean_relative[task][model_name][dataset_to_fit].values())) if i in x_data_idx]\n",
        "    y_data_fit = [round(v,2) for i, v in enumerate(list(p_mean[task][model_name][dataset_to_fit].values())) if i in x_data_idx]\n",
        "    # x, y, y_data_fit, ylabel, mmlu_dataset_names, task, model\n",
        "    plot_figure_mmlu_multi(x=x_full, y=p_mean_relative[task][model_name], y_err=p_std_relative[task][model_name], y_data_fit=np.array(y_data_fit_relative), ylabel=ylabel_relative, y_tagging='p', mmlu_dataset_names=dataset_names, task=task, model_name=model_name, f_fit=model, fit_label='')\n",
        "    plot_figure_mmlu_multi(x=x_full, y=p_mean[task][model_name], y_err=p_std[task][model_name], y_data_fit=np.array(y_data_fit), ylabel=ylabel, y_tagging='p', mmlu_dataset_names=dataset_names, task=task, model_name=model_name, f_fit=model_no_bias, fit_label='_fitted')\n",
        "    plot_figure_mmlu_multi(x=x_full, y=p_mean[task][model_name], y_err=p_std[task][model_name], y_data_fit=np.array(y_data_fit), ylabel=ylabel, y_tagging='p', mmlu_dataset_names=dataset_names, task=task, model_name=model_name, f_fit=model_no_bias, fit_label='')\n",
        "    plot_figure_mmlu_multi(x=x_full, y=acc_mean[task][model_name], y_err=acc_std[task][model_name], y_data_fit=np.array(y_data_fit), ylabel=ylabel, y_tagging='acc', mmlu_dataset_names=dataset_names, task=task, model_name=model_name, f_fit=model_no_bias, fit_label='')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "AX7XpmGOg2Mf"
      },
      "outputs": [],
      "source": [
        "root = os.path.dirname(os.getcwd())\n",
        "\n",
        "def plot_figure_behavior_multi(x, y, y_err, y_fit, task, title, y_label, model_name, fit_label=''):\n",
        "    # Create a plot\n",
        "    folder_name = f'{root}/data/plots/behavior_plots/{model_name}_alignment/'\n",
        "    os.makedirs(folder_name, exist_ok=True)\n",
        "    plt.figure(figsize=(8, 6))  # Adjust the figure size as needed\n",
        "\n",
        "    # Plot x vs y\n",
        "    plt.plot(x, y)  # Adjust marker and linestyle as needed\n",
        "    if fit_label == '_fitted':\n",
        "      plt.plot(x, y_fit, label = f'fitted')\n",
        "    y_err_plot = np.array(y_err) / 10 # The standard error is std/sqrt(n). in our case n=100 for all mmlu sub-datasets\n",
        "    plt.fill_between(np.array(x), np.array(y) - y_err_plot, np.array(y) + y_err_plot, alpha=0.2)\n",
        "\n",
        "    # Add labels and title\n",
        "    plt.xlabel(r\"$r_e$\")\n",
        "    plt.ylabel(y_label)\n",
        "    plt.title(title)\n",
        "    plt.legend()\n",
        "    if 'A, B, C, D' in ylabel:\n",
        "      plt.savefig(os.path.join(folder_name, f'relative_{model_name}_{task}{fit_label}.png'))\n",
        "    else:\n",
        "      plt.savefig(os.path.join(folder_name, f'{model_name}_{task}{fit_label}.png'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {},
      "outputs": [],
      "source": [
        "info_dict1 = {'output_dir': f'{root}/data/harmfulness_experiments_outputs/safety_...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct',\n",
        "            'behavior_types': 'safety_harmfulness'}\n",
        "info_dict2 = {'output_dir': f'{root}/data/harmfulness_experiments_outputs/safety_...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B',\n",
        "            'behavior_types': 'safety_harmfulness'}\n",
        "info_dict3 = {'output_dir': f'{root}/data/fairness_experiments_outputs/safety_...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct',\n",
        "            'behavior_types': 'safety_fairness'}\n",
        "info_dict4 = {'output_dir': f'{root}/data/fairness_experiments_outputs/safety_...',\n",
        "            'model_name': 'meta-llama/Meta-Llama-3.1-8B',\n",
        "            'behavior_types': 'safety_fairness'}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "JzypXldTcxpF"
      },
      "outputs": [],
      "source": [
        "info_dict = info_dict2\n",
        "output_dir = info_dict['output_dir']\n",
        "\n",
        "behavior_types = [info_dict['behavior_types']]\n",
        "model_names = [info_dict['model_name'].replace(\"/\",\"_\")]\n",
        "behavior_mean = {model_name_key: {behavior_type_key: {key: 0 for key in x} for behavior_type_key in behavior_types} for model_name_key in model_names}\n",
        "behavior_std = {model_name_key: {behavior_type_key: {key: 0 for key in x} for behavior_type_key in behavior_types} for model_name_key in model_names}\n",
        "y_labels = {model_name_key: {behavior_type_key: {key: '' for key in x} for behavior_type_key in behavior_types} for model_name_key in model_names}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "v-nt6iaCc1Fe"
      },
      "outputs": [],
      "source": [
        "for model_name in model_names:\n",
        "  for behavior_type in behavior_types:\n",
        "    if behavior_type == 'safety_harmfulness':\n",
        "      task = 'harmfulness'\n",
        "      stats_dict_mean = 'behavior_harmful_mean'\n",
        "      stats_dict_std = 'behavior_harmful_std'\n",
        "    else:\n",
        "      task = 'fairness'\n",
        "      stats_dict_mean = 'behavior_harmful_mean' # 'behavior_bias_mean'\n",
        "      stats_dict_std = 'behavior_harmful_std' # 'behavior_bias_std'\n",
        "    with open(f'{output_dir}/{behavior_type}_{model_name}_stats_sample.json', 'r') as file:\n",
        "      stats_dict = json.load(file)\n",
        "    behavior_mean[model_name][behavior_type] = stats_dict[stats_dict_mean]\n",
        "    behavior_std[model_name][behavior_type] = stats_dict[stats_dict_std]"
      ]
    },
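    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Minimal sketch on synthetic data (all values illustrative, nothing from the\n",
        "# experiments): verify the lmfit workflow recovers known tanh parameters\n",
        "# before fitting the measured behavior curves in the next cell.\n",
        "from lmfit import Model\n",
        "import numpy as np\n",
        "\n",
        "def _tanh_response(x, a, b0):\n",
        "    return np.tanh(a * np.asarray(x) + b0)\n",
        "\n",
        "_x = np.linspace(-3, 3, 50)\n",
        "_y = _tanh_response(_x, a=1.5, b0=-0.5)\n",
        "_fit = Model(_tanh_response).fit(_y, x=_x, a=1.0, b0=0.0)\n",
        "print(_fit.params['a'].value, _fit.params['b0'].value)  # ~1.5, ~-0.5"
      ]
    },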
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "PotyiSgxoI8v",
        "outputId": "8b7db1a0-c0a7-4a05-a447-40d524dea7bc"
      },
      "outputs": [],
      "source": [
        "\n",
        "import numpy as np\n",
        "from scipy.optimize import curve_fit\n",
        "from lmfit import Model\n",
        "\n",
        "# Define the quadratic function\n",
        "def theory_1_quadratic_function(x, a, b0):\n",
        "    return np.array([math.tanh((a * x_elem) + b0) for x_elem in x])\n",
        "\n",
        "model = Model(theory_1_quadratic_function)\n",
        "\n",
        "# Define your (x, y) coordinates\n",
        "x_data = [round(i,2) for i in np.arange(-3.2, 4.1, 0.2)]\n",
        "x_full = [round(i,2) for i in np.arange(-5, 5.1, 0.2)]\n",
        "x_data_idx = [i for i, v in enumerate(x_full) if v in x_data]\n",
        "\n",
        "b0_idx = [i for i, v in enumerate(x_data) if v == 0][0]\n",
        "\n",
        "for behavior_type in behavior_types:\n",
        "  for model_name in model_names:\n",
        "    y_data_fit = [round(v,2) for i, v in enumerate(list(behavior_mean[model_name][behavior_type].values())) if i in x_data_idx]\n",
        "    y_data_fit = np.array(y_data_fit)\n",
        "    # if y_data_fit[0] > 0:\n",
        "    #   y_data_fit = [y_elem * (-1) for y_elem in y_data_fit]\n",
        "\n",
        "    # set the plot title\n",
        "    title_prefix = 'Agrees' if ('chat' in model_name or 'Instruct' in model_name) else 'Refuses'\n",
        "    title_suffix = 'harmful instructions' if 'harmful' in behavior_type else 'racist statements'\n",
        "    short_model_name = model_name.split(\"_\")[1]\n",
        "    title = title_prefix + ' ' + title_suffix + f'- {short_model_name}'\n",
        "\n",
        "    # set the y_label title\n",
        "    y_label = ''\n",
        "    if ('bias' in behavior_type):\n",
        "      y_label = 'P(\"Yes\") - P(\"No\")' if ('chat' in model_name or 'Instruct' in model_name) else 'P(\"No\") - P(\"Yes\")'\n",
        "    else:\n",
        "      y_label = 'P(Agrees) - P(Refuses)' if ('chat' in model_name or 'Instruct' in model_name) else 'P(Refuses) - P(Agrees)'\n",
        "\n",
        "    if behavior_type == 'behavior_harmful':\n",
        "      task = 'harmfulness'\n",
        "    else:\n",
        "      task = 'fairness'\n",
        "\n",
        "    # Set initial parameter values and bounds (if needed)\n",
        "    params = model.make_params(a=2, b0=-2)\n",
        "    # Perform the fit\n",
        "    result = model.fit(y_data_fit, x=x_data, params=params)\n",
        "\n",
        "    y_fit = [theory_1_quadratic_function([x], result.params['a'].value, result.params['b0'].value) for x in x_data]\n",
        "\n",
        "    y = [v for i,v in enumerate(list(behavior_mean[model_name][behavior_type].values())) if i in x_data_idx]\n",
        "    y_err = [v for i,v in enumerate(list(behavior_std[model_name][behavior_type].values())) if i in x_data_idx]\n",
        "\n",
        "    # params: x, y, y_err, y_fit, task, title, y_label, model_name\n",
        "    plot_figure_behavior_multi(x=x_data, y=y, y_err=y_err, y_fit=y_fit, task=task, title=title, y_label=y_label, model_name=model_name, fit_label='_fitted')\n",
        "    plot_figure_behavior_multi(x=x_data, y=y, y_err=y_err, y_fit=y_fit, task=task, title=title, y_label=y_label, model_name=model_name, fit_label='')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
