{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "af8faf1c-f143-46a8-af1c-bbea5e80a0f4",
   "metadata": {},
   "source": [
    "# Dynamic branching for information-driven contextual decoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67967611",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a119fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Memory management\n",
    "import os\n",
    "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f39f6c4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "torch._dynamo.config.cache_size_limit = 64  # (needed for Gemma)\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "from src.decoding.beam_search import entropy_guided_beam_search\n",
    "from src.comparison.compare import format_answer, compare_batch\n",
    "\n",
    "from helpers import cleanup, dataset_fields, get_dataset, get_model_and_tokenizer, get_model_params\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a494813d",
   "metadata": {},
   "outputs": [],
   "source": [
    "cleanup()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "488c9e67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets: gsm8k, math-500, humaneval, scibench\n",
    "dataset_name, dataset_subset, dataset_split, text_field, answer_field, choices_field, local_dataset_dir, prompt_generator = dataset_fields(\"gsm8k\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f014594",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_examples = None\n",
    "max_new_tokens = 400 # From GSM8k paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77cbd780",
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_model_name = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
    "local_model_dir = f\"./local/models/{hf_model_name}\"\n",
    "hf_cache_dir = \"./.hf_cache\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1d84ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output files\n",
    "# Replace slashes to avoid unintended subdirectories\n",
    "safe_model_name = hf_model_name.replace(\"/\", \"__\")\n",
    "safe_dataset_name = dataset_name.replace(\"/\", \"__\")\n",
    "\n",
    "results_folder = f\"./results/{safe_model_name}/\"\n",
    "\n",
    "# Ensure all necessary directories exist\n",
    "os.makedirs(results_folder, exist_ok=True)\n",
    "\n",
    "def save_results(res, subname):\n",
    "    \"\"\"Write a nested mapping to a CSV file inside the results folder.\n",
    "\n",
    "    ``res`` is converted to a DataFrame and transposed, so each outer key\n",
    "    becomes one row (index column named ``question``) and each inner key\n",
    "    becomes a column. The file name is built from the dataset name and\n",
    "    ``subname``.\n",
    "    \"\"\"\n",
    "    out_path = f\"{results_folder}{safe_dataset_name}_{subname}.csv\"\n",
    "    table = pd.DataFrame(res).transpose()\n",
    "    table.index.name = \"question\"\n",
    "    table.to_csv(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e8fe8ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = get_dataset(local_dataset_dir, dataset_name, dataset_subset, dataset_split, hf_cache_dir, max_examples)\n",
    "\n",
    "# Load model and tokenizer\n",
    "model, tokenizer = get_model_and_tokenizer(local_model_dir, hf_model_name, hf_cache_dir, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0a3fbee",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = min(dataset.shape[0], 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f33871be",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = defaultdict(dict)\n",
    "exceptions = defaultdict(dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6575a450",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maximum beam sizes to evaluate. The full sweep is [3, 5, 7, 9]; only a\n",
    "# single size is run here.\n",
    "beam_sizes = [5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddf3b7fa-4488-4297-b8a7-c05f1209cbe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_proposed = True\n",
    "\n",
    "if run_proposed:\n",
    "\n",
    "    # Map each result column name to the maximum beam size it is run with\n",
    "    proposed_configurations = {\n",
    "        f\"proposed_{beam_size}\": beam_size for beam_size in beam_sizes\n",
    "    }\n",
    "\n",
    "    for method_name, max_beam_size in proposed_configurations.items():\n",
    "\n",
    "        # Tree statistics collected per prompt for this configuration only\n",
    "        proposed_info = {}\n",
    "        print(f\"Running {method_name}..\")\n",
    "\n",
    "        for idx, example in enumerate(tqdm(dataset, desc=\"Examples\")):\n",
    "            prompt = prompt_generator(example)\n",
    "\n",
    "            gt_answer = example.get(answer_field, \"N/A\")\n",
    "\n",
    "            if prompt not in results:\n",
    "                results[prompt] = {\n",
    "                    \"ground_truth\": gt_answer\n",
    "                }\n",
    "\n",
    "            inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "            input_ids = inputs[\"input_ids\"]\n",
    "\n",
    "            # Bound up front so the cleanup below also works when the search fails\n",
    "            tree = None\n",
    "\n",
    "            try:\n",
    "                tree = entropy_guided_beam_search(model, tokenizer, input_ids,\n",
    "                                                max_new_tokens=max_new_tokens,\n",
    "                                                max_beam_size=max_beam_size)\n",
    "\n",
    "                proposed_answer = format_answer(tokenizer, tree.best_path(eos_token_id=tokenizer.eos_token_id)[-1].context)\n",
    "\n",
    "                proposed_info[prompt] = {\n",
    "                    \"mean_branching_factor\": tree.mean_branching_factor(),\n",
    "                    \"total_branches\": tree.total_branches(),\n",
    "                }\n",
    "\n",
    "                for k, v in tree.beam_sizes.items():\n",
    "                    proposed_info[prompt][f\"beam_{k}\"] = v\n",
    "\n",
    "            except Exception as e:\n",
    "                # Best effort: log the failure, record it, and continue with the next example\n",
    "                print(e)\n",
    "                proposed_answer = \"NA\"\n",
    "                exceptions[prompt][method_name] = str(e)\n",
    "\n",
    "            results[prompt][method_name] = proposed_answer\n",
    "\n",
    "            # Clear intermediate variables (the tree too, also after a failed iteration)\n",
    "            del inputs, input_ids, tree\n",
    "\n",
    "            # Checkpoint after every `batch_size` completed examples\n",
    "            if (idx + 1) % batch_size == 0:\n",
    "                save_results(results, \"results\")\n",
    "                save_results(exceptions, \"exceptions\")\n",
    "                save_results(proposed_info, f\"{method_name}_info\")\n",
    "\n",
    "        # Final save covers the trailing partial batch\n",
    "        save_results(results, \"results\")\n",
    "        save_results(exceptions, \"exceptions\")\n",
    "        save_results(proposed_info, f\"{method_name}_info\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9056fa75",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_comparisons = True\n",
    "\n",
    "# Compare with baseline decoding strategies\n",
    "params = {\n",
    "    \"top_k\": 10,\n",
    "    \"top_p\": 0.1,\n",
    "    \"min_p\": 0.9,\n",
    "    \"num_beams\": 3,\n",
    "    \"greedy\": True,\n",
    "}\n",
    "\n",
    "#params = [\n",
    "#    (f\"num_beams_{beam_size}\", beam_size) for beam_size in beam_sizes\n",
    "#]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea534a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_gen_config = get_model_params(local_model_dir)\n",
    "\n",
    "# Override the hard-coded comparison values with the model's own generation\n",
    "# settings wherever the model config defines the same key.\n",
    "# list(...) snapshots the keys because `params` is updated inside the loop.\n",
    "for key in list(params.keys()):\n",
    "    if key in raw_gen_config:\n",
    "        val = raw_gen_config[key]\n",
    "        print(\"Setting\", key, \"to\", val)\n",
    "        # Write into `params`; writing back into `raw_gen_config` changes nothing\n",
    "        params[key] = val\n",
    "\n",
    "# (name, value) pairs consumed by the comparison loop\n",
    "comparisons = list(params.items())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37ef852b",
   "metadata": {},
   "outputs": [],
   "source": [
    "if run_comparisons:\n",
    "\n",
    "    dataloader = DataLoader(dataset, batch_size=batch_size)\n",
    "\n",
    "    # Loop over batches\n",
    "    for batch in tqdm(dataloader, desc=\"Batches\"):\n",
    "        # Rebuild one example dict per row of the collated batch before building its prompt\n",
    "        prompts = [prompt_generator({k: batch[k][i] for k in batch}) for i in range(len(batch[text_field]))]\n",
    "        gt_answers = batch.get(answer_field, [\"N/A\"] * len(prompts))\n",
    "\n",
    "        # Tokenize batch\n",
    "        inputs = tokenizer(\n",
    "            prompts,\n",
    "            return_tensors=\"pt\",\n",
    "            padding=True,\n",
    "            truncation=True,\n",
    "        ).to(device)\n",
    "        input_ids = inputs[\"input_ids\"]\n",
    "        attention_mask = inputs[\"attention_mask\"]\n",
    "\n",
    "        # Initialize per-question entries\n",
    "        for i, q in enumerate(prompts):\n",
    "            if q not in results:\n",
    "                results[q] = {\n",
    "                    \"ground_truth\": gt_answers[i]\n",
    "                }\n",
    "\n",
    "        # Run all decoding strategies\n",
    "        for name, value in comparisons:\n",
    "            print(f\"Running comparison: {name}\")\n",
    "            # Deal with num_beams_x case\n",
    "            processed_name = \"num_beams\" if \"num_beams\" in name else name\n",
    "            kwargs = {processed_name: value}\n",
    "            try:\n",
    "                answers = compare_batch(\n",
    "                    model, tokenizer, input_ids, max_new_tokens=max_new_tokens, attention_mask=attention_mask, **kwargs\n",
    "                )\n",
    "            except Exception as e:\n",
    "                # Best effort: mark the whole batch as failed for this strategy and keep going\n",
    "                answers = [\"NA\"] * len(prompts)\n",
    "                print(\"Exception\", e)\n",
    "                for q in prompts:\n",
    "                    exceptions[q][name] = str(e)\n",
    "\n",
    "            for i, q in enumerate(prompts):\n",
    "                results[q][name] = answers[i]\n",
    "\n",
    "        # Drop every reference to the batch tensors before cleanup();\n",
    "        # attention_mask would otherwise stay referenced until the next batch\n",
    "        del inputs, input_ids, attention_mask\n",
    "\n",
    "        cleanup()\n",
    "\n",
    "        # Save as we go\n",
    "        save_results(results, \"results\")\n",
    "        save_results(exceptions, \"exceptions\")\n",
    "\n",
    "    save_results(results, \"results\")\n",
    "    save_results(exceptions, \"exceptions\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
