{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d2170089",
   "metadata": {},
   "source": [
    "## Emulating a closed-source API\n",
    "\n",
    "Note: before running this notebook, first set up the local model with \"python server.py\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8519c571",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b5047ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from transformers import AutoTokenizer\n",
    "import os\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from openai import OpenAI\n",
    "import torch\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "from src.decoding.beam_search import entropy_guided_beam_search_api\n",
    "from src.comparison.compare import format_answer\n",
    "\n",
    "from helpers import dataset_fields, get_dataset, get_model_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6573595f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoding settings read by the cells further down.\n",
    "top_k = 20\n",
    "max_beam_size = 5\n",
    "\n",
    "# Client for the local server; the key is a placeholder because no key is needed for a local model.\n",
    "client = OpenAI(\n",
    "    api_key=\"dummy\",\n",
    "    base_url=\"http://localhost:8000/v1\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "008e45b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model identifier and the on-disk locations derived from it.\n",
    "hf_model_name = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
    "hf_cache_dir = \"./.hf_cache\"\n",
    "local_model_dir = f\"./local/models/{hf_model_name}\"\n",
    "\n",
    "# local_files_only=True restricts loading to files that are already on disk.\n",
    "tokenizer = AutoTokenizer.from_pretrained(local_model_dir, cache_dir=hf_cache_dir, local_files_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d858d65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets: gsm8k, math-500, humaneval, scibench\n",
    "max_examples = 100\n",
    "\n",
    "# Unpack the eight values that dataset_fields returns for the chosen dataset.\n",
    "(\n",
    "    dataset_name,\n",
    "    dataset_subset,\n",
    "    dataset_split,\n",
    "    text_field,\n",
    "    answer_field,\n",
    "    choices_field,\n",
    "    local_dataset_dir,\n",
    "    prompt_generator,\n",
    ") = dataset_fields(\"math-500\")\n",
    "\n",
    "dataset = get_dataset(\n",
    "    local_dataset_dir,\n",
    "    dataset_name,\n",
    "    dataset_subset,\n",
    "    dataset_split,\n",
    "    hf_cache_dir,\n",
    "    max_examples,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ab82f7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output files\n",
    "# Slashes in the model/dataset names are replaced so they do not create nested directories.\n",
    "safe_model_name = hf_model_name.replace(\"/\", \"__\")\n",
    "safe_dataset_name = dataset_name.replace(\"/\", \"__\")\n",
    "\n",
    "results_folder = f\"./results/closed/{safe_model_name}/\"\n",
    "os.makedirs(results_folder, exist_ok=True)  # create the output directory if it is missing\n",
    "\n",
    "\n",
    "def save_results(res, subname):\n",
    "    \"\"\"Write ``res`` to ``<results_folder><safe_dataset_name>_<subname>.csv``.\n",
    "\n",
    "    ``res`` is handed to ``pd.DataFrame`` and then transposed, so for a mapping of\n",
    "    mappings every outer key becomes one row. The index column is named \"question\".\n",
    "    \"\"\"\n",
    "    csv_path = f\"{results_folder}{safe_dataset_name}_{subname}.csv\"\n",
    "    table = pd.DataFrame(res).transpose().rename_axis(\"question\")\n",
    "    table.to_csv(csv_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74da8c06",
   "metadata": {},
   "outputs": [],
   "source": [
    "method_name = \"proposed\"\n",
    "max_beam_size = 5\n",
    "T = 400  # passed as max_new_tokens to the beam search below\n",
    "\n",
    "assert max_beam_size <= top_k, \"Beam size must be less than or equal to top-k\"\n",
    "\n",
    "proposed_info = {}\n",
    "print(f\"Running {method_name}..\")\n",
    "\n",
    "# Both mappings are keyed by prompt; save_results writes one CSV row per key.\n",
    "results = defaultdict(dict)\n",
    "exceptions = defaultdict(dict)\n",
    "\n",
    "for idx, example in enumerate(tqdm(dataset, desc=\"Examples\")):\n",
    "    prompt = prompt_generator(example)\n",
    "    gt_answer = example.get(answer_field, \"N/A\")\n",
    "\n",
    "    if prompt not in results:\n",
    "        results[prompt] = {\"ground_truth\": gt_answer}\n",
    "\n",
    "    # No tensors are built here: entropy_guided_beam_search_api receives the prompt\n",
    "    # string and the tokenizer itself, so tokenizing onto `device` was unused work.\n",
    "    try:\n",
    "        tree = entropy_guided_beam_search_api(client, tokenizer, prompt,\n",
    "                                              max_beam_size=max_beam_size, max_new_tokens=T, top_k=top_k)\n",
    "\n",
    "        # The answer is formatted from the context of the last node on the best path.\n",
    "        proposed_answer = format_answer(tokenizer, tree.best_path(eos_token_id=tokenizer.eos_token_id)[-1].context)\n",
    "\n",
    "        proposed_info[prompt] = {\n",
    "            \"mean_branching_factor\": tree.mean_branching_factor(),\n",
    "            \"total_branches\": tree.total_branches(),\n",
    "        }\n",
    "\n",
    "        # One extra column per entry of tree.beam_sizes.\n",
    "        for k, v in tree.beam_sizes.items():\n",
    "            proposed_info[prompt][f\"beam_{k}\"] = v\n",
    "\n",
    "        del tree\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best effort: record the failure for this prompt and continue with the next example.\n",
    "        print(e)\n",
    "        proposed_answer = \"NA\"\n",
    "        exceptions[prompt][method_name] = str(e)\n",
    "\n",
    "    results[prompt][method_name] = proposed_answer\n",
    "\n",
    "    # Checkpoint every 10 examples so partial results are on disk if the run stops early.\n",
    "    if idx % 10 == 0:\n",
    "        save_results(results, \"results\")\n",
    "        save_results(exceptions, \"exceptions\")\n",
    "        save_results(proposed_info, f\"{method_name}_info\")\n",
    "\n",
    "save_results(results, \"results\")\n",
    "save_results(exceptions, \"exceptions\")\n",
    "save_results(proposed_info, f\"{method_name}_info\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b43fc87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read only to obtain the sampling temperature.\n",
    "# Ideally this value would come from the server rather than from the local model directory.\n",
    "model_default_params = get_model_params(\n",
    "    local_model_dir,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45e059e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline decoding settings, keyed by the column name used in the results table.\n",
    "comparisons = {\n",
    "    \"greedy\": {\"temperature\": 0, \"top_k\": 1},\n",
    "    \"top-k\": {\"temperature\": model_default_params[\"temperature\"], \"top_k\": top_k}\n",
    "}\n",
    "\n",
    "for method_name, method_args in comparisons.items():\n",
    "\n",
    "    for idx, example in enumerate(tqdm(dataset, desc=\"Examples\")):\n",
    "        prompt = prompt_generator(example)\n",
    "        gt_answer = example.get(answer_field, \"N/A\")\n",
    "\n",
    "        if prompt not in results:\n",
    "            results[prompt] = {\"ground_truth\": gt_answer}\n",
    "\n",
    "        # The prompt is sent as text, so it is not tokenized or moved to `device` here.\n",
    "        try:\n",
    "            # Query LLM; the sampling settings travel in `metadata`.\n",
    "            # NOTE(review): presumably the local server reads them from there - confirm in server.py.\n",
    "            response = client.responses.create(\n",
    "                model=\"\",  # Isn't actually used\n",
    "                input=prompt,\n",
    "                max_output_tokens=T,\n",
    "                metadata=method_args\n",
    "            )\n",
    "            comparison_answer = response.choices[0][\"content\"][0][\"text\"]\n",
    "\n",
    "        except Exception as e:\n",
    "            # Best effort: record the failure for this prompt/method and continue.\n",
    "            print(e)\n",
    "            comparison_answer = \"NA\"\n",
    "            exceptions[prompt][method_name] = str(e)\n",
    "\n",
    "        results[prompt][method_name] = comparison_answer\n",
    "\n",
    "        # Checkpoint every 10 examples. proposed_info is deliberately not written here:\n",
    "        # this cell never changes it, and naming the file after method_name would store\n",
    "        # the proposed method's statistics under \"greedy_info\" / \"top-k_info\".\n",
    "        if idx % 10 == 0:\n",
    "            save_results(results, \"results\")\n",
    "            save_results(exceptions, \"exceptions\")\n",
    "\n",
    "save_results(results, \"results\")\n",
    "save_results(exceptions, \"exceptions\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
