{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b24e64e-21b2-4e97-bb6c-51efb0242d9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "872d1596-d59c-45ad-a6a3-522055b87fca",
   "metadata": {
    "editable": true,
    "scrolled": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Parametrization (papermill \"parameters\" cell).\n",
    "# Only plain, overridable values live here: papermill injects overrides in a new\n",
    "# cell *after* this one, so everything derived from them is built in the next cell.\n",
    "\n",
    "# Alternative models (previously assigned here and then overridden):\n",
    "#   \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "#   \"RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8\"\n",
    "#   \"mistralai/Mistral-Small-24B-Instruct-2501\"\n",
    "#   \"RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic\"\n",
    "model_name = \"microsoft/phi-4\"\n",
    "\n",
    "backend = \"vllm\"\n",
    "# Forwarded to the vLLM backend as `base_port`.\n",
    "base_port = 8000\n",
    "\n",
    "# None -> build defaults from model_name / base_port in the next cell;\n",
    "# pass a dict to replace those defaults entirely.\n",
    "model_kwargs = None\n",
    "backend_kwargs = None\n",
    "\n",
    "logdir = \"data/temp/\"\n",
    "\n",
    "# Alternative: \"tot_test_split\"\n",
    "dataset = \"four_digits_unsolvable\"\n",
    "\n",
    "do_verify = False\n",
    "n_verify_sample = 3\n",
    "\n",
    "n_propose_sample = 1\n",
    "n_evaluate_sample = 3\n",
    "n_select_sample = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d8ca5e8-2ac5-4080-a8f9-ba5e2cf2e679",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from deploy_utils import get_available_gpus\n",
    "\n",
    "# Derived configuration. It is built here rather than in the parameters cell so that\n",
    "# papermill-injected values such as model_name and base_port take effect.\n",
    "if model_kwargs is None:\n",
    "    model_kwargs = {\n",
    "        \"enable_prefix_caching\": True,\n",
    "        \"tensor_parallel_size\": 1,\n",
    "        \"max_model_len\": 4096,\n",
    "        # Lower concurrent-sequence cap for 70B models.\n",
    "        \"max_num_seqs\": 100 if \"70B\" in model_name else 500,\n",
    "    }\n",
    "    if \"mistralai\" in model_name:\n",
    "        # Native mistralai checkpoints use the Mistral tokenizer/config/weight formats.\n",
    "        model_kwargs[\"tokenizer_mode\"] = \"mistral\"\n",
    "        model_kwargs[\"config_format\"] = \"mistral\"\n",
    "        model_kwargs[\"load_format\"] = \"mistral\"\n",
    "\n",
    "# NOTE(review): the return value is not used in this notebook. The call is kept, and kept\n",
    "# before the inference_client import, in case it has side effects such as GPU\n",
    "# selection -- confirm.\n",
    "available_gpus = get_available_gpus(max_utilization=5)\n",
    "\n",
    "if backend_kwargs is None:\n",
    "    # vLLM backend. For an OpenAI-compatible endpoint use instead:\n",
    "    #   dict(base_url=\"localhost\", api_key=os.environ.get(\"OPENAI_API_KEY\", \"\"))\n",
    "    backend_kwargs = dict(\n",
    "        base_port=base_port,\n",
    "        api_key=os.environ[\"HF_TOKEN\"],\n",
    "        offline=True,\n",
    "    )\n",
    "\n",
    "from inference_client import MultiOpenAIClient\n",
    "#os.environ[\"VLLM_CONFIGURE_LOGGING\"] = \"0\"\n",
    "\n",
    "client = MultiOpenAIClient(backend=backend, model_name=model_name, model_kwargs=model_kwargs, **backend_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ba8d311-d067-41e9-bb53-51568d1b8447",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Block on the multi-client (presumably until the served model(s) are ready -- see inference_client),\n",
    "# then expose its underlying clients; solve_baseline dispatches samples to them round-robin.\n",
    "client.wait_done()\n",
    "clients = client.clients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebbfc83d-d7b0-408b-ae7e-15312f399f12",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import concurrent.futures\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import time\n",
    "from tot.models import gpt_usage, reset_usage\n",
    "import argparse\n",
    "from tot.bfs import solve\n",
    "from tot import get_task\n",
    "import networkx as nx\n",
    "import pickle\n",
    "\n",
    "def _score_matrix(task, outputs):\n",
    "    \"\"\"Score every candidate answer of every solved sample.\n",
    "\n",
    "    Returns a 2-D array with one row per `(i, (ys, infos))` entry of `outputs` and one\n",
    "    column per candidate, holding `task.test_output(i, y)[\"r\"]`. If samples returned\n",
    "    different numbers of candidates, shorter rows are right-padded with 0 so the\n",
    "    array stays rectangular (np.array rejects ragged nested lists).\n",
    "    \"\"\"\n",
    "    rows = [[task.test_output(i, y)[\"r\"] for y in ys] for i, (ys, _infos) in outputs]\n",
    "    width = max((len(row) for row in rows), default=0)\n",
    "    if any(len(row) != width for row in rows):\n",
    "        rows = [row + [0] * (width - len(row)) for row in rows]\n",
    "    return np.array(rows).reshape(len(rows), width)\n",
    "\n",
    "\n",
    "def solve_baseline(task, args, start, end, client_pool=None):\n",
    "    \"\"\"Solve samples [start, end) of `task` concurrently and summarize the run.\n",
    "\n",
    "    Each sample is run via `solve(args, task, i, ...)` in a thread pool and dispatched\n",
    "    round-robin over `client_pool` (defaults to the notebook-level `clients`).\n",
    "\n",
    "    Returns a dict with the score matrix `results` (samples x candidates), the raw\n",
    "    `(index, solve output)` pairs, the counters returned by gpt_usage(), the\n",
    "    wall-clock `runtime` in seconds and `args`.\n",
    "    \"\"\"\n",
    "    pool = clients if client_pool is None else client_pool\n",
    "    start_time = time.time()\n",
    "\n",
    "    outputs = []\n",
    "    # One worker per sample; ThreadPoolExecutor rejects max_workers=0 (empty range).\n",
    "    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, end - start)) as executor:\n",
    "        futures = [\n",
    "            executor.submit(solve, args, task, i, to_print=False, client=pool[i % len(pool)])\n",
    "            for i in range(start, end)\n",
    "        ]\n",
    "        # Collected in submission order; total= lets tqdm render a real progress bar.\n",
    "        for i, future in tqdm(enumerate(futures, start), total=len(futures)):\n",
    "            outputs.append((i, future.result()))\n",
    "\n",
    "    diff_time = time.time() - start_time\n",
    "    results = _score_matrix(task, outputs)\n",
    "    ct, pt, cc = gpt_usage()\n",
    "    first_score = results[:, 0].mean() if results.shape[1] else float(\"nan\")\n",
    "    print(diff_time, first_score, pt, ct, cc)\n",
    "    return dict(results=results, outputs=outputs, prompt_tokens=pt, completion_tokens=ct, model_calls=cc, runtime=diff_time, args=args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc442cb6-2e2e-43ff-abcc-541857c1cef0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import traceback\n",
    "\n",
    "max_workers = 80\n",
    "temp = 0.7\n",
    "\n",
    "args = argparse.Namespace(backend=model_name, temperature=temp, naive_run=False, prompt_sample=None, method_generate='propose',\n",
    "                   method_evaluate='value', method_select='greedy', n_generate_sample=n_propose_sample, n_evaluate_sample=n_evaluate_sample,\n",
    "                   n_select_sample=n_select_sample, do_verify=do_verify, n_verify_sample=n_verify_sample,\n",
    "                   max_workers=max_workers, dataset=dataset)\n",
    "\n",
    "try:\n",
    "    task = get_task(dataset)\n",
    "    start = 0\n",
    "    end = len(task.samples)\n",
    "\n",
    "    results = solve_baseline(task, args, start, end)\n",
    "\n",
    "except Exception as e:\n",
    "    # Best effort: still write the failure to disk below, but show the traceback\n",
    "    # instead of swallowing it silently.\n",
    "    traceback.print_exc()\n",
    "    results = e\n",
    "\n",
    "# open() fails on a missing directory; os.path.join also tolerates a logdir without a trailing slash.\n",
    "if logdir:\n",
    "    os.makedirs(logdir, exist_ok=True)\n",
    "with open(os.path.join(logdir, \"0.dat\"), \"wb\") as f:\n",
    "    pickle.dump(results, f)\n",
    "\n",
    "reset_usage()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8404083-20e4-4b78-ac00-7dab5f4ce769",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The run cell stores the exception in `results` on failure; re-raise it here so the\n",
    "# real error is shown instead of a TypeError from indexing an exception object.\n",
    "if isinstance(results, BaseException):\n",
    "    raise results\n",
    "\n",
    "scores = results[\"results\"]  # samples x candidates (see solve_baseline)\n",
    "print(\"mean score, first candidate:\", scores[:, 0].mean())\n",
    "print(\"mean score, any candidate:  \", scores.any(axis=1).mean())\n",
    "print(\"runtime [h]:\", results[\"runtime\"] / 3600)\n",
    "print(\"completion_tokens:\", results[\"completion_tokens\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6db3d7cc-deb7-4e9f-8b1b-03439cfed876",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hp_env2",
   "language": "python",
   "name": "hp_env2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
