{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "294c83cc-b13c-44e2-9e4e-27c7c5a3308a",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'each_step_generation_tokens' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 16\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;66;03m# each_step_generation_tokens = int(50)\u001b[39;00m\n\u001b[1;32m      9\u001b[0m max_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;241m300\u001b[39m)\n\u001b[1;32m     11\u001b[0m llama2_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(\n\u001b[1;32m     12\u001b[0m     model_type\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mllama2_instruct\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     13\u001b[0m     model_path\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/home/zhangyanan/zyn/lb/ragmedge/llm-master/pretrain-model/Llama-2-7b-chat-hf\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     14\u001b[0m     tokenizer_path\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/home/zhangyanan/zyn/lb/ragmedge/llm-master/pretrain-model/Llama2-7b\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     15\u001b[0m     device_id\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda:5\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[0;32m---> 16\u001b[0m     each_step_max_generate_token\u001b[38;5;241m=\u001b[39m\u001b[43meach_step_generation_tokens\u001b[49m,\n\u001b[1;32m     17\u001b[0m \n\u001b[1;32m     18\u001b[0m )\n\u001b[1;32m     20\u001b[0m llama3_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(\n\u001b[1;32m     21\u001b[0m     model_type\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mllama3_instruct\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     22\u001b[0m     
model_path\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/home/zhangyanan/zyn/lb/ragmedge/llm-master/pretrain-model/LLama3-instruct\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     26\u001b[0m \n\u001b[1;32m     27\u001b[0m )\n\u001b[1;32m     29\u001b[0m mistral_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(\n\u001b[1;32m     30\u001b[0m     model_type\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmistral_instruct_3.0\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     31\u001b[0m     model_path\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/home/zhangyanan/zyn/lb/ragmedge/llm-master/pretrain-model/mistral-instruct-v3.0\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     35\u001b[0m \n\u001b[1;32m     36\u001b[0m )\n",
      "\u001b[0;31mNameError\u001b[0m: name 'each_step_generation_tokens' is not defined"
     ]
    }
   ],
   "source": [
    "from utils import Model_Factory,ModelConfig\n",
    "import torch\n",
    "import concurrent.futures\n",
    "import time\n",
    "from ensemble_generate import ensemble_sevice\n",
    "\n",
    "\n",
    "# Token budget per generation step (shared by every model config) and the\n",
    "# max_tokens value consumed by later cells.\n",
    "each_step_generation_tokens = 50\n",
    "max_tokens = 300\n",
    "\n",
    "# Single root for all checkpoints, so relocating the weights is a one-line edit.\n",
    "PRETRAIN_DIR = \"/home/zhangyanan/zyn/lb/ragmedge/llm-master/pretrain-model\"\n",
    "\n",
    "\n",
    "def _pick_device(cuda_index):\n",
    "    \"\"\"Return ``cuda:<cuda_index>`` when CUDA is available, otherwise the CPU device.\"\"\"\n",
    "    return torch.device(f\"cuda:{cuda_index}\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "llama2_config = dict(\n",
    "    model_type=\"llama2_instruct\",\n",
    "    model_path=f\"{PRETRAIN_DIR}/Llama-2-7b-chat-hf\",\n",
    "    tokenizer_path=f\"{PRETRAIN_DIR}/Llama2-7b\",\n",
    "    device_id=_pick_device(5),\n",
    "    each_step_max_generate_token=each_step_generation_tokens,\n",
    ")\n",
    "\n",
    "llama3_config = dict(\n",
    "    model_type=\"llama3_instruct\",\n",
    "    model_path=f\"{PRETRAIN_DIR}/LLama3-instruct\",\n",
    "    tokenizer_path=f\"{PRETRAIN_DIR}/LLama3-instruct\",\n",
    "    device_id=_pick_device(6),\n",
    "    each_step_max_generate_token=each_step_generation_tokens,\n",
    ")\n",
    "\n",
    "mistral_config = dict(\n",
    "    model_type=\"mistral_instruct_3.0\",\n",
    "    model_path=f\"{PRETRAIN_DIR}/mistral-instruct-v3.0\",\n",
    "    tokenizer_path=f\"{PRETRAIN_DIR}/mistral-instruct-v3.0\",\n",
    "    device_id=_pick_device(7),\n",
    "    each_step_max_generate_token=each_step_generation_tokens,\n",
    ")\n",
    "\n",
    "# Build one wrapper per config; each model is placed on its own device.\n",
    "llama2_model = Model_Factory(model_config=ModelConfig(**llama2_config))\n",
    "llama3_model = Model_Factory(model_config=ModelConfig(**llama3_config))\n",
    "mistral_model = Model_Factory(model_config=ModelConfig(**mistral_config))\n",
    "\n",
    "# Order matters: later cells index candidate lists in this same model order.\n",
    "model_lists = [llama2_model, llama3_model, mistral_model]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "90d1c5d8-69f1-46a9-a646-41da5057e01c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "The internet can be a complex and challenging environment for children, and it's important for parents and  caregivers to be aware of the potential risks and take steps to ensure their child's safety online. Here are some of the main concerns when it comes to children and the internet:\n",
      "\n",
      "1. Inappropriate content: Children may be exposed  to inappropriate or explicit content, such as pornography, violence, or hate speech, which can have a negative impact on their emotional and psychological well -being.\n",
      "2. Online predators : Children may be targeted by online predators who may try to exploit or harm them.\n",
      "3 . Cyberbullying: Children may be subjected to cyberbullying, which can take the form of mean or hurtful messages, threats, or spreading rumors online.\n",
      "4. Priv acy and identity theft: Children may inadvert antly share personal information, such as their name, address, or phone number, which can make them vulnerable to identity theft or other forms of exploitation.\n",
      "5. Addiction: Children may become addicted to the internet, spending excessive  amounts of time online and neglecting other important aspects of their lives, such as schoolwork, physical activity, or socializing with friends and family. <|eot_id|> \n"
     ]
    }
   ],
   "source": [
    "# Per-step value handed to ensemble_sevice together with max_tokens.\n",
    "each_step_max_generate_token = 10\n",
    "\n",
    "LLMs = ensemble_sevice(\n",
    "    model_lists,\n",
    "    max_tokens,\n",
    "    each_step_max_generate_token,\n",
    ")\n",
    "\n",
    "results = LLMs.generate(\"Is the Internet safe for kids?\")\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab0efcbd-c5df-4ed4-a058-e367b4ec1a20",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import Model_Factory,ModelConfig\n",
    "import torch\n",
    "import concurrent.futures\n",
    "import time\n",
    "\n",
    "\n",
    "\n",
    "# Token budget per generation step, shared by every model config.\n",
    "each_step_generation_tokens = 50\n",
    "\n",
    "# Single root for all checkpoints, so relocating the weights is a one-line edit.\n",
    "PRETRAIN_DIR = \"/home/zhangyanan/zyn/lb/ragmedge/llm-master/pretrain-model\"\n",
    "\n",
    "\n",
    "def _pick_device(cuda_index):\n",
    "    \"\"\"Return ``cuda:<cuda_index>`` when CUDA is available, otherwise the CPU device.\"\"\"\n",
    "    return torch.device(f\"cuda:{cuda_index}\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "llama2_config = dict(\n",
    "    model_type=\"llama2_instruct\",\n",
    "    model_path=f\"{PRETRAIN_DIR}/Llama-2-7b-chat-hf\",\n",
    "    tokenizer_path=f\"{PRETRAIN_DIR}/Llama2-7b\",\n",
    "    device_id=_pick_device(5),\n",
    "    each_step_max_generate_token=each_step_generation_tokens,\n",
    ")\n",
    "\n",
    "llama3_config = dict(\n",
    "    model_type=\"llama3_instruct\",\n",
    "    model_path=f\"{PRETRAIN_DIR}/LLama3-instruct\",\n",
    "    tokenizer_path=f\"{PRETRAIN_DIR}/LLama3-instruct\",\n",
    "    device_id=_pick_device(6),\n",
    "    each_step_max_generate_token=each_step_generation_tokens,\n",
    ")\n",
    "\n",
    "mistral_config = dict(\n",
    "    model_type=\"mistral_instruct_3.0\",\n",
    "    model_path=f\"{PRETRAIN_DIR}/mistral-instruct-v3.0\",\n",
    "    tokenizer_path=f\"{PRETRAIN_DIR}/mistral-instruct-v3.0\",\n",
    "    device_id=_pick_device(7),\n",
    "    each_step_max_generate_token=each_step_generation_tokens,\n",
    ")\n",
    "\n",
    "# Build one wrapper per config; each model is placed on its own device.\n",
    "llama2_model = Model_Factory(model_config=ModelConfig(**llama2_config))\n",
    "llama3_model = Model_Factory(model_config=ModelConfig(**llama3_config))\n",
    "mistral_model = Model_Factory(model_config=ModelConfig(**mistral_config))\n",
    "\n",
    "# Order matters: later cells index candidate lists in this same model order.\n",
    "model_lists = [llama2_model, llama3_model, mistral_model]\n",
    "\n",
    "\n",
    "def Parallel_Inference(input_text,model_lists):\n",
    "    results = {}\n",
    "    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:\n",
    "        # 提交所有任务并记录 future 对象\n",
    "        # futures = {executor.submit(task, name, duration): name for name, duration in tasks}\n",
    "        futures = []\n",
    "\n",
    "        for model in model_lists:\n",
    "            futures.append(executor.submit(model.Speculative_generate, input_text))\n",
    "        # futures.append(executor.submit(llama2_model.Speculative_generate, input_text))\n",
    "        # futures.append(executor.submit(llama3_model.Speculative_generate, input_text))\n",
    "        # futures.append(executor.submit(mistral_model.Speculative_generate, input_text))\n",
    "\n",
    "        \n",
    "        # 等待所有任务完成，并获取结果\n",
    "        for future in concurrent.futures.as_completed(futures):\n",
    "            \n",
    "            result,model_name = future.result()\n",
    "            results[model_name] = result\n",
    "            # results.append(result)\n",
    "    \n",
    "    # 所有任务完成后，处理结果\n",
    "    print(\"All model completed this step generation.\")\n",
    "    # print(results)\n",
    "    # for result in results:\n",
    "    #     print(result)\n",
    "    return results\n",
    "\n",
    "\n",
    "\n",
    "#\"</s>\"\n",
    "\n",
    "\n",
    "def Parallel_verify(input_text_list,model_lists):\n",
    "    results = {}\n",
    "    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:\n",
    "        # 提交所有任务并记录 future 对象\n",
    "        # futures = {executor.submit(task, name, duration): name for name, duration in tasks}\n",
    "        futures = []\n",
    "        for model in model_lists:\n",
    "            futures.append(executor.submit(model.model_verify, input_text_list))\n",
    "        # futures.append(executor.submit(llama2_model.model_verify, input_text_list))\n",
    "        # futures.append(executor.submit(llama3_model.model_verify, input_text_list))\n",
    "        # futures.append(executor.submit(mistral_model.model_verify, input_text_list))\n",
    "\n",
    "        \n",
    "        # 等待所有任务完成，并获取结果\n",
    "        for future in concurrent.futures.as_completed(futures):\n",
    "            \n",
    "            scores,model_name = future.result()\n",
    "            results[model_name] = scores\n",
    "            # results.append(result)\n",
    "    \n",
    "    # 所有任务完成后，处理结果\n",
    "    print(\"All model completed this step verify.\")\n",
    "    # print(results)\n",
    "    # for result in results:\n",
    "    #     print(result)\n",
    "    return results\n",
    "\n",
    "\n",
    "def Clear_Cache(model_lists):\n",
    "    for model in model_lists:\n",
    "        model.Clear_Cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f7236a83-dd19-431c-9dbe-d6d86edd8bcd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "############ OUTPUT RESULT ###########\n",
      "Writing a resume for your dream job requires strategic planning, highlighting your relevant skills, and showcasing your achievements. Here's a step-by-step guide to help you craft a compelling resume:\n",
      "\n",
      "1. **Define your dream job**: Clearly identify the job title,  industry, and required skills. Research the job description, responsibilities, and qualifications to understand what the employer is looking for.\n",
      "2. **Tailor your resume**: Customize your resume to match the requirements of your dream job. Use keywords from the job description and highlight your relevant skills, experience, and achievements.\n",
      "3. **Use a reverse-chronological format**: This is the most common format for resumes. It lists your work experience in reverse chronological order, starting with your most recent position.\n",
      "4. **Highlight your achievements**: Instead of just listing your job responsibilities, focus on the achievements and results you've achieved in your previous roles. Quantify your achievements by using numbers and statistics .\n",
      "5. **Emphasize transferable skills**: Even if you don't have direct experience in your dream job, you may have skills that are transferable. Identify the skills that are relevant to the job and highlight them in your resume.\n",
      "6 . **Keep it concise**: Aim for a one-page resume, unless you have extensive experience or are applying for a senior-level position. Keep your language clear, concise, and easy to read.\n",
      "7. **Proofread **: Carefully review your resume for spelling, grammar, and punctuation errors. Consider asking a friend or mentor to review it as well.\n",
      "8. **Optimize for ATS**: Many companies use applicant tracking systems (ATS) to screen resumes. Optimize your resume for ATS by using key words from the job description, bullet points, and a clear format.\n",
      "9. **Showcase your soft skills**: Soft  skills, such as communication, teamwork, and problem-solving, are essential for many jobs. Highlight your soft skills and provide specific examples of how you've demonstrated them.\n",
      "10. **Include relevant sections**: Make sure to include relevant sections such as a summary or objective statement, education, skills, and any relevant certifications or licenses.\n",
      "11. **Get feedback**: Once you have a draft of your resume, ask for feedback from mentors, friends, or professionals in your industry. Use their feedback to make improvements and refine your resume.\n",
      "\n",
      "By following these steps, you can create a strong resume that showcases your skills and achievements, and helps you stand out from other candidates. Remember, your resume is often the first impression you make on potential employers, so make it count!</s> \n"
     ]
    }
   ],
   "source": [
    "\n",
    "def find_eos_token(new_generation_list=None,indices=None,Stop_token=None):\n",
    "    for indice in indices:\n",
    "        # print(indice)\n",
    "        # print(new_generation_list[indice.item()])\n",
    "        for stop in Stop_token:\n",
    "            if stop in new_generation_list[indice.item()]:\n",
    "                \n",
    "                return new_generation_list[indice.item()]\n",
    "    return None\n",
    "                \n",
    "        \n",
    "                \n",
    "\n",
    "\n",
    "def ensemble_decoding(input_text):\n",
    "    \"\"\"Build a response by letting the models draft and cross-score segments.\n",
    "\n",
    "    Each round every model in the global ``model_lists`` drafts a continuation of\n",
    "    the previously selected segment (``Parallel_Inference``), the drafts are scored\n",
    "    by every model (``Parallel_verify``), and the per-model scores are summed to\n",
    "    rank the drafts. In fast mode the first ranked draft containing a stop marker\n",
    "    ends decoding; otherwise the top-ranked draft is appended and decoding continues.\n",
    "\n",
    "    Args:\n",
    "        input_text: Prompt handed to the models in the first round.\n",
    "\n",
    "    Returns:\n",
    "        The selected segments, each followed by a single space.\n",
    "    \"\"\"\n",
    "    max_tokens = 500  # budget counted in space-separated pieces, not tokenizer tokens\n",
    "    Stop_token = [\"</s>\",\"<|eot_id|>\"]\n",
    "    Fast_decoder = True\n",
    "    # Fixes the candidate order so sorted indices map back to a draft.\n",
    "    model_list=[\"llama2_instruct\",\"llama3_instruct\",\"mistral_instruct_3.0\"]\n",
    "\n",
    "    best_response = \"\"\n",
    "    best_new_generation = input_text\n",
    "    eos_flag = True\n",
    "\n",
    "    try:\n",
    "        while eos_flag and len(best_response.split(\" \"))<max_tokens:\n",
    "            new_generations=Parallel_Inference(best_new_generation,model_lists)\n",
    "            new_generation_list = [new_generations[name] for name in model_list]\n",
    "\n",
    "            results_score = Parallel_verify(new_generation_list,model_lists)\n",
    "            # Scores may live on different devices; gather them on the CPU so the sum\n",
    "            # does not depend on one particular GPU index being present.\n",
    "            sum_socre = sum(score.to(\"cpu\") for score in results_score.values())\n",
    "            _, indices = torch.sort(sum_socre, descending=True)\n",
    "\n",
    "            finished_text = None\n",
    "            if Fast_decoder:\n",
    "                finished_text = find_eos_token(\n",
    "                    new_generation_list=new_generation_list,\n",
    "                    indices=indices,\n",
    "                    Stop_token=Stop_token,\n",
    "                )\n",
    "\n",
    "            if finished_text:\n",
    "                # A draft already terminates the answer: take it and stop.\n",
    "                best_new_generation = finished_text\n",
    "                eos_flag = False\n",
    "            else:\n",
    "                best_new_generation = new_generation_list[indices[0].item()]\n",
    "                if not Fast_decoder and any(stop in best_new_generation for stop in Stop_token):\n",
    "                    eos_flag = False\n",
    "\n",
    "            best_response = best_response+best_new_generation+\" \"\n",
    "    finally:\n",
    "        # Clear model caches even when a step raises, so the next call starts clean.\n",
    "        Clear_Cache(model_lists)\n",
    "    return best_response\n",
    "        \n",
    "Clear_Cache(model_lists)      \n",
    "results = ensemble_decoding(\"How do I write a resume for my dream job?\")\n",
    "print(\"############ OUTPUT RESULT ###########\")\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "75193443-e05a-4ddd-9396-6d43ab308cde",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['llama2_instruct', 'llama3_instruct', 'mistral_instruct_3.0']\n"
     ]
    }
   ],
   "source": [
    "model_list2 = [model.model_type   for model in model_lists]\n",
    "print(model_list2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01db1c8f-80d1-481f-bede-44fcf1245ccb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33ff152e-b110-4c8c-be3b-edfb9ebb59e8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3830dcab-9ca1-4798-93f4-dff9e8f168e6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e34d8b06-8efc-431c-b510-f4362ba654fa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e7a316b-75ce-49cb-8d6d-2ca9413c2c60",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3f7ffae-9a01-4368-b411-8c9e65f813c3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b38fbaec-a19c-4221-adaf-e1742ce10f61",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae3fcb8-6044-4596-a3bd-877ac41dcc1a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e37abfd2-28cc-46e2-a80f-a79ec053f21d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fae6f6f8-caec-4382-a3ad-ca0525d45105",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[<utils.Model_Factory object at 0x7fa068113ee0>, <utils.Model_Factory object at 0x7f9f9b3df9d0>, <utils.Model_Factory object at 0x7f9f91cf0b20>]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llama2_instruct mouse button. Most mice have one or two\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llama3_instruct  mouse button down with yourindex finger. This will\n",
      "mistral_instruct_3.0 left mouse button (usually the button located on\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the list in the fixed llama2 / llama3 / mistral order.\n",
    "model_lists = [llama2_model, llama3_model, mistral_model]\n",
    "\n",
    "print(model_lists)\n",
    "\n",
    "# Run the models one after another to compare their drafts for the same prompt.\n",
    "for i in model_lists:\n",
    "    result, model_name = i.Speculative_generate(\n",
    "        \"To click on something on the screen, press the\"\n",
    "    )\n",
    "    print(model_name, result)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "97e3b374-0485-4a85-972d-395e00923898",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All model completed this step generation.\n",
      "{'mistral_instruct_3.0': \"Here's a step-by-step guide\", 'llama2_instruct': '\\n\\nA computer mouse is a simple input device', 'llama3_instruct': \" Well, I'm happy to help you with that\"}\n"
     ]
    }
   ],
   "source": [
    "Stop_token = [\"</s>\",\"<|eot_id|>\"]\n",
    "\n",
    "model_list=[\"llama2_instruct\",\"llama3_instruct\",\"mistral_instruct_3.0\"]\n",
    "# Parallel_Inference requires the model list as its second argument.\n",
    "new_generations=Parallel_Inference(\"How does one work a computer mouse?\",model_lists)\n",
    "print(new_generations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "af2cac49-ecdc-4360-8a26-fc6560b0eb58",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['\\n\\nA computer mouse is a simple input device', \" Well, I'm happy to help you with that\", \"Here's a step-by-step guide\"]\n",
      "All model completed this step verify.\n",
      "{'mistral_instruct_3.0': tensor([0.5814, 0.5204, 0.8070], device='cuda:7'), 'llama3_instruct': tensor([0.6121, 0.6605, 0.7781], device='cuda:6'), 'llama2_instruct': tensor([0.7326, 0.6579, 0.8400], device='cuda:5')}\n",
      "tensor([1.9261, 1.8387, 2.4250])\n",
      "tensor([2, 0, 1])\n",
      "2\n"
     ]
    }
   ],
   "source": [
    "\n",
    "best_response = \"\"\n",
    "# Order the drafts by model_list so a sorted index maps back to a draft.\n",
    "new_generation_list = [new_generations[text] for text in model_list]\n",
    "print(new_generation_list)\n",
    "# Parallel_verify requires the model list as its second argument.\n",
    "results_score = Parallel_verify(new_generation_list,model_lists)\n",
    "print(results_score)\n",
    "\n",
    "# Move every score tensor to the CPU before summing across models.\n",
    "sum_socre = sum([results_score[model_name].to(\"cpu\") for model_name in results_score ]) \n",
    "print(sum_socre)\n",
    "\n",
    "_, indices = torch.sort(sum_socre, descending=True)\n",
    "\n",
    "print(indices[0].item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "94f7bc78-5058-4b3d-8528-9db83302b50c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All model completed this step generation.\n",
      "{'mistral_instruct_3.0': 'on using a basic computer mouse:\\n\\n1', 'llama3_instruct': ':\\n\\n**Basic Mouse Functionality:**\\n\\n1.', 'llama2_instruct': 'on how to use a computer mouse:\\n\\n'}\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Continue from the top-ranked draft; Parallel_Inference also needs the model list.\n",
    "new_generations=Parallel_Inference(new_generation_list[indices[0].item()],model_lists)\n",
    "print(new_generations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "80942ef3-e225-4851-ab0a-ef2fdb5fd414",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "All model completed this step generation.\n",
      "All model completed this step verify.\n",
      "############ OUTPUT RESULT ###########\n",
      "Thomas Keller is a renowned American chef, restaurateur , and cookbook author. He is best known for his Michelin three-starred restaurants, the French Laundry in Yountville, California, and Per Se in New York City. Keller is considered one of the most influential and respected chefs in the world, known for his attention to detail, dedication to quality, and innovative approach to French cuisine. <|eot_id|> <|eot_id|> \n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "\n",
    "(\"\\'\",\"'\")\n",
    "\n",
    "def ensemble_decoding(input_text):\n",
    "    \"\"\"Build a response by letting the models draft and cross-score segments.\n",
    "\n",
    "    Each round every model in the global ``model_lists`` drafts a continuation of\n",
    "    the previously selected segment. If any draft contains a stop marker, the first\n",
    "    such draft (in ``model_list`` order) is appended exactly once and decoding ends.\n",
    "    Otherwise the drafts are scored by every model, the scores are summed, and the\n",
    "    top-ranked draft is appended before the next round.\n",
    "\n",
    "    Args:\n",
    "        input_text: Prompt handed to the models in the first round.\n",
    "\n",
    "    Returns:\n",
    "        The selected segments, each followed by a single space.\n",
    "    \"\"\"\n",
    "    max_tokens = 100  # budget counted in space-separated pieces, not tokenizer tokens\n",
    "    Stop_token = [\"</s>\",\"<|eot_id|>\"]\n",
    "    # Fixes the candidate order so sorted indices map back to a draft.\n",
    "    model_list=[\"llama2_instruct\",\"llama3_instruct\",\"mistral_instruct_3.0\"]\n",
    "\n",
    "    best_response = \"\"\n",
    "    best_new_generation = input_text\n",
    "\n",
    "    while len(best_response.split(\" \"))<max_tokens:\n",
    "        new_generations=Parallel_Inference(best_new_generation,model_lists)\n",
    "        new_generation_list = [new_generations[name] for name in model_list]\n",
    "\n",
    "        # A draft that already holds a stop marker ends decoding. Leaving the loop\n",
    "        # here keeps the terminating draft from being appended a second time below.\n",
    "        finished_text = next(\n",
    "            (\n",
    "                text\n",
    "                for text in new_generation_list\n",
    "                if any(stop in text for stop in Stop_token)\n",
    "            ),\n",
    "            None,\n",
    "        )\n",
    "        if finished_text is not None:\n",
    "            best_response = best_response+finished_text+\" \"\n",
    "            break\n",
    "\n",
    "        results_score = Parallel_verify(new_generation_list,model_lists)\n",
    "        # Move every score tensor to the CPU before summing across models.\n",
    "        sum_socre = sum(score.to(\"cpu\") for score in results_score.values())\n",
    "\n",
    "        _, indices = torch.sort(sum_socre, descending=True)\n",
    "        best_new_generation=new_generation_list[indices[0].item()]\n",
    "        best_response = best_response+best_new_generation+\" \"\n",
    "\n",
    "    # Clear_Cache requires the list of models whose caches should be cleared.\n",
    "    Clear_Cache(model_lists)\n",
    "    return best_response\n",
    "        \n",
    "        \n",
    "results = ensemble_decoding(\"Who is Thomas Keller?\")\n",
    "print(\"############ OUTPUT RESULT ###########\")\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5bf1efdb-12c3-4546-8c57-5c320fae8db2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([])\n",
      "tensor(0)\n",
      "rrr\n"
     ]
    }
   ],
   "source": [
    "# A 0-dim tensor has an empty shape, which is falsy, so the tensor itself is printed.\n",
    "t = torch.tensor(0)\n",
    "print(t.shape)\n",
    "print(\"rrr\" if t.shape else t)\n",
    "\n",
    "# A tensor with at least one dimension has a non-empty (truthy) shape.\n",
    "t = torch.rand(1,3)\n",
    "print(\"rrr\" if t.shape else t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "86b895d1-d448-4a3c-88ce-b00603b22744",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([12., 15., 18.])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Define a few tensors\n",
    "tensor1 = torch.tensor([1.0, 2.0, 3.0])\n",
    "tensor2 = torch.tensor([4.0, 5.0, 6.0])\n",
    "tensor3 = torch.tensor([7.0, 8.0, 9.0])\n",
    "\n",
    "# Add them together with the built-in sum function\n",
    "tensors = [tensor1, tensor2, tensor3]\n",
    "result = sum(tensors)\n",
    "\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7e8f3e7c-8655-4039-89dd-8b62738a9a8c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "on how to use a computer mouse:<|eot_id|>\n",
      "<|eot_id|>\n",
      "7\n"
     ]
    }
   ],
   "source": [
    "t = 'on how to use a computer mouse:<|eot_id|>' \n",
    "Stop_token = [\"</s>\",\"<|eot_id|>\"]\n",
    "# Echo the text and the marker for every stop marker found inside the text.\n",
    "for i in Stop_token:\n",
    "    if i not in t:\n",
    "        continue\n",
    "    print(t)\n",
    "    print(i)\n",
    "# Number of pieces when splitting on single spaces.\n",
    "print(len(t.split(\" \")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9023e023-7088-4c79-a75e-4894cd3fca98",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
