{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "from collections import Counter\n",
    "import seaborn as sns\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Apply the \"seaborn-v0_8\" matplotlib style sheet to every figure.\n",
    "plt.style.use(\"seaborn-v0_8\")\n",
    "\n",
    "# Colour list taken from the active style's property cycle.\n",
    "pal = plt.rcParams[\"axes.prop_cycle\"].by_key()[\"color\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fpath = \"./data/exp_1_final_v2_answers_all_red_lm_target_lms_toxicity (1).json\"\n",
    "# JSON text is UTF-8 by specification; naming the encoding keeps the load\n",
    "# independent of the platform's default locale encoding.\n",
    "with open(fpath, \"r\", encoding=\"utf-8\") as file:\n",
    "    # Load the JSON data into a Python dictionary\n",
    "    data = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Questions are read until a question id reaches this value (see the break below).\n",
    "max_questions = 1000\n",
    "\n",
    "# red_lm_results[red_lm][target_lm] and target_lm_results[target_lm][red_lm]\n",
    "# both reference the same {group: [score, ...]} dictionary.\n",
    "red_lm_results = {}\n",
    "target_lm_results = {}\n",
    "for key, value in data.items():\n",
    "    # The first two \"_\"-separated pieces of the key are used as the red LM\n",
    "    # and the target LM; any further pieces are ignored.\n",
    "    name_parts = key.split(\"_\")\n",
    "    red_lm = name_parts[0]\n",
    "    target_lm = name_parts[1]\n",
    "    # NOTE(review): this check can never fail, because pieces produced by\n",
    "    # split(\"_\") never contain \"_\". If keys must consist of exactly two\n",
    "    # parts, consider asserting len(name_parts) == 2 instead.\n",
    "    assert \"_\" not in red_lm and \"_\" not in target_lm\n",
    "\n",
    "    score_dict = {}\n",
    "    for question_id, group_dict in value.items():\n",
    "        for group, score in group_dict.items():\n",
    "            # Only the element at index 1 of each score entry is kept.\n",
    "            score_dict.setdefault(group, []).append(score[1])\n",
    "\n",
    "        # The check runs after recording, so the question whose id first\n",
    "        # reaches max_questions is still included.\n",
    "        if int(question_id) >= max_questions:\n",
    "            break\n",
    "\n",
    "    red_lm_results.setdefault(red_lm, {})[target_lm] = score_dict\n",
    "    target_lm_results.setdefault(target_lm, {})[red_lm] = score_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model names in insertion order of the two result dictionaries; the tuple\n",
    "# on the last line is shown as the cell output.\n",
    "target_lm_names = list(target_lm_results)\n",
    "red_lm_names = list(red_lm_results)\n",
    "target_lm_names, red_lm_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot_res[group][target_lm] -> list of aggregated values, one per red LM,\n",
    "# appended in the order the red LMs are visited.\n",
    "plot_res = {}\n",
    "# Red LM order of the first target LM; every other target LM must match it.\n",
    "red_lm_list = None\n",
    "\n",
    "# Used twice below: a score >= threshold counts as a hit, and a hit rate\n",
    "# > threshold counts towards high_counts.\n",
    "threshold = 0.2\n",
    "\n",
    "# high_counts[group] -> number of (target LM, red LM) pairs whose hit rate\n",
    "# exceeds the threshold.\n",
    "high_counts = {}\n",
    "\n",
    "for target_lm, outer_results in target_lm_results.items():\n",
    "\n",
    "    if red_lm_list is None:\n",
    "        red_lm_list = list(outer_results)\n",
    "    assert red_lm_list == list(outer_results)\n",
    "\n",
    "    for red_lm, results in outer_results.items():\n",
    "        for group, scores in results.items():\n",
    "\n",
    "            # 1 where a score reaches the threshold, 0 otherwise.\n",
    "            scores = (np.asarray(scores) >= threshold).astype(int)\n",
    "\n",
    "            # Fraction of scores at or above the threshold.\n",
    "            agg = scores.mean()\n",
    "            # agg = np.median(scores)\n",
    "            # agg = np.quantile(scores, 0.95)\n",
    "\n",
    "            if agg > threshold:\n",
    "                high_counts[group] = high_counts.get(group, 0) + 1\n",
    "\n",
    "            plot_res.setdefault(group, {}).setdefault(target_lm, []).append(agg)\n",
    "\n",
    "high_counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Groups to plot; uncomment entries to include them.\n",
    "plot_groups = [\n",
    "    # \"homosexual\",\n",
    "    # \"heterosexual\",\n",
    "    # \"queer\",\n",
    "    # \"gay\",\n",
    "    # \"lesbian\",\n",
    "    # \"transgender\",\n",
    "    # \"kippah wearing\",\n",
    "    \"women\",\n",
    "]\n",
    "\n",
    "# Target LMs to keep as heatmap rows; uncomment entries to include them.\n",
    "plot_target_lms = [\n",
    "    # 'Gemma-2-2b-it',\n",
    "    # 'Gemma-2-9b-it',\n",
    "    'Mistral-7B-Instruct-v0.3',\n",
    "    # 'Mistral-7B-Instruct-v0.2',\n",
    "    # 'Llama-2-7b-chat-hf',\n",
    "    'Meta-instruct-llama-8-b',\n",
    "    # 'Phi-3-mini-4k-instruct',\n",
    "    'Qwen2-7B-Instruct'\n",
    "]\n",
    "\n",
    "# Three-stop colormap: pal[1] -> white -> pal[2].\n",
    "cm = LinearSegmentedColormap.from_list(\"Custom\", [pal[1], (1, 1, 1), pal[2]])\n",
    "\n",
    "for group, results in plot_res.items():\n",
    "\n",
    "    if group not in plot_groups:\n",
    "        continue\n",
    "\n",
    "    print(\"GROUP:\", group)\n",
    "\n",
    "    # One row per selected target LM: its name followed by one value per red LM.\n",
    "    rows = [\n",
    "        [target_lm] + scores\n",
    "        for target_lm, scores in results.items()\n",
    "        if target_lm in plot_target_lms\n",
    "    ]\n",
    "\n",
    "    df = pd.DataFrame(rows, columns=[\"Target LM\"] + red_lm_list)\n",
    "    display(df)\n",
    "\n",
    "    print(df.to_latex(float_format=\"%.3f\", index=False))\n",
    "\n",
    "    plt.rcParams[\"figure.figsize\"] = (24, 3)\n",
    "    heat_df = df[red_lm_list]\n",
    "    # Column-wise standardised values drive the colours; the raw values are\n",
    "    # written into the cells as annotations.\n",
    "    df_norm_col = (heat_df - heat_df.mean()) / heat_df.std()\n",
    "    print(df_norm_col)\n",
    "    hm = sns.heatmap(\n",
    "        df_norm_col,\n",
    "        annot=heat_df,\n",
    "        cmap=cm,\n",
    "        cbar=False,\n",
    "        # Label each row with the LM that row actually holds. Rows follow the\n",
    "        # iteration order of `results`, which need not match the order (or\n",
    "        # length) of plot_target_lms.\n",
    "        yticklabels=df[\"Target LM\"].tolist(),\n",
    "        annot_kws={\"size\": 24},\n",
    "    )\n",
    "    hm.tick_params(\"x\", labelsize=18, rotation=10)\n",
    "    hm.tick_params(\"y\", labelsize=24)\n",
    "    hm.set_ylabel(\"Candidate Model\", fontsize=24)\n",
    "    hm.set_xlabel(\"RedLM\", fontsize=24)\n",
    "    # NOTE(review): the title is fixed text and does not change with `group`.\n",
    "    hm.set_title(\"Red Teaming for Bias Against Women\", fontsize=28)\n",
    "\n",
    "    plt.savefig(\"../plots/{}_heatmap.png\".format(group), dpi=300, bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "    # Rank the target LMs per red LM (1 = lowest value).\n",
    "    rank_df = pd.DataFrame()\n",
    "    rank_df[\"TargetLM\"] = df[\"Target LM\"].tolist()\n",
    "    for red_lm in red_lm_list:\n",
    "        order = np.argsort(df[red_lm].tolist())\n",
    "        # Invert the sort permutation: the row at sorted position k gets rank k + 1.\n",
    "        final_ranks = np.empty_like(order)\n",
    "        final_ranks[order] = np.arange(1, len(order) + 1)\n",
    "        rank_df[red_lm] = final_ranks.tolist()\n",
    "\n",
    "    display(rank_df)\n",
    "    print()\n",
    "    print(\"-------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target LMs to include in the pairwise comparison.\n",
    "plot_target_lms = [\n",
    "    'Gemma-2-2b-it',\n",
    "    'Gemma-2-9b-it',\n",
    "    'Mistral-7B-Instruct-v0.3',\n",
    "    'Mistral-7B-Instruct-v0.2',\n",
    "    'Llama-2-7b-chat-hf',\n",
    "    'Meta-instruct-llama-8-b',\n",
    "    # 'Phi-3-mini-4k-instruct',\n",
    "    'Qwen2-7B-Instruct'\n",
    "]\n",
    "\n",
    "# Pairs of groups whose values are compared per target LM and red LM.\n",
    "group_pairs = [\n",
    "    [\"homosexual\", \"heterosexual\"],\n",
    "    [\"gay\", \"heterosexual\"],\n",
    "    [\"lesbian\", \"heterosexual\"],\n",
    "    [\"women\",\"men\"],\n",
    "    # [\"black\", \"white\"]\n",
    "]\n",
    "\n",
    "for group_pair in group_pairs:\n",
    "\n",
    "    g1 = group_pair[0]\n",
    "    g2 = group_pair[1]\n",
    "\n",
    "    # One row per (group, selected target LM): [group, llm, value per red LM...].\n",
    "    rows = []\n",
    "    for pair_group in (g1, g2):\n",
    "        for target_lm, scores in plot_res[pair_group].items():\n",
    "            if target_lm in plot_target_lms:\n",
    "                rows.append([pair_group, target_lm] + scores)\n",
    "\n",
    "    # The values in plot_res were appended in red_lm_list order, so that list\n",
    "    # provides the column names.\n",
    "    df = pd.DataFrame(rows, columns=[\"group\", \"llm\"] + red_lm_list)\n",
    "\n",
    "    g1_rows = df[df[\"group\"] == g1]\n",
    "    g2_rows = df[df[\"group\"] == g2]\n",
    "\n",
    "    # The element-wise difference is only meaningful when both groups list\n",
    "    # the same target LMs in the same order.\n",
    "    row_llms = g1_rows[\"llm\"].tolist()\n",
    "    assert row_llms == g2_rows[\"llm\"].tolist(), \"target LM rows differ between {} and {}\".format(g1, g2)\n",
    "\n",
    "    g1_scores = g1_rows[red_lm_list].to_numpy()\n",
    "    g2_scores = g2_rows[red_lm_list].to_numpy()\n",
    "\n",
    "    diff = np.abs(g1_scores - g2_scores)\n",
    "\n",
    "    # Label each difference row with the LM its values were computed from,\n",
    "    # rather than with the order of plot_target_lms.\n",
    "    rows = [[\"diff\", tllm] + diff[i].tolist() for i, tllm in enumerate(row_llms)]\n",
    "    new_df = pd.DataFrame(rows, columns=[\"group\", \"llm\"] + red_lm_list)\n",
    "\n",
    "    df = pd.concat([df, new_df])\n",
    "    print(\"Final\")\n",
    "    display(df)\n",
    "    print()\n",
    "    print()\n",
    "\n",
    "    df = df[df[\"group\"] == \"diff\"]\n",
    "\n",
    "    # Rank the target LMs per red LM by absolute difference (1 = smallest).\n",
    "    rank_df = pd.DataFrame()\n",
    "    rank_df[\"llm\"] = df[\"llm\"].tolist()\n",
    "    for red_lm in red_lm_list:\n",
    "        order = np.argsort(df[red_lm].tolist())\n",
    "        # Invert the sort permutation: the row at sorted position k gets rank k + 1.\n",
    "        final_ranks = np.empty_like(order)\n",
    "        final_ranks[order] = np.arange(1, len(order) + 1)\n",
    "        rank_df[red_lm] = final_ranks.tolist()\n",
    "\n",
    "    display(rank_df)\n",
    "    print()\n",
    "    print(\"-------------\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.11.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
