{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from src import data\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the relation dataset once and index its relations by name for fast lookup.\n",
    "dataset = data.load_dataset()\n",
    "relations_by_name = dict((relation.name, relation) for relation in dataset.relations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "import pandas as pd\n",
    "\n",
    "def segregate_table_results_categorywise(\n",
    "    results_df: pd.DataFrame | dict,\n",
    "    property_key: Literal[\"relation_type\", \"fn_type\", \"disambiguating\", \"symmetric\"] = \"relation_type\",\n",
    "    metric: Literal[\"efficacy\", \"recall@1\"] = \"efficacy\"\n",
    ") -> dict:\n",
    "    \"\"\"Average a per-relation metric over groups of relations sharing a property value.\n",
    "\n",
    "    Args:\n",
    "        results_df: Either a DataFrame with a `relation` column (one row per relation),\n",
    "            or a dict mapping relation name -> row dict.\n",
    "        property_key: Attribute of `relations_by_name[name].properties` to group by.\n",
    "        metric: Column holding the metric. Values may be plain numbers or strings whose\n",
    "            first whitespace-separated token is the number.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping each property value to the mean metric over its relations.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a relation name is missing from the global `relations_by_name`.\n",
    "    \"\"\"\n",
    "    if isinstance(results_df, pd.DataFrame):\n",
    "        results_df = {row[\"relation\"]: row for row in results_df.to_dict(orient=\"records\")}\n",
    "\n",
    "    scores_by_property = {}\n",
    "    for relation_name, result in results_df.items():\n",
    "        property_value = relations_by_name[relation_name].properties.__dict__[property_key]\n",
    "        raw_score = result[metric]\n",
    "        # pandas parses a purely numeric column as floats, which have no .split();\n",
    "        # only string cells need the leading number extracted.\n",
    "        if isinstance(raw_score, str):\n",
    "            raw_score = raw_score.split()[0]\n",
    "        scores_by_property.setdefault(property_value, []).append(float(raw_score))\n",
    "\n",
    "    return {\n",
    "        property_value: np.array(scores).mean(axis=0)\n",
    "        for property_value, scores in scores_by_property.items()\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "table_path = \"../../results/tables\"\n",
    "\n",
    "MODELS = [\"gpt2-xl\", \"gptj\", \"llama-13b\"]\n",
    "\n",
    "# category -> {model name -> mean efficacy over that category's relations}\n",
    "categorywise_results = {\n",
    "    category: {} for category in (\"factual\", \"linguistic\", \"commonsense\", \"bias\")\n",
    "}\n",
    "\n",
    "for model_name in MODELS:\n",
    "    per_category = segregate_table_results_categorywise(\n",
    "        pd.read_csv(f\"{table_path}/{model_name}-hparams.csv\"),\n",
    "        property_key=\"relation_type\",\n",
    "        metric=\"efficacy\",\n",
    "    )\n",
    "    for category, per_model in categorywise_results.items():\n",
    "        per_model[model_name] = per_category[category]\n",
    "\n",
    "categorywise_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 16\n",
    "MEDIUM_SIZE = 18\n",
    "BIGGER_SIZE = 24\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE+5)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "#####################################################################################\n",
    "\n",
    "# Grouped bar chart: one group per relation category, one bar per model.\n",
    "category_order = [\"factual\", \"linguistic\", \"commonsense\", \"bias\"]\n",
    "models = {\n",
    "    \"gpt2-xl\": \"GPT2-xl\",\n",
    "    \"gptj\": \"GPT-J\",\n",
    "    \"llama-13b\": \"LLaMa-13b\"\n",
    "}\n",
    "model_colors = {\n",
    "    \"gpt2-xl\": \"khaki\",\n",
    "    \"gptj\": \"darkseagreen\",\n",
    "    \"llama-13b\": \"lightblue\"\n",
    "}\n",
    "\n",
    "bar_width = 0.225\n",
    "group_positions = np.arange(len(category_order))\n",
    "# Shift that puts each category tick under the middle of its bar group for any number\n",
    "# of models (for three models this equals bar_width).\n",
    "group_center_offset = (len(models) - 1) / 2 * bar_width\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for idx, model in enumerate(models):\n",
    "    recalls = [categorywise_results[category][model] for category in category_order]\n",
    "    ax.bar(\n",
    "        group_positions + idx * bar_width, recalls,\n",
    "        width = bar_width,\n",
    "        label = models[model],\n",
    "        edgecolor = \"black\",\n",
    "        color = model_colors[model],\n",
    "        alpha = 0.99\n",
    "    )\n",
    "\n",
    "ax.set_xticks(group_positions + group_center_offset)\n",
    "ax.set_xticklabels([cat.capitalize() for cat in category_order])\n",
    "ax.set_ylabel(\"Causality\")\n",
    "ax.legend(ncol = len(models), bbox_to_anchor=(0.5, -.25), loc='lower center', frameon=False)\n",
    "os.makedirs(\"figs\", exist_ok=True)\n",
    "fig.savefig(\"figs/efficacy_lre_models.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.sweep_utils import read_efficacy_baseline_results, format_efficacy_baseline_results\n",
    "############################################\n",
    "# Root directory of the efficacy-baseline sweep results.\n",
    "efficacy_root = \"../../results/efficacy_baselines-24-trials\"\n",
    "############################################\n",
    "\n",
    "# The cells below analyse the results in the `gptj` subdirectory only.\n",
    "efficacy_path = f\"{efficacy_root}/gptj\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the raw sweep results (keyed by relation name) and list the available relations.\n",
    "efficacy_results_raw = read_efficacy_baseline_results(efficacy_path)\n",
    "list(efficacy_results_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert every raw sweep result into its formatted representation.\n",
    "efficacy_results = {}\n",
    "for name, raw_result in efficacy_results_raw.items():\n",
    "    efficacy_results[name] = format_efficacy_baseline_results(raw_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPT-J hyperparameter table; peek at the layer recorded for one relation as a sanity check.\n",
    "hparam_table = pd.read_csv(\"../../results/tables/gptj-hparams.csv\")\n",
    "hparam_table.loc[hparam_table[\"relation\"] == \"country capital city\", \"layer\"].values[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "\n",
    "def segregate_categorywise(\n",
    "    results_formatted: dict,\n",
    "    hparam_table: pd.DataFrame,\n",
    "    property_key: Literal[\"relation_type\", \"fn_type\", \"disambiguating\", \"symmetric\"] = \"relation_type\"\n",
    ") -> dict:\n",
    "    \"\"\"Average per-edit-type scores over relations grouped by a relation property.\n",
    "\n",
    "    For each relation, only the results at the layer listed for it in `hparam_table`\n",
    "    are used.\n",
    "\n",
    "    Args:\n",
    "        results_formatted: relation name -> formatted result whose `layerwise_result`\n",
    "            maps layer -> {edit type -> object exposing `.mean`}.\n",
    "        hparam_table: DataFrame with `relation` and `layer` columns; `layer` holds an\n",
    "            integer-like value or the string \"emb\".\n",
    "        property_key: Attribute of `relations_by_name[name].properties` to group by.\n",
    "\n",
    "    Returns:\n",
    "        property value -> {edit type -> mean of `.mean` over that group's relations}.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a relation has no row in `hparam_table`.\n",
    "    \"\"\"\n",
    "    per_property = {}\n",
    "    for relation_name, formatted in results_formatted.items():\n",
    "        layer_values = hparam_table.loc[hparam_table[\"relation\"] == relation_name, \"layer\"].values\n",
    "        if len(layer_values) == 0:\n",
    "            raise KeyError(f\"relation {relation_name!r} not found in hparam_table\")\n",
    "        layer = layer_values[0]\n",
    "        # The table stores either an integer-like layer index or the literal \"emb\".\n",
    "        layer = int(layer) if layer != \"emb\" else layer\n",
    "        property_value = relations_by_name[relation_name].properties.__dict__[property_key]\n",
    "        efficacy_result = formatted[\"layerwise_result\"][layer]\n",
    "        per_edit_type = per_property.setdefault(property_value, {})\n",
    "        for edit_type in efficacy_result:\n",
    "            per_edit_type.setdefault(edit_type, []).append(efficacy_result[edit_type].mean)\n",
    "\n",
    "    return {\n",
    "        property_value: {\n",
    "            edit_type: np.array(means).mean(axis=0)\n",
    "            for edit_type, means in per_edit_type.items()\n",
    "        }\n",
    "        for property_value, per_edit_type in per_property.items()\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "efficacy_category_wise = segregate_categorywise(\n",
    "    efficacy_results, \n",
    "    hparam_table=hparam_table\n",
    ")\n",
    "\n",
    "efficacy_category_wise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "category_order = ['factual', 'linguistic', 'bias', 'commonsense']\n",
    "#####################################################################################\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 35+5\n",
    "MEDIUM_SIZE = 50\n",
    "BIGGER_SIZE = 55+5\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "# Display names for each edit type on the y axis.\n",
    "edit_type_legends = {\n",
    "    \"low_rank_pinv\": \"Causality\",\n",
    "    \"hidden_baseline\"   : \"Replacing with s′\",\n",
    "    \"embed_baseline\"    : \"Replacing with e\",\n",
    "    \"hidden_baseline_z\" : \"Replacing with o′\"\n",
    "}\n",
    "\n",
    "color_dict = {\n",
    "    \"low_rank_pinv\": \"darkorange\",\n",
    "    \"hidden_baseline\": \"darkgreen\",\n",
    "    \"embed_baseline\": \"darkred\",\n",
    "    \"hidden_baseline_z\": \"blue\"\n",
    "}\n",
    "\n",
    "# Reversed because barh draws index 0 at the bottom; this puts low_rank_pinv on top.\n",
    "edit_types = [\"low_rank_pinv\", \"hidden_baseline\", \"embed_baseline\", \"hidden_baseline_z\"][::-1]\n",
    "\n",
    "#####################################################################################\n",
    "\n",
    "\n",
    "def plot_categorywise(canvas, result, title, set_yticks = True):\n",
    "    \"\"\"Draw one horizontal bar per edit type for a single relation category.\n",
    "\n",
    "    Args:\n",
    "        canvas: Matplotlib Axes to draw on.\n",
    "        result: Mapping edit type -> score; must contain every key in `edit_types`.\n",
    "        title: Category name, shown capitalized as the subplot title.\n",
    "        set_yticks: If False, the y tick labels are blanked (e.g. for non-first columns).\n",
    "    \"\"\"\n",
    "    bar_width = 0.8\n",
    "    causality_scores = [result[edit] for edit in edit_types]\n",
    "    # Bar positions come from this function's own data, one slot per edit type.\n",
    "    y_positions = np.arange(len(edit_types))\n",
    "    canvas.barh(\n",
    "        y_positions, causality_scores,\n",
    "        height = bar_width,\n",
    "        color = \"steelblue\",\n",
    "        edgecolor = \"black\",\n",
    "        alpha = 1\n",
    "    )\n",
    "    canvas.set_xlim(0, 1)\n",
    "    canvas.set_title(title.capitalize(), fontsize = BIGGER_SIZE)\n",
    "\n",
    "    canvas.set_yticks(y_positions)\n",
    "    if set_yticks:\n",
    "        canvas.set_yticklabels([edit_type_legends[edit] for edit in edit_types])\n",
    "    else:\n",
    "        canvas.set_yticklabels([\"\"] * len(edit_types))\n",
    "    canvas.set_xticks(np.linspace(0, 1, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One subplot per category in category_order, laid out in rows of `ncols`.\n",
    "n_subplots = len(category_order)\n",
    "ncols=4\n",
    "nrows=int(np.ceil(n_subplots / ncols))\n",
    "\n",
    "# squeeze=False always returns a 2-D array of Axes, so axes[row][col] works for any grid shape.\n",
    "fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 8, nrows * 8), squeeze=False)\n",
    "\n",
    "for position, category in enumerate(category_order):\n",
    "    ax_row, ax_col = divmod(position, ncols)\n",
    "    plot_categorywise(axes[ax_row][ax_col], efficacy_category_wise[category], category, set_yticks=ax_col == 0)\n",
    "\n",
    "# Hide grid cells left empty when the category count is not a multiple of ncols.\n",
    "for position in range(n_subplots, nrows * ncols):\n",
    "    axes[position // ncols][position % ncols].set_visible(False)\n",
    "\n",
    "# Lay out before saving so the PDF gets the same spacing as the inline figure.\n",
    "fig.tight_layout()\n",
    "os.makedirs(\"figs/gptj\", exist_ok=True)\n",
    "fig.savefig(\"figs/gptj/causality_baselines.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): scratch arithmetic that no other cell uses; its purpose is unclear — consider removing.\n",
    "9*4/5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def export_legend(legend, filename=\"legend.pdf\"):\n",
    "    \"\"\"Save only the part of `legend`'s figure that the legend occupies to `filename`.\"\"\"\n",
    "    figure = legend.figure\n",
    "    # Render first so the legend's window extent is up to date.\n",
    "    figure.canvas.draw()\n",
    "    inches_to_display = figure.dpi_scale_trans\n",
    "    legend_box_inches = legend.get_window_extent().transformed(inches_to_display.inverted())\n",
    "    figure.savefig(filename, dpi=\"figure\", bbox_inches=legend_box_inches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_name = \"work location\"\n",
    "color_dict = {\n",
    "    \"low_rank_pinv\": \"darkorange\",\n",
    "    \"hidden_baseline\": \"darkgreen\",\n",
    "    \"embed_baseline\": \"darkred\",\n",
    "    \"hidden_baseline_z\": \"blue\"\n",
    "}\n",
    "\n",
    "# NOTE(review): the category bar plot labels low_rank_pinv as \"Causality\" while this uses\n",
    "# \"Efficacy\" (its exported legend file is named legend-causality-baselines.pdf); confirm intent.\n",
    "edit_type_legends = {\n",
    "    \"low_rank_pinv\": \"Efficacy\",\n",
    "    \"hidden_baseline\"   : \"Replacing with s′\",\n",
    "    \"embed_baseline\"    : \"Replacing with e\",\n",
    "    \"hidden_baseline_z\" : \"Replacing with o′\"\n",
    "}\n",
    "\n",
    "def plot_efficacy_baseline(\n",
    "        ax, efficacy_result, \n",
    "        show_legend = True,\n",
    "        export_legend_to_file = None, \n",
    "    ):\n",
    "    \"\"\"Plot each edit type's mean score (± one stdev band) across layers for one relation.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib Axes to draw on.\n",
    "        efficacy_result: Formatted result with `relation_name` and `layerwise_result`\n",
    "            (layer -> edit type -> object exposing `.mean` and `.stdev`).\n",
    "        show_legend: Whether to leave a legend drawn on `ax`.\n",
    "        export_legend_to_file: If set, a one-row legend is saved to this path via\n",
    "            `export_legend` and then removed from `ax`.\n",
    "    \"\"\"\n",
    "    layerwise_result = efficacy_result[\"layerwise_result\"]\n",
    "    layers = list(layerwise_result.keys())\n",
    "    x_positions = range(len(layers))\n",
    "    edit_types = [\"low_rank_pinv\", \"hidden_baseline\", \"embed_baseline\", \"hidden_baseline_z\"]\n",
    "\n",
    "    for edit_type in edit_types:\n",
    "        if edit_type not in color_dict:\n",
    "            continue\n",
    "        is_main = edit_type == \"low_rank_pinv\"\n",
    "        means = np.array([layerwise_result[layer][edit_type].mean for layer in layers])\n",
    "        stdevs = np.array([layerwise_result[layer][edit_type].stdev for layer in layers])\n",
    "        ax.plot(\n",
    "            x_positions, means,\n",
    "            label=edit_type_legends[edit_type],\n",
    "            color=color_dict[edit_type],\n",
    "            alpha=1,\n",
    "            linestyle='-' if is_main else '--',\n",
    "            linewidth=2 if is_main else 1\n",
    "        )\n",
    "        ax.fill_between(x_positions, means - stdevs, means + stdevs, alpha=0.1, color=color_dict[edit_type])\n",
    "\n",
    "    ax.set_xticks(x_positions, layers, rotation=90)\n",
    "    ax.set_ylim(0, 1)\n",
    "    # Export first: ax.legend() replaces any existing legend, so exporting after drawing\n",
    "    # the visible legend (then removing the export legend) would leave the axes with none.\n",
    "    if export_legend_to_file is not None:\n",
    "        legend = ax.legend(\n",
    "            ncol = len(edit_types), bbox_to_anchor=(4.7, -0.3), \n",
    "            frameon=False, fontsize=20\n",
    "        )\n",
    "        export_legend(legend, export_legend_to_file)\n",
    "        legend.remove()\n",
    "    if show_legend:\n",
    "        ax.legend()\n",
    "\n",
    "    ax.set_xlabel(\"Layer\")\n",
    "    ax.set_ylabel(\"Success@1\")\n",
    "    if ax.get_title() == \"\":\n",
    "        ax.set_title(efficacy_result[\"relation_name\"])\n",
    "\n",
    "\n",
    "efficacy_result = format_efficacy_baseline_results(\n",
    "    efficacy_results_raw[relation_name]\n",
    ")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "plot_efficacy_baseline(ax, efficacy_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_names = [\n",
    "    \"work location\", \"task person type\", \"city in country\", \"person lead singer of band\"\n",
    "]\n",
    "\n",
    "filtered_results = {\n",
    "    relation_name: format_efficacy_baseline_results(efficacy_results_raw[relation_name])\n",
    "    for relation_name in relation_names\n",
    "}\n",
    "\n",
    "#####################################################################################\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 12\n",
    "MEDIUM_SIZE = 16\n",
    "BIGGER_SIZE = 22\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE+1)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=50)  # fontsize of the figure title\n",
    "#####################################################################################\n",
    "\n",
    "n_cols = 4\n",
    "n_rows = int(np.ceil(len(relation_names)/n_cols))\n",
    "# squeeze=False always returns a 2-D array of Axes, so axes[row][col] works for any\n",
    "# grid shape, including a single row or a single column.\n",
    "fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(n_cols * 4, n_rows * 3.5), squeeze=False)\n",
    "\n",
    "# Derive the layer list from this cell's own data so it does not depend on other cells' variables.\n",
    "layers = list(filtered_results[relation_names[0]][\"layerwise_result\"].keys())\n",
    "# Label the first layer and every third layer from index 4 onward; leave the rest blank.\n",
    "selected_layers = [layers[0]] + layers[1::3][1:]\n",
    "layer_labels = [layer if layer in selected_layers else \"\" for layer in layers]\n",
    "\n",
    "os.makedirs(\"figs/gptj\", exist_ok=True)\n",
    "\n",
    "for position, relation_name in enumerate(relation_names):\n",
    "    cur_row, cur_col = divmod(position, n_cols)\n",
    "    ax = axes[cur_row][cur_col]\n",
    "    relation_result = filtered_results[relation_name]\n",
    "    plot_efficacy_baseline(\n",
    "        ax, relation_result,\n",
    "        show_legend=False,\n",
    "        # The shared legend is identical for every subplot, so export it only once.\n",
    "        export_legend_to_file=\"figs/gptj/legend-causality-baselines.pdf\" if position == 0 else None,\n",
    "    )\n",
    "    ax.set_title(relation_result['relation_name'], fontsize=BIGGER_SIZE)\n",
    "    ax.set_xticks(range(len(layers)), layer_labels)\n",
    "    if cur_col == 0:\n",
    "        ax.set_ylabel(\"Success@1\", fontsize=BIGGER_SIZE)\n",
    "    else:\n",
    "        ax.set_ylabel(\"\")\n",
    "    ax.set_xlabel(\"\")\n",
    "\n",
    "# Hide grid cells left empty when the relation count is not a multiple of n_cols.\n",
    "for position in range(len(relation_names), n_rows * n_cols):\n",
    "    axes[position // n_cols][position % n_cols].set_visible(False)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"figs/gptj/layer-wise-causality-baseline.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
