{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from common_imports import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set_theme()\n",
    "\n",
    "# where generated figures are saved / where the serialized experiment outputs are read from\n",
    "FIGURES_ROOT = Path('figures/')\n",
    "EXP_OUTPUT_ROOT = Path('patching_exp_outputs/')\n",
    "# plt.savefig does not create missing directories, so make sure the figure directory exists\n",
    "FIGURES_ROOT.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading and preprocessing results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved fact-patching results and unpack them into their four parts.\n",
    "# NOTE(review): joblib.load unpickles arbitrary objects -- only point it at files produced by this project.\n",
    "patching_results = joblib.load(EXP_OUTPUT_ROOT / 'fact_patching_results.joblib')\n",
    "best_vs, norm_metrics_df, patching_metrics_df, exp_facts = patching_results\n",
    "\n",
    "def add_logitdiffs_as_fractions_of_clean(df: pd.DataFrame):\n",
    "    \"\"\"\n",
    "    Add an 'ld_fraction' column to `df` (in place) and return `df`.\n",
    "\n",
    "    'ld_fraction' is the row's 'ld' divided by the 'ld' of the row whose\n",
    "    'method' is 'clean' and which has the same 'layer' and 'fact_idx'. When\n",
    "    several clean rows share a (layer, fact_idx) pair the first one is used.\n",
    "    Rows whose (layer, fact_idx) pair has no clean row get NaN instead of\n",
    "    aborting the whole computation.\n",
    "    \"\"\"\n",
    "    # one clean logit difference per (layer, fact_idx), gathered in a single pass\n",
    "    clean_ld = (\n",
    "        df.loc[df['method'] == 'clean']\n",
    "        .drop_duplicates(subset=['layer', 'fact_idx'])\n",
    "        .set_index(['layer', 'fact_idx'])['ld']\n",
    "    )\n",
    "    row_keys = pd.MultiIndex.from_frame(df[['layer', 'fact_idx']])\n",
    "    # divide positionally (numpy arrays) so the result does not depend on df's own index\n",
    "    df['ld_fraction'] = df['ld'].to_numpy() / clean_ld.reindex(row_keys).to_numpy()\n",
    "    return df\n",
    "# annotate the patching metrics (in place) with logit diffs relative to the clean run\n",
    "add_logitdiffs_as_fractions_of_clean(df=patching_metrics_df)\n",
    "\n",
    "\n",
    "editing_result_rows = joblib.load(EXP_OUTPUT_ROOT / 'rank1_editing_results.joblib')\n",
    "\n",
    "# absolute values of the two 'act_diff_sim_to_*' entries of every editing result,\n",
    "# keyed by layer and fact index\n",
    "correlation_records = (\n",
    "    {\n",
    "        'layer': result['layer'],\n",
    "        'fact_idx': result['fact_idx'],\n",
    "        'act_diff_sim_to_row': np.abs(result['act_diff_sim_to_row']),\n",
    "        'act_diff_sim_to_null': np.abs(result['act_diff_sim_to_null']),\n",
    "    }\n",
    "    for result in editing_result_rows\n",
    ")\n",
    "correlation_df = pd.DataFrame(list(correlation_records))\n",
    "\n",
    "def get_rank1_metrics(editing_result_rows: List[dict]):\n",
    "    \"\"\"\n",
    "    Build a metrics DataFrame with one row per rank-1 editing result.\n",
    "\n",
    "    Every row is labelled with method 'rome'; the remaining columns are copied\n",
    "    from the 'edit_*', 'layer' and 'fact_idx' entries of the input dicts.\n",
    "    \"\"\"\n",
    "    # output column -> key of the editing-result dict it is read from\n",
    "    column_sources = {\n",
    "        'ld': 'edit_logit_diff',\n",
    "        'predict_source': 'edit_predicts_source',\n",
    "        'predict_base': 'edit_predicts_base',\n",
    "        'prediction': 'edit_prediction',\n",
    "        'layer': 'layer',\n",
    "        'fact_idx': 'fact_idx',\n",
    "    }\n",
    "    records = [\n",
    "        {'method': 'rome', **{column: result[key] for column, key in column_sources.items()}}\n",
    "        for result in editing_result_rows\n",
    "    ]\n",
    "    return pd.DataFrame(records)\n",
    "\n",
    "def convert_fact_patch_example_to_rome_request(\n",
    "    fp_example: dict) -> dict:\n",
    "    \"\"\"\n",
    "    Translate one fact-patching example into a request dict with the keys\n",
    "    'prompt', 'subject' and 'target_new': the example's base subject becomes\n",
    "    the subject and its source target becomes the new target string.\n",
    "    \"\"\"\n",
    "    prompt = fp_example['prompt']\n",
    "    subject = fp_example['base_subject']\n",
    "    new_target = fp_example['source_target']\n",
    "    return {\"prompt\": prompt, \"subject\": subject, \"target_new\": {\"str\": new_target}}\n",
    "\n",
    "def dump_rome_requests(exp_facts: List[dict]):\n",
    "    \"\"\"\n",
    "    Convert every fact-patching example into a ROME request and write the\n",
    "    resulting list as JSON to 'rome_requests.json' in the working directory.\n",
    "    \"\"\"\n",
    "    rome_requests = list(map(convert_fact_patch_example_to_rome_request, exp_facts))\n",
    "    with open('rome_requests.json', 'w') as out_file:\n",
    "        json.dump(rome_requests, out_file)\n",
    "dump_rome_requests(exp_facts)\n",
    "\n",
    "def load_rome_edits_from_fact_patches():\n",
    "    \"\"\"\n",
    "    Load the saved ROME edits from 'fact_patching_rome_edit_rows.joblib'\n",
    "    (resolved against the working directory).\n",
    "\n",
    "    The file is expected to hold a list of dictionaries of the form\n",
    "    {'fact_idx': i, 'layer': j, 'u': u, 'v': v}\n",
    "    where u, v are the ROME left/right edit vectors.\n",
    "    \"\"\"\n",
    "    return joblib.load('fact_patching_rome_edit_rows.joblib')\n",
    "fact_patching_rome_edit_rows = load_rome_edits_from_fact_patches()\n",
    "\n",
    "# pair every rank-1 edit with the DAS patch for the same (layer, fact_idx);\n",
    "# columns present on both sides get the '_rank1' / '_das' suffixes\n",
    "patching_metrics_rank1 = get_rank1_metrics(editing_result_rows)\n",
    "das_and_rank1_df = patching_metrics_rank1.merge(\n",
    "    patching_metrics_df.query('method == \"das\"'),\n",
    "    on=['layer', 'fact_idx'],\n",
    "    suffixes=('_rank1', '_das'),\n",
    ")\n",
    "# gap between the rank-1 edit's and the DAS patch's logit difference\n",
    "das_and_rank1_df['ld'] = das_and_rank1_df['ld_rank1'] - das_and_rank1_df['ld_das']\n",
    "das_and_rank1_df['predict_source'] = das_and_rank1_df['predict_source_rank1']\n",
    "\n",
    "# long-format views (one row per layer / fact / method) used by the plots\n",
    "das_and_rank1_ld_df = das_and_rank1_df.melt(\n",
    "    id_vars=['layer', 'fact_idx'],\n",
    "    value_vars=['ld_rank1', 'ld_das'],\n",
    "    var_name='method',\n",
    "    value_name='logit_diff',\n",
    ")\n",
    "das_and_rank1_acc_df = das_and_rank1_df.melt(\n",
    "    id_vars=['layer', 'fact_idx'],\n",
    "    value_vars=['predict_source_rank1', 'predict_source_das'],\n",
    "    var_name='method',\n",
    "    value_name='fact_flip',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# editing results for layer 5, ordered by fact_idx\n",
    "rows_layer_5 = sorted(\n",
    "    (result for result in editing_result_rows if result['layer'] == 5),\n",
    "    key=lambda result: result['fact_idx'],\n",
    ")\n",
    "# pull out the 'a' and 'b' entry of every result, keeping that order\n",
    "a_vectors = [result['a'] for result in rows_layer_5]\n",
    "b_vectors = [result['b'] for result in rows_layer_5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two sequences unpacked from the saved file; a later cell compares them with the layer-5 'b' and 'a' vectors.\n",
    "# NOTE(review): presumably the ROME u/v edit vectors for layer 5, ordered by fact_idx -- confirm.\n",
    "us, vs = joblib.load('vectors_layer_5.joblib')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check: shapes of the first u and the first v\n",
    "us[0].shape, vs[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarities between the layer-5 'a'/'b' vectors and the loaded vs/us.\n",
    "# NOTE(review): the 'a' vectors are paired with a randomly permuted copy of vs,\n",
    "# whereas the 'b' vectors are paired with us in order -- confirm the shuffle is\n",
    "# an intended random-pairing baseline rather than a leftover.\n",
    "# A seeded generator keeps the shuffle (and hence a_vs_sims) reproducible across runs.\n",
    "idx_permutation = np.random.default_rng(0).permutation(len(vs)).tolist()\n",
    "vs_random_perm = [vs[i] for i in idx_permutation]\n",
    "# use the GPU when one is available instead of requiring it\n",
    "sim_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# as_tensor avoids the extra copy (and the warning) torch.tensor incurs on inputs that are already tensors\n",
    "a_vs_sims = [\n",
    "    torch.cosine_similarity(torch.as_tensor(a_vector, device=sim_device), torch.as_tensor(v, device=sim_device), 0)\n",
    "    for a_vector, v in zip(a_vectors, vs_random_perm)\n",
    "]\n",
    "b_us_sims = [\n",
    "    torch.cosine_similarity(torch.as_tensor(b_vector, device=sim_device), torch.as_tensor(u, device=sim_device), 0)\n",
    "    for b_vector, u in zip(b_vectors, us)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspect the first layer-5 'a' vector\n",
    "a_vectors[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating fact patching figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################################\n",
    "### Bar plot of fraction of facts flipped by das, rowspace and full mlp patch\n",
    "################################################################################\n",
    "# query() returns a subset of patching_metrics_df; take an explicit copy so that\n",
    "# relabelling 'method' below does not raise SettingWithCopyWarning and cannot\n",
    "# touch the original frame\n",
    "df = patching_metrics_df.query('method in [\"row\", \"das\", \"full_mlp\"]').copy()\n",
    "df['method'] = df['method'].map({'row': 'DAS rowspace component', 'das': 'DAS', 'full_mlp': 'Full MLP patch'})\n",
    "plt.figure(figsize=(12, 8))\n",
    "sns.barplot(data=df, x='layer', y='predict_source', hue='method')\n",
    "# remove the x, y axis titles\n",
    "plt.xlabel('')\n",
    "plt.ylabel('')\n",
    "# increase the font of the x and y ticks\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "# legend with a title and larger fonts\n",
    "plt.legend(title='Method', title_fontsize=24, fontsize=20)\n",
    "# save the figure to disk, then show it so the figures of this cell render in order\n",
    "plt.savefig(FIGURES_ROOT / 'fact-patching-barplot.pdf', format='pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "################################################################################\n",
    "### violin plot of the norm of the nullspace component of the patch\n",
    "################################################################################\n",
    "plt.figure(figsize=(12, 8))\n",
    "ax = sns.violinplot(data=norm_metrics_df, x='layer', y='null_norm', inner=None)\n",
    "# dashed horizontal reference line at y = 1\n",
    "plt.axhline(y=1, linestyle='--', color='black')\n",
    "# enlarge the tick labels on both axes\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "# restrict the y axis to [0, 1]\n",
    "plt.ylim(bottom=0, top=1)\n",
    "# no axis titles\n",
    "plt.xlabel(None)\n",
    "plt.ylabel(None)\n",
    "plt.savefig(FIGURES_ROOT / 'fact-patching-norm-violins.pdf', bbox_inches='tight', format='pdf')\n",
    "plt.show()\n",
    "\n",
    "################################################################################\n",
    "### violin plots of logit difference distribution for each layer for each method\n",
    "################################################################################\n",
    "plt.figure(figsize=(12, 8))\n",
    "# query() returns a subset of patching_metrics_df; take an explicit copy so that\n",
    "# relabelling 'method' below does not raise SettingWithCopyWarning and cannot\n",
    "# touch the original frame\n",
    "df = patching_metrics_df.query('method in [\"row\", \"das\", \"full_mlp\"]').copy()\n",
    "df['method'] = df['method'].map({'row': 'DAS rowspace component', 'das': 'DAS', 'full_mlp': 'Full MLP patch'})\n",
    "# violins of the logit difference (as a fraction of the clean run's) on the y axis,\n",
    "# layer on the x axis, one violin per method\n",
    "ax = sns.violinplot(x=\"layer\", y=\"ld_fraction\", data=df, inner=None, hue='method')\n",
    "# add a horizontal line at 0\n",
    "plt.axhline(0, color='black', linestyle='--')\n",
    "# increase the font of the x and y ticks\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "# legend in the lower right, with a title and larger fonts\n",
    "plt.legend(title='Method', title_fontsize=24, fontsize=20, loc='lower right')\n",
    "# no axis titles\n",
    "plt.xlabel(None)\n",
    "plt.ylabel(None)\n",
    "plt.savefig(FIGURES_ROOT / 'fact-patching-logitdiff-violins.pdf', format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating rank-1 editing figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################################\n",
    "### violin plot of the dot product between the nullspace component of the patch\n",
    "### and the difference in the activations of the source and target\n",
    "################################################################################\n",
    "plt.figure(figsize=(12, 8))\n",
    "ax = sns.violinplot(data=correlation_df, x='layer', y='act_diff_sim_to_null', inner=None)\n",
    "# dashed horizontal reference line at y = 1\n",
    "plt.axhline(y=1, linestyle='--', color='black')\n",
    "# enlarge the tick labels on both axes\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "# restrict the y axis to [0, 1]\n",
    "plt.ylim(bottom=0, top=1)\n",
    "# no axis titles\n",
    "plt.xlabel(None)\n",
    "plt.ylabel(None)\n",
    "plt.savefig(FIGURES_ROOT / 'fact-patching-correlation-violins.pdf', bbox_inches='tight', format='pdf')\n",
    "plt.show()\n",
    "\n",
    "################################################################################\n",
    "### box plots comparing logit difference for rank1 edits and das patches by layer\n",
    "################################################################################\n",
    "plt.figure(figsize=(12, 8))\n",
    "# human-readable legend label for each melted 'method' value\n",
    "das_and_rank1_ld_df['display_method'] = das_and_rank1_ld_df['method'].apply(\n",
    "    lambda name: '1-dimensional patch' if name != 'ld_rank1' else 'rank-1 model edit'\n",
    ")\n",
    "ax = sns.boxplot(data=das_and_rank1_ld_df, x='layer', y='logit_diff', hue='display_method')\n",
    "# enlarge the tick labels, drop the axis titles\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "plt.xlabel(None)\n",
    "plt.ylabel(None)\n",
    "# legend with a title and larger fonts\n",
    "plt.legend(title='Method', fontsize=20, title_fontsize=24)\n",
    "plt.savefig(FIGURES_ROOT / 'fact-patching-rank1-vs-das-ld.pdf', bbox_inches='tight', format='pdf')\n",
    "plt.show()\n",
    "\n",
    "### bar plot of interchange accuracy\n",
    "plt.figure(figsize=(12, 8))\n",
    "# human-readable legend label for each melted 'method' value\n",
    "das_and_rank1_acc_df['display_method'] = das_and_rank1_acc_df['method'].apply(\n",
    "    lambda name: '1-dimensional patch' if name != 'predict_source_rank1' else 'rank-1 model edit'\n",
    ")\n",
    "ax = sns.barplot(data=das_and_rank1_acc_df, x='layer', y='fact_flip', hue='display_method')\n",
    "# enlarge the tick labels\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "# legend with a title and larger fonts\n",
    "plt.legend(title='Method', fontsize=20, title_fontsize=24)\n",
    "# no axis titles\n",
    "plt.xlabel(None)\n",
    "plt.ylabel(None)\n",
    "plt.savefig(FIGURES_ROOT / 'fact-patching-rank1-vs-das-acc.pdf', bbox_inches='tight', format='pdf')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.12",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fd8dbfdfd1a6a4c5f3a98a8b5f239185c4ac44e8c535538c941237e2ab93d1b0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
