{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from scipy import stats\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.font_manager as font_manager\n",
    "import matplotlib.ticker as ticker\n",
    "from matplotlib.lines import Line2D\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from scipy.stats import skew\n",
    "from matplotlib.ticker import MaxNLocator\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_theme(style=\"whitegrid\", rc={'font.family': 'sans-serif'})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Correlation between cost model and gpt4 score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('/home/ /Academics/Alignment-Low-Rank-Analysis/results/gpt4_judge/top0_guided_by_Llama-2-7b-hf_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_idx_Llama-2-7b-hf_base_vs_dpo_on_beavertails_base_completion_table.csv', usecols=['cost_scores', 'gpt4_scores'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creating a linear model plot\n",
    "g = sns.lmplot(x=\"cost_scores\", y=\"gpt4_scores\", data=df, aspect=1.3)\n",
    "\n",
    "# Setting axis labels with font size and using the updated font\n",
    "g.set_axis_labels(x_var='Cost Score', y_var='GPT4 Score', fontsize=18, fontweight=500)\n",
    "\n",
    "# Setting tick labels with font size and ensuring the font weight applies\n",
    "g.set_yticklabels(fontsize=18)\n",
    "g.set_xticklabels(fontsize=18)\n",
    "\n",
    "# Adjusting the limits of the plot\n",
    "g.set(xlim=(-25,45), ylim=(-1,12))\n",
    "\n",
    "# Saving the plot to a PDF file with the specified DPI and tight bounding box\n",
    "plt.savefig('./fig/cost_corr.pdf', dpi=600, format='pdf', bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neuron correlation between different models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.set_theme(style=\"white\", palette='Dark2')\n",
    "hue_order = ['Llama2', 'Gemma', 'Mistral']\n",
    "# Define a custom color palette for the desired hue order\n",
    "custom_palette = sns.color_palette(\"YlGnBu\", len(hue_order))\n",
    "# Map the custom palette to the desired hue order\n",
    "hue_colors = dict(zip(hue_order, custom_palette))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def corr(indexes: list, filenames: list[str], topk: int=None):\n",
    "    if not topk:\n",
    "        topk = len(indexes[0])\n",
    "    common_set = set(indexes[0][:topk])\n",
    "    for index in indexes[1:]:\n",
    "        common_set.intersection_update(set(index[:topk]))\n",
    "    print(len(common_set)/topk)\n",
    "    df = pd.DataFrame({f'{filenames[i]}': [idx for idx in indexes[i] if idx in common_set] for i in range(len(indexes))})\n",
    "    ref = df.iloc[:, 0].to_list()\n",
    "    df_rank = df.map(lambda x: ref.index(x))\n",
    "    return common_set, df_rank\n",
    "\n",
    "def pair_wise_corr(indexes: list, filenames: list[str], topk: int=None):\n",
    "    n = len(indexes)\n",
    "    overlap, spearman = np.zeros((n, n)), np.zeros((n, n))\n",
    "    for i in range(n):\n",
    "        for j in range(i+1, n):\n",
    "            cs, rank = corr([indexes[i], indexes[j]], [filenames[i], filenames[j]], topk)\n",
    "            overlap[i, j] = len(cs) / topk\n",
    "            spearman[i, j] = rank.corr('spearman').iloc[0, 1]\n",
    "    return overlap, spearman"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_sets = [\n",
    "    # 'beavertails',\n",
    "    'hh_harmless',\n",
    "    # 'self_instruct',\n",
    "    # 'hh_helpful'\n",
    "]\n",
    "neuron_indexes = [\n",
    "    # 'base_vs_dpo_on_{}_base_completion',\n",
    "    # 'base_vs_dpo_on_{}_dpo_completion',\n",
    "    'sft_vs_dpo_on_{}_sft_completion',\n",
    "    # 'sft_vs_dpo_on_{}_dpo_completion'\n",
    "]\n",
    "models = [\n",
    "    'hh_harmless', \n",
    "    'rewardbench_safety',\n",
    "    'hh_helpful',\n",
    "    'shp',\n",
    "    'H4_stack_exchange',\n",
    "    'rewardbench_reasoning',\n",
    "    'IEFeedback'\n",
    "]\n",
    "llms = ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']\n",
    "num_neurons = [17612, 22937, 34406]\n",
    "overlaps, spearmans = [], []\n",
    "index_lists = []\n",
    "i = 2\n",
    "for model in models:\n",
    "    for neuron_index in neuron_indexes:\n",
    "        index_path = [f'/data1/ /Alignment/hooked_llama/neuron_activation/{llms[i]}_sharegpt_ia3_ff_1_{model}_dpo_ia3_ff_{neuron_index.format(data_set)}.pt' for data_set in data_sets]\n",
    "        for path in index_path:\n",
    "            _, index, *_ = torch.load(path, weights_only=True)\n",
    "            index_list = index.tolist()\n",
    "            index_lists.append([(a, b) for a, b in index_list])\n",
    "            \n",
    "overlap, spearman = pair_wise_corr(index_lists, models, topk=num_neurons[i])\n",
    "overlaps.append(overlap)\n",
    "spearmans.append(spearman)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(\n",
    "    (spearmans[0] + spearmans[0].T + np.eye(len(index_lists))),\n",
    "    index=models,\n",
    "    columns=models\n",
    ")\n",
    "labels = ['Harmless', 'Safety', 'Helpful', 'SHP', 'H4SE', 'Reasoning', 'IEFeedback']\n",
    "palette = sns.color_palette()\n",
    "colors = [palette[0]] * 2 + [palette[1]] * 2 + [palette[2]] * 2 + [palette[3]] # Colors for each label\n",
    "mask = np.eye(len(labels), dtype=bool)\n",
    "# plot a heatmap with annotation\n",
    "g = sns.heatmap(df, annot=False, mask=None, annot_kws={\"size\": 13}, cbar=True, cmap=\"YlGnBu\")\n",
    "g.set_xticklabels(labels=labels, rotation=45, fontsize=13)\n",
    "g.set_yticklabels(labels=labels, rotation=0, fontsize=13)\n",
    "# Setting colored labels for x-axis\n",
    "for tick_label, color in zip(g.get_xticklabels(), colors):\n",
    "    tick_label.set_fontweight(700)\n",
    "    tick_label.set_color(color)\n",
    "\n",
    "# Setting colored labels for y-axis\n",
    "for tick_label, color in zip(g.get_yticklabels(), colors):\n",
    "    tick_label.set_fontweight(700)\n",
    "    tick_label.set_color(color)\n",
    "\n",
    "cbar = g.collections[0].colorbar\n",
    "cbar.ax.tick_params(labelsize=16)  # Set the color bar font size to 15\n",
    "# Saving the plot to a PDF file with the specified DPI and tight bounding box\n",
    "plt.savefig('./fig/models_corr_gemma.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Distribution of change scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dists = {}\n",
    "for model in ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']:\n",
    "    dist, *_ = torch.load(f'/data1/ /Alignment/hooked_llama/neuron_activation/{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_sft_vs_dpo_on_hh_harmless_sft_completion.pt', weights_only=True)\n",
    "    dists[model] = dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_neurons = {'Llama-2-7b-hf':17612, 'Mistral-7B-v0.1':22937, 'gemma-7b':34406}\n",
    "topk_dists = {}\n",
    "for k, v in dists.items():\n",
    "    topk = v.flatten().topk(20000).values\n",
    "    topk_dists[k] = topk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_theme(style=\"whitegrid\", rc={'font.family': 'sans-serif'})\n",
    "fontsize = 22\n",
    "# custom_palette = ['#fe2727', '#178915', '#7777ff']\n",
    "custom_palette = ['#F67188', '#4BBB6C', '#39A7D7']\n",
    "sns.set_palette(custom_palette)\n",
    "g = sns.displot(topk_dists, aspect=1.6)\n",
    "\n",
    "# Setting tick labels with font size and ensuring the font weight applies\n",
    "# g.set(xlim=(0.0317, 0.1), ylim=(0, 2500))\n",
    "g.set(xlim=(0.03, 0.1), ylim=(0, 2500))\n",
    "g.set_yticklabels(fontsize=fontsize-4, fontweight=500)\n",
    "g.set_xticklabels(fontsize=fontsize-4, fontweight=500)\n",
    "g.set_axis_labels(x_var='Change Score', y_var='Count', fontsize=fontsize, fontweight=500)\n",
    "new_labels = ['Llama2', 'Mistral', 'Gemma']\n",
    "for t, l in zip(g.legend.texts, new_labels):\n",
    "    t.set_text(l)\n",
    "\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.4,1), fontsize=fontsize, frameon=True)\n",
    "# custom_font = font_manager.FontProperties(style='normal', weight=700, size=fontsize)\n",
    "g._legend.set_title(None)\n",
    "# g._legend.get_title().set_fontproperties(custom_font)\n",
    "# # Saving the plot to a PDF file with the specified DPI and tight bounding box\n",
    "plt.savefig('./fig/change_score_hist.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model in ['Llama-2-7b-hf', 'gemma-7b', 'Mistral-7B-v0.1']:\n",
    "    print(skew(topk_dists[model]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neuron distribution in layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_df(indexes: list, names: list, n_groups: int, topk: int):\n",
    "    df_dict = {'layers': [], 'ranks': [], 'names': []}\n",
    "    for index, name in zip(indexes, names):\n",
    "        df_dict['ranks'] += (np.arange(topk) // (topk // n_groups)).tolist()\n",
    "        df_dict['layers'] += [layer.item() for layer, neuron in index[:topk]]\n",
    "        df_dict['names'] += [name] * topk\n",
    "    return pd.DataFrame(data=df_dict)\n",
    "\n",
    "def normalize_layers(values):\n",
    "    # print(values['names'])\n",
    "    layers = {'Llama-2-7b-hf': 32, 'Mistral-7B-v0.1': 32, 'gemma-7b': 28}\n",
    "    values['layers'] = float(values['layers']) / (layers[values['names']] - 2)\n",
    "    return values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']\n",
    "indexes = []\n",
    "for model in models:\n",
    "    _, index, *_ = torch.load(f'/data1/ /Alignment/hooked_llama/neuron_activation/{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_sft_vs_dpo_on_hh_harmless_sft_completion.pt')\n",
    "    indexes.append(index)\n",
    "df = create_df(indexes, models, 4, 20000)\n",
    "df = df.apply(normalize_layers, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_labels = ['Llama2', 'Mistral', 'Gemma']\n",
    "hue_order = ['Llama-2-7b-hf', 'gemma-7b', 'Mistral-7B-v0.1']\n",
    "# Define a custom color palette for the desired hue order\n",
    "# custom_palette = sns.color_palette(\"Dark2\", len(hue_order))\n",
    "# Map the custom palette to the desired hue order\n",
    "# hue_colors = dict(zip(hue_order, custom_palette))\n",
    "\n",
    "g = sns.violinplot(x='ranks', y='layers', hue='names', data=df)\n",
    "g.set_xlabel('Top K Neurons', fontsize=24, fontweight=500)\n",
    "g.set_ylabel('Layer Depth', fontsize=24, fontweight=500)\n",
    "g.set_xticklabels(['5000', '10000', '15000', '20000'], fontsize=24, fontweight=500)\n",
    "g.set_ylim(-0.18, 1.18)\n",
    "g.set_yticklabels(g.get_yticklabels(), fontsize=24, fontweight=500)\n",
    "\n",
    "for t, l in zip(g.legend_.texts, new_labels):\n",
    "    t.set_text(l)\n",
    "custom_font = font_manager.FontProperties(style='normal', weight=700, size=15)\n",
    "g.legend_.set_title(None)\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(1,1), fontsize=15, frameon=False)\n",
    "\n",
    "plt.savefig('./fig/layer_distribution.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neuron correlation between different datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_sets = [\n",
    "    'beavertails',\n",
    "    'hh_harmless',\n",
    "    'jailbreak_llms',\n",
    "    'hh_helpful',\n",
    "    'lima',\n",
    "    'rewardbench_reasoning',\n",
    "]\n",
    "neuron_indexes = [\n",
    "    # 'base_vs_dpo_on_{}_base_completion',\n",
    "    # 'base_vs_dpo_on_{}_dpo_completion',\n",
    "    'sft_vs_dpo_on_{}_sft_completion',\n",
    "    # 'sft_vs_dpo_on_{}_dpo_completion'\n",
    "]\n",
    "models = [\n",
    "    'hh_harmless', \n",
    "]\n",
    "overlaps, spearmans = [], []\n",
    "index_lists = []\n",
    "for model in models:\n",
    "    for neuron_index in neuron_indexes:\n",
    "        index_path = [f'/data1/ /Alignment/hooked_llama/neuron_activation/Llama-2-7b-hf_sharegpt_ia3_ff_1_{model}_dpo_ia3_ff_{neuron_index.format(data_set)}.pt' for data_set in data_sets]\n",
    "        for path in index_path:\n",
    "            _, index, *_ = torch.load(path)\n",
    "            index_list = index.tolist()\n",
    "            index_lists.append([(a, b) for a, b in index_list])\n",
    "            \n",
    "overlap, spearman = pair_wise_corr(index_lists, data_sets, topk=20000)\n",
    "overlaps.append(overlap)\n",
    "spearmans.append(spearman)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(\n",
    "    overlaps[0] + overlaps[0].T + np.eye(len(index_lists)),\n",
    "    index=data_sets,\n",
    "    columns=data_sets\n",
    ")\n",
    "labels = ['Beavertails', 'HH-Harmless', 'JailBreakLLMs', 'HH-Helpful', 'LIMA', 'Reasoning']\n",
    "palette = sns.color_palette(\"rocket\")\n",
    "colors = [palette[1]] * 3 + [palette[3]] * 2 + [palette[5]]  # Colors for each label\n",
    "# plot a heatmap with annotation\n",
    "g = sns.heatmap(df, annot=True, annot_kws={\"size\": 13}, cbar=False, cmap='rocket')\n",
    "g.set_xticklabels(labels=labels, rotation=45, fontsize=13)\n",
    "g.set_yticklabels(labels=labels, rotation=0, fontsize=13)\n",
    "# Setting colored labels for x-axis\n",
    "for tick_label, color in zip(g.get_xticklabels(), colors):\n",
    "    tick_label.set_color(color)\n",
    "\n",
    "# Setting colored labels for y-axis\n",
    "for tick_label, color in zip(g.get_yticklabels(), colors):\n",
    "    tick_label.set_color(color)\n",
    "# g.set_aspect(1)\n",
    "# Saving the plot to a PDF file with the specified DPI and tight bounding box\n",
    "plt.savefig('./fig/datasets_corr.pdf', dpi=600, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Patch Line Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']\n",
    "model_name_map = {'Llama-2-7b-hf': 'Llama2', 'Mistral-7B-v0.1': 'Mistral', 'gemma-7b': 'Gemma'}\n",
    "pefts = ['', '_sharegpt_ia3_ff_1']\n",
    "# titiles = ['Patch Base with DPO', 'Patch SFT with DPO']\n",
    "data_sets = [\n",
    "    # 'beavertails',\n",
    "    'hh_harmless',\n",
    "    # 'self_instruct',\n",
    "    # 'hh_helpful'\n",
    "]\n",
    "neuron_indexes = [\n",
    "    # 'base_vs_dpo_on_{}_base_completion',\n",
    "    # 'base_vs_dpo_on_{}_dpo_completion',\n",
    "    'sft_vs_dpo_on_hh_harmless_sft_completion',\n",
    "    # 'sft_vs_dpo_on_{}_dpo_completion'\n",
    "]\n",
    "dfs = []\n",
    "folder_template = '../results/arena/{}{}_vs_{}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff'\n",
    "for model in models:\n",
    "    for peft in pefts:\n",
    "        for index in neuron_indexes:\n",
    "            file_path = folder_template.format(model, peft, model) + f'/guided_by_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_idx_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_{index}.csv'\n",
    "            df = pd.read_csv(file_path)\n",
    "            new_df = pd.DataFrame({\n",
    "                'topk': df.iloc[:,1],\n",
    "                'cost': df.iloc[:, 3],\n",
    "                'Model': model_name_map[model],\n",
    "                'peft': 'SFT' if peft else 'Base',\n",
    "                'Neuron Type': 'Safety Neurons'\n",
    "            })\n",
    "            dfs.append(new_df)\n",
    "            file_path = folder_template.format(model, peft, model) + f'/guided_by_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_idx_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_{index}_random_neurons.csv'\n",
    "            df = pd.read_csv(file_path)\n",
    "            new_df = pd.DataFrame({\n",
    "                'topk': df.iloc[:,1],\n",
    "                'cost': df.iloc[:, 3],\n",
    "                'Model': model_name_map[model],\n",
    "                'peft': 'SFT' if peft else 'Base',\n",
    "                'Neuron Type': 'Random Neurons'\n",
    "            })\n",
    "            dfs.append(new_df)\n",
    "\n",
    "df = pd.concat(dfs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize = 40\n",
    "linewidth = 5\n",
    "markersize = 16\n",
    "sns.set_theme(style='ticks')\n",
    "# Filter data for 'Base' and 'SFT'\n",
    "df_base = df[df['peft'] == 'Base']\n",
    "df_sft = df[df['peft'] == 'SFT']\n",
    "\n",
    "# Initialize the plot\n",
    "fig, axes = plt.subplots(2, 1, figsize=(14, 14), sharey=False, sharex=True)\n",
    "\n",
    "# Define the colors and markers for each model\n",
    "colors = {'Llama2': 'b', 'Mistral': 'g', 'Gemma': 'r'}\n",
    "markers = {'Llama2': 's', 'Mistral': 'p', 'Gemma': 'D'}\n",
    "\n",
    "# Plot for 'Base'\n",
    "for neuron_type, linestyle in zip(['Safety Neurons', 'Random Neurons'], ['-', '--']):\n",
    "    for model in df_base['Model'].unique():\n",
    "        subset = df_base[(df_base['Model'] == model) & (df_base['Neuron Type'] == neuron_type)]\n",
    "        if not subset.empty:\n",
    "            sns.lineplot(\n",
    "                data=subset,\n",
    "                x='topk', y='cost', ax=axes[0],\n",
    "                color=colors[model],\n",
    "                marker=markers[model],\n",
    "                linestyle=linestyle,\n",
    "                linewidth=linewidth,\n",
    "                markersize=markersize\n",
    "            )\n",
    "\n",
    "# Plot for 'SFT'\n",
    "for neuron_type, linestyle in zip(['Safety Neurons', 'Random Neurons'], ['-', '--']):\n",
    "    for model in df_sft['Model'].unique():\n",
    "        subset = df_sft[(df_sft['Model'] == model) & (df_sft['Neuron Type'] == neuron_type)]\n",
    "        if not subset.empty:\n",
    "            sns.lineplot(\n",
    "                data=subset,\n",
    "                x='topk', y='cost', ax=axes[1],\n",
    "                color=colors[model],\n",
    "                marker=markers[model],\n",
    "                linestyle=linestyle,\n",
    "                linewidth=linewidth,\n",
    "                markersize=markersize\n",
    "            )\n",
    "\n",
    "# Customize the plots\n",
    "axes[0].set_title('Patch Base with DPO', fontsize=fontsize)\n",
    "axes[0].set_xlabel('#Neurons', fontsize=fontsize)\n",
    "axes[0].set_ylabel('Cost Scores', fontsize=fontsize)\n",
    "axes[1].set_title('Patch SFT with DPO', fontsize=fontsize)\n",
    "axes[1].set_xlabel('#Neurons', fontsize=fontsize)\n",
    "axes[1].set_ylabel('Cost Scores', fontsize=fontsize)\n",
    "\n",
    "for ax in axes:\n",
    "    ax.xaxis.set_major_locator(MaxNLocator(nbins=5))\n",
    "    plt.setp(ax.get_xticklabels(), fontsize=fontsize)\n",
    "    plt.setp(ax.get_yticklabels(), fontsize=fontsize)\n",
    "    \n",
    "# Create custom legend\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], color='k', lw=linewidth, markersize=markersize, linestyle='-', label='Safety Neurons'),\n",
    "    Line2D([0], [0], color='k', lw=linewidth, markersize=markersize, linestyle='--', label='Random Neurons')\n",
    "]\n",
    "\n",
    "legend_elements = [Line2D([0], [0], color=colors[key], lw=linewidth, markersize=markersize, marker=markers[key], label=key) for key in colors.keys()] + legend_elements\n",
    "\n",
    "# axes[0].legend(handles=legend_elements, title='Model and Neuron Type')\n",
    "axes[0].legend(handles=legend_elements, title=None, ncol=2, fontsize=32, frameon=False, loc='upper left', bbox_to_anchor=(0.05, 1.8))\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('./fig/patch_lineplot1.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sliding window"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']\n",
    "model_name_map = {'Llama-2-7b-hf': 'Llama2', 'Mistral-7B-v0.1': 'Mistral', 'gemma-7b': 'Gemma'}\n",
    "pefts = ['', '_sharegpt_ia3_ff_1']\n",
    "# titiles = ['Patch Base with DPO', 'Patch SFT with DPO']\n",
    "data_sets = [\n",
    "    # 'beavertails',\n",
    "    'hh_harmless',\n",
    "    # 'self_instruct',\n",
    "    # 'hh_helpful'\n",
    "]\n",
    "neuron_indexes = [\n",
    "    # 'base_vs_dpo_on_{}_base_completion',\n",
    "    # 'base_vs_dpo_on_{}_dpo_completion',\n",
    "    'sft_vs_dpo_on_hh_harmless_sft_completion',\n",
    "    # 'sft_vs_dpo_on_{}_dpo_completion'\n",
    "]\n",
    "dfs = []\n",
    "folder_template = '../results/arena/{}{}_vs_{}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff'\n",
    "for model in models:\n",
    "    for peft in pefts:\n",
    "        for index in neuron_indexes:\n",
    "            file_path = folder_template.format(model, peft, model) + f'/guided_by_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_idx_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_{index}_window_5.csv'\n",
    "            df = pd.read_csv(file_path)\n",
    "            new_df = pd.DataFrame({\n",
    "                'topk': range(0, 31, 5),\n",
    "                'cost': df.iloc[:, -2],\n",
    "                'Model': model_name_map[model],\n",
    "                'peft': 'SFT' if peft else 'Base',\n",
    "            })\n",
    "            dfs.append(new_df)\n",
    "\n",
    "df = pd.concat(dfs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_theme(style=\"whitegrid\", rc={'font.family': 'sans-serif'})\n",
    "fontsize = 22\n",
    "# custom_palette = ['#fe2727', '#178915', '#7777ff']\n",
    "custom_palette = ['#F67188', '#4BBB6C', '#39A7D7']\n",
    "sns.set_palette(custom_palette)\n",
    "g = sns.relplot(x='topk', y='cost', data=df, hue='Model', col='peft', kind='line', style='Model', dashes=False, markers=True, markersize=10, aspect=1.3, linewidth=2.5, facet_kws={'sharey': False, 'gridspec_kws': {'wspace': 0.25}})\n",
    "g.set_axis_labels(x_var='Neurons (%)', y_var='Causal Effect', fontsize=fontsize)\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)  # Adjust ticks appearance\n",
    "    ax.set_xticks(range(0, 31, 5))  # 设置刻度位置\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "# g.axes[0, 0].set_ylim(-10, 6)\n",
    "g.axes[0, 0].set_title('Patch Base with DPO', fontsize=fontsize+2)\n",
    "# g.axes[1, 0].set_ylim(-14, 0)\n",
    "g.axes[0, 1].set_title('Patch SFT with DPO', fontsize=fontsize+2)\n",
    "# Define a custom font\n",
    "custom_font = font_manager.FontProperties(style='normal', weight=700, size=fontsize-4)\n",
    "g.legend.get_title().set_fontproperties(custom_font)\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.91,0.9), fontsize=fontsize-4, frameon=True)\n",
    "plt.savefig('./fig/sliding_window_patch_lineplot.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize = 40\n",
    "linewidth = 5\n",
    "markersize = 16\n",
    "sns.set_theme(style='ticks')\n",
    "# Filter data for 'Base' and 'SFT'\n",
    "df_base = df[df['peft'] == 'Base']\n",
    "df_sft = df[df['peft'] == 'SFT']\n",
    "\n",
    "# Initialize the plot\n",
    "fig, axes = plt.subplots(2, 1, figsize=(14, 14), sharey=False, sharex=True)\n",
    "\n",
    "# Define the colors and markers for each model\n",
    "colors = {'Llama2': 'b', 'Mistral': 'g', 'Gemma': 'r'}\n",
    "markers = {'Llama2': 's', 'Mistral': 'p', 'Gemma': 'D'}\n",
    "\n",
    "# Plot for 'Base'\n",
    "for neuron_type, linestyle in zip(['w/', 'w/o'], ['-', '--']):\n",
    "    for model in df_base['Model'].unique():\n",
    "        subset = df_base[(df_base['Model'] == model) & (df_base['Patched'] == neuron_type)]\n",
    "        if not subset.empty:\n",
    "            sns.lineplot(\n",
    "                data=subset,\n",
    "                x='topk', y='cost', ax=axes[0],\n",
    "                color=colors[model],\n",
    "                marker=markers[model],\n",
    "                linestyle=linestyle,\n",
    "                linewidth=linewidth,\n",
    "                markersize=markersize\n",
    "            )\n",
    "\n",
    "# Plot for 'SFT'\n",
    "for neuron_type, linestyle in zip(['w/', 'w/o'], ['-', '--']):\n",
    "    for model in df_sft['Model'].unique():\n",
    "        subset = df_sft[(df_sft['Model'] == model) & (df_sft['Patched'] == neuron_type)]\n",
    "        if not subset.empty:\n",
    "            sns.lineplot(\n",
    "                data=subset,\n",
    "                x='topk', y='cost', ax=axes[1],\n",
    "                color=colors[model],\n",
    "                marker=markers[model],\n",
    "                linestyle=linestyle,\n",
    "                linewidth=linewidth,\n",
    "                markersize=markersize\n",
    "            )\n",
    "\n",
    "# Customize the plots\n",
    "axes[0].set_title('Patch Base with DPO', fontsize=fontsize)\n",
    "axes[0].set_xlabel('#Neurons', fontsize=fontsize)\n",
    "axes[0].set_ylabel('Cost Scores', fontsize=fontsize)\n",
    "axes[1].set_title('Patch SFT with DPO', fontsize=fontsize)\n",
    "axes[1].set_xlabel('#Neurons', fontsize=fontsize)\n",
    "axes[1].set_ylabel('Cost Scores', fontsize=fontsize)\n",
    "\n",
    "for ax in axes:\n",
    "    ax.xaxis.set_major_locator(MaxNLocator(nbins=5))\n",
    "    plt.setp(ax.get_xticklabels(), fontsize=fontsize)\n",
    "    plt.setp(ax.get_yticklabels(), fontsize=fontsize)\n",
    "    \n",
    "# Create custom legend\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], color='k', lw=linewidth, markersize=markersize, linestyle='-', label='w/'),\n",
    "    Line2D([0], [0], color='k', lw=linewidth, markersize=markersize, linestyle='--', label='w/o')\n",
    "]\n",
    "\n",
    "legend_elements = [Line2D([0], [0], color=colors[key], lw=linewidth, markersize=markersize, marker=markers[key], label=key) for key in colors.keys()] + legend_elements\n",
    "\n",
    "# axes[0].legend(handles=legend_elements, title='Model and Neuron Type')\n",
    "axes[0].legend(handles=legend_elements, title=None, ncol=2, fontsize=28, frameon=False, loc='upper left', bbox_to_anchor=(0.4,0.4))\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('./fig/sliding_window_patch_lineplot.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5 random seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {\n",
    "    'Method': ['Base', 'SFT', 'DPO'],\n",
    "    'Llama2': [2.21, -2.33, -11.85],\n",
    "    'Mistral': [-1.13, -7.50, -13.55],\n",
    "    'Gemma': [0.58, -9.54, -13.78],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']\n",
    "model_name_map = {'Llama-2-7b-hf': 'Llama2', 'Mistral-7B-v0.1': 'Mistral', 'gemma-7b': 'Gemma'}\n",
    "seeds = [1, 42, 66, 88, 3407]\n",
    "pefts = ['', '_sharegpt_ia3_ff_{}']\n",
    "# titiles = ['Patch Base with DPO', 'Patch SFT with DPO']\n",
    "data_sets = [\n",
    "    # 'beavertails',\n",
    "    'hh_harmless',\n",
    "    # 'self_instruct',\n",
    "    # 'hh_helpful'\n",
    "]\n",
    "neuron_indexes = [\n",
    "    # 'base_vs_dpo_on_{}_base_completion',\n",
    "    # 'base_vs_dpo_on_{}_dpo_completion',\n",
    "    'sft_vs_dpo_on_hh_harmless_sft_completion',\n",
    "    # 'sft_vs_dpo_on_{}_dpo_completion'\n",
    "]\n",
    "dfs = []\n",
    "folder_template = '../results/arena/{}{}_vs_{}_sharegpt_ia3_ff_{}_hh_harmless_dpo_ia3_ff'\n",
    "for seed in seeds:\n",
    "    for model in models:\n",
    "        for peft in pefts:\n",
    "            for index in neuron_indexes:\n",
    "                file_path = folder_template.format(model, peft.format(seed), model, seed) + f'/guided_by_{model}_sharegpt_ia3_ff_{seed}_hh_harmless_dpo_ia3_ff_idx_{model}_sharegpt_ia3_ff_{seed}_hh_harmless_dpo_ia3_ff_{index}.csv'\n",
    "                df = pd.read_csv(file_path)\n",
    "                new_df = pd.DataFrame({\n",
    "                    'topk': range(6),\n",
    "                    'cost': df.iloc[:, -2],\n",
    "                    'Model': model_name_map[model],\n",
    "                    'Patched': 'SFT' if peft else 'Base',\n",
    "                    'seed': seed,\n",
    "                    'Neuron Type': 'Safety Neurons'\n",
    "                })\n",
    "                dfs.append(new_df)\n",
    "                file_path = folder_template.format(model, peft.format(seed), model, seed) + f'/guided_by_{model}_sharegpt_ia3_ff_{seed}_hh_harmless_dpo_ia3_ff_idx_{model}_sharegpt_ia3_ff_{seed}_hh_harmless_dpo_ia3_ff_{index}_random_neurons.csv'\n",
    "                df = pd.read_csv(file_path)\n",
    "                new_df = pd.DataFrame({\n",
    "                    'topk': range(6),\n",
    "                    'cost': df.iloc[:, -2],\n",
    "                    'Model': model_name_map[model],\n",
    "                    'Patched': 'SFT' if peft else 'Base',\n",
    "                    'seed': seed,\n",
    "                    'Neuron Type': 'Random Neurons'\n",
    "                })\n",
    "                dfs.append(new_df)\n",
    "\n",
    "df = pd.concat(dfs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Causal effect of patching: one facet per value of `Patched`, colour by model,\n",
    "# line style by neuron type, error bars showing a 95% confidence interval.\n",
    "sns.set_theme(style='whitegrid', rc={'font.family': 'sans-serif'})\n",
    "fontsize = 22\n",
    "custom_palette = ['#F67188', '#4BBB6C', '#39A7D7']\n",
    "sns.set_palette(custom_palette)\n",
    "\n",
    "g = sns.relplot(\n",
    "    data=df, x='topk', y='cost',\n",
    "    hue='Model', style='Neuron Type', col='Patched',\n",
    "    kind='line', errorbar=('ci', 95),\n",
    "    aspect=1.3, linewidth=2.5,\n",
    "    facet_kws={'sharey': False, 'gridspec_kws': {'wspace': 0.25}},\n",
    ")\n",
    "g.set_axis_labels(x_var='Neurons (%)', y_var='Causal Effect', fontsize=fontsize)\n",
    "\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)\n",
    "    ax.set_xticks(range(6))  # set the tick positions\n",
    "    # Show every spine of the panel\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "\n",
    "# Facet titles are assigned by position (first column, then second column)\n",
    "g.axes[0, 0].set_title('Patch Base with DPO', fontsize=fontsize+2)\n",
    "g.axes[0, 1].set_title('Patch SFT with DPO', fontsize=fontsize+2)\n",
    "\n",
    "# Move the legend, then apply the bold font to the entries that read\n",
    "# 'Model' or 'Neuron Type'\n",
    "custom_font = font_manager.FontProperties(style='normal', weight=700, size=fontsize-4)\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.86, 0.9), fontsize=fontsize-4, frameon=True)\n",
    "for text in g.legend.texts:\n",
    "    if text.get_text() in ('Model', 'Neuron Type'):\n",
    "        text.set_fontproperties(custom_font)\n",
    "\n",
    "plt.savefig('./fig/patch_lineplot_errorbar.pdf', dpi=600, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-axes line plot of cost vs. topk: colour/marker = model,\n",
    "# solid line = SFT, dashed line = Base, with a 95% confidence interval.\n",
    "fontsize = 40\n",
    "linewidth = 5\n",
    "markersize = 14\n",
    "sns.set_theme(style='ticks')\n",
    "\n",
    "# Initialize the plot\n",
    "fig, ax = plt.subplots(1, 1, figsize=(14, 10))\n",
    "\n",
    "# Colour and marker used for each model\n",
    "colors = {'Llama2': 'b', 'Mistral': 'g', 'Gemma': 'r'}\n",
    "markers = {'Llama2': 's', 'Mistral': 'p', 'Gemma': 'D'}\n",
    "\n",
    "# The loading cell above stores both 'Safety Neurons' and 'Random Neurons' rows\n",
    "# in `df`. Keep only the safety-neuron rows so that a line is not the average of\n",
    "# the two groups; if the column is absent, plot `df` unchanged.\n",
    "if 'Neuron Type' in df.columns:\n",
    "    plot_df = df[df['Neuron Type'] == 'Safety Neurons']\n",
    "else:\n",
    "    plot_df = df\n",
    "\n",
    "# One line per (patched model, model) pair\n",
    "for patch_type, linestyle in zip(['SFT', 'Base'], ['-', '--']):\n",
    "    for model in plot_df['Model'].unique():\n",
    "        subset = plot_df[(plot_df['Model'] == model) & (plot_df['Patched'] == patch_type)]\n",
    "        if not subset.empty:\n",
    "            sns.lineplot(\n",
    "                data=subset,\n",
    "                x='topk', y='cost', ax=ax,\n",
    "                color=colors[model],\n",
    "                marker=markers[model],\n",
    "                linestyle=linestyle,\n",
    "                linewidth=linewidth,\n",
    "                markersize=markersize,\n",
    "                errorbar=('ci', 95)\n",
    "            )\n",
    "\n",
    "# Customize the plot\n",
    "ax.set_xlabel('#Neurons', fontsize=fontsize)\n",
    "ax.set_ylabel('Cost Scores', fontsize=fontsize)\n",
    "\n",
    "ax.xaxis.set_major_locator(MaxNLocator(nbins=5))\n",
    "plt.setp(ax.get_xticklabels(), fontsize=fontsize)\n",
    "plt.setp(ax.get_yticklabels(), fontsize=fontsize)\n",
    "\n",
    "# Custom legend: coloured entries for the models, black entries for the line styles\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], color='k', lw=linewidth, markersize=markersize, linestyle='-', label='SFT'),\n",
    "    Line2D([0], [0], color='k', lw=linewidth, markersize=markersize, linestyle='--', label='Base')\n",
    "]\n",
    "\n",
    "legend_elements = [Line2D([0], [0], color=colors[key], lw=linewidth, markersize=markersize, marker=markers[key], label=key) for key in colors.keys()] + legend_elements\n",
    "\n",
    "ax.legend(handles=legend_elements, title=None, ncol=2, fontsize=30, frameon=False, loc='upper left', bbox_to_anchor=(0.43,1.0))\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.savefig('./fig/5_random_seeds.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Different datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one long-form frame with the cost curve of every\n",
    "# (prompt dataset, model, patched model) combination.\n",
    "models = ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']\n",
    "model_name_map = {'Llama-2-7b-hf': 'Llama2', 'Mistral-7B-v0.1': 'Mistral', 'gemma-7b': 'Gemma'}\n",
    "# '' selects the base model folder, the suffix selects the SFT model folder\n",
    "pefts = ['', '_sharegpt_ia3_ff_1']\n",
    "data_sets = [\n",
    "    'beavertails',\n",
    "    'hh_harmless',\n",
    "    'jailbreak_llms',\n",
    "    'hh_helpful',\n",
    "    'lima',\n",
    "    'rewardbench_reasoning',\n",
    "]\n",
    "# Display names, in the same order as `data_sets`\n",
    "labels = ['Beavertails', 'HH-Harmless', 'JailBreakLLMs', 'HH-Helpful', 'LIMA', 'Reasoning']\n",
    "neuron_indexes = [\n",
    "    # 'base_vs_dpo_on_{}_base_completion',\n",
    "    # 'base_vs_dpo_on_{}_dpo_completion',\n",
    "    'sft_vs_dpo_on_{}_sft_completion',\n",
    "    # 'sft_vs_dpo_on_{}_dpo_completion'\n",
    "]\n",
    "dfs = []\n",
    "folder_template = '../results/arena/{}{}_vs_{}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff'\n",
    "for dataset, label in zip(data_sets, labels):\n",
    "    for model in models:\n",
    "        for peft in pefts:\n",
    "            for index in neuron_indexes:\n",
    "                file_path = folder_template.format(model, peft, model) + f'/guided_by_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_idx_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_{index.format(dataset)}.csv'\n",
    "                # Separate name for the raw CSV so that `df` only ever means the combined frame\n",
    "                raw_df = pd.read_csv(file_path)\n",
    "                new_df = pd.DataFrame({\n",
    "                    'topk': range(6),\n",
    "                    'cost': raw_df.iloc[:, -2],  # second-to-last CSV column\n",
    "                    'peft': 'SFT' if peft else 'Base',\n",
    "                    'Prompt Dataset': label,\n",
    "                    'model': model\n",
    "                })\n",
    "                dfs.append(new_df)\n",
    "\n",
    "# ignore_index gives the combined frame unique row labels instead of repeating\n",
    "# each file's own 0..5 labels.\n",
    "df = pd.concat(dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of line plots: one row per model, one column per patched model,\n",
    "# one coloured/marked line per prompt dataset.\n",
    "fontsize = 22\n",
    "sns.set_theme(style='whitegrid', rc={'font.family': 'sans-serif'})\n",
    "custom_palette = ['#F67188', '#B59A31', '#8AA730', '#4BBB6C', '#39A7D7', '#8F93F3', '#DB70F4']\n",
    "sns.set_palette(custom_palette)\n",
    "model_names = ['Llama2', 'Mistral', 'Gemma']\n",
    "patch_names = ['Patch Base with DPO', 'Patch SFT with DPO']\n",
    "\n",
    "g = sns.relplot(\n",
    "    data=df, x='topk', y='cost',\n",
    "    hue='Prompt Dataset', style='Prompt Dataset',\n",
    "    row='model', col='peft',\n",
    "    kind='line', dashes=False, markers=True, markersize=13,\n",
    "    aspect=1.5, linewidth=2.5,\n",
    "    facet_kws=dict(sharey=False),\n",
    ")\n",
    "g.set_axis_labels(x_var='Neurons (%)', y_var='Causal Effect', fontsize=fontsize, fontweight=500)\n",
    "\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)\n",
    "    ax.set_xticks(range(6))  # set the tick positions\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "\n",
    "# Titles are assigned by grid position: row i -> model_names[i], column j -> patch_names[j]\n",
    "for i, model_name in enumerate(model_names):\n",
    "    for j, patch_name in enumerate(patch_names):\n",
    "        g.axes[i, j].set_title(f'{patch_name} ({model_name})', fontsize=fontsize+2)\n",
    "\n",
    "g.fig.subplots_adjust(hspace=0.2, wspace=0.25)\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.11, 0.465), fontsize=fontsize-4, frameon=False, ncol=2)\n",
    "# Drop the legend title\n",
    "legend = g._legend\n",
    "legend.set_title(None)\n",
    "\n",
    "plt.savefig('./fig/different_datasets.pdf', dpi=300, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Different models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one long-form frame with the cost curves for both patching directions\n",
    "# ('helpful' and 'harmless'), for every alignment dataset and model.\n",
    "models = ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']\n",
    "model_name_map = {'Llama-2-7b-hf': 'Llama2', 'Mistral-7B-v0.1': 'Mistral', 'gemma-7b': 'Gemma'}\n",
    "dpos = [\n",
    "    'hh_harmless', \n",
    "    'rewardbench_safety',\n",
    "    'hh_helpful',\n",
    "    'shp',\n",
    "    'H4_stack_exchange',\n",
    "    'rewardbench_reasoning',\n",
    "    'IEFeedback'\n",
    "]\n",
    "# Display names, in the same order as `dpos`\n",
    "labels = ['Harmless', 'Safety', 'Helpful', 'SHP', 'H4SE', 'Reasoning', 'IEFeedback']\n",
    "neuron_indexes = [\n",
    "    # 'base_vs_dpo_on_{}_base_completion',\n",
    "    # 'base_vs_dpo_on_{}_dpo_completion',\n",
    "    'sft_vs_dpo_on_hh_harmless_sft_completion',\n",
    "    # 'sft_vs_dpo_on_{}_dpo_completion'\n",
    "]\n",
    "dfs = []\n",
    "\n",
    "# (value of the 'direction' column, results folder template, adapter named in\n",
    "# the 'guided_by_' part of the file name). 'helpful' is loaded first so that\n",
    "# the row order of the combined frame matches the order the files are listed in.\n",
    "direction_specs = [\n",
    "    ('helpful', '../results/arena/{}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_vs_{}_sharegpt_ia3_ff_1_hh_helpful_dpo_ia3_ff', 'hh_helpful'),\n",
    "    ('harmless', '../results/arena/{}_sharegpt_ia3_ff_1_hh_helpful_dpo_ia3_ff_vs_{}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff', 'hh_harmless'),\n",
    "]\n",
    "for direction, folder_template, guide in direction_specs:\n",
    "    for dpo, label in zip(dpos, labels):\n",
    "        for model in models:\n",
    "            for index in neuron_indexes:\n",
    "                file_path = folder_template.format(model, model) + f'/guided_by_{model}_sharegpt_ia3_ff_1_{guide}_dpo_ia3_ff_idx_{model}_sharegpt_ia3_ff_1_{dpo}_dpo_ia3_ff_{index}.csv'\n",
    "                # Separate name for the raw CSV so that `df` only ever means the combined frame\n",
    "                raw_df = pd.read_csv(file_path)\n",
    "                new_df = pd.DataFrame({\n",
    "                    'topk': range(6),\n",
    "                    'cost': raw_df.iloc[:, -2],  # second-to-last CSV column\n",
    "                    'Alignment Dataset': label,\n",
    "                    'model': model,\n",
    "                    'direction': direction\n",
    "                })\n",
    "                dfs.append(new_df)\n",
    "\n",
    "# ignore_index gives the combined frame unique row labels instead of repeating\n",
    "# each file's own 0..5 labels.\n",
    "df = pd.concat(dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two facets (one per `direction`), one coloured/marked line per alignment dataset.\n",
    "fontsize = 22\n",
    "sns.set_theme(style='whitegrid', rc={'font.family': 'sans-serif'})\n",
    "custom_palette = ['#F67188', '#B59A31', '#8AA730', '#4BBB6C', '#39A7D7', '#8F93F3', '#DB70F4']\n",
    "markers = ['o', 's', '^', 'v', 'D', 'p', '*']\n",
    "sns.set_palette(custom_palette)\n",
    "\n",
    "g = sns.relplot(\n",
    "    data=df, x='topk', y='cost',\n",
    "    hue='Alignment Dataset', style='Alignment Dataset',\n",
    "    col='direction',\n",
    "    kind='line', dashes=False, markers=True, markersize=13,\n",
    "    aspect=1.3, linewidth=2.5,\n",
    "    facet_kws={'sharey': False, 'gridspec_kws': {'wspace': 0.25}},\n",
    ")\n",
    "g.set_axis_labels(x_var='Neurons (%)', y_var='Causal Effect', fontsize=fontsize)\n",
    "\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)\n",
    "    ax.set_xticks(range(6))\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "\n",
    "# Facet titles are assigned by position (first column, then second column)\n",
    "g.axes[0, 0].set_title(r'Helpfulness $\\rightarrow$ Safety', fontsize=fontsize+2)\n",
    "g.axes[0, 1].set_title(r'Safety $\\rightarrow$ Helpfulness', fontsize=fontsize+2)\n",
    "\n",
    "# Reposition the legend and give its title the bold font\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.88, 0.9), fontsize=fontsize-4, frameon=True)\n",
    "custom_font = font_manager.FontProperties(style='normal', weight=700, size=fontsize-4)\n",
    "g.legend.get_title().set_fontproperties(custom_font)\n",
    "\n",
    "plt.savefig('./fig/alignment_tax.pdf', dpi=600, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Single Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-panel line plot: one coloured/marked line per alignment dataset.\n",
    "fontsize = 40\n",
    "linewidth = 5\n",
    "markersize = 14\n",
    "sns.set_theme(style='ticks')\n",
    "# Palette that is active right after set_theme\n",
    "palette = sns.color_palette()\n",
    "\n",
    "g = sns.relplot(\n",
    "    data=df, x='topk', y='cost',\n",
    "    hue='Alignment Dataset', style='Alignment Dataset',\n",
    "    kind='line', dashes=False, markers=True, markersize=markersize,\n",
    "    palette=palette, linewidth=linewidth,\n",
    "    height=8, aspect=1.4,\n",
    "    facet_kws=dict(sharey=False),\n",
    ")\n",
    "g.set_axis_labels(x_var='#Neurons', y_var='Reward Scores', fontsize=fontsize, fontweight=500)\n",
    "\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)\n",
    "    # Also show the right and top spines\n",
    "    ax.spines['right'].set_visible(True)\n",
    "    ax.spines['top'].set_visible(True)\n",
    "\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.24, 0.29), fontsize=22, frameon=False, ncol=3)\n",
    "# Drop the legend title\n",
    "legend = g._legend\n",
    "legend.set_title(None)\n",
    "\n",
    "plt.savefig('./fig/different_models_grid.pdf', dpi=300, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Multiple Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of line plots: one row per model, one column per direction,\n",
    "# one coloured/marked line per alignment dataset.\n",
    "fontsize = 22\n",
    "model_names = ['Llama2', 'Mistral', 'Gemma']\n",
    "patch_names = [r'Helpfulness $\\rightarrow$ Safety', r'Safety $\\rightarrow$ Helpfulness']\n",
    "sns.set_theme(style='whitegrid', rc={'font.family': 'sans-serif'})\n",
    "custom_palette = ['#F67188', '#B59A31', '#8AA730', '#4BBB6C', '#39A7D7', '#8F93F3', '#DB70F4']\n",
    "sns.set_palette(custom_palette)\n",
    "\n",
    "g = sns.relplot(\n",
    "    data=df, x='topk', y='cost',\n",
    "    hue='Alignment Dataset', style='Alignment Dataset',\n",
    "    row='model', col='direction',\n",
    "    kind='line', dashes=False, markers=True, markersize=13,\n",
    "    aspect=1.3, linewidth=2.5,\n",
    "    facet_kws={'sharey': False, 'gridspec_kws': {'wspace': 0.25}},\n",
    ")\n",
    "g.set_axis_labels(x_var='Neurons (%)', y_var='Causal Effect', fontsize=fontsize)\n",
    "\n",
    "for i, ax in enumerate(g.axes.flat):\n",
    "    # Flat index -> (row, column) position in the grid\n",
    "    row_idx, col_idx = divmod(i, len(patch_names))\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)\n",
    "    ax.set_xticks(range(6))\n",
    "    ax.set_title(f'{patch_names[col_idx]} ({model_names[row_idx]})', fontsize=fontsize+2)\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "\n",
    "# Reposition the legend and drop its title\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.17, 0.49), fontsize=fontsize-4, frameon=False, ncol=2)\n",
    "g.legend.set_title(None)\n",
    "plt.savefig('./fig/alignment_tax_all.pdf', dpi=300, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One facet per model (Mistral and Gemma only), one line per alignment dataset.\n",
    "fontsize = 22\n",
    "# `df` contains three models, but this figure shows only Mistral and Gemma.\n",
    "# Restrict the data and the facet order to those two and title every facet from\n",
    "# the model it actually shows, so a title can never land on the wrong panel.\n",
    "facet_titles = {'Mistral-7B-v0.1': 'Mistral', 'gemma-7b': 'Gemma'}\n",
    "shown_models = list(facet_titles)\n",
    "g = sns.relplot(\n",
    "    data=df[df['model'].isin(shown_models)], x='topk', y='cost',\n",
    "    col='model', col_order=shown_models,\n",
    "    hue='Alignment Dataset',\n",
    "    kind='line', aspect=1.5, linewidth=3,\n",
    "    facet_kws=dict(sharey=False),\n",
    ")\n",
    "g.set_axis_labels(x_var='TOP K Neurons', y_var='Reward Score', fontsize=fontsize, fontweight=500)\n",
    "# g.col_names lists the facet levels in the order the columns are drawn\n",
    "for ax, model in zip(g.axes.flat, g.col_names):\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)\n",
    "    ax.set_title(facet_titles[model], fontsize=fontsize)\n",
    "g.fig.subplots_adjust(hspace=0.2)\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.88,1.01), fontsize=18, frameon=True)\n",
    "# Give the legend title the bold font\n",
    "legend = g._legend\n",
    "custom_font = font_manager.FontProperties(style='normal', weight=700, size=18)\n",
    "legend.get_title().set_fontproperties(custom_font)\n",
    "plt.savefig('./fig/different_models_mistral_gemma.pdf', dpi=600, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ablation study"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Which Models to Compare?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one long-form frame with the cost curve of every\n",
    "# (model, patched model, neuron ranking) combination.\n",
    "models = ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']\n",
    "model_name_map = {'Llama-2-7b-hf': 'Llama2', 'Mistral-7B-v0.1': 'Mistral', 'gemma-7b': 'Gemma'}\n",
    "# '' selects the base model folder, the suffix selects the SFT model folder\n",
    "pefts = ['', '_sharegpt_ia3_ff_1']\n",
    "data_sets = [\n",
    "    # 'beavertails',\n",
    "    'hh_harmless',\n",
    "    # 'self_instruct',\n",
    "    # 'hh_helpful'\n",
    "]\n",
    "neuron_indexes = [\n",
    "    'base_vs_dpo_on_hh_harmless_base_completion',\n",
    "    'base_vs_dpo_on_hh_harmless_dpo_completion',\n",
    "    'sft_vs_dpo_on_hh_harmless_sft_completion',\n",
    "    'sft_vs_dpo_on_hh_harmless_dpo_completion'\n",
    "]\n",
    "# [value for 'Compared Models', value for 'Generation From'], in the same order\n",
    "# as `neuron_indexes`\n",
    "legends = [\n",
    "    ['Base and DPO', 'First Model'],\n",
    "    ['Base and DPO', 'Second Model'],\n",
    "    ['SFT and DPO', 'First Model'],\n",
    "    ['SFT and DPO', 'Second Model']\n",
    "]\n",
    "dfs = []\n",
    "folder_template = '../results/arena/{}{}_vs_{}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff'\n",
    "for model in models:\n",
    "    for peft in pefts:\n",
    "        for index, legend in zip(neuron_indexes, legends):\n",
    "            file_path = folder_template.format(model, peft, model) + f'/guided_by_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_idx_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_{index}.csv'\n",
    "            # Separate name for the raw CSV so that `df` only ever means the combined frame\n",
    "            raw_df = pd.read_csv(file_path)\n",
    "            new_df = pd.DataFrame({\n",
    "                'topk': range(6),\n",
    "                'cost': raw_df.iloc[:, -2],  # second-to-last CSV column\n",
    "                'Model': model_name_map[model],\n",
    "                'peft': 'SFT' if peft else 'Base',\n",
    "                'Compared Models': legend[0],\n",
    "                'Generation From': legend[1]\n",
    "            })\n",
    "            dfs.append(new_df)\n",
    "\n",
    "# ignore_index gives the combined frame unique row labels instead of repeating\n",
    "# each file's own 0..5 labels.\n",
    "df = pd.concat(dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Multiple Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of line plots: one row per model, one column per patched model;\n",
    "# colour = compared models, line style/marker = which model generated the text.\n",
    "fontsize = 22\n",
    "model_names = ['Llama2', 'Mistral', 'Gemma']\n",
    "patch_names = [r'Patch Base with DPO', r'Patch SFT with DPO']\n",
    "sns.set_theme(style='whitegrid', rc={'font.family': 'sans-serif'})\n",
    "custom_palette = ['#F67188', '#39A7D7', '#8F93F3', '#DB70F4']\n",
    "sns.set_palette(custom_palette)\n",
    "\n",
    "g = sns.relplot(\n",
    "    data=df, x='topk', y='cost',\n",
    "    hue='Compared Models', style='Generation From',\n",
    "    row='Model', col='peft',\n",
    "    kind='line', markers=True, markersize=13,\n",
    "    aspect=1.3, linewidth=2.5,\n",
    "    facet_kws={'sharey': False, 'gridspec_kws': {'wspace': 0.25}},\n",
    ")\n",
    "g.set_axis_labels(x_var='Neurons (%)', y_var='Causal Effect', fontsize=fontsize)\n",
    "\n",
    "for i, ax in enumerate(g.axes.flat):\n",
    "    # Flat index -> (row, column) position in the grid\n",
    "    row_idx, col_idx = divmod(i, len(patch_names))\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)\n",
    "    ax.set_xticks(range(6))\n",
    "    ax.set_title(f'{patch_names[col_idx]} ({model_names[row_idx]})', fontsize=fontsize+2)\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "\n",
    "# Reposition the legend, then apply the bold font to the entries that read\n",
    "# 'Compared Models' or 'Generation From'\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.62, 0.55), fontsize=fontsize-4, frameon=False, ncol=1)\n",
    "custom_font = font_manager.FontProperties(style='normal', weight=700, size=fontsize-4)\n",
    "for text in g.legend.texts:\n",
    "    if text.get_text() in ('Compared Models', 'Generation From'):\n",
    "        text.set_fontproperties(custom_font)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('./fig/which_model.pdf', dpi=300, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matplotlib version of the 'which models to compare' figure:\n",
    "# a 3 x 2 grid with one row per model and one column per patched model.\n",
    "fontsize = 24\n",
    "linewidth = 5\n",
    "markersize = 10\n",
    "sns.set_theme(style='ticks')\n",
    "\n",
    "# Initialize the plot\n",
    "fig, axes = plt.subplots(3, 2, figsize=(14, 14), sharey=False, sharex=True)\n",
    "\n",
    "# Colour and marker for each pair of compared models\n",
    "colors = {'Base and DPO': 'b', 'SFT and DPO': 'g'}\n",
    "markers = {'Base and DPO': 's', 'SFT and DPO': 'p'}\n",
    "\n",
    "for i, model in enumerate(models):\n",
    "    for j, peft in enumerate(['Base', 'SFT']):\n",
    "        df_temp = df[(df['Model'] == model_name_map[model]) & (df['peft'] == peft)]\n",
    "        # Solid line = 'First Model', dashed line = 'Second Model'\n",
    "        for gmodel, linestyle in zip(['First Model', 'Second Model'], ['-', '--']):\n",
    "            for cmodel in df_temp['Compared Models'].unique():\n",
    "                # The loading cell names this column 'Generation From';\n",
    "                # filtering on 'Generation' raised a KeyError.\n",
    "                subset = df_temp[(df_temp['Compared Models'] == cmodel) & (df_temp['Generation From'] == gmodel)]\n",
    "                if not subset.empty:\n",
    "                    sns.lineplot(\n",
    "                        data=subset,\n",
    "                        x='topk', y='cost', ax=axes[i][j],\n",
    "                        color=colors[cmodel],\n",
    "                        marker=markers[cmodel],\n",
    "                        linestyle=linestyle,\n",
    "                        linewidth=linewidth,\n",
    "                        markersize=markersize\n",
    "                    )\n",
    "\n",
    "# Customize the plots: i indexes the column (patched model), j the row (model)\n",
    "for i, patch_type in enumerate(['Patch Base with DPO', 'Patch SFT with DPO']):\n",
    "    for j, model_name in enumerate(models):\n",
    "        panel = axes[j][i]\n",
    "        panel.set_title(f'{patch_type} ({model_name_map[model_name]})', fontsize=fontsize)\n",
    "        # Only the left column keeps a y label; only the bottom row gets the x label\n",
    "        if i == 0:\n",
    "            panel.set_ylabel('Cost Scores', fontsize=fontsize)\n",
    "        else:\n",
    "            panel.set_ylabel(None, fontsize=fontsize)\n",
    "        if j == 2:\n",
    "            panel.set_xlabel('#Neurons', fontsize=fontsize)\n",
    "        panel.xaxis.set_major_locator(MaxNLocator(nbins=5))\n",
    "        plt.setp(panel.get_xticklabels(), fontsize=fontsize)\n",
    "        plt.setp(panel.get_yticklabels(), fontsize=fontsize)\n",
    "\n",
    "# Custom legend: coloured entries for the compared models, black entries for the line styles\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], color='k', lw=linewidth, markersize=markersize, linestyle='-', label='First Model'),\n",
    "    Line2D([0], [0], color='k', lw=linewidth, markersize=markersize, linestyle='--', label='Second Model')\n",
    "]\n",
    "\n",
    "legend_elements = [Line2D([0], [0], color=colors[key], lw=linewidth, markersize=markersize, marker=markers[key], label=key) for key in colors.keys()] + legend_elements\n",
    "\n",
    "axes[1][1].legend(handles=legend_elements, title=None, ncol=1, fontsize=18, frameon=False, loc='upper left', bbox_to_anchor=(0.35,0.95))\n",
    "\n",
    "plt.tight_layout()\n",
    "# Saved under its own name: the seaborn cell above already writes\n",
    "# './fig/which_model.pdf', and this cell must not overwrite that figure.\n",
    "plt.savefig('./fig/which_model_matplotlib.pdf', dpi=300, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training Free methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one long-form frame with the cost curve of every\n",
    "# (model, patched model, neuron-selection method) combination.\n",
    "models = ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']\n",
    "model_name_map = {'Llama-2-7b-hf': 'Llama2', 'Mistral-7B-v0.1': 'Mistral', 'gemma-7b': 'Gemma'}\n",
    "# '' selects the base model folder, the suffix selects the SFT model folder\n",
    "pefts = ['', '_sharegpt_ia3_ff_1']\n",
    "data_sets = [\n",
    "    # 'beavertails',\n",
    "    'hh_harmless',\n",
    "    # 'self_instruct',\n",
    "    # 'hh_helpful'\n",
    "]\n",
    "neuron_indexes = [\n",
    "    'difference_on_hh_prompt_last_token',\n",
    "    'std_on_hh_prompt_last_token',\n",
    "    'sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_sft_vs_dpo_on_hh_harmless_sft_completion',\n",
    "]\n",
    "# Display names, in the same order as `neuron_indexes`\n",
    "labels = [\n",
    "    'Prompt Difference',\n",
    "    'Activation Variance',\n",
    "    'Safety Neurons'\n",
    "]\n",
    "dfs = []\n",
    "folder_template = '../results/arena/{}{}_vs_{}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff'\n",
    "# Only the second and third entries of `models` are loaded here\n",
    "for model in models[1:]:\n",
    "    for peft in pefts:\n",
    "        for index, label in zip(neuron_indexes, labels):\n",
    "            file_path = folder_template.format(model, peft, model) + f'/guided_by_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_idx_{model}_{index}.csv'\n",
    "            # Separate name for the raw CSV so that `df` only ever means the combined frame\n",
    "            raw_df = pd.read_csv(file_path)\n",
    "            new_df = pd.DataFrame({\n",
    "                'topk': raw_df.iloc[:, 1],  # second CSV column\n",
    "                'cost': raw_df.iloc[:, 3],  # fourth CSV column\n",
    "                'Model': model_name_map[model],\n",
    "                'peft': 'SFT' if peft else 'Base',\n",
    "                'Method': label\n",
    "            })\n",
    "            dfs.append(new_df)\n",
    "\n",
    "# ignore_index gives the combined frame unique row labels instead of repeating\n",
    "# each file's own row labels.\n",
    "df = pd.concat(dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Single Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two stacked facets (one per patched model), one line per neuron-selection method.\n",
    "fontsize = 22\n",
    "g = sns.relplot(\n",
    "    data=df, x='topk', y='cost',\n",
    "    hue='Method', row='peft',\n",
    "    kind='line', aspect=2, linewidth=3,\n",
    "    facet_kws=dict(sharey=False),\n",
    ")\n",
    "g.set_axis_labels(x_var='TOP K Neurons', y_var='Cost Score', fontsize=fontsize)\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)\n",
    "\n",
    "# Fixed y range and title for the first row, then for the second row\n",
    "g.axes[0, 0].set_ylim(-10, 10)\n",
    "g.axes[0, 0].set_title('Patch Base with DPO', fontsize=fontsize)\n",
    "g.axes[1, 0].set_ylim(-14, 0)\n",
    "g.axes[1, 0].set_title('Patch SFT with DPO', fontsize=fontsize)\n",
    "g.fig.subplots_adjust(hspace=0.2)\n",
    "\n",
    "# Give the legend title the bold font, then reposition the legend\n",
    "custom_font = font_manager.FontProperties(style='normal', weight=700, size=18)\n",
    "g._legend.get_title().set_fontproperties(custom_font)\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.58, 1.02), fontsize=18, frameon=True)\n",
    "plt.savefig('./fig/different_methods.pdf', dpi=600, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Multiple Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of line plots: one row per patched model, one column per model,\n",
    "# one line per neuron-selection method.\n",
    "fontsize = 22\n",
    "g = sns.relplot(\n",
    "    data=df, x='topk', y='cost',\n",
    "    hue='Method', row='peft', col='Model',\n",
    "    kind='line', aspect=1.5, linewidth=3,\n",
    "    facet_kws=dict(sharey=False),\n",
    ")\n",
    "g.set_axis_labels(x_var='TOP K Neurons', y_var='Cost Score', fontsize=fontsize)\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)\n",
    "\n",
    "# Title each facet from the levels the grid actually drew (g.row_names holds the\n",
    "# 'peft' values, g.col_names the 'Model' values). Indexing a model-name list\n",
    "# left over from an earlier cell put the wrong model name on each column.\n",
    "for i, peft in enumerate(g.row_names):\n",
    "    for j, model in enumerate(g.col_names):\n",
    "        g.axes[i, j].set_title(f'Patch {peft} with DPO ({model})', fontsize=fontsize)\n",
    "g.fig.subplots_adjust(hspace=0.2)\n",
    "\n",
    "# Give the legend title the bold font, then reposition the legend\n",
    "custom_font = font_manager.FontProperties(style='normal', weight=700, size=18)\n",
    "g._legend.get_title().set_fontproperties(custom_font)\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.86,1.02), fontsize=18, frameon=True)\n",
    "plt.savefig('./fig/different_methods_mistral_gemma.pdf', dpi=600, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Which Token Position to Compare?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one long-form frame with the cost curve of every\n",
    "# (model, patched model, token position) combination.\n",
    "models = ['Llama-2-7b-hf', 'Mistral-7B-v0.1', 'gemma-7b']\n",
    "model_name_map = {'Llama-2-7b-hf': 'Llama2', 'Mistral-7B-v0.1': 'Mistral', 'gemma-7b': 'Gemma'}\n",
    "# '' selects the base model folder, the suffix selects the SFT model folder\n",
    "pefts = ['', '_sharegpt_ia3_ff_1']\n",
    "\n",
    "neuron_indexes = [\n",
    "    'sft_vs_dpo_on_hh_harmless_prompt',\n",
    "    'sft_vs_dpo_on_hh_harmless_prompt_last',\n",
    "    'sft_vs_dpo_on_hh_harmless_sft_completion'\n",
    "]\n",
    "# Display names, in the same order as `neuron_indexes`\n",
    "labels = [\n",
    "    'Prompt (all tokens)',\n",
    "    'Prompt (last token)',\n",
    "    'Safety Neurons'\n",
    "]\n",
    "dfs = []\n",
    "folder_template = '../results/arena/{}{}_vs_{}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff'\n",
    "for model in models:\n",
    "    for peft in pefts:\n",
    "        for index, label in zip(neuron_indexes, labels):\n",
    "            file_path = folder_template.format(model, peft, model) + f'/guided_by_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_idx_{model}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_{index}.csv'\n",
    "            # Separate name for the raw CSV so that `df` only ever means the combined frame\n",
    "            raw_df = pd.read_csv(file_path)\n",
    "            new_df = pd.DataFrame({\n",
    "                'topk': range(6),\n",
    "                'cost': raw_df.iloc[:, -2],  # second-to-last CSV column\n",
    "                'Model': model_name_map[model],\n",
    "                'peft': 'SFT' if peft else 'Base',\n",
    "                'Token Position': label\n",
    "            })\n",
    "            dfs.append(new_df)\n",
    "\n",
    "# ignore_index gives the combined frame unique row labels instead of repeating\n",
    "# each file's own 0..5 labels.\n",
    "df = pd.concat(dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Multiple Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of line plots: one row per model, one column per patched model,\n",
    "# one coloured/marked line per token position.\n",
    "fontsize = 22\n",
    "model_names = ['Llama2', 'Mistral', 'Gemma']\n",
    "patch_names = [r'Patch Base with DPO', r'Patch SFT with DPO']\n",
    "sns.set_theme(style='whitegrid', rc={'font.family': 'sans-serif'})\n",
    "custom_palette = ['#F67188', '#4BBB6C', '#39A7D7']\n",
    "sns.set_palette(custom_palette)\n",
    "\n",
    "g = sns.relplot(\n",
    "    data=df, x='topk', y='cost',\n",
    "    hue='Token Position', style='Token Position',\n",
    "    row='Model', col='peft',\n",
    "    kind='line', dashes=False, markers=True, markersize=13,\n",
    "    aspect=1.3, linewidth=2.5,\n",
    "    facet_kws={'sharey': False, 'gridspec_kws': {'wspace': 0.25}},\n",
    ")\n",
    "g.set_axis_labels(x_var='Neurons (%)', y_var='Causal Effect', fontsize=fontsize)\n",
    "\n",
    "for i, ax in enumerate(g.axes.flat):\n",
    "    # Flat index -> (row, column) position in the grid\n",
    "    row_idx, col_idx = divmod(i, len(patch_names))\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)\n",
    "    ax.set_xticks(range(6))\n",
    "    ax.set_title(f'{patch_names[col_idx]} ({model_names[row_idx]})', fontsize=fontsize+2)\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "\n",
    "# Reposition the legend and give its title the bold font\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.62, 0.23), fontsize=fontsize-4, frameon=False, ncol=1)\n",
    "custom_font = font_manager.FontProperties(style='normal', weight=700, size=fontsize-4)\n",
    "g.legend.get_title().set_fontproperties(custom_font)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('./fig/different_positions.pdf', dpi=300, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize = 28\n",
    "sns.set_theme(style='ticks')\n",
    "palettes = sns.color_palette(\"tab10\")\n",
    "palette = [palettes[0], palettes[2], palettes[3]]\n",
    "model_names = ['Llama2', 'Mistral', 'Gemma']\n",
    "patch_names = ['Patch Base with DPO', 'Patch SFT with DPO']\n",
    "g = sns.relplot(x='topk', y='cost', data=df, hue='Token Position', style='Token Position', row='Model', col='peft', kind='line', aspect=1.5, facet_kws=dict(sharey=False), linewidth=6, markers=True, markersize=14, dashes=False, palette=palette)\n",
    "g.set_axis_labels(x_var='#Neurons', y_var='Cost Scores', fontsize=fontsize)\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)  # Adjust ticks appearance\n",
    "for i in range(3):\n",
    "    for j in range(2):\n",
    "        g.axes[i, j].set_title(f'{patch_names[j]} ({model_names[i]})', fontsize=fontsize)\n",
    "g.fig.subplots_adjust(hspace=0.2, wspace=0.25)\n",
    "# Define a custom font\n",
    "# custom_font = font_manager.FontProperties(style='normal', weight=700, size=18)\n",
    "# g._legend.get_title().set_fontproperties(custom_font)\n",
    "g.legend.set_title(None)\n",
    "sns.move_legend(g, 'upper right', bbox_to_anchor=(0.88,0.65), fontsize=24, frameon=False)\n",
    "for ax in g.axes.flat:\n",
    "    ax.spines['right'].set_visible(True)\n",
    "    ax.spines['top'].set_visible(True)\n",
    "plt.savefig('./fig/different_positions.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Std vs rank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "activations = []\n",
    "# model_name = 'Llama-2-7b-hf_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff'\n",
    "# model_name = 'Llama-2-7b-hf_sharegpt_ia3_ff_1'\n",
    "model_name = 'Llama-2-7b-hf'\n",
    "filenames = ['hh_harmless_prompt_last_token', 'hh_helpful_prompt_last_token']\n",
    "for filename in filenames:\n",
    "    activation = torch.load(f'/data1/ /Alignment/output/activations/{model_name}/{filename}.pt')\n",
    "    activations.append(activation)\n",
    "\n",
    "_, index, *_ = torch.load(f'/data1/ /Alignment/hooked_llama/neuron_activation/{model_name}_sharegpt_ia3_ff_1_hh_harmless_dpo_ia3_ff_sft_vs_dpo_on_hh_harmless_sft_completion.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "activation_std = torch.concat(activations, 0).std(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_groups = 4\n",
    "topk = 20000\n",
    "df_dict = {'std': [], 'ranks': []}\n",
    "df_dict['ranks'] += (np.arange(topk) // (topk // n_groups)).tolist()\n",
    "df_dict['std'] += [activation_std[layer, neuron].item() for layer, neuron in index[:topk]]\n",
    "df = pd.DataFrame(data=df_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = sns.violinplot(x='ranks', y='std', data=df)\n",
    "g.set_xlabel('Top K Neurons', fontsize=18, fontweight=500)\n",
    "g.set_ylabel('Activation Std', fontsize=18, fontweight=500)\n",
    "g.set_xticklabels(range(0, topk+1, topk//n_groups), fontsize=18, fontweight=500)\n",
    "# g.set_ylim(-0.18, 1.18)\n",
    "g.set_yticklabels(g.get_yticklabels(), fontsize=18, fontweight=500)\n",
    "plt.savefig('./fig/std_distribution.pdf', dpi=600, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Safety Eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the DataFrame from the provided table data\n",
    "data = {\n",
    "    'Method': ['Base', 'SFT', 'DPO'],\n",
    "    'Llama2_mean': [2.21, -2.33, -11.85],\n",
    "    'Llama2_std': [11.81, 15.26, 5.27],\n",
    "    'Mistral_mean': [-1.13, -7.50, -13.55],\n",
    "    'Mistral_std': [13.35, 11.58, 5.30],\n",
    "    'Gemma_mean': [0.58, -9.54, -13.78],\n",
    "    'Gemma_std': [12.82, 10.50, 5.45]\n",
    "}\n",
    "df = pd.DataFrame(data)\n",
    "# Melt the DataFrame to long format for Seaborn\n",
    "df_mean = df.melt(id_vars='Method', \n",
    "                  value_vars=['Llama2_mean', 'Mistral_mean', 'Gemma_mean'], \n",
    "                  var_name='Model', value_name='Mean')\n",
    "\n",
    "df_std = df.melt(id_vars='Method', \n",
    "                 value_vars=['Llama2_std', 'Mistral_std', 'Gemma_std'], \n",
    "                 var_name='Model', value_name='Std')\n",
    "\n",
    "# Clean up the model names\n",
    "df_mean['Model'] = df_mean['Model'].str.replace('_mean', '')\n",
    "df_std['Model'] = df_std['Model'].str.replace('_std', '')\n",
    "\n",
    "# Merge the mean and std DataFrames\n",
    "df_melted = pd.merge(df_mean, df_std, on=['Method', 'Model'])\n",
    "\n",
    "sns.set_theme(style=\"ticks\")\n",
    "fontsize = 40\n",
    "g = sns.catplot(\n",
    "    x='Method', y='Mean', hue='Model', data=df_melted, kind='bar', \n",
    "    palette=custom_palette, capsize=0.1, aspect=1.3, height=9, edgecolor='black'\n",
    ")\n",
    "# Customizing the plot\n",
    "plt.xlabel(None)\n",
    "plt.ylabel('Cost Scores', fontsize=fontsize)\n",
    "plt.xticks(fontsize=fontsize)\n",
    "plt.yticks(fontsize=fontsize)\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.08,0.5), fontsize=fontsize, frameon=False, ncol=1)\n",
    "g.legend.set_title(None)\n",
    "plt.savefig('./fig/cost_scores.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_json('/home/ /Academics/Alignment-Low-Rank-Analysis/results/safety_guard.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_theme(style=\"whitegrid\")\n",
    "fontsize = 20\n",
    "custom_palette = ['#F77189', '#96A430', '#35AEA4', '#A992F4', '#fe2727']\n",
    "sns.set_palette(custom_palette)\n",
    "titles = ['JailBreakLLMs', 'Beavertails', 'HarmBench', 'RedTeam']\n",
    "model_names = [\"Llama2\", \"Mistral\", \"Gemma\"]\n",
    "\n",
    "# Plot the bar plot with error bars\n",
    "g = sns.catplot(\n",
    "    x='Dataset', y='mean', hue='Type', data=df, col='Model', kind='bar', capsize=1, aspect=1.1, edgecolor='black',\n",
    "    facet_kws={'gridspec_kws': {'wspace': 0.1}}\n",
    ")\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, rotation=45)  # Adjust ticks appearance\n",
    "    ax.set_xticklabels(titles)\n",
    "    ax.set_xlabel(None)\n",
    "    ax.set_ylabel('Cost Scores', fontdict={'size': fontsize})\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(True)\n",
    "for i in range(1):\n",
    "    for j in range(3):\n",
    "        g.axes[i, j].set_title(f'{model_names[j]}', fontsize=fontsize+4)\n",
    "        \n",
    "sns.move_legend(g, 'upper right', bbox_to_anchor=(1.07,0.9), frameon=True, prop={'size': fontsize, 'weight': 500}, ncol=1)\n",
    "g.legend.set_title(None)\n",
    "\n",
    "plt.savefig('./fig/safety_guard.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = sns.catplot(\n",
    "    x='Dataset', y='mean', hue='Type', data=df[df['Model']==models[1]], kind='bar', \n",
    "    palette=custom_palette, capsize=0.1, aspect=1.3, height=9, edgecolor='black'\n",
    ")\n",
    "# Customizing the plot\n",
    "plt.xlabel(None)\n",
    "plt.ylabel('Cost Scores', fontsize=fontsize)\n",
    "plt.xticks(fontsize=fontsize, rotation=45)\n",
    "plt.yticks(fontsize=fontsize)\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.85,1.03), fontsize=fontsize, frameon=False, ncol=1)\n",
    "g.legend.set_title(None)\n",
    "# Set xtick labels to custom titles\n",
    "g.set_xticklabels(titles)\n",
    "plt.savefig('./fig/safety_guard_gemma.pdf', dpi=300, format='pdf', bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.13 ('alignment')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fa5059266a314ab8b54f0ed734c52e1c70437da747820dac7ce3245ce71fcc13"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
