{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b4fb42e9-e513-4291-974b-da0e0ac86c8e",
   "metadata": {},
   "source": [
    "# Estimating the Unseen: Improved Estimators for Entropy and Other Properties\n",
    "\n",
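    "### Entropy estimation error as the number of samples increases\n",
    "\n",
    "For each prompt, $k$ tokens are sampled from the model's next-token distribution, and both the naive (plug-in) estimator and the unseen estimator are compared with the exact entropy, all normalised by $\\log |V|$ (vocabulary size $|V|$)."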
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df89cc7f-2146-4c01-ade5-b5bd4db094b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from math import log\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "from estimators import make_finger, unseen\n",
    "from src.models.language_model import load_model_and_tokenizer\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Seed the RNGs used below (np.random.choice in get_estimation, torch.randn in the\n",
    "# robustness experiment) so that re-runs draw the same samples and noise.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c7da0b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as mticker\n",
    "\n",
    "# Apply matplotlib's \"bmh\" style globally, to every figure drawn in this notebook\n",
    "plt.style.use(\"bmh\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4a459d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_model_name = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
    "# Weights and datasets are read from local mirrors under ./local; HF cache lives here\n",
    "local_model_dir = f\"./local/models/{hf_model_name}\"\n",
    "hf_cache_dir = \"./.hf_cache\"\n",
    "\n",
    "# Load model and tokenizer\n",
    "model, tokenizer = load_model_and_tokenizer(\n",
    "    model_name=local_model_dir,\n",
    "    device=device,\n",
    "    local_files_only=True,\n",
    "    cache_dir=hf_cache_dir,\n",
    ")\n",
    "\n",
    "# Switch to inference mode (disables dropout); the trailing ';' stops the notebook from\n",
    "# rendering the entire module tree that model.eval() returns.\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83e67c6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_prompts(dataset_name, dataset_subset, dataset_split, text_field, cache_dir=None):\n",
    "    \"\"\"Load one split of a locally mirrored dataset and return its prompt column as a list.\n",
    "\n",
    "    The data is read from ./local/data/{dataset_name}-{dataset_subset}. `cache_dir`\n",
    "    defaults to the notebook-level `hf_cache_dir`.\n",
    "    \"\"\"\n",
    "    local_dataset_dir = f\"./local/data/{dataset_name}-{dataset_subset}\"\n",
    "    if cache_dir is None:\n",
    "        cache_dir = hf_cache_dir\n",
    "    # NOTE: trust_remote_code=True runs any loading script shipped in the directory;\n",
    "    # only point this at trusted data.\n",
    "    dataset = load_dataset(local_dataset_dir, dataset_subset, split=dataset_split, cache_dir=cache_dir, trust_remote_code=True)\n",
    "    # Materialise as a plain list: recent `datasets` releases may return a lazy column\n",
    "    # object here, while callers concatenate the results with `+`.\n",
    "    prompts = list(dataset[text_field])\n",
    "    del dataset\n",
    "    return prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c90bd7be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmarks whose prompts are pooled for the experiments below:\n",
    "# name -> (dataset id, subset, split, field holding the prompt text).\n",
    "# Ground-truth answer fields, not used in this notebook: gsm8k/math500 \"answer\",\n",
    "# humaneval \"canonical_solution\", scibench \"answer_number\".\n",
    "DATASET_SPECS = {\n",
    "    \"gsm8k\": (\"openai/gsm8k\", \"main\", \"test\", \"question\"),\n",
    "    \"math500\": (\"HuggingFaceH4/MATH-500\", \"default\", \"test\", \"problem\"),\n",
    "    \"humaneval\": (\"openai/openai_humaneval\", \"openai_humaneval\", \"test\", \"prompt\"),\n",
    "    \"scibench\": (\"xw27/scibench\", \"default\", \"train\", \"problem_text\"),\n",
    "}\n",
    "prompts_by_dataset = {name: read_prompts(*spec) for name, spec in DATASET_SPECS.items()}\n",
    "\n",
    "gsm8k = prompts_by_dataset[\"gsm8k\"]\n",
    "math500 = prompts_by_dataset[\"math500\"]\n",
    "humaneval = prompts_by_dataset[\"humaneval\"]\n",
    "scibench = prompts_by_dataset[\"scibench\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "313bb8c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool all benchmarks in the order gsm8k, math500, humaneval, scibench; the dataset\n",
    "# labelling in the robustness plot relies on this order. Unpacking works for any\n",
    "# sequence type, not only lists.\n",
    "prompts = [*gsm8k, *math500, *humaneval, *scibench]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dda8209-0339-4ea8-8331-a15b90c322ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_llm_probs(prompt):\n",
    "    \"\"\"Return the model's next-token distribution after `prompt` as a float64 numpy array.\n",
    "\n",
    "    The prompt is tokenized as raw text (no chat template) and only the logits at the\n",
    "    last position are used. The result is renormalised in float64 so that it sums to 1\n",
    "    within the tolerance that `np.random.choice(..., p=probs)` checks.\n",
    "    \"\"\"\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    with torch.no_grad():\n",
    "        logits = model(**inputs).logits\n",
    "    next_token_logits = logits[0, -1, :]\n",
    "    probs = F.softmax(next_token_logits.float(), dim=0).cpu().numpy().astype(np.float64)\n",
    "    # A float32 softmax over the whole vocabulary can sum to slightly more or less than 1\n",
    "    return probs / probs.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5ea806b-c6a0-435b-abb2-8b70dea6b5d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_estimation(prompt, ks):\n",
    "    \"\"\"Score the naive (plug-in) and unseen entropy estimators against the true entropy.\n",
    "\n",
    "    The model's next-token distribution for `prompt` is computed once; then, for each\n",
    "    sample size k in `ks`, k tokens are drawn from it and both estimators are applied to\n",
    "    the resulting fingerprint. Every entropy is divided by log(vocab_size).\n",
    "\n",
    "    Returns:\n",
    "        (naive_errors, unseen_errors, true_entropy_norm): one absolute normalised error\n",
    "        per k for each estimator, plus the normalised true entropy.\n",
    "    \"\"\"\n",
    "    probs = get_llm_probs(prompt)\n",
    "    n_tokens = len(probs)\n",
    "    log_vocab = np.log(n_tokens)\n",
    "    true_norm = -np.sum(probs * np.log(probs + 1e-12)) / log_vocab\n",
    "\n",
    "    def _errors_at(k):\n",
    "        # Draw k tokens. The plug-in formula below treats fingerprint[i-1] as the number of\n",
    "        # distinct tokens seen exactly i times (presumably make_finger's convention).\n",
    "        fingerprint = make_finger(np.random.choice(n_tokens, size=k, p=probs))\n",
    "        p_hat = np.arange(1, len(fingerprint) + 1) / k\n",
    "        naive_norm = -np.sum(fingerprint * p_hat * np.log(p_hat + 1e-12)) / log_vocab\n",
    "        # unseen() presumably returns a histogram: mass[j] elements at probability levels[j]\n",
    "        mass, levels = unseen(fingerprint)\n",
    "        unseen_norm = -np.sum(mass * levels * np.log(levels + 1e-12)) / log_vocab\n",
    "        return np.abs(naive_norm - true_norm), np.abs(unseen_norm - true_norm)\n",
    "\n",
    "    per_k = [_errors_at(k) for k in ks]\n",
    "    naive_errors = [naive for naive, _ in per_k]\n",
    "    unseen_errors = [est for _, est in per_k]\n",
    "    return naive_errors, unseen_errors, true_norm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a9b65d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_with_bootstrap_ci(ks, naive_stats_all, unseen_stats_all, n_boot=1000, ci=95, save_path=\"plots/entropy_estimation.pdf\", seed=None):\n",
    "    \"\"\"Plot mean estimator error against sample size with bootstrap confidence bands.\n",
    "\n",
    "    Args:\n",
    "        ks: sample sizes, one per column of the stats arrays (any order; sorted for plotting).\n",
    "        naive_stats_all: array-like of shape (n_reps, len(ks)) with naive-estimator errors.\n",
    "        unseen_stats_all: array-like of shape (n_reps, len(ks)) with unseen-estimator errors.\n",
    "        n_boot: number of bootstrap resamples of the rows.\n",
    "        ci: confidence level (percent) of the percentile bootstrap interval.\n",
    "        save_path: if not None, the figure is saved here (format taken from the extension)\n",
    "            and a 300-dpi PNG copy is saved next to it with the same stem.\n",
    "        seed: optional seed for the bootstrap RNG so the bands are reproducible.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the stats are not 2-D arrays with len(ks) columns.\n",
    "    \"\"\"\n",
    "    ks = np.array(ks)\n",
    "    naive_err = np.array(naive_stats_all)\n",
    "    # Not named `unseen`, to avoid confusion with the imported unseen() estimator\n",
    "    unseen_err = np.array(unseen_stats_all)\n",
    "\n",
    "    if naive_err.ndim != 2 or unseen_err.ndim != 2:\n",
    "        raise ValueError(\"naive_stats_all and unseen_stats_all must be 2D arrays (n_reps, n_ks)\")\n",
    "    if naive_err.shape[1] != ks.shape[0] or unseen_err.shape[1] != ks.shape[0]:\n",
    "        raise ValueError(\"Second dimension of stats arrays must equal len(ks)\")\n",
    "\n",
    "    # sort ks (and reorder columns of stats accordingly)\n",
    "    order = np.argsort(ks)\n",
    "    ks_sorted = ks[order]\n",
    "    naive_err = naive_err[:, order]\n",
    "    unseen_err = unseen_err[:, order]\n",
    "\n",
    "    # One generator shared by both series; seed=None gives non-deterministic bands\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # use colorblind-safe palette\n",
    "    palette = sns.color_palette(\"colorblind\", 2)\n",
    "    col_naive, col_unseen = palette\n",
    "\n",
    "    plt.figure(figsize=(6, 5))\n",
    "    plt.axhline(0, ls=\"--\", color=\"gray\", label=\"True\")\n",
    "\n",
    "    def _bootstrap_ci(data):\n",
    "        \"\"\"Return (mean, lower, upper) per column, resampling rows with replacement.\"\"\"\n",
    "        n_reps = data.shape[0]\n",
    "        boot_means = np.array([\n",
    "            data[rng.integers(0, n_reps, size=n_reps)].mean(axis=0) for _ in range(n_boot)\n",
    "        ])\n",
    "        lower = np.percentile(boot_means, (100 - ci) / 2, axis=0)\n",
    "        upper = np.percentile(boot_means, 100 - (100 - ci) / 2, axis=0)\n",
    "        return data.mean(axis=0), lower, upper\n",
    "\n",
    "    def _plot_bootstrap(data, color, label, linestyle=\"-\"):\n",
    "        \"\"\"Draw the mean curve, the shaded CI band and dotted band edges for one estimator.\"\"\"\n",
    "        mean, lower, upper = _bootstrap_ci(data)\n",
    "        plt.plot(ks_sorted, mean, marker=\"o\", color=color, lw=2, ls=linestyle, label=label)\n",
    "        plt.fill_between(ks_sorted, lower, upper, color=color, alpha=0.25)\n",
    "        plt.plot(ks_sorted, lower, ls=\":\", lw=0.8, color=color, alpha=0.6)\n",
    "        plt.plot(ks_sorted, upper, ls=\":\", lw=0.8, color=color, alpha=0.6)\n",
    "\n",
    "    _plot_bootstrap(naive_err, color=col_naive, label=\"Naive\")\n",
    "    _plot_bootstrap(unseen_err, color=col_unseen, label=\"Unseen\")\n",
    "\n",
    "    # Add 0.5/B reference lines (label at right edge for clarity)\n",
    "    Bs = [3, 5, 10]\n",
    "    for B in Bs:\n",
    "        y = 0.5 / B\n",
    "        plt.axhline(y, color=\"gray\", linestyle=\":\", linewidth=1)\n",
    "        plt.text(ks_sorted[-1], y, r\"$B_{\\mathrm{max}}$=\" + f\"{B}\",\n",
    "                 color=\"black\", va=\"center\", ha=\"right\", fontsize=9)\n",
    "\n",
    "    ax = plt.gca()\n",
    "    ax.set_xscale('log')\n",
    "    ax.xaxis.set_major_formatter(mticker.ScalarFormatter())\n",
    "\n",
    "    plt.xlabel(\"Number of samples\", fontsize=13)\n",
    "    plt.ylabel(\"Entropy estimation error (ε)\", fontsize=13)\n",
    "    plt.legend(title=\"Estimator\", fontsize=11, title_fontsize=12)\n",
    "    plt.grid(alpha=0.2)\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if save_path is not None:\n",
    "        out_dir = os.path.dirname(save_path)\n",
    "        if out_dir:  # os.makedirs(\"\") raises, e.g. for save_path=\"fig.pdf\"\n",
    "            os.makedirs(out_dir, exist_ok=True)\n",
    "        base, ext = os.path.splitext(save_path)\n",
    "        png_path = base + \".png\"\n",
    "        plt.savefig(png_path, dpi=300, bbox_inches=\"tight\")\n",
    "        # Also honour the requested format (e.g. the default .pdf), not just the PNG copy\n",
    "        if ext and save_path != png_path:\n",
    "            plt.savefig(save_path, bbox_inches=\"tight\")\n",
    "\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d85e2da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample sizes k (powers of ten, 10 to 10,000) at which both estimators are evaluated\n",
    "ks = [10**exponent for exponent in range(1, 5)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5668694",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows = prompts, columns = ks: absolute normalised entropy errors of each estimator\n",
    "estimates = [get_estimation(prompt, ks) for prompt in prompts]\n",
    "naive_stats_all = [naive_pred for naive_pred, _, _ in estimates]\n",
    "unseen_stats_all = [unseen_pred for _, unseen_pred, _ in estimates]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a64d3f8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean error per k across prompts, with the default 95% bootstrap interval\n",
    "plot_with_bootstrap_ci(ks=ks, naive_stats_all=naive_stats_all, unseen_stats_all=unseen_stats_all)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a789e9f4",
   "metadata": {},
   "source": [
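    "## Robustness of entropy to input-embedding perturbations\n",
    "\n",
    "Gaussian noise is added to each prompt's input embeddings, and the resulting change in next-token entropy is plotted against the total variation distance between the perturbed and unperturbed next-token distributions."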
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92cf4564-b38d-477e-937e-815f61516efa",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def get_logits_from_embed(embed):\n",
    "    \"\"\"Run the model on input embeddings; return float32 logits at the last position of batch item 0.\"\"\"\n",
    "    last_logits = model(inputs_embeds=embed).logits[0, -1, :]\n",
    "    return last_logits.float()\n",
    "\n",
    "\n",
    "def get_prob_and_entropy(logits):\n",
    "    \"\"\"Softmax `logits` over the last dimension; return (probs, entropy in nats along that dimension).\"\"\"\n",
    "    probs = logits.softmax(dim=-1)\n",
    "    plogp = probs * torch.log(probs + 1e-12)\n",
    "    return probs, plogp.sum(-1).neg()\n",
    "\n",
    "@torch.no_grad()\n",
    "def estimate_entropy_vs_tv_batch(prompt, epsilon=1e-2, trials=10):\n",
    "    \"\"\"Measure how Gaussian noise on the input embeddings changes the next-token distribution.\n",
    "\n",
    "    For each of `trials` draws, i.i.d. N(0, epsilon^2) noise is added to every input\n",
    "    embedding of `prompt`, and the perturbed next-token distribution is compared with the\n",
    "    unperturbed one. All trials run in a single batched forward pass.\n",
    "\n",
    "    Returns:\n",
    "        list of `trials` tuples (total_variation_distance, |H_base - H_perturbed|),\n",
    "        with entropies in nats.\n",
    "    \"\"\"\n",
    "    # Tokenize and embed on the model's device (as get_llm_probs does)\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    base_embed = model.get_input_embeddings()(inputs[\"input_ids\"])  # (1, seq_len, embed_dim)\n",
    "\n",
    "    # Unperturbed reference distribution\n",
    "    base_probs, base_entropy = get_prob_and_entropy(get_logits_from_embed(base_embed))\n",
    "\n",
    "    # Sample float32 noise, then cast it to the embeddings' dtype. Adding float32 noise directly\n",
    "    # would promote half-precision (bf16/fp16) embeddings to float32, which would then no\n",
    "    # longer match the dtype of a half-precision model's weights.\n",
    "    noise = torch.randn(trials, *base_embed.shape[1:], device=base_embed.device) * epsilon\n",
    "    noise = noise.to(base_embed.dtype)\n",
    "    perturbed_embeds = base_embed.expand(trials, -1, -1) + noise  # (trials, seq_len, embed_dim)\n",
    "\n",
    "    # All perturbed copies in one batched forward pass\n",
    "    logits_batch = model(inputs_embeds=perturbed_embeds).logits[:, -1, :].float()  # (trials, vocab_size)\n",
    "    perturbed_probs, perturbed_entropies = get_prob_and_entropy(logits_batch)  # (trials, vocab_size), (trials,)\n",
    "\n",
    "    # Total-variation distance and absolute entropy change per trial\n",
    "    tv = 0.5 * (base_probs.unsqueeze(0) - perturbed_probs).abs().sum(-1)  # (trials,)\n",
    "    entropy_diff = (base_entropy - perturbed_entropies).abs()  # (trials,)\n",
    "\n",
    "    return list(zip(tv.tolist(), entropy_diff.tolist()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32b40043",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (TV distance, |ΔH|) pair per trial, concatenated prompt by prompt in `prompts` order\n",
    "results = [\n",
    "    pair\n",
    "    for prompt in prompts\n",
    "    for pair in estimate_entropy_vs_tv_batch(prompt, epsilon=1e-2, trials=1)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72532a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter |ΔH| against TV distance for every (prompt, trial) pair, coloured by source\n",
    "# benchmark, with the tightest K·δ² curve lying above every plotted point.\n",
    "MIN_TV = 0.01  # points with TV distance <= MIN_TV are left out of the plot\n",
    "dataset_names = [\"gsm8k\", \"math500\", \"humaneval\", \"scibench\"]\n",
    "cumulative_lengths = np.cumsum([len(gsm8k), len(math500), len(humaneval), len(scibench)])\n",
    "\n",
    "# `results` holds the same number of consecutive entries (one per trial) for each prompt,\n",
    "# so entry i belongs to prompt i // trials_per_prompt.\n",
    "trials_per_prompt, leftover = divmod(len(results), int(cumulative_lengths[-1]))\n",
    "assert trials_per_prompt >= 1 and leftover == 0, \"expected an equal number of trials per prompt\"\n",
    "\n",
    "all_tvs = np.array([tv for tv, _ in results])\n",
    "all_entropy_diffs = np.array([ed for _, ed in results])\n",
    "prompt_idx = np.arange(len(results)) // trials_per_prompt\n",
    "# side=\"right\": index of the first dataset whose cumulative length exceeds the prompt index\n",
    "dataset_idx = np.searchsorted(cumulative_lengths, prompt_idx, side=\"right\")\n",
    "\n",
    "keep = all_tvs > MIN_TV\n",
    "if not keep.any():\n",
    "    raise ValueError(f\"No results with TV distance > {MIN_TV}; nothing to plot\")\n",
    "tvs = all_tvs[keep]\n",
    "entropy_diffs = all_entropy_diffs[keep]\n",
    "labels = [dataset_names[j] for j in dataset_idx[keep]]\n",
    "\n",
    "# Map datasets to colours\n",
    "cmap = plt.get_cmap(\"Set2\")\n",
    "color_map = {name: cmap(i) for i, name in enumerate(dataset_names)}\n",
    "colors = [color_map[label] for label in labels]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 5))\n",
    "ax.scatter(tvs, entropy_diffs, alpha=0.5, s=20, c=colors)\n",
    "# Empty scatters serve as legend entries, one per dataset\n",
    "for name, color in color_map.items():\n",
    "    ax.scatter([], [], color=color, label=name)\n",
    "\n",
    "# O(δ²) reference: the smallest K with K·δ² >= |ΔH| for every plotted point\n",
    "safe_tvs = np.maximum(tvs, 1e-12)\n",
    "K = np.max(entropy_diffs / safe_tvs**2)\n",
    "delta_vals = np.linspace(tvs.min(), tvs.max(), 100)\n",
    "scaled_ref_curve = K * delta_vals**2\n",
    "ax.plot(delta_vals, scaled_ref_curve, \"--\", color=\"black\", label=r\"$\\mathcal{O}(\\delta^2)$ Upper bound\")\n",
    "\n",
    "# Annotation\n",
    "annot_pos = int(len(delta_vals) * 0.8)\n",
    "ax.text(delta_vals[annot_pos], scaled_ref_curve[annot_pos] * 0.1, r\"$\\Delta H \\leq \\mathcal{O}(\\delta^2)$\",\n",
    "        fontsize=12, ha=\"center\", va=\"bottom\")\n",
    "\n",
    "ax.set_xlabel(r\"Total Variation Distance $\\delta$\")\n",
    "ax.set_ylabel(r\"Entropy Change $|\\bar{H}(P) - \\bar{H}(P')|$\")\n",
    "ax.set_title(\"Entropy Robustness to Input Perturbation\")\n",
    "ax.set_yscale(\"log\")\n",
    "# Build the legend after every labelled artist exists, so the O(δ²) curve is included\n",
    "ax.legend()\n",
    "# alpha must be passed by keyword: grid()'s first positional argument is `visible`\n",
    "ax.grid(alpha=0.25)\n",
    "fig.tight_layout()\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "fig.savefig(\"plots/entropy_vs_tv_colored.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
