{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MoE-PEFT: An Efficient LLM Fine-Tuning Factory for Mixture of Expert (MoE) Parameter-Efficient Fine-Tuning.\n",
    "[![](https://github.com/TUDB-Labs/MoE-PEFT/actions/workflows/python-test.yml/badge.svg)](https://github.com/TUDB-Labs/MoE-PEFT/actions/workflows/python-test.yml)\n",
    "[![](https://img.shields.io/github/stars/TUDB-Labs/MoE-PEFT?logo=GitHub&style=flat)](https://github.com/TUDB-Labs/MoE-PEFT/stargazers)\n",
    "[![](https://img.shields.io/github/v/release/TUDB-Labs/MoE-PEFT?logo=Github)](https://github.com/TUDB-Labs/MoE-PEFT/releases/latest)\n",
    "[![](https://img.shields.io/pypi/v/moe_peft?logo=pypi)](https://pypi.org/project/moe_peft/)\n",
    "[![](https://img.shields.io/docker/v/mikecovlee/moe_peft?logo=Docker&label=docker)](https://hub.docker.com/r/mikecovlee/moe_peft/tags)\n",
    "[![](https://img.shields.io/github/license/TUDB-Labs/MoE-PEFT)](http://www.apache.org/licenses/LICENSE-2.0)\n",
    "\n",
    "MoE-PEFT is an open-source *LLMOps* framework built on [m-LoRA](https://github.com/TUDB-Labs/mLoRA). It is designed for high-throughput fine-tuning, evaluation, and inference of Large Language Models (LLMs) using techniques such as MoE + Others (like LoRA, DoRA). Key features of MoE-PEFT include:\n",
    "\n",
    "- Concurrent fine-tuning, evaluation, and inference of multiple adapters with a shared pre-trained model.\n",
    "\n",
    "- **MoE PEFT** optimization, mainly for [MixLoRA](https://github.com/TUDB-Labs/MixLoRA) and other MoLE implementation.\n",
    "\n",
    "- Support for multiple PEFT algorithms and various pre-trained models.\n",
    "\n",
    "- Seamless integration with the [HuggingFace](https://huggingface.co) ecosystem.\n",
    "\n",
    "## About this notebook\n",
    "\n",
    "This is a simple jupiter notebook for showcasing the basic process of building chatbot with **Gemma-2 2B**."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clone and install MoE-PEFT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip uninstall torchvision torchaudio -y\n",
    "! pip install moe_peft"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "365a13c3d8654e51ad894b8459a5297c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "import moe_peft\n",
    "\n",
    "base_model = \"google/gemma-2-2b-it\"\n",
    "\n",
    "model = moe_peft.LLMModel.from_pretrained(\n",
    "    base_model,\n",
    "    device=moe_peft.executor.default_device_name(),\n",
    "    load_dtype=torch.bfloat16,\n",
    ")\n",
    "tokenizer = moe_peft.Tokenizer(base_model)\n",
    "\n",
    "model.init_adapter(moe_peft.AdapterConfig(adapter_name=\"default\"))\n",
    "\n",
    "gen_config = moe_peft.GenerateConfig(adapter_name=\"default\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build a chatbot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Markdown\n",
    "import textwrap\n",
    "\n",
    "\n",
    "def display_chat(prompt, text):\n",
    "    formatted_prompt = (\n",
    "        \"<font color='brown'>🙋‍♂️<blockquote>\" + prompt + \"</blockquote></font>\"\n",
    "    )\n",
    "    text = text.replace(\"•\", \"  *\")\n",
    "    text = textwrap.indent(text, \"> \", predicate=lambda _: True)\n",
    "    formatted_text = \"<font color='teal'>🤖\\n\\n\" + text + \"\\n</font>\"\n",
    "    return Markdown(formatted_prompt + formatted_text)\n",
    "\n",
    "\n",
    "def to_markdown(text):\n",
    "    text = text.replace(\"•\", \"  *\")\n",
    "    return Markdown(textwrap.indent(text, \"> \", predicate=lambda _: True))\n",
    "\n",
    "\n",
    "class ChatState:\n",
    "    \"\"\"\n",
    "    Manages the conversation history for a turn-based chatbot\n",
    "    Follows the turn-based conversation guidelines for the Gemma family of models\n",
    "    documented at https://ai.google.dev/gemma/docs/formatting\n",
    "    \"\"\"\n",
    "\n",
    "    __START_TURN_USER__ = \"<start_of_turn>user\\n\"\n",
    "    __START_TURN_MODEL__ = \"<start_of_turn>model\\n\"\n",
    "    __END_TURN__ = \"<end_of_turn>\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        model: moe_peft.LLMModel,\n",
    "        tokenizer: moe_peft.Tokenizer,\n",
    "        gen_config: moe_peft.GenerateConfig,\n",
    "        system: str = \"\",\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Initializes the chat state.\n",
    "\n",
    "        Args:\n",
    "            model: The language model to use for generating responses.\n",
    "            system: (Optional) System instructions or bot description.\n",
    "        \"\"\"\n",
    "        self.model = model\n",
    "        self.tokenizer = tokenizer\n",
    "        self.gen_config = gen_config\n",
    "        self.system = system\n",
    "        self.history = []\n",
    "\n",
    "    def add_to_history_as_user(self, message):\n",
    "        \"\"\"\n",
    "        Adds a user message to the history with start/end turn markers.\n",
    "        \"\"\"\n",
    "        self.history.append(\n",
    "            self.__START_TURN_USER__ + message + self.__END_TURN__ + \"\\n\"\n",
    "        )\n",
    "\n",
    "    def add_to_history_as_model(self, message):\n",
    "        \"\"\"\n",
    "        Adds a model response to the history with start/end turn markers.\n",
    "        \"\"\"\n",
    "        self.history.append(self.__START_TURN_MODEL__ + message)\n",
    "\n",
    "    def get_history(self):\n",
    "        \"\"\"\n",
    "        Returns the entire chat history as a single string.\n",
    "        \"\"\"\n",
    "        return \"\".join([*self.history])\n",
    "\n",
    "    def get_full_prompt(self):\n",
    "        \"\"\"\n",
    "        Builds the prompt for the language model, including history and system description.\n",
    "        \"\"\"\n",
    "        prompt = self.get_history() + self.__START_TURN_MODEL__\n",
    "        if len(self.system) > 0:\n",
    "            prompt = self.system + \"\\n\" + prompt\n",
    "        return prompt\n",
    "\n",
    "    def send_message(self, message):\n",
    "        \"\"\"\n",
    "        Handles sending a user message and getting a model response.\n",
    "\n",
    "        Args:\n",
    "            message: The user's message.\n",
    "\n",
    "        Returns:\n",
    "            The model's response.\n",
    "        \"\"\"\n",
    "        self.add_to_history_as_user(message)\n",
    "        prompt = self.get_full_prompt()\n",
    "        self.gen_config.prompts = [prompt]\n",
    "        response = moe_peft.generate(\n",
    "            self.model, self.tokenizer, [self.gen_config], max_gen_len=2048\n",
    "        )[self.gen_config.adapter_name][0]\n",
    "        result = response.replace(prompt, \"\").replace(\n",
    "            self.__END_TURN__, \"\"\n",
    "        )  # Extract only the new response\n",
    "        self.add_to_history_as_model(result)\n",
    "        return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chat with the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<font color='brown'>🙋‍♂️<blockquote>Tell me, in a few words,  how to compute all prime numbers up to 1000?</blockquote></font><font color='teal'>🤖\n",
       "\n",
       "> Sieve of Eratosthenes.\n",
       "> \n",
       "> **Explanation:** This is an efficient algorithm for finding prime numbers.  \n",
       "\n",
       "</font>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat = ChatState(model, tokenizer, gen_config)\n",
    "message = \"Tell me, in a few words,  how to compute all prime numbers up to 1000?\"\n",
    "display_chat(message, chat.send_message(message))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<font color='brown'>🙋‍♂️<blockquote>Now in Python! No numpy, please!</blockquote></font><font color='teal'>🤖\n",
       "\n",
       "> ```python\n",
       "> def sieve_of_eratosthenes(n):\n",
       ">   \"\"\"Returns a list of prime numbers up to n.\"\"\"\n",
       ">   primes = [True] * (n + 1)\n",
       ">   primes[0] = primes[1] = False\n",
       "> \n",
       ">   for i in range(2, int(n ** 0.5) + 1):\n",
       ">     if primes[i]:\n",
       ">       for j in range(i*i, n + 1, i):\n",
       ">         primes[j] = False\n",
       "> \n",
       ">   return [i for i, is_prime in enumerate(primes) if is_prime]\n",
       "> \n",
       "> print(sieve_of_eratosthenes(1000))\n",
       "> ```\n",
       "> \n",
       "> \n",
       "> **Explanation:**\n",
       "> \n",
       "> 1. **Initialization:** A boolean list `primes` is created with size `n+1`, representing potential primes from 0 to n. Initially, both 0 and 1 are marked as non-primes.\n",
       "> 2. **Iteration:** The loop starts from 2 up to the square root of `n`. We only need to check divisors up to the square root because any composite number has a prime factor less than or equal to its square root. \n",
       "> 3. **Marking Non-Primes:** For each prime `i`, its multiples starting from `i*i` are marked as non-primes in the `primes` array.\n",
       "> 4. **Returning Primes:** Finally, we create a new list by filtering the `primes` array using list comprehension, keeping only those indices where `primes[i]` is True, indicating that the corresponding index corresponds to a prime number. \n",
       "> \n",
       "> \n",
       "> \n",
       "> Let me know if you have any further questions or need additional explanations.\n",
       "</font>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "message = \"Now in Python! No numpy, please!\"\n",
    "display_chat(message, chat.send_message(message))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<font color='brown'>🙋‍♂️<blockquote>Thank you, it works! Can you explain the code in French?</blockquote></font><font color='teal'>🤖\n",
       "\n",
       "> La fonction `sieve_of_eratosthenes(n)` renvoie une liste de nombres premiers jusqu'à `n`.\n",
       "> \n",
       "> ### Explication:\n",
       ">  La fonction utilise la méthode de Sieve d’Erathostène pour trouver les nombres premiers. Voici comment ça marche: \n",
       "> \n",
       "> 1. **Initialisation**: \n",
       ">    -  On crée une liste booléenne `primes` de taille `n+1` représentant des nombres potentiels qui sont premiers. Les éléments initiaux sont tous définis comme `True`.  Nous assignons les valeurs `False` à 0 et 1 car ils ne sont pas premiers.\n",
       "> \n",
       "> 2. **Itération**: \n",
       ">     - On commence par le nombre 2 jusqu'au carré root de `n` (inclusif). Cela signifie que nous allons vérifier seulement les diviseurs potentiels jusqu'à la racine carrée de `n`. \n",
       "> 3. **Marquer les non-premiers**: \n",
       ">     - Pour chaque nombre premier `i`, on vérifie ses multiples (commençant par `i*i`) en plaçant les valeurs de `primes[j]` à `False`.\n",
       "> \n",
       "> \n",
       "> 4. **Retourner les nombres premiers**: \n",
       ">    -  Enfin, un nouvel array est construit en filtrant la liste `primes`  en utilisant la syntaxe de liste comprehension. On sélectionne uniquement les indices correspondant à des nombres premiers où `primes[i]` est `True`. \n",
       "> \n",
       "> \n",
       "> \n",
       "> \n",
       "> \n",
       ">  \n",
       "> \n",
       "> J'espère que cela vous aide à comprendre la fonction plus clairement! Si vous avez des questions supplémentaires, n'hésitez pas à les poser. \n",
       "\n",
       "</font>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "message = \"Thank you, it works! Can you explain the code in French?\"\n",
    "display_chat(message, chat.send_message(message))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<font color='brown'>🙋‍♂️<blockquote>Great! Now add those explanations as comments in the code.</blockquote></font><font color='teal'>🤖\n",
       "\n",
       "> ```python\n",
       "> def sieve_of_eratosthenes(n):\n",
       ">   \"\"\"Returns a list of prime numbers up to n.\"\"\"\n",
       ">   #Initialize a boolean list 'primes' with size n+1 to represent all possible prime numbers.\n",
       ">   primes = [True] * (n + 1) \n",
       ">   primes[0] = primes[1] = False # 0 and 1 are not primes.\n",
       "> \n",
       ">   # Iterate through the integers from 2 up to the square root of n.\n",
       ">   for i in range(2, int(n ** 0.5) + 1): \n",
       ">     if primes[i]:\n",
       ">       # If the current integer 'i' is prime\n",
       ">       # Mark all multiples of 'i' as non-prime by setting their corresponding value in 'primes' to False\n",
       ">       for j in range(i * i, n + 1, i): \n",
       ">           primes[j] = False \n",
       "> \n",
       ">   #Filter the 'primes' list to obtain a list of prime numbers\n",
       ">   return [i for i, is_prime in enumerate(primes) if is_prime]\n",
       "> \n",
       "> print(sieve_of_eratosthenes(1000))\n",
       "> \n",
       "> ```\n",
       "> \n",
       "> \n",
       "> \n",
       "> \n",
       "> Let me know if you would like further details or clarifications! 😄 \n",
       "</font>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "message = \"Great! Now add those explanations as comments in the code.\"\n",
    "display_chat(message, chat.send_message(message))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "moe_peft",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
