{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.listdir('results')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "linestyles = ['-', '--', ':', '-.']\n",
    "\n",
    "model_name = \"vicuna-7b\"\n",
    "save_dir = \"results/mnist_data_attr\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"noisy_data_removal\" in save_dir:\n",
    "    alpha = 1.0\n",
    "    perturb_num = 1\n",
    "    # loss_orig_gd_name = f\"loss_orig_gd_{perturb_num}_{alpha}.npy\"\n",
    "    loss_orig_trans_name = f\"loss_orig_trans_{perturb_num}_{alpha}.npy\"\n",
    "    # loss_orig_gd = np.load(os.path.join(save_dir, loss_orig_gd_name))\n",
    "    loss_orig_trans = np.load(os.path.join(save_dir, loss_orig_trans_name))\n",
    "\n",
    "    print(f\"{loss_orig_trans.mean():.3f} ({loss_orig_trans.std():.3f})\")\n",
    "\n",
    "    # loss_cleaned_gd_name = f\"loss_cleaned_gd_{perturb_num}_{alpha}.npy\"\n",
    "    loss_cleaned_trans_name = f\"loss_cleaned_trans_{perturb_num}_{alpha}.npy\"\n",
    "    # loss_cleaned_gd = np.load(os.path.join(save_dir, loss_cleaned_gd_name))\n",
    "    loss_clean_trans = np.load(os.path.join(save_dir, loss_cleaned_trans_name))\n",
    "\n",
    "    print(f\"{loss_clean_trans.mean():.3f} ({loss_clean_trans.std():.3f})\")\n",
    "\n",
    "    # loss_random_gd_name = f\"loss_random_gd_{perturb_num}_{alpha}.npy\"\n",
    "    loss_random_trans_name = f\"loss_random_trans_{perturb_num}_{alpha}.npy\"\n",
    "    # loss_random_gd = np.load(os.path.join(save_dir, loss_random_gd_name))\n",
    "    loss_random_trans = np.load(os.path.join(save_dir, loss_random_trans_name))\n",
    "\n",
    "    print(f\"{loss_random_trans.mean():.3f} ({loss_random_trans.std():.3f})\")\n",
    "\n",
    "    loss_opt_trans_name = f\"loss_opt_trans_{perturb_num}_{alpha}.npy\"\n",
    "    loss_opt_trans = np.load(os.path.join(save_dir, loss_opt_trans_name))\n",
    "\n",
    "    print(f\"{loss_opt_trans.mean():.3f} ({loss_opt_trans.std():.3f})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"mnist_data_attr\" in save_dir:\n",
    "    acc_orig_mean = []\n",
    "    acc_orig_std = []\n",
    "    acc_rem_high_mean = []\n",
    "    acc_rem_high_std = []\n",
    "    acc_rem_low_mean = []\n",
    "    acc_rem_low_std = []\n",
    "    acc_rem_random_mean = []\n",
    "    acc_rem_random_std = []\n",
    "\n",
    "    alpha = 0.01\n",
    "\n",
    "    for perturb_num in [1,2,3,4,5,6]:\n",
    "\n",
    "        acc_orig_trans_name = f\"acc_orig_trans_{perturb_num}_{alpha}.npy\"\n",
    "        acc_orig_trans = np.load(os.path.join(save_dir, acc_orig_trans_name))\n",
    "        print(f\"{acc_orig_trans.mean():.3f} ({acc_orig_trans.std():.3f})\")\n",
    "        acc_orig_mean.append(acc_orig_trans.mean())\n",
    "        acc_orig_std.append(acc_orig_trans.std())\n",
    "\n",
    "        acc_rem_high_trans_name = f\"acc_rem_high_trans_{perturb_num}_{alpha}.npy\"\n",
    "        acc_rem_high_trans = np.load(os.path.join(save_dir, acc_rem_high_trans_name))\n",
    "        print(f\"{acc_rem_high_trans.mean():.3f} ({acc_rem_high_trans.std():.3f})\")\n",
    "        acc_rem_high_mean.append(acc_rem_high_trans.mean())\n",
    "        acc_rem_high_std.append(acc_rem_high_trans.std())\n",
    "\n",
    "        acc_rem_low_trans_name = f\"acc_rem_low_trans_{perturb_num}_{alpha}.npy\"\n",
    "        acc_rem_low_trans = np.load(os.path.join(save_dir, acc_rem_low_trans_name))\n",
    "        print(f\"{acc_rem_low_trans.mean():.3f} ({acc_rem_low_trans.std():.3f})\")\n",
    "        acc_rem_low_mean.append(acc_rem_low_trans.mean())\n",
    "        acc_rem_low_std.append(acc_rem_low_trans.std())\n",
    "\n",
    "        acc_rem_random_trans_name = f\"acc_rem_random_trans_{perturb_num}_{alpha}.npy\"\n",
    "        acc_rem_random_trans = np.load(os.path.join(save_dir, acc_rem_random_trans_name))\n",
    "        print(f\"{acc_rem_random_trans.mean():.3f} ({acc_rem_random_trans.std():.3f})\")\n",
    "        acc_rem_random_mean.append(acc_rem_random_trans.mean())\n",
    "        acc_rem_random_std.append(acc_rem_random_trans.std())\n",
    "    \n",
    "    # prepend orig to all of rem_high, rem_low, and rem_random\n",
    "    acc_rem_high_mean = [acc_orig_mean[0]] + acc_rem_high_mean\n",
    "    acc_rem_high_std = [acc_orig_std[0]] + acc_rem_high_std\n",
    "\n",
    "    acc_rem_low_mean = [acc_orig_mean[0]] + acc_rem_low_mean\n",
    "    acc_rem_low_std = [acc_orig_std[0]] + acc_rem_low_std\n",
    "\n",
    "    acc_rem_random_mean = [acc_orig_mean[0]] + acc_rem_random_mean\n",
    "    acc_rem_random_std = [acc_orig_std[0]] + acc_rem_random_std\n",
    "    \n",
    "    # plot the results\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    # plt.plot(acc_orig_mean, label=\"all data\")\n",
    "    # plt.fill_between(range(len(acc_orig_mean)), np.array(acc_orig_mean)-np.array(acc_orig_std), np.array(acc_orig_mean)+np.array(acc_orig_std), alpha=0.3)\n",
    "    plt.plot(acc_rem_high_mean, label=\"remove high\", linestyle=linestyles[0])\n",
    "    plt.fill_between(range(len(acc_rem_high_mean)), np.array(acc_rem_high_mean)-np.array(acc_rem_high_std), np.array(acc_rem_high_mean)+np.array(acc_rem_high_std), alpha=0.3)\n",
    "    plt.plot(acc_rem_low_mean, label=\"remove low\", linestyle=linestyles[1])\n",
    "    plt.fill_between(range(len(acc_rem_low_mean)), np.array(acc_rem_low_mean)-np.array(acc_rem_low_std), np.array(acc_rem_low_mean)+np.array(acc_rem_low_std), alpha=0.3)\n",
    "    plt.plot(acc_rem_random_mean, label=\"remove random\", linestyle=linestyles[2])\n",
    "    plt.fill_between(range(len(acc_rem_random_mean)), np.array(acc_rem_random_mean)-np.array(acc_rem_random_std), np.array(acc_rem_random_mean)+np.array(acc_rem_random_std), alpha=0.3)\n",
    "    # set x ticks as [3,4,5,6]\n",
    "    plt.xticks(range(len(acc_rem_high_mean)), [0,1,2,3,4,5,6])\n",
    "    plt.xlabel(\"No. removed\")\n",
    "    plt.ylabel(\"Accuracy\")\n",
    "    plt.legend()\n",
    "\n",
    "    # save the plot\n",
    "    plt.savefig(f\"figs/mnist_data_attr_acc.pdf\", bbox_inches=\"tight\", dpi=300)\n",
    "\n",
    "    plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"cls_data_attr\" in save_dir:\n",
    "    perturb_num = 5\n",
    "    alpha=0.01\n",
    "    acc_orig_trans_name = f\"acc_orig_trans_{perturb_num}_{alpha}.npy\"\n",
    "    acc_rem_low_trans_name = f\"acc_rem_low_trans_{perturb_num}_{alpha}.npy\"\n",
    "    acc_rem_high_trans_name = f\"acc_rem_high_trans_{perturb_num}_{alpha}.npy\"\n",
    "    acc_random_trans_name = f\"acc_random_trans_{perturb_num}_{alpha}.npy\"\n",
    "\n",
    "    acc_orig_trans = np.load(os.path.join(save_dir, acc_orig_trans_name))\n",
    "    acc_rem_low_trans = np.load(os.path.join(save_dir, acc_rem_low_trans_name))\n",
    "    acc_rem_high_trans = np.load(os.path.join(save_dir, acc_rem_high_trans_name))\n",
    "    acc_random_trans = np.load(os.path.join(save_dir, acc_random_trans_name))\n",
    "\n",
    "    print(f\"{np.mean(acc_orig_trans):.3f} {np.std(acc_orig_trans):.3f}\")\n",
    "    print(f\"{np.mean(acc_rem_low_trans):.3f} {np.std(acc_rem_low_trans):.3f}\")\n",
    "    print(f\"{np.mean(acc_rem_high_trans):.3f} {np.std(acc_rem_high_trans):.3f}\")\n",
    "    print(f\"{np.mean(acc_random_trans):.3f} {np.std(acc_random_trans):.3f}\")\n",
    "    # # use accuracy\n",
    "    # new_loss_orig_trans = loss_clean_trans.copy()\n",
    "    # new_loss_orig_trans[loss_orig_trans < 1.0] = 1\n",
    "    # new_loss_orig_trans[loss_orig_trans > 1.0] = 0\n",
    "    # print(f\"{new_loss_orig_trans.mean():.3f} ({new_loss_orig_trans.std():.3f})\")\n",
    "\n",
    "    # new_loss_clean_trans = loss_clean_trans.copy()\n",
    "    # new_loss_clean_trans[loss_clean_trans < 1.0] = 1\n",
    "    # new_loss_clean_trans[loss_clean_trans > 1.0] = 0\n",
    "    # print(f\"{new_loss_clean_trans.mean():.3f} ({new_loss_clean_trans.std():.3f})\")\n",
    "\n",
    "    # new_loss_random_trans = loss_random_trans.copy()\n",
    "    # new_loss_random_trans[loss_random_trans < 1.0] = 1\n",
    "    # new_loss_random_trans[loss_random_trans > 1.0] = 0\n",
    "    # print(f\"{new_loss_random_trans.mean():.3f} ({new_loss_random_trans.std():.3f})\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"cls_llm\" in save_dir and not any([item in save_dir for item in [\"layers\", \"order\", \"pos\", \"other\"]]):\n",
    "    num_removes = [4,7,10,13,16]\n",
    "    is_corrupt = False\n",
    "    use_label_pos = False\n",
    "    acc_orig_mean = []\n",
    "    acc_orig_ste = []\n",
    "    acc_rem_high_mean = []\n",
    "    acc_rem_high_ste = []\n",
    "    acc_rem_low_mean = []\n",
    "    acc_rem_low_ste = []\n",
    "    acc_rem_random_mean = []\n",
    "    acc_rem_random_ste = []\n",
    "\n",
    "    for num_remove in num_removes:\n",
    "        acc_orig_name = f\"acc_orig_{model_name}_{num_remove}{'_corrupt' if is_corrupt else '_remove'}{'_use_label_pos' if use_label_pos else ''}.npy\"\n",
    "        acc_orig = np.load(os.path.join(save_dir, acc_orig_name))\n",
    "        acc_orig_mean.append(acc_orig.mean())\n",
    "        acc_orig_ste.append(acc_orig.std() / np.sqrt(len(acc_orig)))\n",
    "\n",
    "        acc_rem_high_name = f\"acc_rem_high_{model_name}_{num_remove}{'_corrupt' if is_corrupt else '_remove'}{'_use_label_pos' if use_label_pos else ''}.npy\"\n",
    "        acc_rem_high = np.load(os.path.join(save_dir, acc_rem_high_name))\n",
    "        acc_rem_high_mean.append(acc_rem_high.mean())\n",
    "        acc_rem_high_ste.append(acc_rem_high.std() / np.sqrt(len(acc_rem_high)))\n",
    "\n",
    "        acc_rem_low_name = f\"acc_rem_low_{model_name}_{num_remove}{'_corrupt' if is_corrupt else '_remove'}{'_use_label_pos' if use_label_pos else ''}.npy\"\n",
    "        acc_rem_low = np.load(os.path.join(save_dir, acc_rem_low_name))\n",
    "        acc_rem_low_mean.append(acc_rem_low.mean())\n",
    "        acc_rem_low_ste.append(acc_rem_low.std() / np.sqrt(len(acc_rem_low)))\n",
    "\n",
    "        acc_rem_random_name = f\"acc_rem_random_{model_name}_{num_remove}{'_corrupt' if is_corrupt else '_remove'}{'_use_label_pos' if use_label_pos else ''}.npy\"\n",
    "        acc_rem_random = np.load(os.path.join(save_dir, acc_rem_random_name))\n",
    "        acc_rem_random_mean.append(acc_rem_random.mean())\n",
    "        acc_rem_random_ste.append(acc_rem_random.std() / np.sqrt(len(acc_rem_random)))\n",
    "\n",
    "    # plot the results\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    # append the mean and std of the original data to each of the other data\n",
    "    acc_rem_high_mean = [acc_orig_mean[0]] + acc_rem_high_mean\n",
    "    acc_rem_high_ste = [acc_orig_ste[0]] + acc_rem_high_ste\n",
    "\n",
    "    acc_rem_low_mean = [acc_orig_mean[0]] + acc_rem_low_mean\n",
    "    acc_rem_low_ste = [acc_orig_ste[0]] + acc_rem_low_ste\n",
    "\n",
    "    acc_rem_random_mean = [acc_orig_mean[0]] + acc_rem_random_mean\n",
    "    acc_rem_random_ste = [acc_orig_ste[0]] + acc_rem_random_ste\n",
    "\n",
    "    print(f\"remove low: {acc_rem_low_mean[0]:.3f} ({acc_rem_low_ste[0]:.3f})\")\n",
    "    # print(f\"remove high: {acc_rem_high_mean[3]:.3f} ({acc_rem_high_ste[3]:.3f})\")\n",
    "    print(f\"remove low: {acc_rem_low_mean[2]:.3f} ({acc_rem_low_ste[2]:.3f})\")\n",
    "    # print(f\"remove random: {acc_rem_random_mean[3]:.3f} ({acc_rem_random_ste[3]:.3f})\")\n",
    "\n",
    "    # plt.plot(acc_orig_mean, label=\"all data\")\n",
    "    # plt.fill_between(range(len(acc_orig_mean)), np.array(acc_orig_mean)-np.array(acc_orig_ste), np.array(acc_orig_mean)+np.array(acc_orig_ste), alpha=0.2)\n",
    "    # plt.plot(acc_rem_high_mean, label=\"corrupt high\" if is_corrupt else \"remove high\")\n",
    "    # plt.fill_between(range(len(acc_rem_high_mean)), np.array(acc_rem_high_mean)-np.array(acc_rem_high_ste), np.array(acc_rem_high_mean)+np.array(acc_rem_high_ste), alpha=0.2)\n",
    "    # plt.plot(acc_rem_low_mean, label=\"corrupt low\" if is_corrupt else \"remove low\")\n",
    "    # plt.fill_between(range(len(acc_rem_low_mean)), np.array(acc_rem_low_mean)-np.array(acc_rem_low_ste), np.array(acc_rem_low_mean)+np.array(acc_rem_low_ste), alpha=0.2)\n",
    "    # plt.plot(acc_rem_random_mean, label=\"corrupt random\" if is_corrupt else \"remove random\")\n",
    "    # plt.fill_between(range(len(acc_rem_random_mean)), np.array(acc_rem_random_mean)-np.array(acc_rem_random_ste), np.array(acc_rem_random_mean)+np.array(acc_rem_random_ste), alpha=0.2)\n",
    "\n",
    "    if \"cls_llm_comp\" in save_dir:\n",
    "        # plot error bars\n",
    "        plt.errorbar(range(len(acc_rem_high_mean)), acc_rem_high_mean, yerr=acc_rem_high_ste, label=\"remove high\", fmt='o-', capsize=5, linestyle=linestyles[0])\n",
    "        plt.errorbar(range(len(acc_rem_low_mean)), acc_rem_low_mean, yerr=acc_rem_low_ste, label=\"remove low\", fmt='o-', capsize=5, linestyle=linestyles[1])\n",
    "        plt.errorbar(range(len(acc_rem_random_mean)), acc_rem_random_mean, yerr=acc_rem_random_ste, label=\"remove random\", fmt='o-', capsize=5, linestyle=linestyles[2])\n",
    "    else:\n",
    "        plt.plot(acc_rem_high_mean, label=\"corrupt high\" if is_corrupt else \"remove high\", linestyle=linestyles[0])\n",
    "        plt.fill_between(range(len(acc_rem_high_mean)), np.array(acc_rem_high_mean)-np.array(acc_rem_high_ste), np.array(acc_rem_high_mean)+np.array(acc_rem_high_ste), alpha=0.2)\n",
    "        plt.plot(acc_rem_low_mean, label=\"corrupt low\" if is_corrupt else \"remove low\", linestyle=linestyles[1])\n",
    "        plt.fill_between(range(len(acc_rem_low_mean)), np.array(acc_rem_low_mean)-np.array(acc_rem_low_ste), np.array(acc_rem_low_mean)+np.array(acc_rem_low_ste), alpha=0.2)\n",
    "        plt.plot(acc_rem_random_mean, label=\"corrupt random\" if is_corrupt else \"remove random\", linestyle=linestyles[2])\n",
    "        plt.fill_between(range(len(acc_rem_random_mean)), np.array(acc_rem_random_mean)-np.array(acc_rem_random_ste), np.array(acc_rem_random_mean)+np.array(acc_rem_random_ste), alpha=0.2)\n",
    "\n",
    "\n",
    "    # set x ticks as [3,4,5,6]\n",
    "    plt.title(model_name, fontsize=25)\n",
    "    plt.xticks(range(1+len(acc_orig_mean)), [0] + num_removes)\n",
    "    plt.xlabel(f\"No. {'corrupted' if is_corrupt else 'removed'}\")\n",
    "    plt.ylabel(\"Accuracy\")\n",
    "    plt.legend(loc=\"lower left\")\n",
    "\n",
    "    plt.savefig(f\"figs/{save_dir.split('/')[1]}_{model_name}_acc{'_corrupt' if is_corrupt else '_remove'}{'_use_label_pos' if use_label_pos else ''}.pdf\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"cls_llm_other\" in save_dir:\n",
    "    num_remove = 10\n",
    "    rem_low_accs = []\n",
    "    wall_times = []\n",
    "\n",
    "    methods = [\n",
    "        \"datamodel\",\n",
    "        # 'attention',\n",
    "        # 'influence',\n",
    "        # 'ig',\n",
    "        # # 'sig',\n",
    "        # 'lime',\n",
    "        # 'nguyen_infl',\n",
    "        # 'vinay_infl',\n",
    "        # 'random'\n",
    "    ]\n",
    "\n",
    "    for method in methods:\n",
    "        acc_name = f\"acc_rem_low_{method}_{model_name}_{num_remove}_remove.npy\"\n",
    "        acc = np.load(os.path.join(save_dir, acc_name))\n",
    "        rem_low_accs.append(acc)\n",
    "\n",
    "        wall_time_name = f\"wall_time_{method}_{model_name}_{num_remove}_remove.npy\"\n",
    "        wall_time = np.load(os.path.join(save_dir, wall_time_name))\n",
    "        wall_times.append(wall_time)\n",
    "    \n",
    "    # print mean and ste of the low accuracy results\n",
    "    print(f\"Acc:\")\n",
    "    for i, method in enumerate(methods):\n",
    "        mean = rem_low_accs[i].mean()\n",
    "        ste = rem_low_accs[i].std() / np.sqrt(len(rem_low_accs[i]))\n",
    "        print(f\"{method}: {mean:.3f} ({ste:.4f})\")\n",
    "\n",
    "    print(f\"Wall time:\")\n",
    "    for i, method in enumerate(methods):\n",
    "        mean = wall_times[i].mean()\n",
    "        ste = wall_times[i].std() / np.sqrt(len(wall_times[i]))\n",
    "        print(f\"{method}: {mean:.3f} ({ste:.4f})\")\n",
    "    \n",
    "    # plot results\n",
    "\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    # plot box plots\n",
    "    plt.boxplot(rem_low_accs, labels=methods)\n",
    "    plt.ylabel(\"Accuracy\")\n",
    "\n",
    "    # plt.savefig(f\"figs/{save_dir.split('/')[1]}_{model_name}_acc.pdf\", bbox_inches=\"tight\", dpi=300)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"detect_llm\" in save_dir and not any([item in save_dir for item in [\"proj\", \"layers\"]]):\n",
    "    num_perturbed = 4\n",
    "    infl_name = f\"fraction_checked_infl_{model_name}_{num_perturbed}.npy\"\n",
    "    fraction_checked_infl = np.load(os.path.join(save_dir, infl_name))\n",
    "\n",
    "    loo_name = f\"fraction_checked_loo_{model_name}_{num_perturbed}.npy\"\n",
    "    fraction_checked_loo = np.load(os.path.join(save_dir, loo_name))\n",
    "\n",
    "    random_name = f\"fraction_checked_random_{model_name}_{num_perturbed}.npy\"\n",
    "    fraction_checked_random = np.load(os.path.join(save_dir, random_name))\n",
    "\n",
    "    # get the mean and std of the fraction checked\n",
    "    fraction_checked_infl_mean = fraction_checked_infl.mean(axis=0)\n",
    "    fraction_checked_infl_ste = fraction_checked_infl.std(axis=0) / np.sqrt(len(fraction_checked_infl))\n",
    "\n",
    "    fraction_checked_loo_mean = fraction_checked_loo.mean(axis=0)\n",
    "    fraction_checked_loo_ste = fraction_checked_loo.std(axis=0) / np.sqrt(len(fraction_checked_loo))\n",
    "\n",
    "    fraction_checked_random_mean = fraction_checked_random.mean(axis=0)\n",
    "    fraction_checked_random_ste = fraction_checked_random.std(axis=0) / np.sqrt(len(fraction_checked_random))\n",
    "\n",
    "    # prepend a 0 to the fraction checked\n",
    "    fraction_checked_infl_mean = np.insert(fraction_checked_infl_mean, 0, 0)\n",
    "    fraction_checked_infl_ste = np.insert(fraction_checked_infl_ste, 0, 0)\n",
    "\n",
    "    fraction_checked_loo_mean = np.insert(fraction_checked_loo_mean, 0, 0)\n",
    "    fraction_checked_loo_ste = np.insert(fraction_checked_loo_ste, 0, 0)\n",
    "\n",
    "    fraction_checked_random_mean = np.insert(fraction_checked_random_mean, 0, 0)\n",
    "    fraction_checked_random_ste = np.insert(fraction_checked_random_ste, 0, 0)\n",
    "\n",
    "    # plot the results\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    plt.plot(fraction_checked_infl_mean, label=\"DETAIL\", linestyle=linestyles[0])\n",
    "    plt.fill_between(range(len(fraction_checked_infl_mean)), np.array(fraction_checked_infl_mean)-np.array(fraction_checked_infl_ste), np.array(fraction_checked_infl_mean)+np.array(fraction_checked_infl_ste), alpha=0.2)\n",
    "    plt.plot(fraction_checked_loo_mean, label=\"LOO\", linestyle=linestyles[1])\n",
    "    plt.fill_between(range(len(fraction_checked_loo_mean)), np.array(fraction_checked_loo_mean)-np.array(fraction_checked_loo_ste), np.array(fraction_checked_loo_mean)+np.array(fraction_checked_loo_ste), alpha=0.2)\n",
    "    plt.plot(fraction_checked_random_mean, label=\"random\", linestyle=linestyles[2])\n",
    "    plt.fill_between(range(len(fraction_checked_random_mean)), np.array(fraction_checked_random_mean)-np.array(fraction_checked_random_ste), np.array(fraction_checked_random_mean)+np.array(fraction_checked_random_ste), alpha=0.2)\n",
    "    plt.title(model_name, fontsize=25)\n",
    "    plt.xlabel(\"No. demos. checked\")\n",
    "    plt.ylabel(\"Fraction identified\")\n",
    "    plt.legend()\n",
    "\n",
    "    plt.savefig(f\"figs/{save_dir.split('/')[1]}_{model_name}_fraction_checked.pdf\", bbox_inches=\"tight\", dpi=300)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"detect_llm\" in save_dir and not any([item in save_dir for item in [\"proj\", \"layers\"]]):\n",
    "    num_perturbed = 4\n",
    "\n",
    "    save_dirs = [\n",
    "        'results/detect_llm_ag_news',\n",
    "        'results/detect_llm_sst2',\n",
    "        'results/detect_llm_rotten_tomatoes',\n",
    "        'results/detect_llm_subj',\n",
    "    ]\n",
    "\n",
    "    wall_times_infl = []\n",
    "    wall_times_loo = []\n",
    "\n",
    "    for save_dir in save_dirs:\n",
    "        wall_time_infl_name = f\"wall_time_infl_{model_name}_{num_perturbed}.npy\"\n",
    "        wall_time_infl = np.load(os.path.join(save_dir, wall_time_infl_name))\n",
    "\n",
    "        wall_time_loo_name = f\"wall_time_loo_{model_name}_{num_perturbed}.npy\"\n",
    "        wall_time_loo = np.load(os.path.join(save_dir, wall_time_loo_name))\n",
    "\n",
    "        wall_times_infl.append(wall_time_infl)\n",
    "        wall_times_loo.append(wall_time_loo)\n",
    "\n",
    "    # plot the results\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    # plot bar graph\n",
    "    fig, ax = plt.subplots()\n",
    "    width = 0.35\n",
    "\n",
    "    x = np.arange(len(save_dirs))\n",
    "    ax.bar(x - width/2, [np.mean(wall_times_infl[i]) for i in range(len(wall_times_infl))], width, label='DETAIL')\n",
    "    ax.bar(x + width/2, [np.mean(wall_times_loo[i]) for i in range(len(wall_times_loo))], width, label='LOO', hatch='//')\n",
    "    # draw y = 3s\n",
    "    ax.axhline(y=3, color='firebrick', linestyle=linestyles[1], label='3s')\n",
    "\n",
    "    # draw y=0.3s\n",
    "    ax.axhline(y=0.3, color='darkgreen', linestyle=linestyles[2], label=r'0.3s (10$\\times$)')\n",
    "\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(['AG N.', 'SST-2', 'R.T.', 'Subj'])\n",
    "\n",
    "    ax.set_ylabel(\"Wall time (s)\")\n",
    "\n",
    "    ax.legend()\n",
    "\n",
    "    plt.savefig(f\"figs/detect_llm_{model_name}_wall_time.pdf\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"detect_llm_layers\" in save_dir :\n",
    "    num_perturbed = 4\n",
    "    all_fraction_checked_infl_mean = []\n",
    "    all_fraction_checked_infl_ste = []\n",
    "    for layer_num in [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31]:\n",
    "        infl_name = f\"fraction_checked_infl_{model_name}_{num_perturbed}_{layer_num}.npy\"\n",
    "        fraction_checked_infl = np.load(os.path.join(save_dir, infl_name))\n",
    "\n",
    "        # get the mean and std of the fraction checked\n",
    "        fraction_checked_infl = fraction_checked_infl.mean(axis=1)\n",
    "        fraction_checked_infl_mean = fraction_checked_infl.mean(axis=0)\n",
    "        fraction_checked_infl_ste = fraction_checked_infl.std(axis=0) / np.sqrt(len(fraction_checked_infl))\n",
    "\n",
    "        all_fraction_checked_infl_mean.append(fraction_checked_infl_mean)\n",
    "        all_fraction_checked_infl_ste.append(fraction_checked_infl_ste)\n",
    "\n",
    "    # plot the results\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    plt.plot(all_fraction_checked_infl_mean, label=\"DETAIL\")\n",
    "    plt.fill_between(range(len(all_fraction_checked_infl_mean)), np.array(all_fraction_checked_infl_mean)-np.array(all_fraction_checked_infl_ste), np.array(all_fraction_checked_infl_mean)+np.array(all_fraction_checked_infl_ste), alpha=0.2)\n",
    "    plt.xlabel(\"Layer number\")\n",
    "    plt.ylabel(\"Fraction identified\")\n",
    "    plt.xticks(list(range(0, len(all_fraction_checked_infl_mean)-2, 2)) + [len(all_fraction_checked_infl_mean)-1], [1,5,9,13,17,21,25,31])\n",
    "    plt.legend()\n",
    "\n",
    "    plt.savefig(f\"figs/{save_dir.split('/')[1]}_{model_name}_fraction_checked.pdf\", bbox_inches=\"tight\", dpi=300)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"detect_llm_proj\" in save_dir:\n",
    "    all_fraction_checked_infl = []\n",
    "    all_wall_time = []\n",
    "    for project_dim in [1,5,10,50,100,500,1000,2000,4096][::-1]:\n",
    "        num_perturbed = 4\n",
    "        infl_name = f\"fraction_checked_infl_{model_name}_{num_perturbed}_{project_dim}.npy\"\n",
    "        fraction_checked_infl = np.load(os.path.join(save_dir, infl_name))\n",
    "\n",
    "        wall_time_name = f\"wall_time_infl_{model_name}_{num_perturbed}_{project_dim}.npy\"\n",
    "        wall_time = np.load(os.path.join(save_dir, wall_time_name))\n",
    "\n",
    "        all_fraction_checked_infl.append(fraction_checked_infl)\n",
    "        all_wall_time.append(wall_time)\n",
    "\n",
    "\n",
    "    # plot the results\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "    # plt.plot(all_fraction_checked_infl[0])\n",
    "    all_wall_time = np.array(all_wall_time)\n",
    "    all_wall_time_mean = all_wall_time.mean(axis=2).mean(axis=1)\n",
    "    all_wall_time_ste = all_wall_time.mean(axis=2).std(axis=1) / np.sqrt(all_wall_time.shape[1])\n",
    "    lns1 = ax.plot(all_wall_time_mean, color=\"brown\", label=\"Wall time\", linestyle=linestyles[0])\n",
    "    ax.fill_between(range(len(all_wall_time_mean)), all_wall_time_mean-all_wall_time_ste, all_wall_time_mean+all_wall_time_ste, alpha=0.2, color=\"brown\")\n",
    "    ax.set_xlabel(\"Project dimension\")\n",
    "    # ax.set_ylabel(\"Wall time\")\n",
    "    ax.set_xticks([0,2,4,6,7,8], [1,5,10,100,1000,4096][::-1])\n",
    "\n",
    "\n",
    "\n",
    "    # create a second y-axis\n",
    "    ax2 = ax.twinx()\n",
    "    all_fraction_checked_infl = np.array(all_fraction_checked_infl)\n",
    "    all_fraction_checked_infl_mean = all_fraction_checked_infl.mean(axis=2).mean(axis=1)\n",
    "    all_fraction_checked_infl_ste = all_fraction_checked_infl.mean(axis=2).std(axis=1) / np.sqrt(all_fraction_checked_infl.shape[1])\n",
    "    lns2 = ax2.plot(all_fraction_checked_infl_mean, color='orange', label=\"AUC\", linestyle=linestyles[1])\n",
    "    ax2.fill_between(range(len(all_fraction_checked_infl_mean)), all_fraction_checked_infl_mean-all_fraction_checked_infl_ste, all_fraction_checked_infl_mean+all_fraction_checked_infl_ste, alpha=0.2, color=\"orange\")\n",
    "    # ax2.set_ylabel(\"AUC\")\n",
    "    ax.set_xticks([0,2,4,6,7,8], [1,5,10,100,1000,4096][::-1])\n",
    "\n",
    "    lns = lns1 + lns2\n",
    "    labels = [l.get_label() for l in lns]\n",
    "\n",
    "    plt.legend(lns, labels, loc=0)\n",
    "\n",
    "    plt.savefig(f\"figs/{save_dir.split('/')[1]}_{model_name}_fraction_checked_proj.pdf\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"cls_llm_layers\" in save_dir:\n",
    "\n",
    "    num_remove = 10\n",
    "    corrupt = False\n",
    "\n",
    "    all_acc_orig = []\n",
    "    all_acc_rem_high = []\n",
    "    all_acc_rem_low = []\n",
    "    all_acc_rem_random = []\n",
    "\n",
    "    all_acc_orig_ste = []\n",
    "    all_acc_rem_high_ste = []\n",
    "    all_acc_rem_low_ste = []\n",
    "    all_acc_rem_random_ste = []\n",
    "\n",
    "    for layer_num in [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31]:\n",
    "        acc_orig_name = f\"acc_orig_{model_name}_{num_remove}_{layer_num}_{'corrupt' if corrupt else 'remove'}.npy\"\n",
    "        acc_orig = np.load(os.path.join(save_dir, acc_orig_name))\n",
    "\n",
    "        acc_rem_high_name = f\"acc_rem_high_{model_name}_{num_remove}_{layer_num}_{'corrupt' if corrupt else 'remove'}.npy\"\n",
    "        acc_rem_high = np.load(os.path.join(save_dir, acc_rem_high_name))\n",
    "\n",
    "        acc_rem_low_name = f\"acc_rem_low_{model_name}_{num_remove}_{layer_num}_{'corrupt' if corrupt else 'remove'}.npy\"\n",
    "        acc_rem_low = np.load(os.path.join(save_dir, acc_rem_low_name))\n",
    "\n",
    "        acc_rem_random_name = f\"acc_rem_random_{model_name}_{num_remove}_{layer_num}_{'corrupt' if corrupt else 'remove'}.npy\"\n",
    "        acc_rem_random = np.load(os.path.join(save_dir, acc_rem_random_name))\n",
    "\n",
    "        all_acc_orig.append(acc_orig.mean())\n",
    "        all_acc_rem_high.append(acc_rem_high.mean())\n",
    "        all_acc_rem_low.append(acc_rem_low.mean())\n",
    "        all_acc_rem_random.append(acc_rem_random.mean())\n",
    "\n",
    "        all_acc_orig_ste.append(acc_orig.std() / np.sqrt(len(acc_orig)))\n",
    "        all_acc_rem_high_ste.append(acc_rem_high.std() / np.sqrt(len(acc_rem_high)))\n",
    "        all_acc_rem_low_ste.append(acc_rem_low.std() / np.sqrt(len(acc_rem_low)))\n",
    "        all_acc_rem_random_ste.append(acc_rem_random.std() / np.sqrt(len(acc_rem_random)))\n",
    "    \n",
    "    # plot the results\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    # plt.plot(all_acc_orig, label=\"all data\")\n",
    "    plt.plot(all_acc_rem_high, label=f\"{'remove' if not corrupt else 'corrupt'} high\", linestyle=linestyles[0])\n",
    "    plt.fill_between(range(len(all_acc_rem_high)), np.array(all_acc_rem_high)-np.array(all_acc_rem_high_ste), np.array(all_acc_rem_high)+np.array(all_acc_rem_high_ste), alpha=0.2)\n",
    "    plt.plot(all_acc_rem_low, label=f\"{'remove' if not corrupt else 'corrupt'} low\", linestyle=linestyles[1])\n",
    "    plt.fill_between(range(len(all_acc_rem_low)), np.array(all_acc_rem_low)-np.array(all_acc_rem_low_ste), np.array(all_acc_rem_low)+np.array(all_acc_rem_low_ste), alpha=0.2)\n",
    "    plt.plot(all_acc_rem_random, label=f\"{'remove' if not corrupt else 'corrupt'} random\", linestyle=linestyles[2])\n",
    "    plt.fill_between(range(len(all_acc_rem_random)), np.array(all_acc_rem_random)-np.array(all_acc_rem_random_ste), np.array(all_acc_rem_random)+np.array(all_acc_rem_random_ste), alpha=0.2)\n",
    "\n",
    "    plt.xticks(range(0, len(all_acc_rem_high), 2), [1,5,9,13,17,21,25,29])\n",
    "\n",
    "    plt.xlabel(\"Layer number\")\n",
    "    plt.ylabel(\"Accuracy\")\n",
    "    plt.legend(loc=\"lower left\")\n",
    "\n",
    "    plt.savefig(f\"figs/{save_dir.split('/')[1]}_{model_name}_acc_layers_{'remove' if not corrupt else 'corrupt'}.pdf\", bbox_inches=\"tight\", dpi=300)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"cls_llm_order\" in save_dir:\n",
    "    # Influence scores saved for the configured model; the file name encodes\n",
    "    # the number of permutations used by the experiment.\n",
    "    num_perms = 8\n",
    "    data_name = f\"all_infl_{num_perms}_{model_name}.npy\"\n",
    "    all_infl = np.load(os.path.join(save_dir, data_name))\n",
    "\n",
    "    # Min-max rescale every permutation (entry along the first axis) in place.\n",
    "    # The small epsilon keeps the division finite when an entry is constant.\n",
    "    for i, infl in enumerate(all_infl):\n",
    "        lo, hi = infl.min(), infl.max()\n",
    "        all_infl[i] = (infl - lo) / (hi - lo + 1e-9)\n",
    "\n",
    "    # Mean and standard error taken across the first axis.\n",
    "    all_infl_mean = all_infl.mean(axis=0)\n",
    "    all_infl_ste = all_infl.std(axis=0) / np.sqrt(all_infl.shape[0])\n",
    "\n",
    "    # Indices ordered from lowest to highest mean normalized influence.\n",
    "    print(list(np.argsort(all_infl_mean)))\n",
    "\n",
    "    # Draw the mean curve with a +/- one standard error band.\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    positions = range(len(all_infl_mean))\n",
    "    band_low = all_infl_mean - all_infl_ste\n",
    "    band_high = all_infl_mean + all_infl_ste\n",
    "\n",
    "    plt.plot(all_infl_mean)\n",
    "    plt.fill_between(positions, band_low, band_high, alpha=0.2)\n",
    "    plt.xlabel(\"Sample position\")\n",
    "    plt.ylabel(\"Normalized inf.\")\n",
    "\n",
    "    # Tick locations double as their own labels.\n",
    "    tick_positions = [0, 4, 8, 12, 16, 19]\n",
    "    plt.xticks(tick_positions, tick_positions)\n",
    "\n",
    "    # plt.savefig(f\"figs/{save_dir.split('/')[1]}_{model_name}_infl_layers.pdf\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"cls_llm_order\" in save_dir:\n",
    "    # Result directories of the order experiments, listed in the same order\n",
    "    # as the legend names below.\n",
    "    save_dirs = [\n",
    "        'results/cls_llm_order_rotten_tomatoes',\n",
    "        'results/cls_llm_order_subj',\n",
    "        'results/cls_llm_order_sst2',\n",
    "    ]\n",
    "    names = [\n",
    "        'Rotten Tomatoes',\n",
    "        'Subj',\n",
    "        'SST-2',\n",
    "    ]\n",
    "\n",
    "    all_infl_mean = []\n",
    "    all_infl_ste = []\n",
    "\n",
    "    # The loop variable is deliberately NOT named `save_dir`: a for-loop target\n",
    "    # lives in the notebook's global namespace, so reusing that name would\n",
    "    # overwrite the notebook-level setting and silently change which branch /\n",
    "    # directory every other cell uses when it is (re-)run afterwards.\n",
    "    for order_dir in save_dirs:\n",
    "        num_perms = 8\n",
    "        data_name = f\"all_infl_{num_perms}_{model_name}.npy\"\n",
    "        all_infl = np.load(os.path.join(order_dir, data_name))\n",
    "\n",
    "        print(all_infl)\n",
    "\n",
    "        # min-max normalize the influence for each permutation (in place);\n",
    "        # the epsilon avoids a division by zero for constant rows\n",
    "        for i in range(all_infl.shape[0]):\n",
    "            all_infl[i] = (all_infl[i] - all_infl[i].min()) / (all_infl[i].max() - all_infl[i].min() + 1e-9)\n",
    "\n",
    "        # mean and standard error across the first axis\n",
    "        all_infl_mean.append(all_infl.mean(axis=0))\n",
    "        all_infl_ste.append(all_infl.std(axis=0) / np.sqrt(len(all_infl)))\n",
    "\n",
    "    # plot the results: one curve with a +/- one standard error band per dataset\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "    for name, linestyle, mean, ste in zip(names, linestyles, all_infl_mean, all_infl_ste):\n",
    "        ax.plot(mean, label=name, marker='o', markersize=12, linestyle=linestyle)\n",
    "        ax.fill_between(range(len(mean)), mean-ste, mean+ste, alpha=0.2)\n",
    "\n",
    "    ax.set_xlabel(\"Pos. of perturbed demo\")\n",
    "    # NOTE(review): the plotted values are min-max normalized per permutation;\n",
    "    # confirm that \"Test accuracy\" is the intended label for them.\n",
    "    ax.set_ylabel(\"Test accuracy\")\n",
    "\n",
    "    ax.set_xticks([0,4,8,12,16,19], [0,4,8,12,16,19])\n",
    "\n",
    "    ax.legend()\n",
    "\n",
    "    # savefig does not create missing directories, so make sure figs/ exists\n",
    "    os.makedirs(\"figs\", exist_ok=True)\n",
    "    fig.savefig(f\"figs/cls_llm_order_{model_name}_acc_pos.pdf\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"cls_llm_pos\" in save_dir:\n",
    "    num_optimize = 6\n",
    "\n",
    "    # Accuracy arrays stored for this model and `num_optimize` setting.\n",
    "    infl_data_name = f\"all_infl_acc_{model_name}_{num_optimize}.npy\"\n",
    "    orig_data_median_name = f\"all_orig_acc_median_{model_name}_{num_optimize}.npy\"\n",
    "    all_infl = np.load(os.path.join(save_dir, infl_data_name))\n",
    "    all_orig_median = np.load(os.path.join(save_dir, orig_data_median_name))\n",
    "\n",
    "    # Disabled: influence coefficients (relies on a `score` variable that is not set here).\n",
    "    # all_infl_coeff_name = f\"all_infl_coeff_{model_name}_{score}_{num_optimize}.npy\"\n",
    "    # all_infl_coeff = np.load(os.path.join(save_dir, all_infl_coeff_name))\n",
    "\n",
    "    # Element-wise gap between the two loaded arrays.\n",
    "    diff = all_infl - all_orig_median\n",
    "\n",
    "    # Box plot of the gap; the transpose means each row of `diff` gets its own box.\n",
    "    from utils import set_up_plotting\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    set_up_plotting()\n",
    "\n",
    "    plt.boxplot(diff.T)\n",
    "\n",
    "    # Report mean and standard error along the first axis for each quantity.\n",
    "    summaries = (\n",
    "        (\"diff mean:\", diff),\n",
    "        (\"infl acc:\", all_infl),\n",
    "        (\"original acc:\", all_orig_median),\n",
    "    )\n",
    "    for label, values in summaries:\n",
    "        print(label, values.mean(axis=0), values.std(axis=0) / np.sqrt(len(values)))\n",
    "    # print(all_infl_coeff.mean(axis=0), all_infl_coeff.std(axis=0) / np.sqrt(len(all_infl_coeff)))\n",
    "\n",
    "    # plt.savefig(f\"figs/{save_dir.split('/')[1]}_{model_name}_infl_layers.pdf\", bbox_inches=\"tight\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
