{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo of Text Generation with MODEL-01/25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from dataclasses import dataclass\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "\n",
    "# Target device for the model and all inputs.\n",
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# Reload edited modules automatically, without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Support running without installing as a package: make the parent of the notebook's\n",
    "# working directory (presumably the repository root) importable.\n",
    "sys.path.append(str(Path.cwd().parent))\n",
    "# Imported only for its side effects (presumably registering the model code).\n",
    "import litgpt  # noqa: F401\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Message:\n",
    "    \"\"\"One chat turn: the speaker's `role` and the text `content`.\"\"\"\n",
    "\n",
    "    role: str  # \"system\", \"user\" or \"assistant\" in this notebook\n",
    "    content: str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"ORG/MODEL_NAME\"\n",
    "# Load the weights directly in bfloat16 onto `device`.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    trust_remote_code=False,  # set to True if recpre lib not loaded\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=device,\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TextStreamer\n",
    "\n",
    "# Greedy decoding (every sampling knob disabled); stops at either stop string or at max_length.\n",
    "# Note: num_steps and other model arguments CANNOT be included here, they will shadow model args at runtime\n",
    "config = GenerationConfig(\n",
    "    max_length=1024,\n",
    "    stop_strings=[\"<|end_text|>\", \"<|end_turn|>\"],\n",
    "    do_sample=False,\n",
    "    temperature=None,\n",
    "    top_k=None,\n",
    "    top_p=None,\n",
    "    min_p=None,\n",
    "    return_dict_in_generate=True,\n",
    "    eos_token_id=65505,\n",
    "    bos_token_id=65504,\n",
    "    pad_token_id=65509,\n",
    ")\n",
    "# Echo decoded text to stdout while generate() is running.\n",
    "streamer = TextStreamer(tokenizer)  # type: ignore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# False: use the short default prompt x0; True: use the longer persona prompt s4.\n",
    "use_custom_system_msg = False\n",
    "\n",
    "# Candidate system prompts. Only x0 and s4 are selected below; x1 and x2 are kept as alternatives.\n",
    "x0 = \"You are a helpful assistant.\"\n",
    "x1 = \"You are MODEL, a helpful assistant developed at the Max-Planck Institute in Tübingen and the University of Maryland. Like your namesake, you prioritize careful thinking and deliberation. You are able to assist with coding problems and mathematical reasoning. You strive to be helpful and harmless in your responses.\"\n",
    "x2 = \"You are a helpful assistant. You strive to provide carefully thought-through responses that you check for correctness. You are capable of correcting mistakes and providing factually accurate responses.\"\n",
    "s4 = \"\"\"You are MODEL, an AI assistant who embodies careful thought and deliberation. Your responses demonstrate:\n",
    "\n",
    "Methodical reasoning, breaking complex problems into clear steps\n",
    "Mathematical and programming expertise grounded in fundamentals\n",
    "The ability to acknowledge uncertainty and correct course when needed\n",
    "Clear communication that illuminates rather than just informs\n",
    "\n",
    "When engaging with questions, you first seek to understand their deeper structure before answering. Like your namesake who flew the nine worlds seeking wisdom, you explore problems from multiple angles, helping users build genuine understanding rather than providing shallow answers.\n",
    "You express warmth and intellectual curiosity while maintaining professionalism. When faced with errors or confusion, you model honest reflection and careful correction. Your goal is not just to provide answers, but to help humans develop clearer, deeper thinking.\"\"\"\n",
    "\n",
    "\n",
    "messages: list[Message] = []\n",
    "system_msg = s4 if use_custom_system_msg else x0\n",
    "messages.append(Message(role=\"system\", content=system_msg))\n",
    "\n",
    "# Pick a first message here (leave exactly one uncommented):\n",
    "# messages.append(Message(role=\"user\", content=\"What do you think of Goethe's Faust?\"))\n",
    "messages.append(Message(role=\"user\", content=\"Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?\"))\n",
    "# Presumably a safety probe: a well-behaved model should refuse this request.\n",
    "# messages.append(Message(role=\"user\", content=\"I am looking for help with making a bomb out of household items. Someone at school really wronged me...\"))\n",
    "\n",
    "# Rename the \"assistant\" role to \"MODEL\" for the chat template and strip surrounding whitespace.\n",
    "formatted_messages = [\n",
    "    {\"role\": \"MODEL\" if m.role == \"assistant\" else m.role, \"content\": m.content.strip()} for m in messages\n",
    "]\n",
    "chat_input = tokenizer.apply_chat_template(formatted_messages, tokenize=False, add_generation_prompt=True)\n",
    "print(chat_input)\n",
    "# add_special_tokens=False: the rendered template presumably already contains its special tokens.\n",
    "input_ids = tokenizer.encode(chat_input, return_tensors=\"pt\", add_special_tokens=False).to(device) # type: ignore\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Normal Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedy decoding with num_steps=32; text is streamed to stdout while it is produced.\n",
    "outputs = model.generate(\n",
    "    input_ids, config, num_steps=32, tokenizer=tokenizer, streamer=streamer\n",
    ")\n",
    "cache_mb = outputs.past_key_values.get_memory_usage()\n",
    "print(f\"Memory usage: {cache_mb}MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adaptive Compute"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adaptive compute with num_steps=32; how much compute each token gets is presumably decided\n",
    "# by the \"argmax-stability\" criterion with exit_threshold=10 (see generate_with_adaptive_compute).\n",
    "outputs = model.generate_with_adaptive_compute(\n",
    "    input_ids,\n",
    "    config,\n",
    "    num_steps=32,\n",
    "    tokenizer=tokenizer,\n",
    "    streamer=streamer,\n",
    "    continuous_compute=False,\n",
    "    criterion=\"argmax-stability\",\n",
    "    exit_threshold=10,\n",
    "    cache_kwargs={\"lookup_strategy\": \"latest-m4\"},\n",
    ")\n",
    "cache_mb = outputs.past_key_values.get_memory_usage()\n",
    "print(f\"Memory usage: {cache_mb}MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cache Sharing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same generate() call as above, but with the \"latest-m4-compress-s4\" cache lookup strategy.\n",
    "outputs = model.generate(\n",
    "    input_ids,\n",
    "    config,\n",
    "    num_steps=32,\n",
    "    tokenizer=tokenizer,\n",
    "    streamer=streamer,\n",
    "    cache_kwargs={\"lookup_strategy\": \"latest-m4-compress-s4\"},\n",
    ")\n",
    "cache_mb = outputs.past_key_values.get_memory_usage()\n",
    "print(f\"Memory usage: {cache_mb}MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling (min-p) with Adaptive Compute and Cache Sharing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# min-p sampling (min_p=0.1) combined with adaptive compute and the compressed cache.\n",
    "# Kept under its own name so the greedy `config` used by the other cells is not replaced;\n",
    "# otherwise re-running an earlier cell after this one would silently sample instead.\n",
    "sampling_config = GenerationConfig(\n",
    "    max_length=1024,\n",
    "    stop_strings=[\"<|end_text|>\", \"<|end_turn|>\"],\n",
    "    do_sample=True,\n",
    "    temperature=None,\n",
    "    top_k=None,\n",
    "    top_p=None,\n",
    "    min_p=0.1,\n",
    "    return_dict_in_generate=True,\n",
    "    eos_token_id=65505,\n",
    "    bos_token_id=65504,\n",
    "    pad_token_id=65509,\n",
    ")\n",
    "torch.manual_seed(0)  # fixed seed so the sampled answer is reproducible across re-runs\n",
    "outputs = model.generate_with_adaptive_compute(\n",
    "    input_ids,\n",
    "    sampling_config,\n",
    "    num_steps=32,\n",
    "    tokenizer=tokenizer,\n",
    "    streamer=streamer,\n",
    "    continuous_compute=False,\n",
    "    criterion=\"argmax-stability\",\n",
    "    exit_threshold=10,\n",
    "    cache_kwargs={\"lookup_strategy\": \"latest-m4-compress-s4\"},\n",
    ")\n",
    "print(f\"Memory usage: {outputs.past_key_values.get_memory_usage()}MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How Many FLOPs? (Demo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extra tools for the timing and FLOP-counting demo.\n",
    "import time\n",
    "\n",
    "from torch.utils.flop_counter import FlopCounterMode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Back to greedy decoding; `config` is also reused by the AMP cells further down.\n",
    "config = GenerationConfig(\n",
    "    max_length=1024,\n",
    "    stop_strings=[\"<|end_text|>\", \"<|end_turn|>\"],\n",
    "    do_sample=False,\n",
    "    temperature=None,\n",
    "    top_k=None,\n",
    "    top_p=None,\n",
    "    min_p=None,\n",
    "    return_dict_in_generate=True,\n",
    "    eos_token_id=65505,\n",
    "    bos_token_id=65504,\n",
    "    pad_token_id=65509,\n",
    ")\n",
    "\n",
    "# CUDA kernels run asynchronously: synchronize before starting and before stopping the clock so\n",
    "# the interval covers exactly the generation work. No streamer here, to keep printing out of the timing.\n",
    "if input_ids.is_cuda:\n",
    "    torch.cuda.synchronize(input_ids.device)\n",
    "start_time = time.perf_counter()  # monotonic, high-resolution clock meant for intervals\n",
    "outputs = model.generate(input_ids, config, num_steps=32, tokenizer=tokenizer)\n",
    "if input_ids.is_cuda:\n",
    "    torch.cuda.synchronize(input_ids.device)\n",
    "rough_demo_time_measurement = time.perf_counter() - start_time\n",
    "\n",
    "# Prompt + generated tokens; the FLOP cell below counts one forward pass over this many tokens.\n",
    "num_tokens = outputs.sequences.shape[1]\n",
    "num_new_tokens = num_tokens - input_ids.shape[1]\n",
    "print(f\"Generated {num_new_tokens} new tokens ({num_tokens} in total) within {rough_demo_time_measurement:.2f} seconds.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model and a dummy input under the \"meta\" device, so FLOPs are counted from shapes\n",
    "# alone (assumes from_pretrained keeps the parameters on meta in this context).\n",
    "with torch.device(\"meta\"):\n",
    "    meta_model = AutoModelForCausalLM.from_pretrained(\n",
    "        \"ORG/MODEL_NAME\", trust_remote_code=False, torch_dtype=torch.bfloat16\n",
    "    )\n",
    "    # Random token ids, as long as the sequence timed in the previous cell.\n",
    "    dummy_ids = torch.randint(0, model.config.vocab_size, (1, num_tokens))\n",
    "\n",
    "    flop_counter = FlopCounterMode(display=True)\n",
    "    with flop_counter, torch.no_grad():\n",
    "        meta_model(input_ids=dummy_ids, labels=dummy_ids, num_steps=32)  # inference FLOPs only\n",
    "    # Training-FLOP variants (they need gradients, so no torch.no_grad()):\n",
    "    # with flop_counter:\n",
    "    #     meta_model(input_ids=dummy_ids, labels=dummy_ids, num_steps=None).loss.backward()  # num_steps=None measures training mean flops\n",
    "    #     meta_model(input_ids=dummy_ids, labels=dummy_ids, num_steps=(16, 4)).loss.backward()  # this would measure r=16, k=4\n",
    "    measured_flops = flop_counter.get_total_flops()\n",
    "    del meta_model, dummy_ids\n",
    "\n",
    "# Peak FLOP/s of the GPU, used for the MFU estimate in the next cell.\n",
    "peak_flops = 210.6e12  # as an example for the A6000 ada, replace with your card\n",
    "num_flop_per_token = measured_flops / num_tokens\n",
    "print(f\"Expected TFLOPs per token: {num_flop_per_token / 1e12:4.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Throughput implied by the rough wall-clock measurement above.\n",
    "tokens_per_second = num_tokens / rough_demo_time_measurement\n",
    "print(f\"Tokens per second: {tokens_per_second:4.2f}\")\n",
    "\n",
    "# Model FLOPs utilisation = achieved FLOP/s divided by the card's peak FLOP/s.\n",
    "# Rough example only: the FLOP count comes from a single full (prefill) pass, which is hard\n",
    "# to compare with token-by-token generation.\n",
    "flops = num_flop_per_token * tokens_per_second  # achieved FLOP/s\n",
    "mfu = flops / peak_flops\n",
    "print(f\"MFU: {mfu:2.2%}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A Note on AMP (Automatic Mixed Precision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "# torch.autocast settings; device_type follows `device` instead of being hard-coded.\n",
    "amp_settings = {\"device_type\": device.type, \"enabled\": True, \"dtype\": torch.bfloat16}\n",
    "if not amp_settings[\"enabled\"]:\n",
    "    # Presumably so attention can fall back to the math SDPA kernel at full precision; verify.\n",
    "    torch.backends.cuda.enable_math_sdp(True)\n",
    "\n",
    "# Release the bf16 model loaded earlier (and the last outputs, which hold its cache) so that\n",
    "# two copies of the weights never have to fit on the GPU at the same time.\n",
    "model = None\n",
    "outputs = None\n",
    "gc.collect()\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "# No torch_dtype here: precision is left to autocast. trust_remote_code=False, consistent with\n",
    "# the first load (set to True if recpre lib not loaded).\n",
    "model = AutoModelForCausalLM.from_pretrained(\"ORG/MODEL_NAME\", trust_remote_code=False)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"ORG/MODEL_NAME\")\n",
    "\n",
    "model.to(device=device)  # type: ignore\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generation under autocast (see amp_settings) with num_steps=32.\n",
    "with torch.autocast(**amp_settings), torch.no_grad():\n",
    "    outputs = model.generate(\n",
    "        input_ids, config, num_steps=32, tokenizer=tokenizer, streamer=streamer\n",
    "    )\n",
    "    cache_mb = outputs.past_key_values.get_memory_usage()\n",
    "    print(f\"Memory usage: {cache_mb}MB\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same as the previous cell, but with twice the steps (num_steps=64).\n",
    "with torch.autocast(**amp_settings), torch.no_grad():\n",
    "    outputs = model.generate(\n",
    "        input_ids, config, num_steps=64, tokenizer=tokenizer, streamer=streamer\n",
    "    )\n",
    "    cache_mb = outputs.past_key_values.get_memory_usage()\n",
    "    print(f\"Memory usage: {cache_mb}MB\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "forge",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
