{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import string\n",
    "from algorithm import BaseAlgorithm\n",
    "from datasets import load_dataset, Dataset\n",
    "\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 06-04 21:20:14 llm_engine.py:161] Initializing an LLM engine (v0.4.3) with config: model='meta-llama/Meta-Llama-3-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir='/shared/share_mala/{author}/huggingface/cache', load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=meta-llama/Meta-Llama-3-8B-Instruct)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/user/as6154/.conda/envs/person/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 06-04 21:20:15 weight_utils.py:207] Using model weights format ['*.safetensors']\n",
      "INFO 06-04 21:20:17 model_runner.py:146] Loading model weights took 14.9595 GB\n",
      "INFO 06-04 21:20:18 gpu_executor.py:83] # GPU blocks: 27895, # CPU blocks: 2048\n",
      "INFO 06-04 21:20:20 model_runner.py:854] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.\n",
      "INFO 06-04 21:20:20 model_runner.py:858] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\n",
      "INFO 06-04 21:20:24 model_runner.py:924] Graph capturing finished in 4 secs.\n"
     ]
    }
   ],
   "source": [
    "from vllm import LLM, SamplingParams\n",
    "from tqdm import tqdm\n",
    "\n",
    "class MetaLearnAlgorithm(BaseAlgorithm):\n",
    "    \"\"\"\n",
    "    Few-shot baseline: answers each test prompt with vLLM, conditioning only on the\n",
    "    past (prompt, response, reward) triples stored in the same row.\n",
    "\n",
    "    The meta_learning_database is forwarded to BaseAlgorithm but never read by this class.\n",
    "    \"\"\"\n",
    "\n",
    "    # Decoding settings shared by single-row and batched generation.\n",
    "    _SAMPLING_KWARGS = {\"temperature\": 0.5, \"top_p\": 0.95, \"max_tokens\": 512}\n",
    "    # Maximum number of leading rows kept when debug=True.\n",
    "    _DEBUG_ROWS = 3\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        meta_learning_database: Dataset = None,\n",
    "        model_name: str = \"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
    "        download_dir: str = \"/shared/share_mala/{author}/huggingface/cache\",\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Build the vLLM engine used for generation.\n",
    "\n",
    "        Args:\n",
    "            meta_learning_database: passed to BaseAlgorithm unchanged.\n",
    "            model_name: model identifier handed to vLLM.\n",
    "            download_dir: directory handed to vLLM as its download/cache location.\n",
    "        \"\"\"\n",
    "        super().__init__(meta_learning_database)\n",
    "        self.llm = LLM(\n",
    "            model=model_name,\n",
    "            trust_remote_code=True,\n",
    "            tensor_parallel_size=1,\n",
    "            download_dir=download_dir,\n",
    "            disable_log_stats=True,\n",
    "        )\n",
    "\n",
    "    def rag_prompt_response(self, row: dict) -> str:\n",
    "        \"\"\"Placeholder: retrieval-augmented prompting is not implemented here; returns None.\"\"\"\n",
    "        return None\n",
    "\n",
    "    def _build_prompt(self, row: dict) -> str:\n",
    "        \"\"\"\n",
    "        Render one row as a few-shot prompt string.\n",
    "\n",
    "        Reads row[\"user_history_length\"], the numbered prompt_i / response_i / reward_i\n",
    "        fields (1-based) and row[\"test_prompt\"].\n",
    "        \"\"\"\n",
    "        # The trailing newline keeps the first \"User: \" turn from being glued onto the header sentence.\n",
    "        parts = [\"Below are some examples of the user's past conversation history.\\n\"]\n",
    "        for i in range(1, int(row[\"user_history_length\"]) + 1):\n",
    "            past_prompt = row[f\"prompt_{i}\"]\n",
    "            past_response = row[f\"response_{i}\"]\n",
    "            past_reward = row[f\"reward_{i}\"]\n",
    "            parts.append(\n",
    "                f\"User: {past_prompt}\\nAssistant: {past_response}\\nReward: {past_reward}\\n\\n\"\n",
    "            )\n",
    "        parts.append(\n",
    "            \"Use the contexts above to generate a good response for the user prompt below. Stop after answering the User Prompt, don't give a reward.\\n\"\n",
    "        )\n",
    "        parts.append(\"User: \" + row[\"test_prompt\"])\n",
    "        return \"\".join(parts)\n",
    "\n",
    "    def generate_response(self, row: dict) -> str:\n",
    "        \"\"\"Generate the model's response text for a single row.\"\"\"\n",
    "        sampling_params = SamplingParams(**self._SAMPLING_KWARGS)\n",
    "        output = self.llm.generate([self._build_prompt(row)], sampling_params)\n",
    "        return output[0].outputs[0].text\n",
    "\n",
    "    def generate_evaluation_responses(self, debug=False) -> Dataset:\n",
    "        \"\"\"\n",
    "        Generate a response for every row of self.eval_dataset.\n",
    "\n",
    "        Args:\n",
    "            debug: when True, only the first _DEBUG_ROWS rows (fewer if the dataset is smaller) are used.\n",
    "\n",
    "        Returns:\n",
    "            The (possibly truncated) dataset with an added \"test_response\" column.\n",
    "        \"\"\"\n",
    "        # NOTE(review): eval_dataset is not defined in this class; presumably set by BaseAlgorithm - confirm.\n",
    "        dataset = self.eval_dataset\n",
    "        if debug:\n",
    "            # Clamp so an evaluation set shorter than the debug size does not make select() fail.\n",
    "            dataset = dataset.select(range(min(self._DEBUG_ROWS, len(dataset))))\n",
    "        prompts = [self._build_prompt(dataset[i]) for i in range(len(dataset))]\n",
    "        # One batched call lets vLLM schedule all prompts together instead of one engine run per row.\n",
    "        # Skipped entirely for an empty dataset so the engine is never called with no prompts.\n",
    "        sampling_params = SamplingParams(**self._SAMPLING_KWARGS)\n",
    "        outputs = self.llm.generate(prompts, sampling_params) if prompts else []\n",
    "        responses = [request_output.outputs[0].text for request_output in outputs]\n",
    "        return dataset.add_column(\"test_response\", responses)\n",
    "\n",
    "# Constructing the algorithm builds the vLLM engine in __init__, so the model is loaded by this line.\n",
    "test_gen = MetaLearnAlgorithm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processed prompts: 100%|██████████| 1/1 [00:01<00:00,  1.29s/it, Generation Speed: 68.35 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.60s/it, Generation Speed: 77.66 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.88s/it, Generation Speed: 75.43 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.57s/it, Generation Speed: 74.36 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.12s/it, Generation Speed: 74.46 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.66s/it, Generation Speed: 76.90 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.09s/it, Generation Speed: 75.69 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.75s/it, Generation Speed: 76.28 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:05<00:00,  5.20s/it, Generation Speed: 75.44 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.81s/it, Generation Speed: 75.20 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.80s/it, Generation Speed: 75.27 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.55s/it, Generation Speed: 74.98 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:05<00:00,  5.37s/it, Generation Speed: 75.50 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.33s/it, Generation Speed: 77.14 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.59s/it, Generation Speed: 75.93 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:02<00:00,  2.64s/it, Generation Speed: 75.49 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.58s/it, Generation Speed: 76.08 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:03<00:00,  3.55s/it, Generation Speed: 73.50 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:02<00:00,  2.04s/it, Generation Speed: 68.62 toks/s]\n",
      "Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.74s/it, Generation Speed: 76.01 toks/s]\n",
      "generating_evaluation_response: 100%|██████████| 20/20 [01:37<00:00,  4.87s/it]\n"
     ]
    }
   ],
   "source": [
    "# debug=True limits generation to a 3-row leading slice of the evaluation set.\n",
    "updated_dataset = test_gen.generate_evaluation_responses(debug=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "person",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
