{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a891d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import random\n",
    "import math\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams[\"text.usetex\"] = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "302926ef",
   "metadata": {},
   "source": [
    "## Tree visualisation example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46591634",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_llm_decoding_tree_topdown(\n",
    "    B: int,\n",
    "    T: int,\n",
    "    seed=None,\n",
    "    cmap='cividis',\n",
    "    bias_power=3,\n",
    "    eos_prob_at_T=0.3,\n",
    "    early_stopping_prob=0.1,\n",
    "    length_penalty_alpha=1.0,\n",
    "    output_file=\"llm_decoding_tree.pdf\",\n",
    "):\n",
    "    \"\"\"Draw a random LLM-decoding-style token tree top-down and save it.\n",
    "\n",
    "    Nodes are generated breadth-first with Python's ``random`` module. Each\n",
    "    node stores a random value ``e`` in [0, 1) (shown as the node colour) and\n",
    "    expands into ``max(1, floor(B * e))`` children unless it has reached its\n",
    "    branch depth limit or is randomly stopped early. Edges carry random\n",
    "    probabilities. The EOS-terminated root-to-leaf path with the highest\n",
    "    score ``prob / len(path) ** length_penalty_alpha`` is drawn in red and\n",
    "    the other EOS-terminated paths in dark grey.\n",
    "\n",
    "    Args:\n",
    "        B: Branching scale; a node gets ``max(1, floor(B * e))`` children.\n",
    "        T: Maximum tree depth.\n",
    "        seed: Optional seed for ``random`` (the only RNG used here).\n",
    "        cmap: Colormap name or ``Colormap`` instance for node colours.\n",
    "        bias_power: Values > 1 make deeper branch depth limits more likely.\n",
    "        eos_prob_at_T: Chance that a node at depth ``T - 1`` is labelled EOS.\n",
    "        early_stopping_prob: Chance that a non-root node below its depth\n",
    "            limit gets no children.\n",
    "        length_penalty_alpha: Exponent of the path-length penalty.\n",
    "        output_file: Path the figure is saved to.\n",
    "    \"\"\"\n",
    "    from collections import deque\n",
    "\n",
    "    if seed is not None:\n",
    "        random.seed(seed)\n",
    "\n",
    "    # ===== 1. Build graph (breadth-first) =====\n",
    "    G = nx.DiGraph()\n",
    "    node_id = 0\n",
    "    G.add_node(node_id, e=random.random(), depth=0, max_depth=T, terminal=False)\n",
    "    queue = deque([node_id])  # O(1) pops from the front (list.pop(0) is O(n))\n",
    "    node_id += 1\n",
    "    ellipsis_positions = []\n",
    "\n",
    "    while queue:\n",
    "        parent = queue.popleft()\n",
    "        depth = G.nodes[parent]['depth']\n",
    "        branch_max_depth = G.nodes[parent]['max_depth']\n",
    "\n",
    "        if depth >= branch_max_depth:\n",
    "            G.nodes[parent]['terminal'] = True\n",
    "            continue\n",
    "\n",
    "        e = G.nodes[parent]['e']\n",
    "        num_children = max(1, math.floor(B * e))\n",
    "\n",
    "        # Allow early stopping before max depth (never at the root)\n",
    "        if depth > 0 and random.random() < early_stopping_prob:\n",
    "            num_children = 0\n",
    "\n",
    "        child_ids = []\n",
    "        for _ in range(num_children):\n",
    "            remaining_depth = T - (depth + 1)\n",
    "            if remaining_depth <= 0:\n",
    "                child_max_depth = depth + 1\n",
    "            else:\n",
    "                # r ** (1 / bias_power) is pushed towards 1 when bias_power > 1,\n",
    "                # which makes deeper branch limits more likely\n",
    "                r = random.random()\n",
    "                extra_depth = math.floor(remaining_depth * (r ** (1 / bias_power)))\n",
    "                child_max_depth = depth + 1 + extra_depth\n",
    "\n",
    "            G.add_node(\n",
    "                node_id,\n",
    "                e=random.random(),\n",
    "                depth=depth + 1,\n",
    "                max_depth=child_max_depth,\n",
    "                terminal=False\n",
    "            )\n",
    "            child_ids.append(node_id)\n",
    "            queue.append(node_id)\n",
    "            node_id += 1\n",
    "\n",
    "        # Nodes at or past their depth limit already hit `continue` above,\n",
    "        # so every node reaching this line gets an ellipsis marker.\n",
    "        ellipsis_positions.append(parent)\n",
    "\n",
    "        # Assign probabilities to children\n",
    "        if child_ids:\n",
    "            probs = [random.random() for _ in child_ids]\n",
    "            probs.sort(reverse=True)  # bias left > right\n",
    "            # Children share only part (50-100%) of the probability mass\n",
    "            allocation_fraction = random.uniform(0.5, 1.0)\n",
    "            max_sum = allocation_fraction\n",
    "            sum_probs = sum(probs)\n",
    "            probs = [(p / sum_probs) * max_sum for p in probs]\n",
    "\n",
    "            # Add noise for non-greedy surprises, then rescale back to max_sum\n",
    "            probs = [max(0.01, p + random.uniform(-0.05, 0.05)) for p in probs]\n",
    "            sum_probs = sum(probs)\n",
    "            if sum_probs > 0:\n",
    "                probs = [(p / sum_probs) * max_sum for p in probs]\n",
    "\n",
    "            for c, p in zip(child_ids, probs):\n",
    "                G.add_edge(parent, c, prob=p)\n",
    "\n",
    "    # ===== 2. Labels =====\n",
    "    # Terminal nodes shallower than T - 1 are \"EOS\"; nodes at depth T - 1 are\n",
    "    # \"EOS\" with probability eos_prob_at_T; all others show \"t{depth}\".\n",
    "    labels = {}\n",
    "    for n in G.nodes():\n",
    "        d = G.nodes[n]['depth']\n",
    "        term = G.nodes[n]['terminal']\n",
    "        if term and d < (T - 1):\n",
    "            labels[n] = \"EOS\"\n",
    "        elif d == (T - 1):\n",
    "            labels[n] = \"EOS\" if random.random() < eos_prob_at_T else f\"t{d}\"\n",
    "        else:\n",
    "            labels[n] = f\"t{d}\"\n",
    "\n",
    "    # ===== 3. Most probable EOS path with length penalty =====\n",
    "    def most_probable_eos_path(G, start, alpha=1.0):\n",
    "        \"\"\"Exhaustive DFS for the EOS leaf maximising prob / len(path) ** alpha.\"\"\"\n",
    "        best_path = []\n",
    "        best_score = -1\n",
    "        best_prob = 0\n",
    "\n",
    "        def dfs(node, path, prob):\n",
    "            nonlocal best_path, best_score, best_prob\n",
    "            children = list(G.successors(node))\n",
    "            if not children:\n",
    "                if labels[node] == \"EOS\":\n",
    "                    length_penalty = (len(path)) ** alpha\n",
    "                    score = prob / length_penalty\n",
    "                    if score > best_score:\n",
    "                        best_score = score\n",
    "                        best_prob = prob\n",
    "                        best_path = path[:]\n",
    "                return\n",
    "            for child in children:\n",
    "                dfs(child, path + [child], prob * G.edges[node, child]['prob'])\n",
    "\n",
    "        dfs(start, [start], 1.0)\n",
    "        return best_path, best_prob, best_score\n",
    "\n",
    "    best_path, _, _ = most_probable_eos_path(G, 0, alpha=length_penalty_alpha)\n",
    "    # Empty when no leaf is labelled EOS\n",
    "    best_edges = list(zip(best_path[:-1], best_path[1:]))\n",
    "\n",
    "    # ===== 4. Identify all EOS-ending edges =====\n",
    "    eos_edges = set()\n",
    "    def dfs_collect_eos_edges(node, path):\n",
    "        children = list(G.successors(node))\n",
    "        if not children:\n",
    "            if labels[node] == \"EOS\":\n",
    "                eos_edges.update(zip(path[:-1], path[1:]))\n",
    "            return\n",
    "        for child in children:\n",
    "            dfs_collect_eos_edges(child, path + [child])\n",
    "\n",
    "    dfs_collect_eos_edges(0, [0])\n",
    "\n",
    "    # ===== 5. Layout: one row per depth, nodes evenly spaced in (0, 1) =====\n",
    "    def hierarchy_pos(G, root):\n",
    "        pos = {}\n",
    "        layer_nodes = {}\n",
    "        for n, data in G.nodes(data=True):\n",
    "            layer_nodes.setdefault(data['depth'], []).append(n)\n",
    "        for depth, nodes in layer_nodes.items():\n",
    "            x_spacing = 1 / (len(nodes) + 1)\n",
    "            for i, n in enumerate(nodes):\n",
    "                pos[n] = ((i + 1) * x_spacing, -depth)\n",
    "        return pos\n",
    "\n",
    "    pos = hierarchy_pos(G, 0)\n",
    "\n",
    "    # ===== 6. Plot =====\n",
    "    fig, ax = plt.subplots(figsize=(6, 5), dpi=300)\n",
    "    # plt.get_cmap accepts a name or a Colormap instance. plt.cm.get_cmap\n",
    "    # (matplotlib.cm.get_cmap) was deprecated in 3.7 and removed in 3.9.\n",
    "    colormap = plt.get_cmap(cmap)\n",
    "    node_color = [G.nodes[n]['e'] for n in G.nodes()]\n",
    "\n",
    "    # Draw nodes\n",
    "    nx.draw_networkx_nodes(\n",
    "        G, pos,\n",
    "        node_size=500,\n",
    "        node_color=node_color,\n",
    "        cmap=colormap,\n",
    "        vmin=0, vmax=1,\n",
    "        edgecolors=\"black\",\n",
    "        linewidths=0.8,\n",
    "        ax=ax\n",
    "    )\n",
    "\n",
    "    # Edges on no EOS-ending path (faint)\n",
    "    other_edges = set(G.edges()) - eos_edges\n",
    "    nx.draw_networkx_edges(\n",
    "        G, pos, edgelist=list(other_edges - set(best_edges)),\n",
    "        alpha=0.3, width=0.8, arrows=True, arrowstyle='-|>', ax=ax\n",
    "    )\n",
    "\n",
    "    # EOS-ending paths in dark grey\n",
    "    nx.draw_networkx_edges(\n",
    "        G, pos, edgelist=list(eos_edges - set(best_edges)),\n",
    "        width=1.5, edge_color=\"black\", alpha=0.4, arrows=True, arrowstyle='-|>', ax=ax\n",
    "    )\n",
    "\n",
    "    # Best EOS path in bright red\n",
    "    nx.draw_networkx_edges(\n",
    "        G, pos, edgelist=best_edges,\n",
    "        width=2.5, edge_color=\"red\", alpha=0.9, arrows=True, arrowstyle='-|>', ax=ax\n",
    "    )\n",
    "\n",
    "    # Labels\n",
    "    edge_labels = {(u, v): f\"{G.edges[u, v]['prob']:.2f}\" for u, v in G.edges()}\n",
    "    nx.draw_networkx_labels(G, pos, labels=labels, font_size=8, ax=ax)\n",
    "    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=6, ax=ax)\n",
    "\n",
    "    # Ellipsis markers: three faint dots to the right of each expanded node\n",
    "    for n in ellipsis_positions:\n",
    "        x, y = pos[n]\n",
    "        dot_spacing = 0.015\n",
    "        for offset in [-dot_spacing, 0, dot_spacing]:\n",
    "            ax.scatter(\n",
    "                x + 0.06 + offset, y,\n",
    "                s=8, c=\"#777777\", alpha=0.25, marker=\"o\", zorder=3\n",
    "            )\n",
    "\n",
    "    # Colorbar\n",
    "    sm = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin=0, vmax=1))\n",
    "    cbar = fig.colorbar(sm, ax=ax)\n",
    "    cbar.set_label('Conditional Entropy of Node', rotation=270, labelpad=15)\n",
    "\n",
    "    ax.set_axis_off()\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(output_file, bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Example usage: render one seeded tree and save it to llm_decoding_tree.pdf\n",
    "plot_llm_decoding_tree_topdown(\n",
    "    B=4,\n",
    "    T=6,\n",
    "    seed=42,\n",
    "    bias_power=4,\n",
    "    eos_prob_at_T=0.4,\n",
    "    early_stopping_prob=0.1,\n",
    "    length_penalty_alpha=1.0,\n",
    "    output_file=\"llm_decoding_tree.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55c21c03",
   "metadata": {},
   "source": [
    "## Monte Carlo Estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e138a866",
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.default_rng(7)\n",
    "\n",
    "# Simulation knobs\n",
    "T = 1200\n",
    "K = 25\n",
    "M_total = T * 5\n",
    "runs = 30\n",
    "\n",
    "T_min = 0.12\n",
    "T_max = 1.10\n",
    "\n",
    "def softmax_temp_gap(K, gap, temp):\n",
    "    logits = np.zeros(K)\n",
    "    logits[0] = gap\n",
    "    z = (logits / temp)\n",
    "    e = np.exp(z - np.max(z))\n",
    "    return e / np.sum(e)\n",
    "\n",
    "def generate_distributions_softmax(T, K, variability, rng, gap=3.0):\n",
    "    \"\"\"Sample a (T, K) array whose rows are peaked tempered-softmax distributions.\n",
    "\n",
    "    Per-row temperatures are uniform on a window centred at the midpoint of the\n",
    "    module-level range [T_min, T_max], with half-width ``variability`` times half\n",
    "    that range, then clipped to [T_min, T_max]. Row t is\n",
    "    ``softmax_temp_gap(K, gap, temp_t)`` with its peak rolled to a random index.\n",
    "    \"\"\"\n",
    "    midpoint = 0.5 * (T_min + T_max)\n",
    "    half_width = variability * 0.5 * (T_max - T_min)\n",
    "    lower = np.full(T, midpoint - half_width)\n",
    "    upper = np.full(T, midpoint + half_width)\n",
    "    temperatures = np.clip(rng.uniform(low=lower, high=upper), T_min, T_max)\n",
    "\n",
    "    rows = np.zeros((T, K))\n",
    "    for row_idx, temp in enumerate(temperatures):\n",
    "        peaked = softmax_temp_gap(K, gap=gap, temp=temp)\n",
    "        # The peak starts at index 0; rolling moves it to a random category\n",
    "        rows[row_idx] = np.roll(peaked, shift=rng.integers(K))\n",
    "    return rows\n",
    "\n",
    "def step_entropy(p):\n",
    "    p = np.clip(p, 1e-12, 1.0)\n",
    "    return -np.sum(p * np.log(p))\n",
    "\n",
    "def allocate_samples(P, total_budget, mode=\"fixed\", min_per_step=1):\n",
    "    T, K = P.shape\n",
    "    if mode == \"fixed\":\n",
    "        m = np.full(T, total_budget // T, dtype=int)\n",
    "        rem = total_budget - m.sum()\n",
    "        if rem > 0:\n",
    "            m[:rem] += 1\n",
    "        return m\n",
    "    elif mode == \"adaptive\":\n",
    "        # allocate a floor to avoid starving easy steps\n",
    "        floor_total = min_per_step * T\n",
    "        floor_total = min(floor_total, total_budget)  # guard\n",
    "        m = np.full(T, min_per_step, dtype=int)\n",
    "        remaining = total_budget - floor_total\n",
    "        if remaining <= 0:\n",
    "            return m\n",
    "        H = np.array([step_entropy(P[t]) for t in range(T)])\n",
    "        H_norm = H / np.log(K)\n",
    "        w = H_norm + 1e-9\n",
    "        w /= w.sum()\n",
    "        inc = np.floor(w * remaining).astype(int)\n",
    "        m += inc\n",
    "        diff = total_budget - m.sum()\n",
    "        if diff > 0:\n",
    "            idx = np.argsort(-(w - inc / remaining))[:diff]\n",
    "            m[idx] += 1\n",
    "        elif diff < 0:\n",
    "            idx = np.argsort(w - inc / remaining)[:abs(diff)]\n",
    "            m[idx] -= 1\n",
    "        return m\n",
    "    else:\n",
    "        raise ValueError(\"Unknown allocation mode\")\n",
    "\n",
    "def simulate_one(P, m, rng):\n",
    "    T, K = P.shape\n",
    "    regret = 0.0\n",
    "    for t in range(T):\n",
    "        p = P[t]\n",
    "        best = np.argmax(p)  # true best category\n",
    "        mt = m[t]\n",
    "        if mt <= 0:\n",
    "            chosen = rng.integers(K)\n",
    "        else:\n",
    "            samples = rng.choice(K, size=mt, p=p)\n",
    "            counts = np.bincount(samples, minlength=K)\n",
    "            chosen = np.argmax(counts)\n",
    "        # 0 if exact match, 1 if not\n",
    "        if chosen != best:\n",
    "            regret += 1\n",
    "    return regret\n",
    "\n",
    "# Sweep the temperature variability; for each level, both allocation\n",
    "# policies are evaluated on the same freshly sampled P in every run.\n",
    "variabilities = np.linspace(0.5, 1.0, 5)\n",
    "reg_fixed_mean, reg_adapt_mean = [], []\n",
    "reg_fixed_std, reg_adapt_std = [], []\n",
    "\n",
    "for v in variabilities:\n",
    "    fixed_runs = []\n",
    "    adapt_runs = []\n",
    "    for _ in range(runs):\n",
    "        P = generate_distributions_softmax(T, K, v, rng)\n",
    "        m_fixed = allocate_samples(P, M_total, mode=\"fixed\")\n",
    "        m_adapt = allocate_samples(P, M_total, mode=\"adaptive\", min_per_step=1)\n",
    "        # Both policies must spend exactly the same total budget\n",
    "        assert m_adapt.sum() == m_fixed.sum() == M_total\n",
    "        r_fixed = simulate_one(P, m_fixed, rng)\n",
    "        r_adapt = simulate_one(P, m_adapt, rng)\n",
    "        fixed_runs.append(r_fixed)\n",
    "        adapt_runs.append(r_adapt)\n",
    "    # Summarise each policy's regret across the runs at this level\n",
    "    for means, stds, samples in (\n",
    "        (reg_fixed_mean, reg_fixed_std, fixed_runs),\n",
    "        (reg_adapt_mean, reg_adapt_std, adapt_runs),\n",
    "    ):\n",
    "        means.append(np.mean(samples))\n",
    "        stds.append(np.std(samples))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab951c1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(5, 3))\n",
    "ax.plot(variabilities, reg_fixed_mean, 'o-', label='Fixed per step', color=\"orange\")\n",
    "ax.plot(variabilities, reg_adapt_mean, 's-', label='Entropy-adaptive', color=\"purple\")\n",
    "\n",
    "# Shaded bands: mean +/- one standard deviation across the Monte Carlo runs\n",
    "f = np.array(reg_fixed_mean)\n",
    "a = np.array(reg_adapt_mean)\n",
    "fs = np.array(reg_fixed_std)\n",
    "as_ = np.array(reg_adapt_std)\n",
    "\n",
    "ax.fill_between(variabilities, f - fs, f + fs, alpha=0.15, color=\"orange\")\n",
    "ax.fill_between(variabilities, a - as_, a + as_, alpha=0.15, color=\"purple\")\n",
    "\n",
    "# Replace numeric tick values with qualitative \"Low\" / \"High\" labels\n",
    "ax.set_xticks([variabilities.min(), variabilities.max()])\n",
    "ax.set_xticklabels(['Low', 'High'])\n",
    "ax.set_yticks([min(f.min(), a.min()), max(f.max(), a.max())])\n",
    "ax.set_yticklabels(['Low', 'High'])\n",
    "\n",
    "# Labels\n",
    "ax.set_xlabel('Var($H_t$)', fontsize=16)\n",
    "ax.set_ylabel('Total regret', fontsize=16)\n",
    "\n",
    "# Remove spines\n",
    "for spine in ax.spines.values():\n",
    "    spine.set_visible(False)\n",
    "\n",
    "# Grid\n",
    "ax.grid(True, ls='--', lw=0.5, alpha=0.25)\n",
    "\n",
    "# Legend above the plot. `ncol` is accepted by every matplotlib version,\n",
    "# whereas `ncols` only exists from matplotlib 3.6 onwards.\n",
    "ax.legend(\n",
    "    title=f\"M={M_total}\",\n",
    "    ncol=2,\n",
    "    frameon=False,\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.5, 1.2)\n",
    ")\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"monte_carlo.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
