{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules (e.g. the project's src package) automatically before each\n",
    "# cell runs, so edits to project code take effect without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the parent directory importable so the local `src` package resolves\n",
    "# (assumes the notebook lives one level below the repository root).\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "import baukit\n",
    "from tqdm.auto import tqdm\n",
    "import json\n",
    "import os\n",
    "from src import functional\n",
    "import src.tokens as tokenization_utils\n",
    "\n",
    "# Record library / CUDA versions for reproducibility.\n",
    "torch.__version__, transformers.__version__, torch.version.cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# Checkpoint to analyse. Known alternatives:\n",
    "#   \"state-spaces/mamba-2.8b\"  (default)\n",
    "#   \"state-spaces/mamba-2.8b-slimpj\"\n",
    "MODEL_PATH = \"state-spaces/mamba-2.8b\"\n",
    "\n",
    "# Load model and tokenizer together, requesting float32 weights.\n",
    "mt = ModelandTokenizer(model_path=MODEL_PATH, torch_dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subject = \"The Space Needle\"\n",
    "prompt_template = \"{} is located in the city of\"\n",
    "\n",
    "# Fill in the subject; maybe_prefix_eos presumably prepends an EOS token when needed.\n",
    "filled_prompt = prompt_template.format(subject)\n",
    "prompt = tokenization_utils.maybe_prefix_eos(mt, filled_prompt)\n",
    "\n",
    "# Clean baseline prediction (reset_forward presumably removes any earlier patches).\n",
    "mt.reset_forward()\n",
    "functional.predict_next_token(mt=mt, prompt=prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.hooking.mamba import selective_scan_with_mask\n",
    "\n",
    "tokenized = mt.tokenizer(\n",
    "    prompt,\n",
    "    return_tensors=\"pt\",\n",
    "    padding=True,\n",
    "    return_offsets_mapping=True,\n",
    ").to(mt.device)\n",
    "# Character offsets are only needed to locate the subject, so split them off.\n",
    "offsets = tokenized.pop(\"offset_mapping\")\n",
    "\n",
    "# Token span of the subject; the returned end is used as exclusive, hence the -1\n",
    "# to get the inclusive index of the last subject token.\n",
    "subj_first, subj_end = functional.find_token_range(\n",
    "    string=prompt,\n",
    "    substring=subject,\n",
    "    offset_mapping=offsets[0],\n",
    ")\n",
    "subj_last = subj_end - 1\n",
    "\n",
    "# Index of the final prompt token.\n",
    "last_idx = tokenized.input_ids.shape[1] - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import types\n",
    "\n",
    "mamba_block_format = \"layers.{}.mixer\"\n",
    "\n",
    "\n",
    "def mask_layers(model, layer_indices, mask_config):\n",
    "    \"\"\"Patch the selective scan of selected Mamba layers to block information flow.\n",
    "\n",
    "    For every index in `layer_indices`, the module at `layers.{idx}.mixer` gets its\n",
    "    `selective_scan` replaced by a method bound to that module, which forwards to\n",
    "    `selective_scan_with_mask(..., mask=mask_config)`. The model is modified in place.\n",
    "\n",
    "    Args:\n",
    "        model: model whose mixer modules are patched.\n",
    "        layer_indices: iterable of layer indices to patch.\n",
    "        mask_config: dict mapping a target token index to the list of source token\n",
    "            indices it must not see, e.g. `{last_idx: [subj_last]}`.\n",
    "    \"\"\"\n",
    "\n",
    "    def masked_scan(self, u, delta, A, B, C, D):\n",
    "        return selective_scan_with_mask(\n",
    "            self=self, u=u, delta=delta, A=A, B=B, C=C, D=D, mask=mask_config\n",
    "        )\n",
    "\n",
    "    for layer_idx in layer_indices:\n",
    "        block = baukit.get_module(model, mamba_block_format.format(layer_idx))\n",
    "        # MethodType binds the wrapper to this module, so `self` is the mixer itself.\n",
    "        block.selective_scan = types.MethodType(masked_scan, block)\n",
    "\n",
    "\n",
    "# Start from an unpatched model (reset_forward presumably undoes earlier patches).\n",
    "mt.reset_forward()\n",
    "\n",
    "subject_span = range(subj_first, subj_last + 1)\n",
    "query_span = range(subj_last + 1, last_idx + 1)\n",
    "\n",
    "# Named mask configurations: {target token idx: [source token idxs it cannot see]}.\n",
    "mask_everything = {idx: list(range(idx + 1)) for idx in range(last_idx + 1)}\n",
    "mask_configs = {\n",
    "    \"no_mask\": {last_idx: []},\n",
    "    \"last_cant_see_subj_last\": {last_idx: [subj_last]},\n",
    "    \"last_cant_see_subject\": {last_idx: list(subject_span)},\n",
    "    \"query_cant_see_subj_last\": {t: [subj_last] for t in query_span},\n",
    "    \"query_cant_see_subject\": {t: list(subject_span) for t in query_span},\n",
    "    \"mask_everything\": mask_everything,\n",
    "}\n",
    "MASK_CONFIG_NAME = \"query_cant_see_subject\"\n",
    "\n",
    "mask_layers(\n",
    "    model=mt.model,\n",
    "    layer_indices=list(range(mt.n_layer)),  # e.g. [20] to patch a single layer\n",
    "    mask_config=mask_configs[MASK_CONFIG_NAME],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The masking patches from the previous cell are still installed here, so this is the\n",
    "# ablated prediction. Uncomment to (presumably) compare against the unpatched model:\n",
    "# mt.reset_forward()\n",
    "\n",
    "# ! Ablating only the diagonal SSM is not enough.\n",
    "# TODO: figure out how to ablate the shift-SSM or the conv as well.\n",
    "# The \"attention\" visualization is also wrong because it ignores the conv:\n",
    "# the SSM does not attend to one particular token but to its whole receptive field.\n",
    "\n",
    "functional.predict_next_token(mt=mt, prompt=prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils import experiment_utils\n",
    "\n",
    "experiment_utils.set_seed(123456)\n",
    "\n",
    "# Random toy inputs: batch 1, 4 tokens, 5120 channels; 16 is presumably the SSM state size.\n",
    "n_batch, n_tokens, n_channels, n_state = 1, 4, 5120, 16\n",
    "u = torch.randn(n_batch, n_tokens, n_channels)\n",
    "delta = torch.randn(n_batch, n_tokens, n_channels)\n",
    "A = torch.randn(n_channels, n_state)\n",
    "B = torch.randn(n_batch, n_tokens, n_state)\n",
    "C = torch.randn(n_batch, n_tokens, n_state)\n",
    "D = torch.randn(n_channels)\n",
    "\n",
    "# Block every token from itself and all earlier tokens; {0: []} would presumably mask nothing.\n",
    "full_causal_mask = {t: list(range(t + 1)) for t in range(n_tokens)}\n",
    "output = selective_scan_with_mask(\n",
    "    self=None, u=u, delta=delta, A=A, B=B, C=C, D=D, mask=full_causal_mask\n",
    ")\n",
    "\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_A = torch.randn(1, 4, 5120, 16)\n",
    "src_idx = target_idx = 2\n",
    "\n",
    "# Product of the per-step factors strictly after src_idx up to and including target_idx.\n",
    "# With src_idx == target_idx the slice is empty, so the product is all ones.\n",
    "delta_A_src_to_target = delta_A[:, src_idx + 1 : target_idx + 1].prod(dim=1)\n",
    "delta_A_src_to_target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random stand-in shaped like delta_A_src_to_target, i.e. (1, 5120, 16).\n",
    "delta_B_src = torch.randn((1, 5120, 16))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An all-ones (empty) product must leave delta_B_src unchanged when broadcast in.\n",
    "delta_AB_src = delta_B_src * delta_A_src_to_target\n",
    "torch.allclose(delta_B_src, delta_AB_src)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the module tree to look up submodule paths such as \"layers.{i}.mixer\".\n",
    "mt.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab the mixer of layer 10 and look at its conv1d parameter shapes.\n",
    "layer10_path = \"layers.10.mixer\"\n",
    "block = baukit.get_module(mt.model, layer10_path)\n",
    "\n",
    "(block.conv1d.weight.shape, block.conv1d.bias.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random conv input laid out as (batch, channels, length). The channel count comes from\n",
    "# the conv layer rather than a hard-coded 5120, so other model sizes work too.\n",
    "l = 16\n",
    "b = 1\n",
    "x = torch.randn(b, block.conv1d.in_channels, l).to(mt.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conv1d zero-pads both ends; keeping the first l outputs gives the causal part\n",
    "# (assuming padding == kernel_size - 1).\n",
    "y_conv = block.conv1d(x)[..., :l]\n",
    "y_conv.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convolve_one_filter(\n",
    "    x: torch.Tensor,\n",
    "    filter_weights: torch.Tensor,\n",
    "    filter_bias: torch.Tensor,\n",
    "    prefix_padding: int = block.conv1d.padding[0],\n",
    "):\n",
    "    \"\"\"Causally convolve a single channel with a single conv filter, by hand.\n",
    "\n",
    "    Output position t is sum_k x_padded[..., t + k] * w[k] + bias, where x_padded is x\n",
    "    left-padded with `prefix_padding` zeros (the cross-correlation torch's Conv1d uses).\n",
    "\n",
    "    Args:\n",
    "        x: (batch, length) input for one channel.\n",
    "        filter_weights: one filter, e.g. `conv1d.weight[c]`; flattened to (kernel_size,).\n",
    "        filter_bias: scalar bias for the filter.\n",
    "        prefix_padding: number of zeros prepended along the last axis.\n",
    "\n",
    "    Returns:\n",
    "        (batch, length) tensor (shorter if prefix_padding < kernel_size - 1).\n",
    "    \"\"\"\n",
    "    # A filter is presumably (1, kernel_size); flatten so the kernel length is its size,\n",
    "    # not shape[0] (which would be 1).\n",
    "    weights = filter_weights.reshape(-1)\n",
    "    kernel_size = weights.shape[0]\n",
    "    seq_len = x.shape[-1]\n",
    "    x_padded = torch.nn.functional.pad(x, (prefix_padding, 0))\n",
    "    # Sliding windows along time: (batch, n_windows, kernel_size).\n",
    "    windows = x_padded.unfold(-1, kernel_size, 1)\n",
    "    y = (windows * weights).sum(dim=-1) + filter_bias\n",
    "    return y[..., :seq_len]\n",
    "\n",
    "\n",
    "filter_idx = 20\n",
    "# The conv appears depthwise (weight is (channels, 1, kernel)), so channel c uses filter c.\n",
    "inp_channel = filter_idx\n",
    "f_weights = block.conv1d.weight[filter_idx]\n",
    "f_bias = block.conv1d.bias[filter_idx]\n",
    "y_conv_manual = convolve_one_filter(x[:, inp_channel], f_weights, f_bias)\n",
    "\n",
    "(\n",
    "    x[:, inp_channel].shape,\n",
    "    y_conv_manual.shape,\n",
    "    torch.allclose(y_conv_manual, y_conv[:, inp_channel], atol=1e-5),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-computed conv output for channel inp_channel.\n",
    "y_conv_manual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The same channel taken from the library Conv1d output, for comparison.\n",
    "y_conv[:, inp_channel]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_manual(x, weight, bias, padding):\n",
    "    \"\"\"Depthwise 1-D convolution with left-only padding, written out explicitly.\n",
    "\n",
    "    Matches `conv1d(x)[..., :width]` for a depthwise torch Conv1d whose padding is\n",
    "    kernel_size - 1: x is left-padded with `padding[0]` zeros and each channel is\n",
    "    cross-correlated with its own filter.\n",
    "\n",
    "    Args:\n",
    "        x: (batch, channels, width) input.\n",
    "        weight: (channels, 1, kernel_size) filters.\n",
    "        bias: (channels,) bias, or None.\n",
    "        padding: conv padding tuple; only `padding[0]` is used, on the left.\n",
    "\n",
    "    Returns:\n",
    "        (batch, channels, width + padding[0] - kernel_size + 1) tensor on x's device.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the weights are not depthwise (dim 1 is not 1).\n",
    "    \"\"\"\n",
    "    _, in_channels_per_group, kernel_size = weight.shape\n",
    "    if in_channels_per_group != 1:\n",
    "        raise ValueError(f\"conv_manual expects depthwise weights, got shape {tuple(weight.shape)}\")\n",
    "    x_padded = torch.nn.functional.pad(x, (padding[0], 0), \"constant\", 0)\n",
    "    # Sliding windows along time: (batch, channels, output_width, kernel_size).\n",
    "    windows = x_padded.unfold(-1, kernel_size, 1)\n",
    "    # (channels, kernel_size) -> (channels, 1, kernel_size) to broadcast over windows.\n",
    "    filters = weight.squeeze(dim=1).unsqueeze(1)\n",
    "    y = (windows * filters).sum(dim=-1)\n",
    "    if bias is not None:\n",
    "        y = y + bias[:, None]\n",
    "    return y\n",
    "\n",
    "\n",
    "conv1d = block.conv1d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh random conv input, (batch, channels, length). The channel count comes from the\n",
    "# conv layer rather than a hard-coded 5120, so other model sizes work too.\n",
    "l = 16\n",
    "b = 1\n",
    "x = torch.randn(b, block.conv1d.in_channels, l).to(mt.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Library reference output, truncated to the first l positions.\n",
    "y_conv = conv1d(x)[..., :l]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-written convolution, to compare against y_conv.\n",
    "y_conv_manual = conv_manual(\n",
    "    x,\n",
    "    conv1d.weight,\n",
    "    conv1d.bias,\n",
    "    padding=conv1d.padding,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference output from torch's Conv1d.\n",
    "y_conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output of the hand-written conv_manual.\n",
    "y_conv_manual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shapes must match and values should agree up to float32 rounding.\n",
    "# conv_manual's output is moved to y_conv's device in case they differ.\n",
    "(\n",
    "    y_conv.shape,\n",
    "    y_conv_manual.shape,\n",
    "    torch.allclose(y_conv, y_conv_manual.to(y_conv.device), atol=1e-5),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left-pad only, the same way conv_manual does.\n",
    "x_padded = torch.nn.functional.pad(x, pad=(conv1d.padding[0], 0), mode=\"constant\", value=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "channel = 2\n",
    "batch_idx = 0\n",
    "time_step = 1\n",
    "kernel_size = conv1d.kernel_size[0]\n",
    "\n",
    "# One channel's output at one time step, computed by hand from its input window,\n",
    "# shown next to the library value at the same position.\n",
    "x_slice = x_padded[batch_idx, channel, time_step : time_step + kernel_size]\n",
    "manual_value = (x_slice * conv1d.weight[channel]).sum() + conv1d.bias[channel]\n",
    "manual_value, y_conv[batch_idx, channel, time_step]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_idx = 0\n",
    "# Window at padded positions 1..4 for every channel at once: (channels, 4).\n",
    "x_slice = x_padded[batch_idx, :, 1:5]\n",
    "# conv1d.weight is presumably (channels, 1, kernel); the cells below squeeze dim 1\n",
    "# before multiplying it with x_slice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop weight's singleton dim 1 so it lines up with x_slice's (channels, window) shape.\n",
    "(x_slice * conv1d.weight.squeeze(dim=1)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One output per channel: sum over the kernel window, then add the per-channel bias.\n",
    "(x_slice * conv1d.weight.squeeze(dim=1)).sum(dim=-1) + conv1d.bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shapes of the left-padded input and the single-window slice.\n",
    "x_padded.shape, x_slice.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight shapes of this mixer's x_proj and dt_proj layers.\n",
    "block.x_proj.weight.shape, block.dt_proj.weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dummy activations; layout assumed to be (batch, length, channels) -- TODO confirm.\n",
    "dummy_shape = (1, 16, 5120)\n",
    "dummy_x = torch.randn(dummy_shape).to(mt.device)\n",
    "block.x_proj(dummy_x).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected x_proj output width: dt_rank plus 2 * 16 (presumably B and C, each with\n",
    "# state size 16 like A's last dim above); compare with the shape in the previous cell.\n",
    "mt.model.args.dt_rank + 16 * 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
