{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67e96f86-5b21-495b-b760-894c8e0cf228",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f2b14bc-fdd2-4a6b-be98-13d686e6b4b5",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Parametrization\n",
    "# Every setting is assigned exactly once; alternatives stay commented out\n",
    "# so the active value is never shadowed by a later assignment.\n",
    "from deploy_utils import get_available_gpus\n",
    "\n",
    "# Model to run (uncomment exactly one).\n",
    "#model_name = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "#model_name = \"RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8\"\n",
    "#model_name = \"mistralai/Mistral-Small-24B-Instruct-2501\"\n",
    "#model_name = \"RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic\"\n",
    "model_name = \"microsoft/phi-4\"\n",
    "\n",
    "\n",
    "# model instantiation kwargs\n",
    "model_kwargs = {\n",
    "    \"enable_prefix_caching\": True,\n",
    "    \"tensor_parallel_size\": 1,\n",
    "    \"max_model_len\": 4096,\n",
    "    \"max_num_seqs\": 512,\n",
    "}\n",
    "\n",
    "available_gpus = get_available_gpus(max_utilization=5)\n",
    "\n",
    "# openai api (alternative backend kwargs; inactive)\n",
    "#backend_kwargs = dict(\n",
    "#    base_url=\"localhost\",\n",
    "#    api_key=os.environ.get(\"OPENAI_API_KEY\", \"\"),\n",
    "#)\n",
    "\n",
    "# vllm backend kwargs\n",
    "backend_kwargs = dict(\n",
    "    base_port=8000,\n",
    "    api_key=os.environ[\"HF_TOKEN\"],  # raises KeyError when HF_TOKEN is not set\n",
    "    offline=True,\n",
    ")\n",
    "backend = \"vllm\"\n",
    "\n",
    "\n",
    "# Directory the pickled result logs are written to.\n",
    "logdir = \"data/temp/\"\n",
    "\n",
    "# Dataset name handed to get_task (uncomment exactly one).\n",
    "#dataset = \"tot_test_split\"\n",
    "dataset = \"four_digits_unsolvable\"\n",
    "baseline = \"cot\" # io or cot\n",
    "# Number of completions requested per prompt; also the number of result logs written.\n",
    "n = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f902099-a1af-477e-b9d8-d74a2f9f0223",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from inference_client import MultiOpenAIClient\n",
    "\n",
    "# Optional: uncomment before the client is created to change vLLM's logging setup.\n",
    "#os.environ[\"VLLM_CONFIGURE_LOGGING\"] = \"0\"\n",
    "\n",
    "# Build the client from the values defined in the parameters cell.\n",
    "client = MultiOpenAIClient(\n",
    "    backend=backend,\n",
    "    model_name=model_name,\n",
    "    model_kwargs=model_kwargs,\n",
    "    **backend_kwargs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32f6807d-da66-4f9e-bc5b-0f571730251e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tasks import get_task\n",
    "\n",
    "# Look up the task object for the dataset named in the parameters cell.\n",
    "task = get_task(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a973c162-ba36-46a0-9a47-c728d4d59131",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): presumably blocks until the client's pending startup work has finished - confirm in inference_client.\n",
    "client.wait_done()\n",
    "\n",
    "# Flushed separator line in the cell output.\n",
    "print(\"-\" * 100, flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bf91189-ca3e-4782-8793-85e0204c572e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per task sample; the sample itself is stored in \"root\".\n",
    "data = pd.DataFrame({\"root\": task.samples})\n",
    "\n",
    "# \"io\" selects the io prompt builder; any other baseline value uses the cot one.\n",
    "data[\"convo\"] = data[\"root\"].apply(\n",
    "    task.io_prompt_wrap if baseline == \"io\" else task.cot_prompt_wrap\n",
    ")\n",
    "\n",
    "# Request n completions per prompt; the return value of client.chat overwrites the prompts in \"convo\".\n",
    "data[\"convo\"] = client.chat(\n",
    "    messages=data[\"convo\"].tolist(),\n",
    "    temperature=0.7,\n",
    "    n=n,\n",
    "    max_tokens=1024,\n",
    "    return_format=\"chatml\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74581f39-7af5-4257-b397-f73221d8a841",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per entry of each \"convo\" list.\n",
    "data_flat = data.explode(\"convo\")\n",
    "\n",
    "# NOTE(review): the io finalize turn is applied for both baselines - confirm this is intended for \"cot\".\n",
    "data_flat[\"convo\"] = data_flat[\"convo\"].apply(task.add_finalize_io_answer_turn)\n",
    "data_flat[\"convo\"] = client.chat(\n",
    "    messages=data_flat[\"convo\"].tolist(),\n",
    "    temperature=0.1,\n",
    "    n=1,\n",
    "    max_tokens=32,\n",
    "    return_format=\"chatml\",\n",
    ")\n",
    "\n",
    "# Keep the content of the last message of the first conversation returned for each row.\n",
    "data_flat[\"solution\"] = data_flat[\"convo\"].apply(lambda reply: reply[0][-1][\"content\"])\n",
    "\n",
    "# Entry k of result_list holds the k-th row of every \"root\" group.\n",
    "g = data_flat.groupby(\"root\")\n",
    "result_list = [g.nth(run)[[\"root\", \"solution\"]] for run in tqdm(range(n))]\n",
    "\n",
    "# Total completion tokens divided by n; the same value is stored in every log entry.\n",
    "ct = sum(client.completion_token_usage.values()) / n\n",
    "all_logs = [\n",
    "    {\n",
    "        \"token_usage\": ct,\n",
    "        \"model_calls\": (n + 1) * len(task.samples),\n",
    "        \"results\": run_results,\n",
    "    }\n",
    "    for run_results in result_list\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d013954a-8335-4e48-91de-3682198ced77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the log directory first so the dump cannot fail on a missing path\n",
    "# after the inference above has already been paid for.\n",
    "os.makedirs(logdir, exist_ok=True)\n",
    "\n",
    "# One pickle file per entry of all_logs, named by its position: 0.dat, 1.dat, ...\n",
    "for i, logs in enumerate(all_logs):\n",
    "    # os.path.join avoids the doubled separator when logdir ends with \"/\".\n",
    "    with open(os.path.join(logdir, f\"{i}.dat\"), \"wb\") as f:\n",
    "        pickle.dump(logs, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7488c481-112f-46ff-8846-3f2d441c15e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the client once the result logs have been written.\n",
    "client.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00f7b841-226c-4bc9-b922-61c48bf45723",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raise SystemExit to end the run now that the client has been closed.\n",
    "# NOTE(review): notebook runners may report SystemExit as a failed cell - confirm this is intended.\n",
    "import sys\n",
    "\n",
    "sys.exit()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hp_env2",
   "language": "python",
   "name": "hp_env2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
