{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "We have obtained a trained checkpoint from the `self-cognition-sft.ipynb` tutorial; here we use `PtEngine` to run inference with it. Set `last_model_checkpoint` below to the checkpoint directory produced by that run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import some libraries\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig, get_template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Model ----\n",
    "# Base model: a hub model id or a local model directory.\n",
    "model_id_or_path = 'Qwen/Qwen2.5-3B-Instruct'\n",
    "# LoRA checkpoint from `self-cognition-sft.ipynb`; replace `xxx` with the actual checkpoint step.\n",
    "last_model_checkpoint = 'output/checkpoint-xxx'\n",
    "system = 'You are a helpful assistant.'  # passed to the template as its default system prompt\n",
    "infer_backend = 'pt'  # this notebook only builds a `PtEngine` (PyTorch backend)\n",
    "\n",
    "# ---- Generation ----\n",
    "max_new_tokens = 512  # cap on newly generated tokens per response\n",
    "temperature = 0  # 0 => greedy (non-sampled) decoding\n",
    "stream = True  # True: print tokens as they are generated; False: print the whole reply at once"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the PyTorch inference engine on the base model and attach the trained LoRA adapter.\n",
    "# Only `PtEngine` is wired up here, so reject any other `infer_backend` instead of silently ignoring it.\n",
    "if infer_backend != 'pt':\n",
    "    raise ValueError(f'This notebook only supports the PtEngine backend (pt), got infer_backend={infer_backend!r}.')\n",
    "engine = PtEngine(model_id_or_path, adapters=[last_model_checkpoint])\n",
    "# Build the chat template from the model's metadata, using `system` as the default system prompt.\n",
    "template = get_template(engine.model_meta.template, engine.tokenizer, default_system=system)\n",
    "# You can modify `default_template` directly here, or pass a template in when calling `engine.infer`.\n",
    "engine.default_template = template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "query: who are you?\n",
      "response: I am an artificial intelligence language model named Xiao Huang, developed by ModelScope. I can answer various questions and engage in conversation with humans. If you have any questions or need help, feel free to ask me at any time.\n",
      "--------------------------------------------------\n",
      "query: What should I do if I can't sleep at night?\n",
      "response: If you're having trouble sleeping, there are several things you can try:\n",
      "\n",
      "1. Establish a regular sleep schedule: Try to go to bed and wake up at the same time every day, even on weekends.\n",
      "\n",
      "2. Create a relaxing bedtime routine: Engage in calming activities before bed, such as reading a book or taking a warm bath.\n",
      "\n",
      "3. Make your bedroom conducive to sleep: Keep your bedroom cool, dark, and quiet. Invest in comfortable bedding and pillows.\n",
      "\n",
      "4. Avoid stimulating activities before bed: Avoid using electronic devices, watching TV, or engaging in mentally stimulating activities before bed.\n",
      "\n",
      "5. Exercise regularly: Regular physical activity can help improve your sleep quality, but avoid exercising too close to bedtime.\n",
      "\n",
      "6. Manage stress: Practice relaxation techniques, such as deep breathing, meditation, or yoga, to help manage stress and promote better sleep.\n",
      "\n",
      "7. Limit caffeine and alcohol intake: Both caffeine and alcohol can disrupt sleep patterns, so it's best to limit their consumption, especially in the evening.\n",
      "\n",
      "8. Seek professional help: If you continue to have difficulty sleeping despite trying these strategies, consider seeking help from a healthcare provider or a sleep specialist.\n",
      "--------------------------------------------------\n",
      "query: 你是谁训练的？\n",
      "response: 我是由魔搭团队训练和开发的。\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Probe the new self-cognition (identity questions, in English and Chinese) and general ability.\n",
    "query_list = [\n",
    "    'who are you?',\n",
    "    \"What should I do if I can't sleep at night?\",\n",
    "    '你是谁训练的？',\n",
    "]\n",
    "\n",
    "\n",
    "def _last_query(infer_request: InferRequest) -> str:\n",
    "    \"\"\"Return the content of the request's last message, i.e. the turn being answered.\n",
    "\n",
    "    Using `messages[-1]` (not `messages[0]`) keeps the printed query correct when the\n",
    "    request also carries a system message or earlier conversation turns.\n",
    "    \"\"\"\n",
    "    return infer_request.messages[-1]['content']\n",
    "\n",
    "\n",
    "def infer_stream(engine: InferEngine, infer_request: InferRequest):\n",
    "    \"\"\"Run a single request in streaming mode, printing the response as it is generated.\"\"\"\n",
    "    request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature, stream=True)\n",
    "    gen_list = engine.infer([infer_request], request_config)\n",
    "    print(f'query: {_last_query(infer_request)}\\nresponse: ', end='')\n",
    "    # gen_list[0] is the response stream for our single request; skip any `None` items it yields.\n",
    "    for resp in gen_list[0]:\n",
    "        if resp is None:\n",
    "            continue\n",
    "        print(resp.choices[0].delta.content, end='', flush=True)\n",
    "    print()\n",
    "\n",
    "\n",
    "def infer(engine: InferEngine, infer_request: InferRequest):\n",
    "    \"\"\"Run a single request without streaming and print the complete response.\"\"\"\n",
    "    request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)\n",
    "    resp_list = engine.infer([infer_request], request_config)\n",
    "    response = resp_list[0].choices[0].message.content\n",
    "    print(f'query: {_last_query(infer_request)}')\n",
    "    print(f'response: {response}')\n",
    "\n",
    "\n",
    "# `stream` (set in the hyperparameter cell) selects how responses are printed.\n",
    "infer_func = infer_stream if stream else infer\n",
    "for query in query_list:\n",
    "    infer_func(engine, InferRequest(messages=[{'role': 'user', 'content': query}]))\n",
    "    print('-' * 50)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test_py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
