{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import json\n",
    "import torch\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score\n",
    "from warnings import simplefilter\n",
    "from sklearn.exceptions import ConvergenceWarning\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "simplefilter(\"ignore\", category=ConvergenceWarning)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification using top neurons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random-neuron baselines are averaged over these seeds (one activation file per seed).\n",
    "seeds = [1, 42, 66, 88, 3419]\n",
    "# Neuron budgets to evaluate; the evaluation loop walks them in ascending order (topks[::-1]).\n",
    "topks = [1500, 1200, 900, 700, 500, 300, 150, 50, 30, 20, 10]\n",
    "# The probe is fit on `train_name` activations and scored on each dataset in `test_names`.\n",
    "train_name = 'beavertails'\n",
    "test_names = ['jailbreak_llms', 'harmbench', 'hh_harmless', 'red_team_attempts']\n",
    "# Probe classifier class and its constructor kwargs, e.g. {'max_iter': 100}.\n",
    "clf_cls = LogisticRegression\n",
    "param = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scores stored below are accuracies (accuracy_score) despite the historical `_f1` variable names.\n",
    "ACT_DIR = ' / /Alignment/output/activations/harmless_prediction'\n",
    "MODEL_TAG = '-1_Llama-2-7b-hf_sharegpt_ia3_ff_1'\n",
    "\n",
    "top_neuron_f1, random_neuron_f1, last_random_neuron_f1, everywhere_random_neuron_f1 = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)\n",
    "for topk in topks[::-1]:\n",
    "    # (result store, activation-file suffixes whose scores are averaged into one entry)\n",
    "    variants = [\n",
    "        (top_neuron_f1, [f'top{topk}']),\n",
    "        (random_neuron_f1, [f'top{topk}_seed{seed}' for seed in seeds]),\n",
    "        (last_random_neuron_f1, [f'last_top{topk}_seed{seed}' for seed in seeds]),\n",
    "        # To re-enable the 'everywhere' baseline (scored with accuracy, like the others):\n",
    "        # (everywhere_random_neuron_f1, [f'everywhere_top{topk}_seed{seed}' for seed in seeds]),\n",
    "    ]\n",
    "    for test_name in test_names:\n",
    "        print(f'topk: {topk}, dataset: {test_name}')\n",
    "        for store, suffixes in variants:\n",
    "            scores = []\n",
    "            for suffix in suffixes:\n",
    "                # NOTE: torch.load unpickles; only load activation files you generated yourself.\n",
    "                X_train, y_train = torch.load(f'{ACT_DIR}/{train_name}_{MODEL_TAG}_{suffix}.pt')\n",
    "                X_test, y_test = torch.load(f'{ACT_DIR}/{test_name}_{MODEL_TAG}_{suffix}.pt')\n",
    "                clf = clf_cls(**param)\n",
    "                clf.fit(X_train, y_train)\n",
    "                scores.append(accuracy_score(y_test, clf.predict(X_test)))\n",
    "            store[test_name].append(np.mean(scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Majority-class baseline (assumes binary 0/1 labels): predict each test set's more frequent label,\n",
    "# then average over test sets. The accumulator must live outside the loop; resetting it per dataset\n",
    "# made the final mean cover only the last dataset.\n",
    "acc_maj = []\n",
    "for test_name in test_names:\n",
    "    print(f'dataset: {test_name}')\n",
    "    data_test = f' / /Alignment/output/activations/harmless_prediction/{test_name}_-1_Llama-2-7b-hf_sharegpt_ia3_ff_1_top1500.pt'\n",
    "    _, y_test = torch.load(data_test)\n",
    "    y_pred = torch.ones_like(y_test) if y_test.float().mean() > 0.5 else torch.zeros_like(y_test)\n",
    "    acc_maj.append(accuracy_score(y_test, y_pred))\n",
    "np.mean(acc_maj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Earlier results kept under their own name: the next cell rebinds `results`, so loading into it was a dead store.\n",
    "with open('../results/result.json', 'r') as f:\n",
    "    previous_results = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group the per-dataset score lists by neuron type for saving to JSON.\n",
    "results = {\n",
    "    'safety_neuron': top_neuron_f1,\n",
    "    'random_neuron': random_neuron_f1,\n",
    "    'random_neuron_last': last_random_neuron_f1,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the accuracy results as indented JSON.\n",
    "result_path = '../results/result_acc.json'\n",
    "with open(result_path, mode='w') as fp:\n",
    "    json.dump(results, fp, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import matplotlib.font_manager as font_manager\n",
    "\n",
    "# sns.set_theme(style=\"whitegrid\", rc={'font.family': 'sans-serif', 'font.sans-serif': 'Helvetica'}, palette='Dark2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections.abc import Iterable, Sequence\n",
    "\n",
    "\n",
    "def create_df(results: Iterable[dict], test_names: Sequence[str], neuron_types: Sequence[str], topk: Sequence[int]) -> pd.DataFrame:\n",
    "    \"\"\"Flatten per-dataset score curves into a long-form DataFrame for seaborn.\n",
    "\n",
    "    Args:\n",
    "        results: one mapping per neuron type, test_name -> list of scores aligned with `topk`\n",
    "            (a list, or e.g. dict.values()).\n",
    "        test_names: datasets to take from each mapping.\n",
    "        neuron_types: display label for each entry of `results`, in the same order.\n",
    "        topk: neuron counts, one per score in every list.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns 'f1' (the score), 'topk', 'Test Dataset' and 'Neuron Type'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `results` and `neuron_types` differ in length, or a score list does not match `topk`.\n",
    "    \"\"\"\n",
    "    topk = list(topk)\n",
    "    df_dict = {'f1': [], 'topk': [], 'Test Dataset': [], 'Neuron Type': []}\n",
    "    # strict=True: a missing or extra neuron type would otherwise be dropped silently.\n",
    "    for result, neuron_type in zip(results, neuron_types, strict=True):\n",
    "        for test_name in test_names:\n",
    "            scores = list(result[test_name])\n",
    "            if len(scores) != len(topk):\n",
    "                raise ValueError(f'{neuron_type}/{test_name}: {len(scores)} scores for {len(topk)} topk values')\n",
    "            df_dict['f1'] += scores\n",
    "            df_dict['topk'] += topk\n",
    "            df_dict['Test Dataset'] += [test_name] * len(topk)\n",
    "            df_dict['Neuron Type'] += [neuron_type] * len(topk)\n",
    "    return pd.DataFrame(data=df_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read back the accuracy results saved above (result_f1.json is never written here, and the figures label this metric 'Accuracy').\n",
    "with open('../results/result_acc.json', 'r') as f:\n",
    "    results = json.load(f)\n",
    "# Pair JSON keys with display names explicitly rather than relying on key order.\n",
    "neuron_display_names = {'safety_neuron': 'Safety Neuron', 'random_neuron': 'Random Neuron', 'random_neuron_last': 'Random Neuron(last)'}\n",
    "df = create_df([results[key] for key in neuron_display_names], test_names=test_names, neuron_types=list(neuron_display_names.values()), topk=topks[::-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview only; the full long-form table has one row per (neuron type, test dataset, topk).\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "fontsize = 40\n",
    "linewidth = 5\n",
    "markersize = 14\n",
    "sns.set_theme(style='ticks')\n",
    "fig, ax = plt.subplots(1, 1, figsize=(16, 10))\n",
    "legend_name_map = {'jailbreak_llms': 'JailbreakLLMs', 'harmbench': 'HarmBench', 'hh_harmless':'HH-Harmless', 'red_team_attempts':'RedTeam'}\n",
    "# One colour and one marker per test dataset, keyed by its display name.\n",
    "colors = {legend_name_map[name]: c for name, c in zip(test_names, ['r','g','b','y'])}\n",
    "markers = {legend_name_map[name]: m for name, m in zip(test_names, ['s','P','D','X'])}\n",
    "# One line style per neuron type: (value in df['Neuron Type'], legend label, linestyle).\n",
    "# Shared by the plotting loop and the legend so the two cannot drift apart.\n",
    "neuron_styles = [\n",
    "    ('Random Neuron', 'Random Neuron', '--'),\n",
    "    ('Random Neuron(last)', 'Random Neuron (last)', '-.'),\n",
    "    ('Safety Neuron', 'Safety Neuron', '-'),\n",
    "]\n",
    "\n",
    "# One line per (neuron type, test dataset): colour/marker encode the dataset, line style the neuron type.\n",
    "for neuron_type, _label, linestyle in neuron_styles:\n",
    "    for dataset in df['Test Dataset'].unique():\n",
    "        subset = df[(df['Test Dataset'] == dataset) & (df['Neuron Type'] == neuron_type)]\n",
    "        if not subset.empty:\n",
    "            sns.lineplot(\n",
    "                data=subset,\n",
    "                x='topk', y='f1', ax=ax,\n",
    "                color=colors[legend_name_map[dataset]],\n",
    "                marker=markers[legend_name_map[dataset]],\n",
    "                linestyle=linestyle,\n",
    "                linewidth=linewidth,\n",
    "                markersize=markersize\n",
    "            )\n",
    "\n",
    "ax.set_xlabel('#Neurons', fontsize=fontsize)\n",
    "ax.set_ylabel('Accuracy', fontsize=fontsize)\n",
    "plt.setp(ax.get_xticklabels(), fontsize=fontsize)\n",
    "plt.setp(ax.get_yticklabels(), fontsize=fontsize)\n",
    "\n",
    "# Custom legend: datasets (colour + marker) first, then neuron types (line style).\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], color=colors[key], lw=linewidth, markersize=markersize, marker=markers[key], label=key)\n",
    "    for key in colors\n",
    "] + [\n",
    "    Line2D([0], [0], color='k', lw=linewidth, linestyle=style, label=label)\n",
    "    for _type, label, style in neuron_styles\n",
    "]\n",
    "ax.legend(handles=legend_elements, title=None, ncol=2, fontsize=30, frameon=False, loc='upper left', bbox_to_anchor=(0.15,0.5))\n",
    "\n",
    "fig.tight_layout()\n",
    "# savefig fails if the output directory does not exist yet.\n",
    "fig_path = Path('./fig/harm_prediction.pdf')\n",
    "fig_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(fig_path, dpi=300, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize = 18\n",
    "legend_name_map = {'jailbreak_llms': 'JailbreakLLMs', 'harmbench': 'HarmBench', 'hh_harmless':'HH-Harmless', 'red_team_attempts':'RedTeam'}\n",
    "g = sns.relplot(x='topk', y='f1', data=df, hue='Test Dataset', kind='line', style='Neuron Type', aspect=1.3, linewidth=3)\n",
    "g.set_axis_labels(x_var='Number of Neurons', y_var='Accuracy', fontsize=fontsize)\n",
    "for ax in g.axes.flat:\n",
    "    ax.tick_params(axis='both', labelsize=fontsize, width=2)  # larger tick labels, thicker tick marks\n",
    "sns.move_legend(g, 'upper left', bbox_to_anchor=(0.72,1), fontsize=12, frameon=True)\n",
    "\n",
    "# Relabel dataset entries with their display names and bold the two legend section headings.\n",
    "legend = g._legend\n",
    "custom_font = font_manager.FontProperties(style='normal', weight=700, size=12)\n",
    "section_titles = {'Test Dataset', 'Neuron Type'}\n",
    "for text in legend.texts:\n",
    "    label = text.get_text()\n",
    "    if label in legend_name_map:\n",
    "        text.set_text(legend_name_map[label])\n",
    "    elif label in section_titles:\n",
    "        text.set_fontproperties(custom_font)\n",
    "\n",
    "# plt.savefig('./fig/harm_prediction.pdf', dpi=600, format='pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `df_melted` was never defined in this notebook (NameError on Restart & Run All);\n",
    "# show the plotted results in wide form instead: one row per #neurons, one column per (neuron type, dataset).\n",
    "df.pivot_table(index='topk', columns=['Neuron Type', 'Test Dataset'], values='f1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean score over test datasets at 150 and at 1500 neurons, for each neuron type.\n",
    "# Score lists follow topks[::-1]; look the positions up instead of hard-coding 4 and -1.\n",
    "topk_order = topks[::-1]\n",
    "idx_150, idx_1500 = topk_order.index(150), topk_order.index(1500)\n",
    "for neuron, result in results.items():\n",
    "    print(neuron)\n",
    "    acc_150 = [metrics[idx_150] for metrics in result.values()]\n",
    "    acc_1500 = [metrics[idx_1500] for metrics in result.values()]\n",
    "    print(np.mean(acc_150), np.mean(acc_1500))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick-look plot of the in-memory curves against the actual neuron counts (not list positions).\n",
    "fig, ax = plt.subplots(figsize=(12, 8))\n",
    "line_colors = ['r', 'g', 'b', 'y']  # own name so the `colors` dict of the paper figure is not overwritten\n",
    "topk_order = topks[::-1]  # order in which scores were appended\n",
    "for i, test_name in enumerate(test_names):\n",
    "    ax.plot(topk_order, random_neuron_f1[test_name], label=f'{test_name}_random', linestyle='dashed', c=line_colors[i])\n",
    "    ax.plot(topk_order, last_random_neuron_f1[test_name], label=f'{test_name}_random_last', linestyle='dotted', c=line_colors[i])\n",
    "    ax.plot(topk_order, top_neuron_f1[test_name], label=f'{test_name}_top', c=line_colors[i])\n",
    "ax.set_xlabel('#Neurons')\n",
    "ax.set_ylabel('Accuracy')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tabular view: one column per test dataset, indexed by #neurons (the evaluation loop's order).\n",
    "pd.DataFrame(random_neuron_f1, index=topks[::-1]).rename_axis('topk')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): relies on `clf`, `X_test`, `y_test` left over from earlier cells (the last classifier\n",
    "# fitted in the evaluation loop), so it scores a single configuration, not the run as a whole.\n",
    "f1_score(y_true=y_test, y_pred=clf.predict(X_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification using random neurons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name the test set explicitly instead of silently reusing `test_name` leaked from an earlier loop\n",
    "# (those loops end on test_names[-1], so the loaded files are unchanged).\n",
    "random_test_name = test_names[-1]\n",
    "data_train = f' / /Alignment/output/activations/{train_name}_-1_Llama-2-7b-hf_sharegpt_ia3_ff_1_top10000_random.pt'\n",
    "data_test = f' / /Alignment/output/activations/{random_test_name}_-1_Llama-2-7b-hf_sharegpt_ia3_ff_1_top10000_random.pt'\n",
    "X_train, y_train = torch.load(data_train)\n",
    "X_test, y_test = torch.load(data_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a classifier on the random-neuron features just loaded. The `clf` left over from the evaluation loop\n",
    "# was trained on different activation files (presumably a different feature count), so reusing it is wrong.\n",
    "clf_random = clf_cls(**param)\n",
    "clf_random.fit(X_train, y_train)\n",
    "f1_score(y_test, clf_random.predict(X_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.13 ('alignment')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "fa5059266a314ab8b54f0ed734c52e1c70437da747820dac7ce3245ce71fcc13"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
