{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from task import get_acts, get_acts_pca, get_all_acts\n",
    "from days_of_week_task import DaysOfWeekTask, days_of_week\n",
    "from months_of_year_task import MonthsOfYearTask, months_of_year\n",
    "import os\n",
    "from adjustText import adjust_text\n",
    "import dill as pickle\n",
    "import matplotlib\n",
    "from matplotlib.lines import Line2D\n",
    "from sklearn.decomposition import PCA\n",
    "import scipy.stats\n",
    "from task import activation_patching\n",
    "import torch\n",
    "from utils import BASE_DIR\n",
    "import pandas as pd\n",
    "from circle_finding_utils import do_regression\n",
    "import einops\n",
    "\n",
    "# Every figure and table produced below is written into this directory.\n",
    "os.makedirs(\"figs/paper_plots\", exist_ok=True)\n",
    "\n",
    "# Use the first CUDA device when one is visible, otherwise run on CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda:0\"\n",
    "else:\n",
    "    device = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All PCA plots: a 4x8 grid of per-layer 2-D PCA projections of the\n",
    "# activations at the a token, one figure per (task, model) pair.\n",
    "\n",
    "s = 0.01\n",
    "\n",
    "font = {\"size\": 4}\n",
    "\n",
    "matplotlib.rc(\"font\", **font)\n",
    "\n",
    "for task_name in [\"weekdays\", \"months\"]:\n",
    "    for model_name in [\"llama\", \"mistral\"]:\n",
    "        if task_name == \"weekdays\":\n",
    "            task = DaysOfWeekTask(device, model_name)\n",
    "            # Weekdays use matplotlib's default color cycle.\n",
    "            colorwheel = None\n",
    "        else:\n",
    "            task = MonthsOfYearTask(device, model_name)\n",
    "            # 12 evenly spaced rainbow colors on [0, 1 - 1/12].\n",
    "            colorwheel = plt.cm.rainbow(np.linspace(0, 1 - 1 / 12, 12))\n",
    "\n",
    "        problems = task.generate_problems()\n",
    "        tokens = task.allowable_tokens\n",
    "\n",
    "        # Problem indices grouped by the a operand (info[0]). This does not\n",
    "        # depend on the layer, so it is computed once per task.\n",
    "        indices_by_token = [\n",
    "            [i for i, p in enumerate(problems) if p.info[0] == token_index]\n",
    "            for token_index in range(len(tokens))\n",
    "        ]\n",
    "\n",
    "        fig, axs = plt.subplots(4, 8, figsize=(2.75, 1.5))\n",
    "        axs = axs.flatten()\n",
    "\n",
    "        for layer in range(32):\n",
    "            ax = axs[layer]\n",
    "            acts = get_acts_pca(task, layer=layer, token=task.a_token, pca_k=2)[0]\n",
    "\n",
    "            for token_index, token in enumerate(tokens):\n",
    "                indices = indices_by_token[token_index]\n",
    "                scatter_kwargs = {\"label\": token, \"s\": s}\n",
    "                if colorwheel is not None:\n",
    "                    scatter_kwargs[\"color\"] = colorwheel[token_index]\n",
    "                ax.scatter(acts[indices, 0], acts[indices, 1], **scatter_kwargs)\n",
    "\n",
    "            # Minimal axes: spines through the origin, no ticks, hairline\n",
    "            # spines. Styling is per-axes, so it is applied once, not per token.\n",
    "            ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
    "            ax.spines[\"top\"].set_visible(False)\n",
    "            ax.spines[\"right\"].set_visible(False)\n",
    "            ax.spines[\"left\"].set_position(\"zero\")\n",
    "            ax.spines[\"bottom\"].set_position(\"zero\")\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "            for spine in ax.spines.values():\n",
    "                spine.set_linewidth(0.2)\n",
    "\n",
    "            ax.set_title(f\"Layer {layer}\")\n",
    "\n",
    "        handles, labels = axs[0].get_legend_handles_labels()\n",
    "        ncol = 6 if (task_name == \"months\") else 7\n",
    "        legend = fig.legend(\n",
    "            handles,\n",
    "            labels,\n",
    "            loc=\"upper center\",\n",
    "            ncol=ncol,\n",
    "            fontsize=3,\n",
    "            bbox_to_anchor=(0.5, 0),\n",
    "            frameon=False,\n",
    "        )\n",
    "\n",
    "        # Enlarge the tiny scatter markers in the legend. `legendHandles` was\n",
    "        # renamed `legend_handles` in Matplotlib 3.7 and removed in 3.9.\n",
    "        legend_handles = getattr(legend, \"legend_handles\", None)\n",
    "        if legend_handles is None:\n",
    "            legend_handles = legend.legendHandles\n",
    "        for handle in legend_handles:\n",
    "            handle.set_sizes([2])\n",
    "\n",
    "        fig.tight_layout()\n",
    "\n",
    "        fig.savefig(\n",
    "            f\"figs/paper_plots/{model_name}_{task_name}_all_pca.pdf\",\n",
    "            bbox_inches=\"tight\",\n",
    "        )\n",
    "        # Display now. With the default inline backend this also closes the\n",
    "        # figure, so the large multi-panel figures do not pile up in memory.\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run all patching experiments. The plotting cells below load results from\n",
    "# figs/{model}_{task}/patching/{layer_type}/keep-same{0,1}_chars-in-answer0_n20.npy\n",
    "# (presumably written by `activation_patching` - TODO confirm).\n",
    "\n",
    "for model_name in [\"mistral\", \"llama\"]:\n",
    "    for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
    "        if task_name == \"days_of_week\":\n",
    "            task = DaysOfWeekTask(model_name=model_name, device=device)\n",
    "        else:\n",
    "            task = MonthsOfYearTask(model_name=model_name, device=device)\n",
    "\n",
    "        for keep_same_index in [0, 1]:\n",
    "            for layer_type in [\"mlp\", \"attention\", \"resid\", \"attention_head\"]:\n",
    "                extra_kwargs = {}\n",
    "                if layer_type == \"attention_head\":\n",
    "                    # Patch individual heads only at the token before c; (12, 32)\n",
    "                    # presumably means layers 12-31, which the head plots assume.\n",
    "                    extra_kwargs[\"patching_sweep\"] = [task.before_c_token, (12, 32)]\n",
    "                activation_patching(\n",
    "                    task,\n",
    "                    keep_same_index=keep_same_index,\n",
    "                    num_chars_in_answer_to_include=0,\n",
    "                    num_activation_patching_experiments_to_run=20,\n",
    "                    layer_type=layer_type,\n",
    "                    **extra_kwargs,\n",
    "                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Patching plots: token x layer heatmaps of the MLP / attention patching\n",
    "# effect, averaged over the keep-same-0 and keep-same-1 runs.\n",
    "\n",
    "# Set plotting sizes\n",
    "SMALL_SIZE = 24\n",
    "MEDIUM_SIZE = 24\n",
    "BIGGER_SIZE = 24\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=MEDIUM_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\n",
    "    \"axes\", labelsize=SMALL_SIZE\n",
    ")  # fontsize of the x and y labels for the small plots\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "plt.rc(\n",
    "    \"figure\", labelsize=MEDIUM_SIZE\n",
    ")  # fontsize of the x and y labels for the big plots\n",
    "\n",
    "\n",
    "for model_name in [\"mistral\", \"llama\"]:\n",
    "    for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
    "        if task_name == \"days_of_week\":\n",
    "            task = DaysOfWeekTask(device, model_name=model_name)\n",
    "        else:\n",
    "            task = MonthsOfYearTask(device, model_name=model_name)\n",
    "\n",
    "        for patching_type in [\"mlp\", \"attention\"]:\n",
    "            fig, ax = plt.subplots(figsize=(10, 5))\n",
    "\n",
    "            patching_dir = f\"figs/{model_name}_{task_name}/patching/{patching_type}\"\n",
    "            combined = np.concatenate(\n",
    "                [\n",
    "                    np.load(f\"{patching_dir}/keep-same{keep_same}_chars-in-answer0_n20.npy\")\n",
    "                    for keep_same in (0, 1)\n",
    "                ],\n",
    "                axis=0,\n",
    "            )\n",
    "\n",
    "            average_patching_data = np.mean(combined, axis=0)\n",
    "\n",
    "            # NOTE(review): the largest token_map key is treated as exclusive\n",
    "            # here; confirm this matches the 5 hard-coded tick labels below.\n",
    "            ending_token_excl = max(task.token_map.keys())\n",
    "            starting_token_incl = min(task.token_map.keys())\n",
    "\n",
    "            average_patching_data = average_patching_data[\n",
    "                :, starting_token_incl:ending_token_excl\n",
    "            ]\n",
    "\n",
    "            ax.set_yticks(range(starting_token_incl, ending_token_excl))\n",
    "\n",
    "            # imshow draws row 0 at the top, while tick labels run bottom-to-top,\n",
    "            # so the prompt-order labels are reversed.\n",
    "            if task_name == \"days_of_week\":\n",
    "                ax.set_yticklabels(\n",
    "                    [\"*Monday\", \"is\", \"two\", \"days\", \"from\"][::-1], ha=\"right\"\n",
    "                )\n",
    "            else:\n",
    "                ax.set_yticklabels(\n",
    "                    [\"*Two\", \"months\", \"from\", \"*January\", \"is\"][::-1], ha=\"right\"\n",
    "                )\n",
    "\n",
    "            ax.set_xlabel(\"Layer\")\n",
    "\n",
    "            average_patching_data = average_patching_data.T\n",
    "\n",
    "            # Set negatives to 0\n",
    "            average_patching_data[average_patching_data < 0] = 0\n",
    "\n",
    "            im = ax.imshow(\n",
    "                average_patching_data,\n",
    "                cmap=\"OrRd\",\n",
    "                extent=[-0.5, 31.5, starting_token_incl - 0.5, ending_token_excl - 0.5],\n",
    "                aspect=\"auto\",\n",
    "            )\n",
    "\n",
    "            # Finish (colorbar, layout) and save BEFORE showing: the inline\n",
    "            # backend closes the figure on show, and a plt.tight_layout() after\n",
    "            # it would act on a new, empty current figure instead of this one.\n",
    "            fig.colorbar(im, ax=ax)\n",
    "            fig.tight_layout()\n",
    "            fig.savefig(\n",
    "                f\"figs/paper_plots/{model_name}_{patching_type}_{task_name}_patching.pdf\",\n",
    "                bbox_inches=\"tight\",\n",
    "            )\n",
    "            plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention head patching plots: (layer, head) heatmaps per (model, task),\n",
    "# plus the strongest heads, which the regression cell below consumes.\n",
    "\n",
    "# Set plotting sizes\n",
    "SMALL_SIZE = 24\n",
    "MEDIUM_SIZE = 24\n",
    "BIGGER_SIZE = 24\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=MEDIUM_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\n",
    "    \"axes\", labelsize=SMALL_SIZE\n",
    ")  # fontsize of the x and y labels for the small plots\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "plt.rc(\n",
    "    \"figure\", labelsize=MEDIUM_SIZE\n",
    ")  # fontsize of the x and y labels for the big plots\n",
    "\n",
    "all_top_heads = {}\n",
    "all_average_patching_data = {}\n",
    "num_heads_to_save_per = 10\n",
    "\n",
    "patching_type = \"attention_head\"\n",
    "\n",
    "# Only saved .npy files are read here, so no task/model object is constructed.\n",
    "for model_name in [\"mistral\", \"llama\"]:\n",
    "    for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
    "        fig, ax = plt.subplots(figsize=(10, 5))\n",
    "\n",
    "        patching_dir = f\"figs/{model_name}_{task_name}/patching/{patching_type}\"\n",
    "        combined = np.concatenate(\n",
    "            [\n",
    "                np.load(f\"{patching_dir}/keep-same{keep_same}_chars-in-answer0_n20.npy\")\n",
    "                for keep_same in (0, 1)\n",
    "            ],\n",
    "            axis=0,\n",
    "        )\n",
    "\n",
    "        average_patching_data = np.mean(combined, axis=0)\n",
    "\n",
    "        # Set negatives to 0\n",
    "        average_patching_data[average_patching_data < 0] = 0\n",
    "\n",
    "        all_average_patching_data[(model_name, task_name)] = average_patching_data\n",
    "\n",
    "        # The extent maps row 0 to layer 12 at the top and columns to heads\n",
    "        # 0-31 (assumes 20 rows, i.e. layers 12-31 from the head sweep).\n",
    "        im = ax.imshow(\n",
    "            average_patching_data,\n",
    "            cmap=\"OrRd\",\n",
    "            aspect=\"auto\",\n",
    "            extent=[-0.5, 31.5, 31.5, 11.5],\n",
    "        )\n",
    "\n",
    "        # Set ticks\n",
    "        ax.set_xticks(range(0, 32, 4))\n",
    "        ax.set_yticks(range(30, 12, -4))\n",
    "\n",
    "        ax.set_xlabel(\"Head\")\n",
    "        ax.set_ylabel(\"Layer\")\n",
    "\n",
    "        # Finish (colorbar, layout) and save BEFORE showing: the inline\n",
    "        # backend closes the figure on show, and a plt.tight_layout() after\n",
    "        # it would act on a new, empty current figure instead of this one.\n",
    "        fig.colorbar(im, ax=ax)\n",
    "        fig.tight_layout()\n",
    "        fig.savefig(\n",
    "            f\"figs/paper_plots/{model_name}_{patching_type}_{task_name}_patching.pdf\",\n",
    "            bbox_inches=\"tight\",\n",
    "        )\n",
    "        plt.show()\n",
    "\n",
    "        # (row, head) indices of the strongest heads; rows are relative to the\n",
    "        # first layer of the head sweep.\n",
    "        top_flat_indices = torch.topk(\n",
    "            torch.tensor(average_patching_data.flatten()), num_heads_to_save_per\n",
    "        ).indices\n",
    "        top_heads = np.array(\n",
    "            np.unravel_index(top_flat_indices.numpy(), average_patching_data.shape)\n",
    "        ).T\n",
    "        all_top_heads[(model_name, task_name)] = top_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _info_tensor(problems, position):\n",
    "    \"\"\"Stack `problem.info[position]` for every problem into a 1-D tensor.\"\"\"\n",
    "    return torch.tensor([problem.info[position] for problem in problems])\n",
    "\n",
    "\n",
    "def _one_hot_columns(values, num_values=None):\n",
    "    \"\"\"Boolean indicator columns `values == i` for i in range(num_values).\n",
    "\n",
    "    num_values defaults to max(values) + 1, so every observed category gets a\n",
    "    column. A fixed width of 8 silently dropped operands 8-11 for months.\n",
    "    \"\"\"\n",
    "    if num_values is None:\n",
    "        num_values = int(values.max()) + 1\n",
    "    return [values == i for i in range(num_values)]\n",
    "\n",
    "\n",
    "def get_r_squared_one_hot_a_b(acts, task, num_values=None):\n",
    "    \"\"\"Regress `acts` on an intercept plus one-hot a and b operands.\n",
    "\n",
    "    Returns the first element of `do_regression`'s result (presumably R^2).\n",
    "    Pass `num_values=8` to reproduce the old fixed-width encoding.\n",
    "    \"\"\"\n",
    "    # Generate problems once so every column shares the same problem order.\n",
    "    problems = task.generate_problems()\n",
    "    a = _info_tensor(problems, 0)\n",
    "    b = _info_tensor(problems, 1)\n",
    "\n",
    "    explanatory_vecs = [torch.ones_like(a)]\n",
    "    explanatory_vecs += _one_hot_columns(a, num_values)\n",
    "    explanatory_vecs += _one_hot_columns(b, num_values)\n",
    "\n",
    "    return do_regression(task, explanatory_vecs, acts, verbose=False)[0]\n",
    "\n",
    "\n",
    "def get_r_squared_one_hot_a_b_c(acts, task, num_values=None):\n",
    "    \"\"\"Like `get_r_squared_one_hot_a_b`, with a one-hot c operand added as well.\"\"\"\n",
    "    problems = task.generate_problems()\n",
    "    a = _info_tensor(problems, 0)\n",
    "    b = _info_tensor(problems, 1)\n",
    "    c = _info_tensor(problems, 2)\n",
    "\n",
    "    explanatory_vecs = [torch.ones_like(a)]\n",
    "    explanatory_vecs += _one_hot_columns(a, num_values)\n",
    "    explanatory_vecs += _one_hot_columns(b, num_values)\n",
    "    explanatory_vecs += _one_hot_columns(c, num_values)\n",
    "\n",
    "    return do_regression(task, explanatory_vecs, acts, verbose=False)[0]\n",
    "\n",
    "\n",
    "# Row 0 of the attention-head patching data corresponds to this model layer\n",
    "# (the head sweep starts at layer 12).\n",
    "HEAD_SWEEP_FIRST_LAYER = 12\n",
    "\n",
    "data = []\n",
    "for model_name, task_name in all_top_heads.keys():\n",
    "    if task_name == \"days_of_week\":\n",
    "        task = DaysOfWeekTask(device, model_name=model_name)\n",
    "    else:\n",
    "        task = MonthsOfYearTask(device, model_name=model_name)\n",
    "\n",
    "    acts = get_all_acts(\n",
    "        task,\n",
    "        save_results_csv=False,\n",
    "        names_filter=\"heads\",\n",
    "        save_file_prefix=\"attn_heads\",\n",
    "        verbose=True,\n",
    "    )\n",
    "\n",
    "    acts = einops.rearrange(\n",
    "        acts,\n",
    "        \"n (layer head_index) token d -> n token layer head_index d\",\n",
    "        head_index=32,\n",
    "    )\n",
    "\n",
    "    top_heads = all_top_heads[(model_name, task_name)]\n",
    "    average_patching_data = all_average_patching_data[(model_name, task_name)]\n",
    "    for row, head in top_heads:\n",
    "        average_intervention_effect = average_patching_data[row, head]\n",
    "        layer = row + HEAD_SWEEP_FIRST_LAYER\n",
    "        head_acts = acts[:, task.before_c_token, layer, head, :]\n",
    "\n",
    "        percent_explained_a_b = get_r_squared_one_hot_a_b(head_acts, task)\n",
    "        percent_explained_a_b_c = get_r_squared_one_hot_a_b_c(head_acts, task)\n",
    "\n",
    "        data.append(\n",
    "            [\n",
    "                model_name,\n",
    "                task_name,\n",
    "                layer,\n",
    "                head,\n",
    "                average_intervention_effect,\n",
    "                percent_explained_a_b,\n",
    "                percent_explained_a_b_c,\n",
    "            ]\n",
    "        )\n",
    "\n",
    "dataframe = pd.DataFrame(\n",
    "    data,\n",
    "    columns=[\n",
    "        \"model\",\n",
    "        \"task\",\n",
    "        \"layer\",\n",
    "        \"head\",\n",
    "        \"average intervention effect\",\n",
    "        \"percent explained by one hot a and b\",\n",
    "        \"percent explained by one hot a and b and c\",\n",
    "    ],\n",
    ")\n",
    "\n",
    "for model_name in [\"mistral\", \"llama\"]:\n",
    "    for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
    "        # Build a new frame rather than dropping in place on a filtered slice,\n",
    "        # which triggers pandas' SettingWithCopyWarning.\n",
    "        filtered_dataframe = dataframe[\n",
    "            (dataframe[\"model\"] == model_name) & (dataframe[\"task\"] == task_name)\n",
    "        ].drop(columns=[\"model\", \"task\"])\n",
    "\n",
    "        print(filtered_dataframe)\n",
    "\n",
    "        # Save to latex\n",
    "        filtered_dataframe.to_latex(\n",
    "            f\"figs/paper_plots/{model_name}_{task_name}_attention_head_patching.tex\",\n",
    "            index=False,\n",
    "            float_format=\"%.2f\",\n",
    "            header=[\n",
    "                \"Layer\",\n",
    "                \"Head\",\n",
    "                \"Average Intervention Effect\",\n",
    "                \"EVR $R^2$, One Hot $\\\\alpha$, $\\\\beta$\",\n",
    "                \"EVR $R^2$, One Hot $\\\\alpha$, $\\\\beta$, $\\\\gamma$\",\n",
    "            ],\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latex tables of results.csv: per-problem outputs of both models, keeping\n",
    "# only the rows where at least one model is wrong.\n",
    "\n",
    "TASK_DISPLAY_NAMES = {\"days_of_week\": \"Weekdays\", \"months_of_year\": \"Months\"}\n",
    "\n",
    "\n",
    "def load_model_results(model_prefix, task_name):\n",
    "    \"\"\"Load `{BASE_DIR}/{model_prefix}_{task_name}/results.csv` for one model.\n",
    "\n",
    "    Correctness is recomputed as `best_token == ground_truth`. Output columns\n",
    "    are prefixed with the model name so two models can be merged side by side.\n",
    "    Prints the number of correct rows.\n",
    "    \"\"\"\n",
    "    out_col = f\"{model_prefix}_out\"\n",
    "    correct_col = f\"{model_prefix}_correct\"\n",
    "    results = pd.read_csv(\n",
    "        f\"{BASE_DIR}/{model_prefix}_{task_name}/results.csv\", skipinitialspace=True\n",
    "    )\n",
    "    results = results.rename(\n",
    "        columns={\"best_token\": out_col, \"model_correct\": correct_col}\n",
    "    )\n",
    "    results[correct_col] = results[out_col] == results[\"ground_truth\"]\n",
    "    print(sum(results[correct_col]))\n",
    "    return results[[\"a\", \"b\", \"ground_truth\", out_col, correct_col]]\n",
    "\n",
    "\n",
    "for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
    "    results_mistral = load_model_results(\"mistral\", task_name)\n",
    "    results_llama = load_model_results(\"llama\", task_name)\n",
    "\n",
    "    # Merges dfs\n",
    "    results = pd.merge(results_mistral, results_llama, on=[\"a\", \"b\", \"ground_truth\"])\n",
    "\n",
    "    results = results.sort_values(by=[\"b\", \"a\"])\n",
    "\n",
    "    # Replace NANs with \"<whitespace>\"\n",
    "    results = results.fillna(\"<whitespace>\")\n",
    "\n",
    "    # Map model_correct to yes no\n",
    "    yes_no = {True: \"Yes\", False: \"No\"}\n",
    "    results[\"mistral_correct\"] = results[\"mistral_correct\"].map(yes_no)\n",
    "    results[\"llama_correct\"] = results[\"llama_correct\"].map(yes_no)\n",
    "\n",
    "    # Drop rows that both models get correct\n",
    "    results = results[\n",
    "        ~((results[\"mistral_correct\"] == \"Yes\") & (results[\"llama_correct\"] == \"Yes\"))\n",
    "    ]\n",
    "\n",
    "    # The caption must name the task being tabulated (the months table was\n",
    "    # previously also captioned \"Weekdays\").\n",
    "    display_name = TASK_DISPLAY_NAMES[task_name]\n",
    "    results.to_latex(\n",
    "        f\"figs/paper_plots/{task_name}_results.tex\",\n",
    "        index=False,\n",
    "        caption=f\"$\\\\texttt{{{display_name}}}$ finegrained results. Row omitted if both models get it correct.\",\n",
    "        header=[\n",
    "            \"$\\\\alpha$\",\n",
    "            \"$\\\\beta$\",\n",
    "            \"Ground truth $\\\\gamma$\",\n",
    "            \"Mistral top logit $\\\\gamma$\",\n",
    "            \"Mistral correct?\",\n",
    "            \"Llama top logit $\\\\gamma$\",\n",
    "            \"Llama correct?\",\n",
    "        ],\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "multiplexing",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
