{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from task import get_acts, get_acts_pca\n",
    "from days_of_week_task import DaysOfWeekTask\n",
    "from months_of_year_task import MonthsOfYearTask\n",
    "import os\n",
    "from adjustText import adjust_text\n",
    "import dill as pickle\n",
    "import matplotlib\n",
    "from matplotlib.lines import Line2D\n",
    "from sklearn.decomposition import PCA\n",
    "import scipy.stats\n",
    "import pandas as pd\n",
    "from utils import BASE_DIR\n",
    "import torch\n",
    "\n",
    "os.makedirs(\"figs/paper_plots\", exist_ok=True)\n",
    "\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two PCA plots\n",
    "\n",
    "\n",
    "s = 2\n",
    "\n",
    "fig = plt.figure(figsize=(2.75, 1.25))\n",
    "ax1 = plt.subplot(1, 2, 1)\n",
    "ax2 = plt.subplot(1, 2, 2)\n",
    "\n",
    "\n",
    "for ax in [ax1, ax2]:\n",
    "    ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "    ax.spines[\"top\"].set_visible(False)\n",
    "    ax.spines[\"right\"].set_visible(False)\n",
    "    # ax.spines['bottom'].set_visible(False)\n",
    "    # ax.spines['left'].set_visible(False)\n",
    "    ax.spines[\"left\"].set_position(\"zero\")\n",
    "    ax.spines[\"bottom\"].set_position(\"zero\")\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "\n",
    "\n",
    "# Left plot\n",
    "task = DaysOfWeekTask(\"cpu\", \"mistral\")\n",
    "problems = task.generate_problems()\n",
    "tokens = task.allowable_tokens\n",
    "acts = get_acts_pca(task, layer=30, token=task.a_token, pca_k=2)[0]\n",
    "\n",
    "texts = []\n",
    "for token_index, token in enumerate(tokens):\n",
    "    indices = [i for i, p in enumerate(problems) if p.info[0] == token_index]\n",
    "    ax1.scatter(acts[indices, 0], acts[indices, 1], label=token, s=s)\n",
    "\n",
    "    # Add label\n",
    "    texts.append(\n",
    "        ax1.text(acts[indices[0], 0], acts[indices[0], 1], token[:3], fontsize=8)\n",
    "    )\n",
    "\n",
    "adjust_text(texts, ax=ax1, force_text=(0.5, 1))\n",
    "\n",
    "# Move thursday up\n",
    "thursday_text = texts[3]\n",
    "thursday_pos = texts[3].get_position()\n",
    "texts[3].set_position(thursday_pos + np.array([0, 1.2]))\n",
    "\n",
    "# Move sunday down\n",
    "sunday_text = texts[6]\n",
    "sunday_pos = texts[6].get_position()\n",
    "texts[6].set_position(sunday_pos + np.array([0, -0.5]))\n",
    "\n",
    "ax1.set_xlim(-8, 8)\n",
    "ax1.set_ylim(-8, 8)\n",
    "\n",
    "# Right plot\n",
    "task = MonthsOfYearTask(\"cpu\", \"llama\")\n",
    "problems = task.generate_problems()\n",
    "tokens = task.allowable_tokens\n",
    "acts = get_acts_pca(task, layer=3, token=task.a_token, pca_k=2)[0]\n",
    "colorwheel = plt.cm.rainbow(np.linspace(0, 1 - 1 / 12, 12))\n",
    "\n",
    "texts = []\n",
    "for token_index, token in enumerate(tokens):\n",
    "    indices = [i for i, p in enumerate(problems) if p.info[0] == token_index]\n",
    "    ax2.scatter(acts[indices, 0], acts[indices, 1], s=s, color=colorwheel[token_index])\n",
    "\n",
    "    # Add label\n",
    "    texts.append(\n",
    "        ax2.text(acts[indices[0], 0], acts[indices[0], 1], token[:3], fontsize=8)\n",
    "    )\n",
    "\n",
    "adjust_text(texts, ax=ax2, force_text=(0.25, 0.5))\n",
    "\n",
    "ax2.set_xlim(-0.6, 0.6)\n",
    "ax2.set_ylim(-0.6, 0.6)\n",
    "\n",
    "x_line = ax1.get_position().x1 + 0.025\n",
    "fig.add_artist(\n",
    "    Line2D(\n",
    "        [x_line, x_line],\n",
    "        [0.05, 0.95],\n",
    "        transform=fig.transFigure,\n",
    "        color=\"grey\",\n",
    "        linewidth=0.5,\n",
    "    )\n",
    ")\n",
    "\n",
    "\n",
    "plt.tight_layout(pad=1)\n",
    "\n",
    "fig = plt.gcf()\n",
    "\n",
    "\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(\"figs/paper_plots/paper_pcas.pdf\", bbox_inches=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "font = {\"size\": 6}\n",
    "\n",
    "matplotlib.rc(\"font\", **font)\n",
    "\n",
    "width = 0.4\n",
    "\n",
    "layer_averages_files = [\n",
    "    \"figs/mistral_days_of_week/rotation_probing/a_cos_sin_all_layers_token_15_mean_logit_diffs.pkl\",\n",
    "    \"figs/mistral_months_of_year/rotation_probing/a_cos_sin_all_layers_token_13_mean_logit_diffs.pkl\",\n",
    "    \"figs/llama_days_of_week/rotation_probing/a_cos_sin_all_layers_token_14_mean_logit_diffs.pkl\",\n",
    "    \"figs/llama_months_of_year/rotation_probing/a_cos_sin_all_layers_token_11_mean_logit_diffs.pkl\",\n",
    "]\n",
    "\n",
    "fig, axs = plt.subplots(2, 2, figsize=(3.3, 2), sharex=True)\n",
    "axs = axs.flatten()\n",
    "\n",
    "max_layer = 25\n",
    "\n",
    "titles = [\"Mistral Weekdays\", \"Mistral Months\", \"Llama Weekdays\", \"Llama Months\"]\n",
    "\n",
    "for i, layer_averages_file in enumerate(layer_averages_files):\n",
    "    layer_averages = pickle.load(open(layer_averages_file, \"rb\"))\n",
    "    intervention_pca_k = 5\n",
    "    layers = []\n",
    "    layer_averages = layer_averages[1::2]\n",
    "\n",
    "    # average_no_replace.append(t[2])\n",
    "    # average_replace_circle.append(t[3])\n",
    "    # average_replace_pca.append(t[4])\n",
    "    # average_replace_all.append(t[5])\n",
    "    # average_average_ablate.append(t[6])\n",
    "    # average_zero_circle.append(t[7])\n",
    "    # average_zero_everything_but_circle.append(t[8])\n",
    "\n",
    "    index_to_name = {\n",
    "        2: \"no_replace\",\n",
    "        3: \"replace_circle\",\n",
    "        4: \"replace_pca\",\n",
    "        5: \"replace_all\",\n",
    "        6: \"average_ablate\",\n",
    "        7: \"zero_circle\",\n",
    "        8: \"zero_everything_but_circle\",\n",
    "    }\n",
    "\n",
    "    average_dict = {key: [] for key in index_to_name.values()}\n",
    "    ci_dict = {key: [] for key in index_to_name.values()}\n",
    "\n",
    "    def mean_confidence_interval(data, confidence=0.96):\n",
    "        a = 1.0 * np.array(data)\n",
    "        n = len(a)\n",
    "        m, se = np.mean(a), scipy.stats.sem(a)\n",
    "        h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1)\n",
    "        return m, m - h, m + h\n",
    "\n",
    "    for t in layer_averages:\n",
    "        if t[1] == intervention_pca_k:\n",
    "            layers.append(t[0])\n",
    "            for index in range(2, 9):\n",
    "                mean, lower, upper = mean_confidence_interval(t[index])\n",
    "                average_dict[index_to_name[index]].append(mean)\n",
    "                ci_dict[index_to_name[index]].append((lower, upper))\n",
    "\n",
    "    plot_dict = {\n",
    "        \"no_replace\": \"No-op\",\n",
    "        \"replace_all\": \"Patch layer\",\n",
    "        \"replace_circle\": \"Patch circle\",\n",
    "        \"replace_pca\": \"Patch pca\",\n",
    "        \"average_ablate\": \"Average ablate\",\n",
    "    }\n",
    "\n",
    "    for key, label in plot_dict.items():\n",
    "        axs[i].plot(\n",
    "            layers[:max_layer],\n",
    "            average_dict[key][:max_layer],\n",
    "            label=label,\n",
    "            linewidth=width,\n",
    "            alpha=0.9,\n",
    "        )\n",
    "\n",
    "        # Plot ci\n",
    "        lower = [ci[0] for ci in ci_dict[key][:max_layer]]\n",
    "        upper = [ci[1] for ci in ci_dict[key][:max_layer]]\n",
    "        axs[i].fill_between(layers[:max_layer], lower, upper, alpha=0.6)\n",
    "\n",
    "    axs[i].set_title(titles[i])\n",
    "\n",
    "\n",
    "def format_subplot(ax, grid_x=True):\n",
    "    ax.spines[\"top\"].set_visible(False)\n",
    "    ax.spines[\"right\"].set_visible(False)\n",
    "    if grid_x:\n",
    "        ax.grid(linestyle=\"--\", alpha=0.4)\n",
    "    else:\n",
    "        ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "\n",
    "for ax in axs:\n",
    "    format_subplot(ax)\n",
    "\n",
    "xlabel = fig.supxlabel(\"Layer\")\n",
    "ylabel = fig.supylabel(\"Average logit diff\")\n",
    "\n",
    "# Move xlabel up\n",
    "xlabel.set_position((0.55, 0.06))\n",
    "\n",
    "# Move ylabel right\n",
    "ylabel.set_position((0.01, 0.5))\n",
    "\n",
    "fig = plt.gcf()\n",
    "handles, labels = axs[0].get_legend_handles_labels()\n",
    "\n",
    "colorwheel = plt.cm.tab10(np.linspace(0, 1, 10))\n",
    "linesize = 0.01\n",
    "markersize = 0\n",
    "legend_elements = [\n",
    "    Line2D(\n",
    "        [0, linesize],\n",
    "        [0, 0],\n",
    "        marker=\"o\",\n",
    "        color=colorwheel[0],\n",
    "        label=\"No-op\",\n",
    "        markerfacecolor=colorwheel[0],\n",
    "        markersize=markersize,\n",
    "    ),\n",
    "    Line2D(\n",
    "        [0, linesize],\n",
    "        [0, 0],\n",
    "        marker=\"o\",\n",
    "        color=colorwheel[1],\n",
    "        label=\"Patch layer\",\n",
    "        markerfacecolor=colorwheel[1],\n",
    "        markersize=markersize,\n",
    "    ),\n",
    "    Line2D(\n",
    "        [0, linesize],\n",
    "        [0, 0],\n",
    "        marker=\"o\",\n",
    "        color=colorwheel[2],\n",
    "        label=\"Patch circle\",\n",
    "        markerfacecolor=colorwheel[2],\n",
    "        markersize=markersize,\n",
    "    ),\n",
    "    Line2D(\n",
    "        [0, linesize],\n",
    "        [0, 0],\n",
    "        marker=\"o\",\n",
    "        color=colorwheel[3],\n",
    "        label=\"Patch PCA\",\n",
    "        markerfacecolor=colorwheel[3],\n",
    "        markersize=markersize,\n",
    "    ),\n",
    "    Line2D(\n",
    "        [0, linesize],\n",
    "        [0, 0],\n",
    "        marker=\"o\",\n",
    "        color=colorwheel[4],\n",
    "        label=\"Average ablate\",\n",
    "        markerfacecolor=colorwheel[4],\n",
    "        markersize=markersize,\n",
    "    ),\n",
    "]\n",
    "\n",
    "leg = fig.legend(\n",
    "    handles=legend_elements,\n",
    "    loc=\"upper center\",\n",
    "    ncol=3,\n",
    "    bbox_to_anchor=(0.52, 0.06),\n",
    "    labelspacing=0,\n",
    "    handletextpad=0.3,\n",
    "    columnspacing=1,\n",
    "    handlelength=0.8,\n",
    ")\n",
    "for legobj in leg.legendHandles:\n",
    "    legobj.set_linewidth(1.5)\n",
    "\n",
    "fig.add_artist(\n",
    "    Line2D(\n",
    "        [0.53, 0.53], [0.15, 0.99], transform=fig.transFigure, color=\"grey\", linewidth=1\n",
    "    )\n",
    ")\n",
    "fig.add_artist(\n",
    "    Line2D([0.1, 1], [0.58, 0.58], transform=fig.transFigure, color=\"grey\", linewidth=1)\n",
    ")\n",
    "\n",
    "\n",
    "plt.tight_layout(pad=0.7)\n",
    "\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "fig.savefig(\n",
    "    f\"figs/paper_plots/combined_intervention.pdf\",\n",
    "    bbox_inches=\"tight\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from intervene_in_middle_of_circle import get_points\n",
    "\n",
    "# from matplotlib.colors import tab20\n",
    "\n",
    "font = {\"size\": 5}\n",
    "\n",
    "matplotlib.rc(\"font\", **font)\n",
    "\n",
    "s = 0.1\n",
    "\n",
    "\n",
    "task = DaysOfWeekTask(\"cpu\", model_name=\"mistral\")\n",
    "layer = 5\n",
    "token = task.a_token\n",
    "durations = range(2, 6)\n",
    "circle_letter = \"a\"\n",
    "pca_k = 5\n",
    "circle_size = len(task.allowable_tokens)\n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(2, 2, figsize=(2.2, 2.2))\n",
    "\n",
    "for ax, duration in zip(axs.flatten(), durations):\n",
    "    filename = f\"figs/{task.name}/varying_circle/logits_{layer}_{token}_{pca_k}_{duration}_{circle_letter}.npy\"\n",
    "    all_logits = np.load(filename)\n",
    "    all_points, angles, radius_vals = get_points()\n",
    "    best_a = np.argmax(all_logits, axis=1)\n",
    "    for i in range(circle_size):\n",
    "        ax.scatter(\n",
    "            all_points[best_a == i, 0],\n",
    "            all_points[best_a == i, 1],\n",
    "            label=task.allowable_tokens[i],\n",
    "            s=s,\n",
    "            # color=tab20(i)\n",
    "        )\n",
    "    ax.set_title(f\"Task Duration = {duration} Days\")\n",
    "    # ax.legend()\n",
    "    ax.set_xlim(-2, 2)\n",
    "    ax.set_ylim(-2, 2)\n",
    "    ax.set_aspect(\"equal\", adjustable=\"box\")\n",
    "\n",
    "    # Plot unit circle of size 7 in black\n",
    "    # ax.plot(np.cos(np.arange(0, 7) * 2 * np.pi / 7), np.sin(np.arange(0, 7) * 2 * np.pi / 7), 'o', color=\"black\")\n",
    "\n",
    "handles, labels = axs[0][0].get_legend_handles_labels()\n",
    "\n",
    "\n",
    "# plt.suptitle(\"Highest Logit Day After Intervention\")\n",
    "\n",
    "fig = plt.gcf()\n",
    "\n",
    "fig.tight_layout()\n",
    "\n",
    "# Put legend below figure\n",
    "lgnd = fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc=\"upper center\",\n",
    "    ncol=4,\n",
    "    bbox_to_anchor=(0.5, 0.02),\n",
    "    fontsize=5,\n",
    "    frameon=False,\n",
    "    columnspacing=0,\n",
    ")\n",
    "for i in range(circle_size):\n",
    "    lgnd.legendHandles[i]._sizes = [10]\n",
    "\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(\n",
    "    f\"figs/paper_plots/mistral_highest_logit_day_after_intervention.pdf\",\n",
    "    bbox_extra_artists=(lgnd,),\n",
    "    bbox_inches=\"tight\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "font = {\"size\": 7}\n",
    "\n",
    "matplotlib.rc(\"font\", **font)\n",
    "\n",
    "s = 2\n",
    "\n",
    "fig = plt.figure(figsize=(1.65, 1.5))\n",
    "ax = plt.gca()\n",
    "\n",
    "task = DaysOfWeekTask(\"cpu\", model_name=\"mistral\")\n",
    "acts = get_acts(task, layer_fetch=25, token_fetch=task.before_c_token)\n",
    "\n",
    "problems = task.generate_problems()\n",
    "a = np.array([problem.info[0] for problem in problems])\n",
    "b = np.array([problem.info[1] for problem in problems])\n",
    "c = np.array([problem.info[2] for problem in problems])\n",
    "\n",
    "explaining_variables = []\n",
    "for i in range(7):\n",
    "    explaining_variables.append(a == i)\n",
    "for i in range(1, 8):\n",
    "    explaining_variables.append(b == i)\n",
    "\n",
    "explaining_variables = np.array(explaining_variables).T\n",
    "\n",
    "print(explaining_variables.shape)\n",
    "\n",
    "least_squares_sol = np.linalg.lstsq(explaining_variables, acts)[0]\n",
    "\n",
    "residuals = acts - explaining_variables @ least_squares_sol\n",
    "\n",
    "pca = PCA(n_components=2)\n",
    "pca.fit(residuals)\n",
    "\n",
    "print(pca.explained_variance_ratio_)\n",
    "\n",
    "projected = pca.transform(residuals)\n",
    "\n",
    "for day_of_week in range(7):\n",
    "    ax.plot(\n",
    "        projected[c == day_of_week, 0],\n",
    "        projected[c == day_of_week, 1],\n",
    "        \"o\",\n",
    "        label=task.allowable_tokens[day_of_week],\n",
    "        markersize=s,\n",
    "    )\n",
    "\n",
    "# ax.legend(loc=\"upper center\", bbox_to_anchor=(0.5, -0.1), ncol=3)\n",
    "\n",
    "ax.tick_params(axis=\"both\", which=\"major\", labelsize=5)\n",
    "ax.spines[\"top\"].set_visible(False)\n",
    "ax.spines[\"right\"].set_visible(False)\n",
    "# ax.spines['bottom'].set_visible(False)\n",
    "# ax.spines['left'].set_visible(False)\n",
    "ax.spines[\"left\"].set_position(\"zero\")\n",
    "ax.spines[\"bottom\"].set_position(\"zero\")\n",
    "# ax.set_xticks([])\n",
    "# ax.set_yticks([])\n",
    "\n",
    "positions = [\n",
    "    [-4.5, 0.5],\n",
    "    [-3, 2.2],\n",
    "    [0.8, 2.3],\n",
    "    [2.5, 0.7],\n",
    "    [2.3, -1.0],\n",
    "    [0.5, -2.6],\n",
    "    [-3.5, -2.6],\n",
    "]\n",
    "\n",
    "# Add text to plot\n",
    "texts = []\n",
    "for i in range(7):\n",
    "    x, y = positions[i]\n",
    "    # texts.append(ax.text(x, y, rf'$\\gamma$ = {task.allowable_tokens[i][:3]}', fontsize=6))\n",
    "    texts.append(ax.text(x, y, rf\"{task.allowable_tokens[i][:3]}\", fontsize=6))\n",
    "\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(\n",
    "    f\"figs/paper_plots/mistral_residuals_c_pca.pdf\",\n",
    "    bbox_inches=\"tight\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Performance table\n",
    "\n",
    "for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
    "    for model_name in [\"mistral\", \"llama\"]:\n",
    "        results = pd.read_csv(\n",
    "            f\"{BASE_DIR}/{model_name}_{task_name}/results.csv\", skipinitialspace=True\n",
    "        )\n",
    "        number_correct = results[\"best_token\"] == results[\"ground_truth\"]\n",
    "        print(task_name, model_name, np.sum(number_correct))\n",
    "\n",
    "# GPT 2\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "model = HookedTransformer.from_pretrained(\"gpt2\")\n",
    "\n",
    "for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
    "    if task_name == \"days_of_week\":\n",
    "        task = DaysOfWeekTask(\"cpu\", model_name=\"gpt2\")\n",
    "    else:\n",
    "        task = MonthsOfYearTask(\"cpu\", model_name=\"gpt2\")\n",
    "    problems = task.generate_problems()\n",
    "    answer_logits = [model.to_single_token(token) for token in task.allowable_tokens]\n",
    "    num_correct = 0\n",
    "    for problem in problems:\n",
    "        logits = model(problem.prompt).cpu()[0][-1]\n",
    "        top_from_answer_logits = np.argmax(logits[answer_logits])\n",
    "        correct_answer = problem.info[2]\n",
    "        if top_from_answer_logits == correct_answer:\n",
    "            num_correct += 1\n",
    "    print(task_name, \"gpt2\", num_correct)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "multiplexing",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
