{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import os\n",
    "\n",
    "torch.__version__, transformers.__version__, torch.version.cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# Checkpoint to load. Swap in the commented line to run a different checkpoint instead.\n",
    "MODEL_PATH = \"state-spaces/mamba-2.8b\"\n",
    "# MODEL_PATH = \"EleutherAI/pythia-2.8b-deduped\"\n",
    "\n",
    "# float32 is requested explicitly; `mt` (and `mt.tokenizer`) is used by the cells below.\n",
    "mt = ModelandTokenizer(model_path=MODEL_PATH, torch_dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import predict_next_token\n",
    "\n",
    "###################################################\n",
    "# Trace presets: keep exactly one (prompt_template, subject, alt_subject)\n",
    "# triple uncommented. `{}` in the template is the slot the subject is formatted into.\n",
    "\n",
    "# prompt_template = mt.tokenizer.eos_token + \" {} professionally played the sport of\"\n",
    "# subject = \"Michael Jordan\"\n",
    "# alt_subject = \"Pele\"\n",
    "\n",
    "# prompt_template = mt.tokenizer.eos_token + \" {} is located in the city of\"\n",
    "# subject = \"Harvard University\"\n",
    "# alt_subject = \"Oxford University\"\n",
    "\n",
    "# prompt_template = mt.tokenizer.eos_token + \" The headquarter of {} is located in\"\n",
    "# subject = \"Microsoft\"\n",
    "# alt_subject = \"Google\"\n",
    "\n",
    "# prompt_template = mt.tokenizer.eos_token + \" The capital of {} is the city of\"\n",
    "# subject = \"Canada\"\n",
    "# alt_subject = \"Japan\"\n",
    "\n",
    "# prompt_template = mt.tokenizer.eos_token + \" {} originated in the country of\"\n",
    "# subject = \"Pizza\"\n",
    "# alt_subject = \"Sushi\"\n",
    "\n",
    "prompt_template = mt.tokenizer.eos_token + \" By profession {} was a\"\n",
    "subject = \"Bruce Lee\"\n",
    "alt_subject = \"Harper Lee\"\n",
    "\n",
    "###################################################\n",
    "\n",
    "# Saved figures are prefixed with the subject, spaces replaced by underscores.\n",
    "FILE_NAME_PREFIX = subject.replace(\" \", \"_\")\n",
    "SAVE_DIR = \"../Figures/causal_tracing_examples\"\n",
    "os.makedirs(SAVE_DIR, exist_ok=True)\n",
    "\n",
    "prompt = prompt_template.format(subject)\n",
    "\n",
    "# Sanity check: next-token predictions for the subject prompt and the alt-subject prompt.\n",
    "# The comprehension variable is deliberately not called `subject`, to avoid shadowing\n",
    "# the cell-level name it iterates over.\n",
    "predict_next_token(\n",
    "    mt=mt,\n",
    "    prompt=[prompt_template.format(name) for name in (subject, alt_subject)],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tracing import calculate_hidden_flow\n",
    "from src.plotting import plot_trace_heatmap\n",
    "\n",
    "# Trace with the function's default settings (no `mamba_block_hook` / `kind` override).\n",
    "# To trace in the opposite direction, pass subject=alt_subject and alt_subject=subject.\n",
    "indirect_effects = calculate_hidden_flow(\n",
    "    mt=mt, prompt=prompt_template, subject=subject, alt_subject=alt_subject\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the trace computed above, saved under SAVE_DIR as \"<prefix>_res.pdf\".\n",
    "plot_trace_heatmap(\n",
    "    indirect_effects,\n",
    "    modelname=MODEL_PATH.rsplit(\"/\", 1)[-1],\n",
    "    relative_recovery=True,\n",
    "    savepdf=os.path.join(SAVE_DIR, FILE_NAME_PREFIX + \"_res.pdf\"),\n",
    "    scale_range=(0, 0.5),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hook points to trace, one heatmap per hook.\n",
    "# To sweep every hook instead, iterate over get_args(MambaBlock_Hook_Points)\n",
    "# (typing.get_args and src.hooking.mamba.MambaBlock_Hook_Points must be imported first).\n",
    "mamba_block_hooks = [\n",
    "    \"mlp_after_silu\",\n",
    "    \"after_down_proj\",\n",
    "    \"ssm_after_ssm\",\n",
    "]\n",
    "\n",
    "# Each PDF is named after the first \"_\"-separated token of its hook, so two hooks\n",
    "# sharing that token would silently overwrite each other's figure. Fail before\n",
    "# running any tracing rather than losing a result.\n",
    "hook_tags = [hook.split(\"_\")[0] for hook in mamba_block_hooks]\n",
    "if len(set(hook_tags)) != len(hook_tags):\n",
    "    raise ValueError(f\"hooks map to duplicate file tags: {hook_tags}\")\n",
    "\n",
    "for mamba_block_hook, hook_tag in zip(mamba_block_hooks, hook_tags):\n",
    "    print(\"-\" * 80)\n",
    "    print(f\"{mamba_block_hook=}\")\n",
    "    # Reset the model's forward before each trace (as the original loop does).\n",
    "    mt.reset_forward()\n",
    "    indirect_effects = calculate_hidden_flow(\n",
    "        mt=mt,\n",
    "        prompt=prompt_template,\n",
    "        subject=subject,\n",
    "        alt_subject=alt_subject,\n",
    "        mamba_block_hook=mamba_block_hook,\n",
    "    )\n",
    "    plot_trace_heatmap(\n",
    "        indirect_effects,\n",
    "        modelname=MODEL_PATH.split(\"/\")[-1],\n",
    "        relative_recovery=True,\n",
    "        savepdf=os.path.join(SAVE_DIR, f\"{FILE_NAME_PREFIX}_{hook_tag}.pdf\"),\n",
    "        scale_range=(0, 0.5),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot_trace_heatmap(\n",
    "#     indirect_effects, \n",
    "#     modelname=MODEL_PATH.split(\"/\")[-1],\n",
    "#     relative_recovery=True,\n",
    "#     # savepdf=os.path.join(SAVE_DIR, f\"{FILE_NAME_PREFIX}_{mamba_block_hook.split('_')[0]}.pdf\"),\n",
    "#     # scale_range=(0, 0.5)\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.tracing import calculate_hidden_flow\n",
    "# from src.plotting import plot_trace_heatmap\n",
    "\n",
    "# prompt_template = mt.tokenizer.eos_token + \" {} is located in the city of\"\n",
    "\n",
    "# subject = \"The Space Needle\"\n",
    "# prompt = prompt_template.format(subject)\n",
    "\n",
    "# indirect_effects = calculate_hidden_flow(\n",
    "#     mt = mt,\n",
    "#     prompt = prompt,\n",
    "#     subject = subject,\n",
    "# )\n",
    "# plot_trace_heatmap(indirect_effects, modelname=MODEL_PATH.split(\"/\")[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.hooking.mamba import MambaBlock_Hook_Points\n",
    "# from typing import get_args\n",
    "\n",
    "# for mamba_block_hook in get_args(MambaBlock_Hook_Points):\n",
    "#     print(\"-\"*80)\n",
    "#     print(f\"{mamba_block_hook=}\")\n",
    "#     mt.reset_forward()\n",
    "#     indirect_effects = calculate_hidden_flow(\n",
    "#         mt = mt,\n",
    "#         prompt = prompt,\n",
    "#         subject = subject,\n",
    "#         mamba_block_hook=mamba_block_hook\n",
    "#     )\n",
    "#     plot_trace_heatmap(indirect_effects, modelname=MODEL_PATH.split(\"/\")[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for kind in [\"mlp\", \"attn\"]:\n",
    "#     print(\"-\"*80)\n",
    "#     print(f\"{kind=}\")\n",
    "#     mt.reset_forward()\n",
    "#     indirect_effects = calculate_hidden_flow(\n",
    "#         mt = mt,\n",
    "#         prompt = prompt,\n",
    "#         subject = subject,\n",
    "#         alt_subject=alt_subject,\n",
    "#         kind=kind\n",
    "#     )\n",
    "#     plot_trace_heatmap(indirect_effects, modelname=MODEL_PATH.split(\"/\")[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
