{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mRunning cells with 'openlm' requires the ipykernel package.\n",
      "\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
      "\u001b[1;31mCommand: 'conda install -n openlm ipykernel --update-deps --force-reinstall'"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import sys\n",
    "import json\n",
    "from dataclasses import dataclass\n",
    "\n",
    "sys.path.append(\"../../../open_lm\")\n",
    "from open_lm.model import Transformer\n",
    "from open_lm.norms import RmsNorm\n",
    "\n",
    "# Use the first CUDA device when one is visible; otherwise fall back to CPU so\n",
    "# the later `.to(device)` calls do not fail on machines without a GPU.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Read the model hyperparameters. The context manager closes the file handle\n",
    "# right away instead of leaving it open until garbage collection.\n",
    "with open(\"../model_configs/llama2_7b.json\") as config_file:\n",
    "    cfg = json.load(config_file)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Params:\n",
    "    \"\"\"Hyperparameter bundle passed to ``Transformer`` when the model is built.\n",
    "\n",
    "    Fields without a default must be supplied by the caller; the fields with\n",
    "    defaults carry the settings that the inline comments call out for LLaMA.\n",
    "    NOTE(review): the field names presumably mirror the attributes that\n",
    "    ``open_lm.model.Transformer`` reads from its argument - confirm there.\n",
    "    \"\"\"\n",
    "\n",
    "    dim: int\n",
    "    n_layers: int\n",
    "    n_heads: int\n",
    "    vocab_size: int\n",
    "    norm_eps: float\n",
    "    seq_len: int\n",
    "    post_embed_norm: bool\n",
    "    weight_tying: bool\n",
    "    # The imported RmsNorm is stored uncalled (not an instance), so the\n",
    "    # annotation is ``type[...]``. Assumes RmsNorm subclasses nn.Module - TODO confirm.\n",
    "    norm_type: type[nn.Module] = RmsNorm  # Make sure to use RmsNorm for LLaMA\n",
    "    apply_qk_norm: bool = False\n",
    "    positional_embedding_type: str = \"llama_rotary\"  # Make sure to set this for LLaMA\n",
    "    ffn_type: str = \"swiglu\"\n",
    "\n",
    "\n",
    "# Config entries whose JSON key equals the Params field name are copied in one\n",
    "# pass; `dim` comes from the differently named \"hidden_dim\" key, and `norm_eps`\n",
    "# is fixed here rather than read from the config.\n",
    "shared_keys = (\"n_layers\", \"n_heads\", \"seq_len\", \"vocab_size\", \"post_embed_norm\", \"weight_tying\")\n",
    "args = Params(\n",
    "    dim=cfg[\"hidden_dim\"],\n",
    "    norm_eps=1e-5,\n",
    "    **{key: cfg[key] for key in shared_keys},\n",
    ")\n",
    "\n",
    "# Build the model from the hyperparameters, then overwrite its weights with the\n",
    "# converted checkpoint.\n",
    "model = Transformer(args)\n",
    "# map_location=\"cpu\" makes loading independent of the device the checkpoint was\n",
    "# saved from; the weights are moved to `device` in one step below.\n",
    "# NOTE(review): torch.load unpickles the file, so only load checkpoints from a\n",
    "# trusted source (consider weights_only=True if the installed torch supports it).\n",
    "state_dict = torch.load(\"./LLAMA2/llama-2-7b/consolidated.00.converted.pth\", map_location=\"cpu\")\n",
    "# strict=True: raise if checkpoint keys and model parameters do not match exactly.\n",
    "model.load_state_dict(state_dict, strict=True)\n",
    "# Switch to eval mode and move the model to the target device.\n",
    "model = model.eval().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mRunning cells with 'openlm' requires the ipykernel package.\n",
      "\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
      "\u001b[1;31mCommand: 'conda install -n openlm ipykernel --update-deps --force-reinstall'"
     ]
    }
   ],
   "source": [
    "sys.path.append(\"./LLAMA2/llama\")\n",
    "from llama.tokenizer import Tokenizer\n",
    "\n",
    "# Build the tokenizer from the model file at ./LLAMA2/tokenizer.model; its\n",
    "# `encode` / `decode` methods are used by the generation helper below.\n",
    "tokenizer = Tokenizer(\"./LLAMA2/tokenizer.model\")\n",
    "\n",
    "\n",
    "def sample_top_p(probs, p):\n",
    "    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)\n",
    "    probs_sum = torch.cumsum(probs_sort, dim=-1)\n",
    "    mask = probs_sum - probs_sort > p\n",
    "    probs_sort[mask] = 0.0\n",
    "    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))\n",
    "    next_token = torch.multinomial(probs_sort, num_samples=1)\n",
    "    next_token = torch.gather(probs_idx, -1, next_token)\n",
    "    return next_token\n",
    "\n",
    "\n",
    "def generate_top_p_language(prefix: str, temperature: float = 0.6, top_p: float = 0.9, max_len: int = 128):\n",
    "    \"\"\"Autoregressively extend `prefix` and return the decoded text.\n",
    "\n",
    "    Args:\n",
    "        prefix: prompt text; it is encoded with a BOS token and no EOS token.\n",
    "        temperature: softmax temperature. Values > 0 use top-p sampling; any\n",
    "            other value selects the arg-max token at every step.\n",
    "        top_p: nucleus threshold forwarded to `sample_top_p`.\n",
    "        max_len: upper bound on the number of newly generated tokens.\n",
    "\n",
    "    Returns:\n",
    "        The decoded string for the prompt tokens followed by the generated\n",
    "        tokens. Generation stops early once the tokenizer's end-of-sequence id\n",
    "        is produced (that token is not appended).\n",
    "    \"\"\"\n",
    "    input_tokens = tokenizer.encode(prefix, bos=True, eos=False)\n",
    "    tokens = torch.tensor(input_tokens).unsqueeze(0).to(device)\n",
    "    # Not every tokenizer exposes an EOS id; without one only `max_len` bounds the loop.\n",
    "    eos_id = getattr(tokenizer, \"eos_id\", None)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for _ in range(max_len):\n",
    "            # The whole sequence is re-run each step; only the logits at the\n",
    "            # last position are used to pick the next token.\n",
    "            logits, _, _ = model(tokens)\n",
    "            last_logits = logits[:, -1]\n",
    "            if temperature > 0:\n",
    "                probs = torch.softmax(last_logits / temperature, dim=-1)\n",
    "                next_token = sample_top_p(probs, top_p)\n",
    "            else:\n",
    "                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)\n",
    "            # The batch holds a single sequence (see unsqueeze above), so .item() is safe.\n",
    "            if eos_id is not None and next_token.item() == eos_id:\n",
    "                break\n",
    "            tokens = torch.cat([tokens, next_token], dim=-1)\n",
    "\n",
    "    generation = tokenizer.decode(tokens[0].tolist())\n",
    "    return generation\n",
    "\n",
    "\n",
    "# Prompts passed one at a time to generate_top_p_language by the loop below.\n",
    "# The triple-quoted entries keep their embedded newlines and leading spaces\n",
    "# exactly as written, so that whitespace is part of the text that gets encoded.\n",
    "prompts = [\n",
    "    # For these prompts, the expected answer is the natural continuation of the prompt\n",
    "    \"I believe the meaning of life is\",\n",
    "    \"Simply put, the theory of relativity states that \",\n",
    "    \"\"\"A brief message congratulating the team on the launch:\n",
    "\n",
    "    Hi everyone,\n",
    "    \n",
    "    I just \"\"\",\n",
    "    # Few shot prompt (providing a few examples before asking model to complete more);\n",
    "    \"\"\"Translate English to French:\n",
    "    \n",
    "    sea otter => loutre de mer\n",
    "    peppermint => menthe poivrée\n",
    "    plush girafe => girafe peluche\n",
    "    cheese =>\"\"\",\n",
    "    \"\"\"He -> Him, She -> Her, They ->\"\"\",\n",
    "    \"\"\"Who is Donald Trump?\"\"\",\n",
    "]\n",
    "\n",
    "# Generate one completion per prompt and print it between separator lines.\n",
    "separator = \"====================================\"\n",
    "for prompt in prompts:\n",
    "    # The opening separator is printed before generation starts.\n",
    "    print(separator)\n",
    "    generated_text = generate_top_p_language(prompt)\n",
    "    # One call with newline separators emits the same three lines as three prints.\n",
    "    print(prompt, generated_text, separator, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mRunning cells with 'openlm' requires the ipykernel package.\n",
      "\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
      "\u001b[1;31mCommand: 'conda install -n openlm ipykernel --update-deps --force-reinstall'"
     ]
    }
   ],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "modeldiff",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
