{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "import numllama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The lightning module __init__ of the addition model used custom\n",
    "# types, so its saved checkpoints contain references to classes\n",
    "# under the module names `svgai` and `svgai.train`. Those names are\n",
    "# not importable in this project, so the checkpoint cannot be loaded\n",
    "# as-is. As a workaround, register the current modules under the old\n",
    "# names so that the references resolve to the correct classes.\n",
    "# For future checkpoints: use only built-in types in the lightning\n",
    "# module __init__ to keep checkpoints portable.\n",
    "\n",
    "legacy_module_aliases = {\n",
    "    \"svgai\": numllama,\n",
    "    \"svgai.train\": numllama.addition,\n",
    "}\n",
    "sys.modules.update(legacy_module_aliases)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/tmp/xkadlci2/.micromamba/envs/numllama/lib/python3.12/site-packages/torch/nn/modules/transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Path of the pretrained addition-model (lightning) checkpoint.\n",
    "# Set the NUMLLAMA_ADDITION_CHECKPOINT environment variable to point to\n",
    "# a different file; set it to an empty string to skip loading and use\n",
    "# the default configs from the `else` branch instead.\n",
    "load_numeric_checkpoint = os.environ.get(\n",
    "    \"NUMLLAMA_ADDITION_CHECKPOINT\",\n",
    "    '/home/xkadlci2/svgai/checkpoints/vocal-frost-603__c5a8xfde/global-step=240000__valid-acc=1.000.ckpt',\n",
    ") or None\n",
    "\n",
    "if load_numeric_checkpoint:\n",
    "    # take both configs from the trained model; the embedding weights\n",
    "    # of this model are copied into llama in a later cell\n",
    "    addition_model = numllama.addition.AdditionLightning.load_from_checkpoint(load_numeric_checkpoint)\n",
    "    numeric_input_emb_config = addition_model.model.embedding_config.model_dump()\n",
    "    numeric_encoder_config = addition_model.model.num_encoder_config\n",
    "else:\n",
    "    # no checkpoint: the numeric modules are configured from scratch\n",
    "    addition_model = None\n",
    "    numeric_input_emb_config = dict(\n",
    "        embedding_dim=256,\n",
    "        min_value=0,\n",
    "        max_value=10000,\n",
    "        use_l2_norm=False,\n",
    "        norm_const=None,\n",
    "    )\n",
    "    numeric_encoder_config = dict(\n",
    "        _target_=\"numllama.nn.feedforward_backbone\",\n",
    "        model_dim=256,\n",
    "        ff_dim=128,\n",
    "        num_blocks=8,\n",
    "        normalization=None,\n",
    "        use_skips=True,\n",
    "        skips_are_learnable=False,\n",
    "        linears_constraint=None,\n",
    "        dropout=0,\n",
    "        activation_fn=dict(\n",
    "            _target_=\"torch.nn.GELU\"\n",
    "        ),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the NumLlama config: all fields of the pretrained llama config\n",
    "# plus the two numeric configs selected in the previous cell.\n",
    "checkpoint_name = \"meta-llama/Llama-3.2-1B\"\n",
    "original_config = transformers.LlamaConfig.from_pretrained(checkpoint_name)\n",
    "llama_config_kwargs = original_config.to_dict()\n",
    "config = numllama.NumLlamaConfig(\n",
    "    numeric_input_emb_config=numeric_input_emb_config,\n",
    "    numeric_encoder_config=numeric_encoder_config,\n",
    "    **llama_config_kwargs,\n",
    ")\n",
    "\n",
    "# load the pretrained llama model with the extended config\n",
    "num_llama: numllama.NumLlamaForCausalLM = numllama.NumLlamaForCausalLM.from_pretrained(\n",
    "    checkpoint_name,\n",
    "    config=config,\n",
    ")\n",
    "\n",
    "# create the new numeric embedding layer inside llama\n",
    "num_llama.apply_numeric_patch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the tokenizer of the same pretrained checkpoint as the model\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('hello', 15339),\n",
       " ('Ġ', 220),\n",
       " ('0', 15),\n",
       " ('Ġ', 220),\n",
       " ('1', 16),\n",
       " ('Ġ', 220),\n",
       " ('2', 17),\n",
       " ('Ġ', 220),\n",
       " ('3', 18),\n",
       " ('Ġ', 220),\n",
       " ('250', 5154),\n",
       " ('Ġ', 220),\n",
       " ('640', 14033),\n",
       " ('1', 16),\n",
       " ('Ġ', 220),\n",
       " ('131', 9263),\n",
       " ('070', 17819),\n",
       " ('Ġ', 220),\n",
       " ('131', 9263),\n",
       " ('071', 24508)]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check original tokenization: pair each token string with its id\n",
    "# (strict=True makes zip raise if the two lists differ in length)\n",
    "test_string = \"hello 0 1 2 3 250 6401 131070 131071\"\n",
    "original_tokens = tokenizer.tokenize(test_string)\n",
    "original_token_ids = tokenizer.encode(test_string, add_special_tokens=False)\n",
    "list(zip(original_tokens, original_token_ids, strict=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xkadlci2/numllama/numllama/numllama.py:103: UserWarning: This function assumes the specific behavior of llama 3 1b tokenizer. Make sure to verify it works as expected on your particular tokenizer.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# change how llama tokenizes numbers\n",
    "# NOTE(review): the helper warns that it assumes the behavior of the\n",
    "# llama 3 1b tokenizer; with any other tokenizer, verify the result\n",
    "# using the tokenization check cells.\n",
    "numllama.patch_llama_digit_splitting(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xkadlci2/numllama/numllama/numllama.py:128: UserWarning: This function assumes the specific behavior of llama 3 1b tokenizer. Make sure to verify it works as expected on your particular tokenizer.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# add the number tokens to the tokenizer; the value range is taken\n",
    "# from the config of the numeric input embedding\n",
    "num_min_value = numeric_input_emb_config[\"min_value\"]\n",
    "num_max_value = numeric_input_emb_config[\"max_value\"]\n",
    "numllama.add_num_tokens_to_tokenizer(num_min_value, num_max_value, tokenizer, num_llama)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('hello', 15339),\n",
       " (' 0', 128256),\n",
       " (' 1', 128257),\n",
       " (' 2', 128258),\n",
       " (' 3', 128259),\n",
       " (' 250', 128506),\n",
       " (' 6401', 134657),\n",
       " (' 131070', 259326),\n",
       " (' 131071', 259327)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# checking the new tokenization of the same test string\n",
    "patched_tokens = tokenizer.tokenize(test_string)\n",
    "patched_token_ids = tokenizer.encode(test_string, add_special_tokens=False)\n",
    "list(zip(patched_tokens, patched_token_ids, strict=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# copy the trained embedding weights of the addition model into the\n",
    "# numeric embedding of llama (skipped when no checkpoint was loaded)\n",
    "\n",
    "if addition_model is not None:\n",
    "    pretrained_emb_weights = addition_model.model.embedding.state_dict()\n",
    "    numeric_emb = num_llama.get_numeric_emb()\n",
    "    numeric_emb.load_state_dict(pretrained_emb_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[\"<|begin_of_text|>Hello, it's me. 35293 10674 63234 7739 36662\",\n",
       " '<|begin_of_text|>Hello, 2 is what? 63214 7750 35282 63447 63155']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Trying out generation\n",
    "\n",
    "input_str = [\n",
    "    \"Hello, it's me.\",\n",
    "    \"Hello, 2 is what?\"\n",
    "]\n",
    "\n",
    "# Prompts that tokenize to different lengths cannot be stacked into one\n",
    "# tensor without padding (the two prompts above only happen to have the\n",
    "# same length). If the tokenizer has no pad token, reuse EOS; pad on the\n",
    "# left so that generation continues right after the last prompt token.\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = \"left\"\n",
    "\n",
    "inputs = tokenizer(input_str, return_tensors=\"pt\", padding=True, truncation=True)\n",
    "\n",
    "# generation runs inside the context manager provided by the model\n",
    "with num_llama.build_num_latents():\n",
    "    outputs = num_llama.generate(**inputs.to(num_llama.device), max_new_tokens=5)\n",
    "\n",
    "tokenizer.batch_decode(outputs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
