{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# modules\n",
    "\n",
    "> Fill in a module description here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch\n",
    "import torch.nn as  nn\n",
    "\n",
    "import json\n",
    "import os\n",
    "from collections import namedtuple\n",
    "from dataclasses import field, dataclass\n",
    "from functools import partial\n",
    "\n",
    "from mamba_ssm.models.config_mamba import MambaConfig\n",
    "from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn\n",
    "from mamba_ssm.models.mixer_seq_simple import _init_weights, MixerModel\n",
    "from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf\n",
    "from mamba_ssm.utils.generation import *\n",
    "from transformers import PretrainedConfig\n",
    "\n",
    "@dataclass\n",
    "class MambaConfig(PretrainedConfig):\n",
    "    d_model: int = 2560\n",
    "    n_layer: int = 64\n",
    "    vocab_size: int = 50277\n",
    "    ssm_cfg: dict = field(default_factory=dict)\n",
    "    rms_norm: bool = True\n",
    "    residual_in_fp32: bool = True\n",
    "    fused_add_norm: bool = True\n",
    "    pad_vocab_size_multiple: int = 8\n",
    "    max_position_embeddings: int = 2048"
   ]
  },
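  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since `MambaConfig` is a plain dataclass, it can be constructed with keyword overrides. A minimal sketch (the sizes below are illustrative, not trained model sizes):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: build a small config and read back a few fields.\n",
    "# All values here are illustrative, chosen only for experimentation.\n",
    "cfg = MambaConfig(d_model=256, n_layer=4, vocab_size=32)\n",
    "print(cfg.d_model, cfg.n_layer, cfg.vocab_size, cfg.max_position_embeddings)"
   ]
  },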
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def sample_safe(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):\n",
    "    \"\"\"Sample from top-k logits.\n",
    "    Arguments:\n",
    "        logits: Tensor of shape (batch_size, vocab_size)\n",
    "    \"\"\"\n",
    "    if top_k == 1:  # Short-circuit for greedy decoding\n",
    "        return logits.argmax(dim=-1)\n",
    "    else:\n",
    "        if top_p > 0.0:\n",
    "            assert top_p <= 1.0, \"top-p should be in (0, 1].\"\n",
    "        if top_k > 0:\n",
    "            top_k = min(top_k, logits.size(-1))  # Safety check\n",
    "            logits_top, indices = torch.topk(logits, top_k, dim=-1)\n",
    "            if temperature != 1.0:\n",
    "                logits_top /= temperature\n",
    "            modify_logits_for_top_p_filtering(logits_top, top_p)\n",
    "            return indices[\n",
    "                torch.arange(indices.shape[0], device=indices.device),\n",
    "                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),\n",
    "            ]\n",
    "        else:\n",
    "            if min_p > 0.0:\n",
    "                logits_top = logits.clone()\n",
    "                max_prob = logits_top[..., 0].item()\n",
    "                min_prob = max_prob * min_p\n",
    "                modify_logits_for_min_p_filtering(logits_top, min_p)\n",
    "                if temperature != 1.0:\n",
    "                    logits_top /= temperature\n",
    "                return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)\n",
    "            # Clone so that when we modify for top_p we don't change the original logits\n",
    "            logits_top = logits / temperature if temperature != 1.0 else logits.clone()\n",
    "            modify_logits_for_top_p_filtering(logits_top, top_p)\n",
    "            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(\n",
    "                dim=-1\n",
    "            )\n",
    "\n",
    "@torch.inference_mode()\n",
    "def decode_safe(\n",
    "    input_ids,\n",
    "    position_ids,\n",
    "    seq_position_ids,\n",
    "    is_fim,\n",
    "    model,\n",
    "    max_length,\n",
    "    top_k=1,\n",
    "    top_p=0.0,\n",
    "    min_p=0.0,\n",
    "    temperature=1.0,\n",
    "    repetition_penalty=1.0,\n",
    "    eos_token_id=None,\n",
    "    teacher_outputs=None,\n",
    "    vocab_size=None,\n",
    "    cg=False,\n",
    "    enable_timing=False,\n",
    "    streamer: Optional[TextStreamer] = None\n",
    "):\n",
    "    \"\"\"Decoding, either greedy or with top-k or top-p sampling.\n",
    "    If top-k = 0, don't limit the number of candidates (pure sampling).\n",
    "    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,\n",
    "    then top-p.\n",
    "    We assume that all sequences in the same batch have the same length.\n",
    "\n",
    "    Arguments:\n",
    "        input_ids: (batch, seq_len)\n",
    "        max_length: int\n",
    "        is_fim: dictionary with mask indices and associated position indices\n",
    "        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the\n",
    "            logits, the next token is taken from the teacher_outputs. Useful for testing.\n",
    "    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:\n",
    "        sequences: (batch, max_length)\n",
    "        scores: tuples of (batch, vocab_size)\n",
    "    \"\"\"\n",
    "    if streamer is not None:\n",
    "        streamer.put(input_ids.cpu())\n",
    "\n",
    "    batch_size, seqlen_og = input_ids.shape\n",
    "    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0\n",
    "    if cg:\n",
    "        if not hasattr(model, \"_decoding_cache\"):\n",
    "            model._decoding_cache = None\n",
    "        model._decoding_cache = update_graph_cache(\n",
    "            model,\n",
    "            model._decoding_cache,\n",
    "            batch_size,\n",
    "            seqlen_og,\n",
    "            max_length,\n",
    "        )\n",
    "        inference_params = model._decoding_cache.inference_params\n",
    "        inference_params.reset(max_length, batch_size)\n",
    "    else:\n",
    "        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)\n",
    "\n",
    "    def get_logits(input_ids, position_ids, seq_position_ids, inference_params):\n",
    "        decoding = inference_params.seqlen_offset > 0\n",
    "        if not cg or not decoding:\n",
    "            logits = model(\n",
    "                input_ids,\n",
    "                position_ids=position_ids,\n",
    "                seq_position_ids=seq_position_ids,\n",
    "                inference_params=inference_params,\n",
    "                num_last_tokens=1,\n",
    "            ).logits.squeeze(dim=1)\n",
    "        else:\n",
    "            logits = model._decoding_cache.run(\n",
    "                input_ids, position_ids, inference_params.seqlen_offset, seq_position_ids=seq_position_ids\n",
    "            ).squeeze(dim=1)\n",
    "        return logits[..., :vocab_size] if vocab_size is not None else logits\n",
    "\n",
    "    def sample_tokens(logits, inference_params):\n",
    "        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:\n",
    "            token = sample_safe(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)\n",
    "        else:\n",
    "            token = teacher_outputs[:, inference_params.seqlen_offset]\n",
    "        # return rearrange(token, \"b -> b 1\")\n",
    "        return token.unsqueeze(1)\n",
    "\n",
    "    def get_fim_position_id(last_position_ids, sampled_tokens, is_fim, repeat_next=False):\n",
    "        val = int(last_position_ids) + 1 \n",
    "        should_repeat_next = False\n",
    "        if is_fim and int(sampled_tokens) in is_fim:\n",
    "            val = is_fim[int(sampled_tokens)]\n",
    "            should_repeat_next = True\n",
    "        elif repeat_next:\n",
    "            val = int(last_position_ids)\n",
    "        return torch.full_like(last_position_ids, fill_value=val), should_repeat_next\n",
    "    \n",
    "    def should_stop(current_token, inference_params):\n",
    "        if inference_params.seqlen_offset == 0:\n",
    "            return False\n",
    "        if eos_token_id is not None and (current_token == eos_token_id).any():\n",
    "            if current_token.shape[1] > 1:\n",
    "                raise NotImplementedError(\"Batched eos_token_id not supported\")\n",
    "            return True\n",
    "        if inference_params.seqlen_offset >= max_length - 1:\n",
    "            return True\n",
    "        return False\n",
    "\n",
    "    start = torch.cuda.Event(enable_timing=enable_timing)\n",
    "    end = torch.cuda.Event(enable_timing=enable_timing)\n",
    "\n",
    "    if enable_timing:\n",
    "        start.record()\n",
    "    scores, sequences = [], [input_ids]\n",
    "    new_position_ids, new_seq_position_ids = [position_ids], [seq_position_ids]\n",
    "    sequences_cat = input_ids\n",
    "    repeat_next = False\n",
    "    if position_ids.shape[0]>1:\n",
    "        raise NotImplementedError(\"Batched generation with position_ids not supported\")\n",
    "    while not should_stop(sequences[-1], inference_params):\n",
    "        scores.append(get_logits(sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], inference_params))\n",
    "        inference_params.seqlen_offset += sequences[-1].shape[1]\n",
    "        if repetition_penalty == 1.0:\n",
    "            sampled_tokens = sample_tokens(scores[-1], inference_params)\n",
    "        else:\n",
    "            logits = modify_logit_for_repetition_penalty(\n",
    "                scores[-1].clone(), sequences_cat, repetition_penalty\n",
    "            )\n",
    "            sampled_tokens = sample_tokens(logits, inference_params)\n",
    "            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)\n",
    "        sequences.append(sampled_tokens)\n",
    "        # Update position_ids\n",
    "        if position_ids is not None:\n",
    "            last_position_ids, repeat_next = get_fim_position_id(new_position_ids[-1][:,-1:], sampled_tokens, is_fim, repeat_next)\n",
    "            new_position_ids.append(last_position_ids)\n",
    "        # Update seq_position_ids\n",
    "        if seq_position_ids is not None:\n",
    "            new_seq_position_ids.append(new_seq_position_ids[-1][:,-1:])\n",
    "\n",
    "        if streamer is not None:\n",
    "            streamer.put(sampled_tokens.cpu())\n",
    "    if streamer is not None:\n",
    "        streamer.end()\n",
    "    if enable_timing:\n",
    "        end.record()\n",
    "        torch.cuda.synchronize()\n",
    "        print(f\"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms\")\n",
    "    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput\n",
    "    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))\n",
    "\n",
    "class GenerationMixinSafe(GenerationMixin):\n",
    "    \n",
    "    def generate(\n",
    "        self,\n",
    "        input_ids,\n",
    "        position_ids,\n",
    "        seq_position_ids,\n",
    "        is_fim=None,\n",
    "        max_length=1,\n",
    "        top_k=1,\n",
    "        top_p=0.0,\n",
    "        min_p=0.0,\n",
    "        temperature=1.0,\n",
    "        return_dict_in_generate=False,\n",
    "        output_scores=False,\n",
    "        **kwargs):\n",
    "        \n",
    "        output = decode_safe(\n",
    "            input_ids, position_ids, seq_position_ids, is_fim, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, **kwargs\n",
    "        )\n",
    "        if not output_scores:\n",
    "            output.scores = None\n",
    "        return output if return_dict_in_generate else output.sequences"
   ]
  },
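  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick smoke test for `sample_safe` (a sketch: `top_k=1` is pure argmax, while `top_k>1` additionally relies on the `mamba_ssm` filtering helpers pulled in by the wildcard import above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: greedy vs. top-k sampling on random logits of shape (batch, vocab).\n",
    "logits = torch.randn(2, 10)\n",
    "print(sample_safe(logits, top_k=1))                   # greedy: argmax per row\n",
    "print(sample_safe(logits, top_k=5, temperature=0.8))  # sample from the top 5"
   ]
  },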
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "from torch.utils.checkpoint import checkpoint\n",
    "from mamba_ssm.modules.mamba_simple import Mamba, Block\n",
    "\n",
    "\n",
    "class CheckpointedModule(torch.nn.Module):\n",
    "    def __init__(self, layer):\n",
    "        super().__init__()\n",
    "        self.ckpt_layer = layer\n",
    "\n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        return checkpoint(self.ckpt_layer, x, use_reentrant=False)\n",
    "\n",
    "    # def state_dict(self, **kwargs):\n",
    "    #     # Get the state dict of the underlying layer\n",
    "    #     layer_state_dict = self.ckpt_layer.state_dict(**kwargs)\n",
    "    #     # Create a new state dict with the original keys\n",
    "    #     state_dict = {k.replace('ckpt_layer.', ''): v for k, v in layer_state_dict.items()}\n",
    "    #     return state_dict\n",
    "\n",
    "def create_block(\n",
    "    d_model,\n",
    "    ssm_cfg=None,\n",
    "    norm_epsilon=1e-5,\n",
    "    rms_norm=False,\n",
    "    residual_in_fp32=False,\n",
    "    fused_add_norm=False,\n",
    "    layer_idx=None,\n",
    "    device=None,\n",
    "    dtype=None,\n",
    "    checkpoint_mixer=False,\n",
    "):\n",
    "    if ssm_cfg is None:\n",
    "        ssm_cfg = {}\n",
    "    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)\n",
    "    norm_cls = partial(\n",
    "        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs\n",
    "    )\n",
    "    block = Block(\n",
    "        d_model,\n",
    "        mixer_cls,\n",
    "        norm_cls=norm_cls,\n",
    "        fused_add_norm=fused_add_norm,\n",
    "        residual_in_fp32=residual_in_fp32,\n",
    "    )\n",
    "    block.layer_idx = layer_idx\n",
    "    if checkpoint_mixer:\n",
    "        block.mixer = CheckpointedModule(block.mixer)\n",
    "    return block"
   ]
  },
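  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of building a single block (assumes a CUDA device, since the Mamba mixer uses fused GPU kernels; `checkpoint_mixer=True` wraps the mixer in `CheckpointedModule`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: one Mamba block with a checkpointed mixer. Sizes are illustrative.\n",
    "if torch.cuda.is_available():\n",
    "    block = create_block(\n",
    "        d_model=256, rms_norm=True, fused_add_norm=True,\n",
    "        layer_idx=0, checkpoint_mixer=True,\n",
    "        device=\"cuda\", dtype=torch.bfloat16,\n",
    "    )\n",
    "    x = torch.randn(1, 16, 256, device=\"cuda\", dtype=torch.bfloat16)\n",
    "    hidden, residual = block(x)  # Block returns (hidden_states, residual)\n",
    "    print(hidden.shape, type(block.mixer).__name__)"
   ]
  },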
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plain Mamba model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "    \n",
    "class MixerModelSafe(MixerModel):\n",
    "    \"\"\"\n",
    "        Overwrite the forward method to allow saving intermediate layers.\n",
    "    \"\"\"\n",
    "    \n",
    "    def forward(self, input_ids, inference_params=None, save_layer=[]):\n",
    "        hidden_states = self.embedding(input_ids)\n",
    "        residual = None\n",
    "        if len(save_layer) > 0:\n",
    "            hidden_states_dict = {}\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            hidden_states, residual = layer(\n",
    "                hidden_states, residual, inference_params=inference_params\n",
    "            )\n",
    "            if i+1 in save_layer:\n",
    "                hidden_states_dict[i+1] = hidden_states.detach().cpu().to(torch.float).numpy()\n",
    "        if len(save_layer) > 0:\n",
    "            return hidden_states_dict\n",
    "            \n",
    "        if not self.fused_add_norm:\n",
    "            residual = (hidden_states + residual) if residual is not None else hidden_states\n",
    "            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))\n",
    "        else:\n",
    "            # Set prenorm=False here since we don't need the residual\n",
    "            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn\n",
    "            hidden_states = fused_add_norm_fn(\n",
    "                hidden_states,\n",
    "                self.norm_f.weight,\n",
    "                self.norm_f.bias,\n",
    "                eps=self.norm_f.eps,\n",
    "                residual=residual,\n",
    "                prenorm=False,\n",
    "                residual_in_fp32=self.residual_in_fp32,\n",
    "            )\n",
    "        return hidden_states\n",
    "\n",
    "class MambaLMHeadModelSafe(nn.Module, GenerationMixinSafe):\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            config: MambaConfig,\n",
    "            initializer_cfg=None,\n",
    "            device=None,\n",
    "            dtype=None,\n",
    "            checkpoint_mixer=False,\n",
    "    ) -> None:\n",
    "        self.config = config\n",
    "        d_model = config.d_model\n",
    "        n_layer = config.n_layer\n",
    "        vocab_size = config.vocab_size\n",
    "        ssm_cfg = config.ssm_cfg\n",
    "        rms_norm = config.rms_norm\n",
    "        residual_in_fp32 = config.residual_in_fp32\n",
    "        fused_add_norm = config.fused_add_norm\n",
    "        pad_vocab_size_multiple = config.pad_vocab_size_multiple\n",
    "        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "        if checkpoint_mixer:\n",
    "            raise NotImplementedError(\"Checkpointing is not yet supported for MambaLMHeadModelSafe\")\n",
    "\n",
    "        super().__init__()\n",
    "        if vocab_size % pad_vocab_size_multiple != 0:\n",
    "            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)\n",
    "        self.backbone = MixerModelSafe(\n",
    "            d_model=d_model,\n",
    "            n_layer=n_layer,\n",
    "            vocab_size=vocab_size,\n",
    "            ssm_cfg=ssm_cfg,\n",
    "            rms_norm=rms_norm,\n",
    "            initializer_cfg=initializer_cfg,\n",
    "            fused_add_norm=fused_add_norm,\n",
    "            residual_in_fp32=residual_in_fp32,\n",
    "            **factory_kwargs,\n",
    "        )\n",
    "        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)\n",
    "\n",
    "        # Initialize weights and apply final processing\n",
    "        self.apply(\n",
    "            partial(\n",
    "                _init_weights,\n",
    "                n_layer=n_layer,\n",
    "                **(initializer_cfg if initializer_cfg is not None else {}),\n",
    "            )\n",
    "        )\n",
    "        self.tie_weights()\n",
    "    \n",
    "    def tie_weights(self):\n",
    "        self.lm_head.weight = self.backbone.embedding.weight\n",
    "        \n",
    "    def clip_grad_norm_(self, max_norm, norm_type=2.0):\n",
    "        r\"\"\"Clip the norm of the gradients for the model.\n",
    "        Args:\n",
    "            max_norm (float or int): The maximum norm of the gradients.\n",
    "                The gradients are modified in-place.\n",
    "            norm_type (float or int): The type of the used p-norm. Can be 'inf' for infinity norm.\n",
    "        Returns:\n",
    "            Total norm of the parameters (viewed as a single vector).\n",
    "        \"\"\"\n",
    "        return torch.nn.utils.clip_grad_value_(self.parameters(), max_norm)\n",
    "\n",
    "    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n",
    "        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n",
    "\n",
    "    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, save_layer=[], *args, **kwargs):\n",
    "        \"\"\"\n",
    "        \"position_ids\" is just to be compatible with Transformer generation. We don't use it.\n",
    "        num_last_tokens: if > 0, only return the logits for the last n tokens\n",
    "        \"\"\"\n",
    "        return self.protected_forward(input_ids, position_ids, inference_params, num_last_tokens, save_layer)\n",
    "\n",
    "    def protected_forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, save_layer=[]):\n",
    "        hidden_states = self.backbone(input_ids, inference_params=inference_params, save_layer=save_layer)\n",
    "        if len(save_layer) > 0:\n",
    "            return hidden_states\n",
    "        if num_last_tokens > 0:\n",
    "            hidden_states = hidden_states[:, -num_last_tokens:]\n",
    "        lm_logits = self.lm_head(hidden_states)\n",
    "        CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"loss\", \"logits\"])\n",
    "        return CausalLMOutput(loss=None, logits=lm_logits)\n",
    "\n",
    "    @classmethod\n",
    "    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):\n",
    "        config_data = load_config_hf(pretrained_model_name)\n",
    "        config = MambaConfig(**config_data)\n",
    "        model = cls(config, device=device, dtype=dtype, **kwargs)\n",
    "        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype), strict=False)\n",
    "        return model\n",
    "\n",
    "    def save_pretrained(self, save_directory):\n",
    "        \"\"\"\n",
    "        Minimal implementation of save_pretrained for MambaLMHeadModel.\n",
    "        Save the model and its configuration file to a directory.\n",
    "        \"\"\"\n",
    "        # Ensure save_directory exists\n",
    "        if not os.path.exists(save_directory):\n",
    "            os.makedirs(save_directory)\n",
    "\n",
    "        # Save the model's state_dict\n",
    "        model_path = os.path.join(save_directory, 'pytorch_model.bin')\n",
    "        torch.save(self.state_dict(), model_path)\n",
    "\n",
    "        # Save the configuration of the model\n",
    "        config_path = os.path.join(save_directory, 'config.json')\n",
    "        with open(config_path, 'w') as f:\n",
    "            json.dump(self.config.__dict__, f)"
   ]
  },
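  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of an end-to-end forward pass with a tiny, untrained `MambaLMHeadModelSafe` (illustrative sizes, CUDA required for the Mamba kernels):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: tiny untrained model, random tokens, logits shape check.\n",
    "if torch.cuda.is_available():\n",
    "    cfg = MambaConfig(d_model=128, n_layer=2, vocab_size=32)\n",
    "    model = MambaLMHeadModelSafe(cfg, device=\"cuda\", dtype=torch.bfloat16)\n",
    "    ids = torch.randint(0, 32, (1, 16), device=\"cuda\")\n",
    "    print(model(ids).logits.shape)  # (batch, seq_len, padded vocab size)"
   ]
  },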
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mamba model with positional ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "class MixerModelWithPosids(nn.Module):\n",
    "    r\"\"\"Mixer model for Mamba but we add positional encodings to the input embeddings.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            d_model: int,\n",
    "            n_layer: int,\n",
    "            vocab_size: int,\n",
    "            max_position_embeddings: int,\n",
    "            ssm_cfg=None,\n",
    "            norm_epsilon: float = 1e-5,\n",
    "            rms_norm: bool = False,\n",
    "            initializer_cfg=None,\n",
    "            fused_add_norm=False,\n",
    "            residual_in_fp32=False,\n",
    "            device=None,\n",
    "            dtype=None,\n",
    "            checkpoint_mixer=False,\n",
    "    ) -> None:\n",
    "        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "        super().__init__()\n",
    "        self.residual_in_fp32 = residual_in_fp32\n",
    "\n",
    "        self.embedding = nn.Embedding(vocab_size, d_model // 2, **factory_kwargs)\n",
    "        self.position_embedding = nn.Embedding(max_position_embeddings, d_model - d_model // 2, **factory_kwargs)\n",
    "\n",
    "        # We change the order of residual and layer norm:\n",
    "        # Instead of LN -> Attn / MLP -> Add, we do:\n",
    "        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and\n",
    "        # the main branch (output of MLP / Mixer). The model definition is unchanged.\n",
    "        # This is for performance reason: we can fuse add + layer_norm.\n",
    "        self.fused_add_norm = fused_add_norm\n",
    "        if self.fused_add_norm:\n",
    "            if layer_norm_fn is None or rms_norm_fn is None:\n",
    "                raise ImportError(\"Failed to import Triton LayerNorm / RMSNorm kernels\")\n",
    "\n",
    "        self.layers = nn.ModuleList(\n",
    "            [\n",
    "                create_block(\n",
    "                    d_model,\n",
    "                    ssm_cfg=ssm_cfg,\n",
    "                    norm_epsilon=norm_epsilon,\n",
    "                    rms_norm=rms_norm,\n",
    "                    residual_in_fp32=residual_in_fp32,\n",
    "                    fused_add_norm=fused_add_norm,\n",
    "                    layer_idx=i,\n",
    "                    checkpoint_mixer=checkpoint_mixer,\n",
    "                    **factory_kwargs,\n",
    "                )\n",
    "                for i in range(n_layer)\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(\n",
    "            d_model, eps=norm_epsilon, **factory_kwargs\n",
    "        )\n",
    "\n",
    "        self.apply(\n",
    "            partial(\n",
    "                _init_weights,\n",
    "                n_layer=n_layer,\n",
    "                **(initializer_cfg if initializer_cfg is not None else {}),\n",
    "            )\n",
    "        )\n",
    "\n",
    "    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n",
    "        return {\n",
    "            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n",
    "            for i, layer in enumerate(self.layers)\n",
    "        }\n",
    "\n",
    "    def forward(self, input_ids, position_ids, inference_params=None, save_layer=[]):\n",
    "        hidden_states = torch.cat([self.embedding(input_ids), self.position_embedding(position_ids), ], -1)\n",
    "        residual = None\n",
    "        if len(save_layer) > 0:\n",
    "            hidden_states_dict = {}\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            hidden_states, residual = layer(\n",
    "                hidden_states, residual, inference_params=inference_params\n",
    "            )\n",
    "            if i+1 in save_layer:\n",
    "                hidden_states_dict[i+1] = hidden_states.detach().cpu().to(torch.float).numpy()\n",
    "        if len(save_layer) > 0:\n",
    "            return hidden_states_dict\n",
    "            \n",
    "        if not self.fused_add_norm:\n",
    "            residual = (hidden_states + residual) if residual is not None else hidden_states\n",
    "            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))\n",
    "        else:\n",
    "            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn\n",
    "            hidden_states = fused_add_norm_fn(\n",
    "                hidden_states,\n",
    "                self.norm_f.weight,\n",
    "                self.norm_f.bias,\n",
    "                eps=self.norm_f.eps,\n",
    "                residual=residual,\n",
    "                prenorm=False,\n",
    "                residual_in_fp32=self.residual_in_fp32,\n",
    "            )\n",
    "        return hidden_states\n",
    "\n",
    "\n",
    "class MixerModelWith2DPosids(nn.Module):\n",
    "    r\"\"\"Mixer model for Mamba but we add positional encodings to the input embeddings.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            d_model: int,\n",
    "            n_layer: int,\n",
    "            vocab_size: int,\n",
    "            max_position_embeddings: int,\n",
    "            max_sequence_position_embeddings: int = 512,\n",
    "            ssm_cfg=None,\n",
    "            norm_epsilon: float = 1e-5,\n",
    "            rms_norm: bool = False,\n",
    "            initializer_cfg=None,\n",
    "            fused_add_norm=False,\n",
    "            residual_in_fp32=False,\n",
    "            device=None,\n",
    "            dtype=None,\n",
    "            checkpoint_mixer=False,\n",
    "    ) -> None:\n",
    "        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "        super().__init__()\n",
    "        self.residual_in_fp32 = residual_in_fp32\n",
    "\n",
    "        self.embedding = nn.Embedding(vocab_size, d_model - 2 * d_model // 4, **factory_kwargs)\n",
    "        self.position_embedding = nn.Embedding(max_position_embeddings, d_model // 4, **factory_kwargs)\n",
    "        self.seq_position_embedding = nn.Embedding(max_sequence_position_embeddings, d_model // 4, **factory_kwargs)\n",
    "        self.d_embeddings = d_model - 2 * d_model // 4\n",
    "\n",
    "        # We change the order of residual and layer norm:\n",
    "        # Instead of LN -> Attn / MLP -> Add, we do:\n",
    "        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and\n",
    "        # the main branch (output of MLP / Mixer). The model definition is unchanged.\n",
    "        # This is for performance reason: we can fuse add + layer_norm.\n",
    "        self.fused_add_norm = fused_add_norm\n",
    "        if self.fused_add_norm:\n",
    "            if layer_norm_fn is None or rms_norm_fn is None:\n",
    "                raise ImportError(\"Failed to import Triton LayerNorm / RMSNorm kernels\")\n",
    "\n",
    "        self.layers = nn.ModuleList(\n",
    "            [\n",
    "                create_block(\n",
    "                    d_model,\n",
    "                    ssm_cfg=ssm_cfg,\n",
    "                    norm_epsilon=norm_epsilon,\n",
    "                    rms_norm=rms_norm,\n",
    "                    residual_in_fp32=residual_in_fp32,\n",
    "                    fused_add_norm=fused_add_norm,\n",
    "                    layer_idx=i,\n",
    "                    checkpoint_mixer=checkpoint_mixer,\n",
    "                    **factory_kwargs,\n",
    "                )\n",
    "                for i in range(n_layer)\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(\n",
    "            d_model, eps=norm_epsilon, **factory_kwargs\n",
    "        )\n",
    "\n",
    "        self.apply(\n",
    "            partial(\n",
    "                _init_weights,\n",
    "                n_layer=n_layer,\n",
    "                **(initializer_cfg if initializer_cfg is not None else {}),\n",
    "            )\n",
    "        )\n",
    "\n",
    "    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n",
    "        return {\n",
    "            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n",
    "            for i, layer in enumerate(self.layers)\n",
    "        }\n",
    "\n",
    "    def forward(self, input_ids, position_ids, seq_position_ids, inference_params=None, save_layer=[]):\n",
    "        hidden_states = torch.cat([self.embedding(input_ids), self.position_embedding(position_ids), self.seq_position_embedding(seq_position_ids), ], -1)\n",
    "        residual = None\n",
    "        if len(save_layer) > 0:\n",
    "            hidden_states_dict = {}\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            hidden_states, residual = layer(\n",
    "                hidden_states, residual, inference_params=inference_params\n",
    "            )\n",
    "            if i+1 in save_layer:\n",
    "                hidden_states_dict[i+1] = hidden_states.detach().cpu().to(torch.float).numpy()\n",
    "        if len(save_layer) > 0:\n",
    "            return hidden_states_dict\n",
    "        \n",
    "        if not self.fused_add_norm:\n",
    "            residual = (hidden_states + residual) if residual is not None else hidden_states\n",
    "            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))\n",
    "        else:\n",
    "            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn\n",
    "            hidden_states = fused_add_norm_fn(\n",
    "                hidden_states,\n",
    "                self.norm_f.weight,\n",
    "                self.norm_f.bias,\n",
    "                eps=self.norm_f.eps,\n",
    "                residual=residual,\n",
    "                prenorm=False,\n",
    "                residual_in_fp32=self.residual_in_fp32,\n",
    "            )\n",
    "        return hidden_states\n",
    "\n",
    "\n",
    "\n",
    "class MambaLMHeadModelwithPosids(nn.Module, GenerationMixinSafe):\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            config: MambaConfig,\n",
    "            initializer_cfg=None,\n",
    "            device=None,\n",
    "            dtype=None,\n",
    "            checkpoint_mixer=False,\n",
    "    ) -> None:\n",
    "        self.config = config\n",
    "        d_model = config.d_model\n",
    "        n_layer = config.n_layer\n",
    "        vocab_size = config.vocab_size\n",
    "        max_position_embeddings = config.max_position_embeddings\n",
    "        ssm_cfg = config.ssm_cfg\n",
    "        rms_norm = config.rms_norm\n",
    "        residual_in_fp32 = config.residual_in_fp32\n",
    "        fused_add_norm = config.fused_add_norm\n",
    "        pad_vocab_size_multiple = config.pad_vocab_size_multiple\n",
    "        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "\n",
    "        super().__init__()\n",
    "        if vocab_size % pad_vocab_size_multiple != 0:\n",
    "            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)\n",
    "        self.backbone = MixerModelWithPosids(\n",
    "            d_model=d_model,\n",
    "            n_layer=n_layer,\n",
    "            vocab_size=vocab_size,\n",
    "            max_position_embeddings=max_position_embeddings,\n",
    "            ssm_cfg=ssm_cfg,\n",
    "            rms_norm=rms_norm,\n",
    "            initializer_cfg=initializer_cfg,\n",
    "            fused_add_norm=fused_add_norm,\n",
    "            residual_in_fp32=residual_in_fp32,\n",
    "            checkpoint_mixer=checkpoint_mixer,\n",
    "            **factory_kwargs,\n",
    "        )\n",
    "        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)\n",
    "\n",
    "        # Initialize weights and apply final processing\n",
    "        self.apply(\n",
    "            partial(\n",
    "                _init_weights,\n",
    "                n_layer=n_layer,\n",
    "                **(initializer_cfg if initializer_cfg is not None else {}),\n",
    "            )\n",
    "        )\n",
    "        self.tie_weights()\n",
    "\n",
    "    def tie_weights(self):\n",
    "        self.lm_head.weight = self.backbone.embedding.weight\n",
    "\n",
    "    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n",
    "        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n",
    "\n",
    "    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, save_layer=[], *args, **kwargs):\n",
    "        \"\"\"\n",
    "        \"position_ids\" is just to be compatible with Transformer generation. We don't use it.\n",
    "        num_last_tokens: if > 0, only return the logits for the last n tokens\n",
    "        \"\"\"\n",
    "        return self.protected_forward(input_ids, position_ids, inference_params, num_last_tokens, save_layer)\n",
    "\n",
    "    def protected_forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, save_layer=[]):\n",
    "        hidden_states = self.backbone(input_ids, position_ids=position_ids, inference_params=inference_params, save_layer=save_layer)\n",
    "        if len(save_layer) > 0:\n",
    "            return hidden_states\n",
    "        hidden_states = hidden_states[:, :, :self.config.d_model // 2]\n",
    "        if num_last_tokens > 0:\n",
    "            hidden_states = hidden_states[:, -num_last_tokens:]\n",
    "        lm_logits = self.lm_head(hidden_states)\n",
    "        CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"loss\", \"logits\"])\n",
    "        return CausalLMOutput(loss=None, logits=lm_logits)\n",
    "    \n",
    "    @classmethod\n",
    "    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, checkpoint_mixer=False, **kwargs):\n",
    "        config_data = load_config_hf(pretrained_model_name)\n",
    "        config = MambaConfig(**config_data)\n",
    "        model = cls(config, device=device, dtype=dtype, checkpoint_mixer=checkpoint_mixer, **kwargs)\n",
    "        state_dict = load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)\n",
    "        if state_dict.keys() != model.state_dict().keys():\n",
    "            if checkpoint_mixer:\n",
    "                for key in model.state_dict().keys():\n",
    "                    if \"ckpt_layer\" in key:\n",
    "                        state_dict[key] = state_dict.pop(key.replace(\"ckpt_layer.\", \"\"))\n",
    "                print(\"Using a model that was pretrained without gradient checkpointing and now want to use it. Changed the keys of the state_dict to match the model's keys.\")\n",
    "            else:\n",
    "                for key in list(state_dict.keys()):\n",
    "                    if \"ckpt_layer\" in key:\n",
    "                        state_dict[key.replace(\"ckpt_layer.\", \"\")] = state_dict.pop(key)\n",
    "                print(\"Using a model that was pretrained with gradient checkpointing but now do not want to use it. Changed the keys of the state_dict to match the model's keys.\")\n",
    "            assert state_dict.keys() == model.state_dict().keys(), \"The keys of the state_dict do not match the model's keys.\"\n",
    "        model.load_state_dict(state_dict)\n",
    "        return model\n",
    "\n",
    "    def save_pretrained(self, save_directory):\n",
    "        \"\"\"\n",
    "        Minimal implementation of save_pretrained for MambaLMHeadModel.\n",
    "        Save the model and its configuration file to a directory.\n",
    "        \"\"\"\n",
    "        # Ensure save_directory exists\n",
    "        if not os.path.exists(save_directory):\n",
    "            os.makedirs(save_directory)\n",
    "\n",
    "        # Save the model's state_dict\n",
    "        model_path = os.path.join(save_directory, 'pytorch_model.bin')\n",
    "        torch.save(self.state_dict(), model_path)\n",
    "\n",
    "        # Save the configuration of the model\n",
    "        config_path = os.path.join(save_directory, 'config.json')\n",
    "        with open(config_path, 'w') as f:\n",
    "            json.dump(self.config.__dict__, f)\n",
    "\n",
    "\n",
    "class MambaLMHeadModelwith2DPosids(nn.Module, GenerationMixinSafe):\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            config: MambaConfig,\n",
    "            initializer_cfg=None,\n",
    "            device=None,\n",
    "            dtype=None,\n",
    "            checkpoint_mixer=False,\n",
    "    ) -> None:\n",
    "        self.config = config\n",
    "        d_model = config.d_model\n",
    "        n_layer = config.n_layer\n",
    "        vocab_size = config.vocab_size\n",
    "        max_position_embeddings = config.max_position_embeddings\n",
    "        ssm_cfg = config.ssm_cfg\n",
    "        rms_norm = config.rms_norm\n",
    "        residual_in_fp32 = config.residual_in_fp32\n",
    "        fused_add_norm = config.fused_add_norm\n",
    "        pad_vocab_size_multiple = config.pad_vocab_size_multiple\n",
    "        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n",
    "\n",
    "        super().__init__()\n",
    "        if vocab_size % pad_vocab_size_multiple != 0:\n",
    "            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)\n",
    "        self.backbone = MixerModelWith2DPosids(\n",
    "            d_model=d_model,\n",
    "            n_layer=n_layer,\n",
    "            vocab_size=vocab_size,\n",
    "            max_position_embeddings=max_position_embeddings,\n",
    "            ssm_cfg=ssm_cfg,\n",
    "            rms_norm=rms_norm,\n",
    "            initializer_cfg=initializer_cfg,\n",
    "            fused_add_norm=fused_add_norm,\n",
    "            residual_in_fp32=residual_in_fp32,\n",
    "            checkpoint_mixer=checkpoint_mixer,\n",
    "            **factory_kwargs,\n",
    "        )\n",
    "        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)\n",
    "\n",
    "        # Initialize weights and apply final processing\n",
    "        self.apply(\n",
    "            partial(\n",
    "                _init_weights,\n",
    "                n_layer=n_layer,\n",
    "                **(initializer_cfg if initializer_cfg is not None else {}),\n",
    "            )\n",
    "        )\n",
    "        self.tie_weights()\n",
    "\n",
    "    def tie_weights(self):\n",
    "        self.lm_head.weight = self.backbone.embedding.weight\n",
    "\n",
    "    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n",
    "        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n",
    "\n",
    "    def forward(self, input_ids, position_ids=None, seq_position_ids=None, inference_params=None, num_last_tokens=0, save_layer=[], *args, **kwargs):\n",
    "        \"\"\"\n",
    "        \"position_ids\" is just to be compatible with Transformer generation. We don't use it.\n",
    "        num_last_tokens: if > 0, only return the logits for the last n tokens\n",
    "        \"\"\"\n",
    "        return self.protected_forward(input_ids, position_ids, seq_position_ids, inference_params, num_last_tokens, save_layer)\n",
    "\n",
    "    def protected_forward(self, input_ids, position_ids=None, seq_position_ids=None, inference_params=None, num_last_tokens=0, save_layer=[]):\n",
    "        hidden_states = self.backbone(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids, inference_params=inference_params, save_layer=save_layer)\n",
    "        if len(save_layer) > 0:\n",
    "            return hidden_states\n",
    "        hidden_states = hidden_states[:, :, :self.backbone.d_embeddings]\n",
    "        if num_last_tokens > 0:\n",
    "            hidden_states = hidden_states[:, -num_last_tokens:]\n",
    "        lm_logits = self.lm_head(hidden_states)\n",
    "        CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"loss\", \"logits\"])\n",
    "        return CausalLMOutput(loss=None, logits=lm_logits)\n",
    "\n",
    "    @classmethod\n",
    "    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, checkpoint_mixer=False, **kwargs):\n",
    "        config_data = load_config_hf(pretrained_model_name)\n",
    "        config = MambaConfig(**config_data)\n",
    "        model = cls(config, device=device, dtype=dtype, checkpoint_mixer=checkpoint_mixer, **kwargs)\n",
    "        state_dict = load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)\n",
    "        if state_dict.keys() != model.state_dict().keys():\n",
    "            if checkpoint_mixer:\n",
    "                for key in model.state_dict().keys():\n",
    "                    if \"ckpt_layer\" in key:\n",
    "                        state_dict[key] = state_dict.pop(key.replace(\"ckpt_layer.\", \"\"))\n",
    "                print(\"Using a model that was pretrained without gradient checkpointing and now want to use it. Changed the keys of the state_dict to match the model's keys.\")\n",
    "            else:\n",
    "                for key in list(state_dict.keys()):\n",
    "                    if \"ckpt_layer\" in key:\n",
    "                        state_dict[key.replace(\"ckpt_layer.\", \"\")] = state_dict.pop(key)\n",
    "                print(\"Using a model that was pretrained with gradient checkpointing but now do not want to use it. Changed the keys of the state_dict to match the model's keys.\")\n",
    "            assert state_dict.keys() == model.state_dict().keys(), \"The keys of the state_dict do not match the model's keys.\"\n",
    "        model.load_state_dict(state_dict)\n",
    "        return model\n",
    "\n",
    "    def save_pretrained(self, save_directory):\n",
    "        \"\"\"\n",
    "        Minimal implementation of save_pretrained for MambaLMHeadModel.\n",
    "        Save the model and its configuration file to a directory.\n",
    "        \"\"\"\n",
    "        # Ensure save_directory exists\n",
    "        if not os.path.exists(save_directory):\n",
    "            os.makedirs(save_directory)\n",
    "\n",
    "        # Save the model's state_dict\n",
    "        model_path = os.path.join(save_directory, 'pytorch_model.bin')\n",
    "        torch.save(self.state_dict(), model_path)\n",
    "\n",
    "        # Save the configuration of the model\n",
    "        config_path = os.path.join(save_directory, 'config.json')\n",
    "        with open(config_path, 'w') as f:\n",
    "            json.dump(self.config.__dict__, f)"
   ]
  },
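  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same sketch for the positional-id variant: the token embedding only covers `d_model // 2` dimensions, the positional embedding fills the rest, and the head reads back only the token half of the hidden state (illustrative sizes, CUDA required):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: tiny untrained model with explicit position ids.\n",
    "if torch.cuda.is_available():\n",
    "    cfg = MambaConfig(d_model=128, n_layer=2, vocab_size=32, max_position_embeddings=64)\n",
    "    model = MambaLMHeadModelwithPosids(cfg, device=\"cuda\", dtype=torch.bfloat16)\n",
    "    ids = torch.randint(0, 32, (1, 16), device=\"cuda\")\n",
    "    pos = torch.arange(16, device=\"cuda\").unsqueeze(0)  # (batch, seq_len)\n",
    "    print(model(ids, position_ids=pos).logits.shape)"
   ]
  },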
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "def load_model(model_path, device, model_class=MambaLMHeadModelSafe, dtype=torch.bfloat16, checkpoint_mixer=False):\n",
    "    model = model_class.from_pretrained(model_path, device=device, dtype=dtype, checkpoint_mixer=checkpoint_mixer)\n",
    "    return model"
   ]
  },
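  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Typical usage (the path below is a placeholder; substitute a real checkpoint directory or Hugging Face model ID):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch with a placeholder checkpoint path; not run here.\n",
    "# model = load_model(\"path/to/checkpoint\", device=\"cuda\",\n",
    "#                    model_class=MambaLMHeadModelwithPosids)"
   ]
  },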
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
