{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import sys\n",
    "import time\n",
    "from torch import Tensor\n",
    "from typing import List\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import seaborn as sns\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "sns.set_style('white')\n",
    "\n",
    "import torch\n",
    "from modeling.mamba2.modeling_mamba2_dao import Mamba2ForCausalLM\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "\n",
    "def get_long_prompt():\n",
    "    return \"My name is John, and I like eating donuts. \" * 2000\n",
    "\n",
    "\n",
    "def get_tensor_stats(t):\n",
    "    mean = torch.mean(t)\n",
    "    var = torch.var(t)\n",
    "    norm = torch.sqrt(torch.mean(t * t))  # L2-norm\n",
    "    median = torch.median(t)\n",
    "    # mn = torch.min(t)\n",
    "    # mx = torch.max(t)\n",
    "    return mean, var, norm, median\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def get_stats(\n",
    "    model: Mamba2ForCausalLM,\n",
    "    input_ids: Tensor,\n",
    "    chunk_size: int = 128,\n",
    "    n_layers: int = 48,\n",
    "):\n",
    "    print(f\"Getting stats on prompt with {input_ids.shape[0]} tokens...\")\n",
    "    cur_state = None\n",
    "    # print(input_ids)\n",
    "    ssm_stats = {i: [] for i in range(n_layers)}\n",
    "    conv_stats = {i: [] for i in range(n_layers)}\n",
    "    states = []\n",
    "    seqlen = len(input_ids)\n",
    "    print(f\"Seq len: {seqlen}\")\n",
    "    for i in tqdm(range(0, seqlen, chunk_size)):\n",
    "        this_inputs = input_ids[i : i + chunk_size].unsqueeze(0)\n",
    "        # print(f\"Chunk {i} - {i + chunk_size}\")\n",
    "        output = model(this_inputs, states=cur_state)\n",
    "        logits = output.logits\n",
    "        cur_state = output.states\n",
    "        states.append(\n",
    "            [tuple(cur_state[li][j].clone() for j in range(2)) for li in range(n_layers)]\n",
    "        )\n",
    "        for layer_idx in range(n_layers):\n",
    "            conv_state = cur_state[layer_idx][0]\n",
    "            ssm_state = cur_state[layer_idx][1]\n",
    "            conv_stats[layer_idx].append(get_tensor_stats(conv_state))\n",
    "            ssm_stats[layer_idx].append(get_tensor_stats(ssm_state))\n",
    "        # print(cur_state)\n",
    "        # breakpoint()\n",
    "    return states, conv_stats, ssm_stats\n",
    "\n",
    "\n",
    "def plot_stats(\n",
    "    stats,\n",
    "    layer_indices: List[int],\n",
    "    dst_path: Path,\n",
    "    chunk_size: int = 128,\n",
    "    train_len: int = 8192,\n",
    "    figsize: tuple = (6, 3),\n",
    "):\n",
    "    metric_names = [\"Mean\", \"Var\", \"Norm\", \"Median\"]\n",
    "    metric_indices = [0, 2]\n",
    "    fig, ax = plt.subplots(1, len(metric_indices), figsize=figsize)\n",
    "    n_chunks = len(stats[0])\n",
    "    for i, axi in enumerate(ax):\n",
    "        for layer_i in layer_indices:\n",
    "            metric_idx = metric_indices[i]\n",
    "            metrics = [stats[layer_i][k][metric_idx] for k in range(n_chunks)]\n",
    "            # print(metrics)\n",
    "            xs = [chunk_size * x for x in range(len(metrics))]\n",
    "            axi.plot(xs, metrics, label=f\"Layer {layer_i}\", alpha=0.6, linewidth=0.8)\n",
    "        axi.set_title(metric_names[metric_idx])\n",
    "        axi.axvline(x=train_len, color=\"r\", linestyle=\"--\")\n",
    "        # axi.legend()\n",
    "    plt.legend(loc='right', ncol=1, bbox_to_anchor=(1, 0.5))\n",
    "    plt.tight_layout()\n",
    "    print(f\"Saving to {dst_path}\")\n",
    "    plt.savefig(dst_path, dpi=300)\n",
    "\n",
    "\n",
    "def plot_heads(\n",
    "    states,\n",
    "    dst_dir: Path,\n",
    "    target_layer: int,\n",
    "    chunk_size: int = 128,\n",
    "    train_len: int = 8192,\n",
    "    file_ext: str = 'jpg',\n",
    "    figsize: tuple = (6, 3),\n",
    "):\n",
    "    \"\"\"\n",
    "    states: (C, L, 2, B, H, P, N)\n",
    "    \"\"\"\n",
    "    n_chunks: int = len(states)\n",
    "    # Get SSM states: (C, H, P, N)\n",
    "    layer_state = [\n",
    "        states[c][target_layer][1].squeeze(0).cpu().float() for c in range(n_chunks)\n",
    "    ]\n",
    "    print(layer_state[0].shape)\n",
    "    print(f\"# chunks: {n_chunks}\")\n",
    "    nheads: int = layer_state[0].shape[0]\n",
    "    print(f\"# heads: {nheads}\")\n",
    "    metric_names = [\"Mean\", \"Var\", \"Norm\", \"Median\"]\n",
    "    metric_indices = [0, 2]\n",
    "\n",
    "    head_stats = {}  # {0~48: [(metrics ...), ...]}\n",
    "    for h in range(nheads):\n",
    "        head_stats[h] = [get_tensor_stats(layer_state[c][h]) for c in range(n_chunks)]\n",
    "\n",
    "    head_chunk_size = 12\n",
    "    for lo in range(0, nheads, head_chunk_size):\n",
    "        fig, ax = plt.subplots(1, len(metric_indices), figsize=figsize)\n",
    "        dst_path = dst_dir / f\"{lo}-{lo + head_chunk_size}.{file_ext}\"\n",
    "        head_indices = list(range(lo, lo + head_chunk_size))\n",
    "        for i, axi in enumerate(ax):\n",
    "            metric_idx = metric_indices[i]\n",
    "            for h in head_indices:\n",
    "                metrics = [head_stats[h][c][metric_idx] for c in range(n_chunks)]\n",
    "                xs = [chunk_size * x for x in range(n_chunks)]\n",
    "                axi.plot(xs, metrics, label=f\"Head {h}\", alpha=0.6)\n",
    "            axi.set_title(metric_names[metric_idx])\n",
    "            axi.axvline(x=train_len, color=\"r\", linestyle=\"--\")\n",
    "            # axi.legend()\n",
    "        fig.legend(loc='right', ncol=1, bbox_to_anchor=(1, 0.5))\n",
    "        dst_dir.mkdir(exist_ok=True, parents=True)\n",
    "        print(f\"Saving to {dst_path}\")\n",
    "        # plt.tight_layout()\n",
    "        plt.savefig(dst_path, dpi=300)\n",
    "        plt.clf()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
