{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from src import data\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "############################################\n",
    "path = \"../../results/faithfulness_baselines_updated\"\n",
    "model_name = \"llama-13b\"\n",
    "fig_dir = f\"figs/{model_name}\"\n",
    "############################################\n",
    "os.makedirs(fig_dir, exist_ok=True)\n",
    "from scripts.baselines.faithfulness_baselines import load_raw_results\n",
    "\n",
    "results_raw = load_raw_results(\n",
    "    model_name, results_path=path, \n",
    "    multiple_files=False\n",
    "    # multiple_files=\"llama\" in model_name\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_none(arr):\n",
    "    return [x for x in arr if x is not None]\n",
    "\n",
    "def format_results(results_raw):\n",
    "    results_formatted = {}\n",
    "    for relation_results in results_raw:\n",
    "        result = {k: v for k, v in relation_results.items() if k != \"trials\"}\n",
    "        result[\"recall\"] = {}\n",
    "        if len(relation_results[\"trials\"]) == 0:\n",
    "            continue\n",
    "        for trial_result in relation_results[\"trials\"]:\n",
    "            prompting_approaches = [\"zero_shot\", \"icl\"]\n",
    "            for approach in prompting_approaches:\n",
    "                if approach not in result[\"recall\"]:\n",
    "                    result[\"recall\"][approach] = {}\n",
    "                for method_key in trial_result[approach]:\n",
    "                    if method_key not in result[\"recall\"][approach]:\n",
    "                        result[\"recall\"][approach][method_key] = []\n",
    "                    result[\"recall\"][approach][method_key].append(trial_result[approach][method_key][\"recall\"])\n",
    "        \n",
    "        for approach in prompting_approaches:\n",
    "            # print(f\"{relation_results['relation_name']} -- {approach}\")\n",
    "            for method_key in result[\"recall\"][approach]:\n",
    "                # print(f\" ====> {method_key} | {len(result['recall'][approach][method_key])}\")\n",
    "                try:\n",
    "                    current_results = remove_none(result[\"recall\"][approach][method_key])\n",
    "                    result[\"recall\"][approach][method_key] = np.array(current_results).mean(axis = 0)\n",
    "                except:\n",
    "                    print(result[\"recall\"][approach][method_key])\n",
    "                    raise Exception(\"ValueError\")\n",
    "        \n",
    "        results_formatted[relation_results[\"relation_name\"]] = result\n",
    "\n",
    "    return results_formatted\n",
    "\n",
    "results_formatted = format_results(results_raw)"
   ]
  },
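  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check (a sketch based on the structure assumed above): each raw entry\n",
    "# carries per-trial recalls under trial[approach][method][\"recall\"], so every formatted\n",
    "# relation should expose mean recall@k arrays per approach and method.\n",
    "example = next(iter(results_formatted.values()))\n",
    "{approach: sorted(methods.keys()) for approach, methods in example[\"recall\"].items()}"
   ]
  },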
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = data.load_dataset()\n",
    "relations_by_name = {r.name: r for r in dataset.relations}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "\n",
    "def segregate_categorywise(\n",
    "    results_formatted: dict,\n",
    "    property_key: Literal[\"relation_type\", \"fn_type\", \"disambiguating\", \"symmetric\"] = \"relation_type\"\n",
    ") -> dict:\n",
    "    performance_category_wise = {}\n",
    "    for relation_name in results_formatted:\n",
    "        property_value = relations_by_name[relation_name].properties.__dict__[property_key]\n",
    "        # print(f\"{relation_name} : {property_value}\")\n",
    "        result = results_formatted[relation_name]\n",
    "        if property_value not in performance_category_wise:\n",
    "            performance_category_wise[property_value] = {}\n",
    "        for prompting in result[\"recall\"]:\n",
    "            if prompting not in performance_category_wise[property_value]:\n",
    "                performance_category_wise[property_value][prompting] = {}\n",
    "\n",
    "            for method in result[\"recall\"][prompting]:\n",
    "                if method not in performance_category_wise[property_value][prompting]:\n",
    "                    performance_category_wise[property_value][prompting][method] = []\n",
    "                performance_category_wise[property_value][prompting][method].append(result[\"recall\"][prompting][method])\n",
    "        \n",
    "\n",
    "    for property_value in performance_category_wise:\n",
    "        for prompting in performance_category_wise[property_value]:\n",
    "            for method in performance_category_wise[property_value][prompting]:\n",
    "                performance_category_wise[property_value][prompting][method] = np.array(performance_category_wise[property_value][prompting][method]).mean(axis = 0)\n",
    "    \n",
    "    return performance_category_wise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "performance_category_wise = segregate_categorywise(results_formatted, \"relation_type\")\n",
    "performance_category_wise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "category_order = ['factual', 'linguistic', 'bias', 'commonsense']\n",
    "#####################################################################################\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 35+5\n",
    "MEDIUM_SIZE = 50\n",
    "BIGGER_SIZE = 55+5\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "# prompting_colors = {\"zero_shot\": \"deepskyblue\", \"icl\": \"darkblue\"}\n",
    "prompting_colors = {\"zero_shot\": \"lightsteelblue\", \"icl\": \"steelblue\"}\n",
    "prompting_alpha = {\"zero_shot\": 0.7, \"icl\": 1}\n",
    "\n",
    "method_name_dict  = {\n",
    "    \"logit_lens\": \"Logit Lens\",\n",
    "    \"corner\": \"Corner Translation\",\n",
    "    \"translation\": \"Translation\",\n",
    "    \"learned_linear\": \"Linear Regression\",\n",
    "    \"lre_emb\": \"R(e)\",\n",
    "    \"lre\": \"R(s)\",\n",
    "    \"corner_lre\": \"corner + LRE\",\n",
    "}\n",
    "method_color  = {\n",
    "    \"logit_lens\": \"darkred\",\n",
    "    \"corner\": \"black\",\n",
    "    \"learned_linear\": \"olive\",\n",
    "    \"lre_emb\": \"lightsteelblue\",\n",
    "    \"lre\": \"darkblue\",\n",
    "    \"translation\": \"purple\"\n",
    "}\n",
    "\n",
    "method_order = [\n",
    "    \"lre\", \"lre_emb\", \n",
    "    \"learned_linear\", \n",
    "    # \"corner\", \n",
    "    \"translation\", \n",
    "    \"logit_lens\"\n",
    "][::-1]\n",
    "#####################################################################################\n",
    "\n",
    "\n",
    "def plot_categorywise(canvas, result, title, set_yticks = True, separate_prompting = True):\n",
    "    bar_width = 0.4\n",
    "    if separate_prompting:\n",
    "        idx = 0\n",
    "        for prompting in result:\n",
    "            recalls = [result[prompting][method][0] for method in method_order]\n",
    "            canvas.barh(\n",
    "                np.arange(len(recalls)) + idx * bar_width, recalls, \n",
    "                height = bar_width, label = prompting.capitalize(), \n",
    "                color = prompting_colors[prompting],\n",
    "                edgecolor = \"black\",\n",
    "                alpha = prompting_alpha[prompting]\n",
    "            )\n",
    "            idx += 1\n",
    "    else:\n",
    "        prompting = \"icl\"\n",
    "        recalls = [result[prompting][method][0] for method in method_order]\n",
    "        canvas.barh(\n",
    "            np.arange(len(recalls)), recalls, \n",
    "            height = bar_width*2, \n",
    "            color = \"steelblue\", #[method_color[method] for method in method_order],\n",
    "            edgecolor = \"black\",\n",
    "            alpha = 1\n",
    "        )\n",
    "    \n",
    "    canvas.set_xlim(0, 1)\n",
    "    canvas.set_title(title.capitalize(), fontsize = BIGGER_SIZE)\n",
    "\n",
    "    if separate_prompting == False:\n",
    "        canvas.set_yticks(np.arange(len(method_order)))\n",
    "    else:    \n",
    "        canvas.set_yticks(np.arange(len(method_order)) + 0.5 * bar_width)\n",
    "    if set_yticks:\n",
    "        canvas.set_yticklabels([method_name_dict[method] for method in method_order])\n",
    "    else:\n",
    "        canvas.set_yticklabels([\"\"]*len(method_order))\n",
    "    canvas.set_xticks(np.linspace(0, 1, 5))\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_subplots = len(performance_category_wise)\n",
    "ncols=4\n",
    "nrows=int(np.ceil(n_subplots / ncols))\n",
    "\n",
    "fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 8, nrows * 9))\n",
    "if nrows == 1:\n",
    "    axes = [axes]\n",
    "if ncols == 1:\n",
    "    axes = [[ax] for ax in axes]\n",
    "\n",
    "ax_col, ax_row = 0, 0\n",
    "# for i, (category, result) in enumerate(performance_category_wise.items()):\n",
    "for category in category_order:\n",
    "    result = performance_category_wise[category]\n",
    "    plot_categorywise(canvas = axes[ax_row][ax_col], result = result, title = category, set_yticks = ax_col == 0)\n",
    "    ax_col += 1\n",
    "    if ax_col == ncols:\n",
    "        ax_col = 0\n",
    "        ax_row += 1\n",
    "    # break\n",
    "\n",
    "plt.savefig(f\"{fig_dir}/{model_name}-faithfulness_baselines_prompting.pdf\", bbox_inches=\"tight\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib as mpl\n",
    "palette = list(prompting_colors.keys())[::-1]\n",
    "desc = {\"zero_shot\": \"Without few-shot examples\", \"icl\": \"With few-shot examples\"}\n",
    "handles = [\n",
    "    mpl.patches.Patch(\n",
    "        facecolor=prompting_colors[x], label=desc[x], edgecolor=\"black\", alpha=prompting_alpha[x]) \n",
    "        for x in palette\n",
    "    ]\n",
    "# Create legend\n",
    "plt.legend(handles=handles, ncols=2)\n",
    "# Get current axes object and turn off axis\n",
    "plt.gca().set_axis_off()\n",
    "plt.savefig(f\"{fig_dir}/legend_faithfulness_prompting.pdf\", format=\"pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_subplots = len(performance_category_wise)\n",
    "ncols=4\n",
    "nrows=int(np.ceil(n_subplots / ncols))\n",
    "\n",
    "fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 8, nrows * 9))\n",
    "if nrows == 1:\n",
    "    axes = [axes]\n",
    "if ncols == 1:\n",
    "    axes = [[ax] for ax in axes]\n",
    "\n",
    "ax_col, ax_row = 0, 0\n",
    "# for i, (category, result) in enumerate(performance_category_wise.items()):\n",
    "for category in category_order:\n",
    "    result = performance_category_wise[category]\n",
    "    plot_categorywise(\n",
    "        canvas = axes[ax_row][ax_col], result = result, \n",
    "        title = category, set_yticks = ax_col == 0,\n",
    "        separate_prompting=False    \n",
    "    )\n",
    "    ax_col += 1\n",
    "    if ax_col == ncols:\n",
    "        ax_col = 0\n",
    "        ax_row += 1\n",
    "    # break\n",
    "\n",
    "plt.savefig(f\"{fig_dir}/{model_name}-faithfulness_baselines.pdf\", bbox_inches=\"tight\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS = [\n",
    "    \"gpt2-xl\", \n",
    "    \"gptj\", \n",
    "    \"llama-13b\"\n",
    "]\n",
    "categorywise_results = {}\n",
    "\n",
    "for model_name in MODELS:\n",
    "    print(model_name)\n",
    "    results_raw = load_raw_results(\n",
    "        model_name, results_path=path, \n",
    "        multiple_files=False\n",
    "    )\n",
    "    results_formatted = format_results(results_raw)\n",
    "\n",
    "    category_wise = segregate_categorywise(results_formatted, \"relation_type\")\n",
    "    for category in category_wise:\n",
    "        category_wise[category] = category_wise[category]['icl'][\"lre\"][0]\n",
    "    \n",
    "    for category in category_wise:\n",
    "        if category not in categorywise_results:\n",
    "            categorywise_results[category] = {}\n",
    "        categorywise_results[category][model_name] = category_wise[category]\n",
    "\n",
    "categorywise_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 16\n",
    "MEDIUM_SIZE = 18\n",
    "BIGGER_SIZE = 24\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE+5)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "#####################################################################################\n",
    "\n",
    "\n",
    "n_subplots = len(categorywise_results)\n",
    "ncols=n_subplots\n",
    "nrows=int(np.ceil(n_subplots / ncols))\n",
    "\n",
    "category_order = [\"factual\", \"linguistic\", \"commonsense\", \"bias\"]\n",
    "models = {\n",
    "    \"gpt2-xl\": \"GPT2-xl\", \n",
    "    \"gptj\": \"GPT-J\", \n",
    "    \"llama-13b\": \"LLaMa-13b\"\n",
    "}\n",
    "\n",
    "# model_colors = {\n",
    "#     \"gpt2-xl\": \"khaki\", \n",
    "#     \"gptj\": \"darkseagreen\", \n",
    "#     \"llama-13b\": \"teal\"\n",
    "# }\n",
    "\n",
    "model_colors = {\n",
    "    \"gpt2-xl\": \"khaki\", \n",
    "    \"gptj\": \"darkseagreen\", \n",
    "    \"llama-13b\": \"lightblue\"\n",
    "}\n",
    "\n",
    "\n",
    "idx = 0\n",
    "bar_width = 0.225\n",
    "for model in models:\n",
    "    recalls = []\n",
    "    for category in category_order:\n",
    "        recalls.append(categorywise_results[category][model])\n",
    "    \n",
    "    plt.bar(\n",
    "        np.arange(len(recalls)) + idx * bar_width, recalls,\n",
    "        width = bar_width,\n",
    "        label = models[model],\n",
    "        edgecolor = \"black\",\n",
    "        color = model_colors[model],\n",
    "        alpha = 0.99\n",
    "    )\n",
    "    idx += 1\n",
    "\n",
    "plt.xticks(np.arange(len(recalls)) + bar_width, [cat.capitalize() for cat in category_order])\n",
    "plt.ylabel(\"Faithfulness\")\n",
    "plt.legend(ncol = 3, bbox_to_anchor=(0.5, -.25), loc='lower center', frameon=False)\n",
    "plt.savefig(f\"figs/faithfulness_lre_models.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 15\n",
    "MEDIUM_SIZE = 25\n",
    "BIGGER_SIZE = 28\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", labelsize=BIGGER_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "#####################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = {\n",
    "    \"gpt2-xl\": \"GPT2-xl\", \n",
    "    \"gptj\": \"GPT-J\", \n",
    "    \"llama-13b\": \"LLaMa-13b\"\n",
    "}\n",
    "\n",
    "model_name = \"llama-13b\"\n",
    "results_raw = load_raw_results(\"llama-13b\", results_path=\"results/faithfulne\")\n",
    "results_formatted = format_results(results_raw)\n",
    "\n",
    "relation_and_recall = []\n",
    "\n",
    "for relation_name in results_formatted:\n",
    "    relation_and_recall.append({\n",
    "        \"relation\": relation_name,\n",
    "        \"recall@1\": results_formatted[relation_name][\"recall\"]['icl'][\"lre\"][0]\n",
    "    })\n",
    "\n",
    "relation_and_recall = sorted(relation_and_recall, key = lambda x: x[\"recall@1\"])\n",
    "\n",
    "relations = [r[\"relation\"] for r in relation_and_recall]\n",
    "recalls = [r[\"recall@1\"] for r in relation_and_recall]\n",
    "\n",
    "plt.figure(figsize = (10, 20))\n",
    "plt.barh(np.arange(len(relations)), recalls, color = \"teal\", alpha = 0.6)\n",
    "plt.yticks(np.arange(len(relations)), relations)\n",
    "plt.xticks(np.linspace(0, 1, 11), [np.round(v, 1) for v in np.linspace(0, 1, 11)])\n",
    "plt.ylim(-0.7,len(relation_and_recall)-.3)\n",
    "plt.xlabel(\"Recall@1\")\n",
    "plt.xlim(0, 1)\n",
    "\n",
    "for x_tick in np.linspace(0, 1, 11):\n",
    "    plt.axvline(x_tick, color = \"black\", alpha = 0.2)\n",
    "\n",
    "plt.title(f\"LRE faithfulness in \", x = 0.3, pad=15)\n",
    "# plt.savefig(f\"{fig_dir}/faithfulness_lre_relationwise.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
