{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "from typing import List, Tuple, Dict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_jsonl(path: str) -> pd.DataFrame:\n",
    "    df = pd.read_json(path, lines=True)\n",
    "    assert {\"gt\", \"top_ids\", \"top_probs\"}.issubset(df.columns), \\\n",
    "        \"Input must have keys: gt, top_ids, top_probs\"\n",
    "    # Convert top_probs to float if it's a string\n",
    "    if df[\"top_probs\"].dtype == object:\n",
    "        df[\"top_probs\"] = df[\"top_probs\"].apply(lambda x: [float(p) for p in x])\n",
    "    return df\n",
    "\n",
    "def print_last_k_lists(acc: np.ndarray, labels: List[str], support: np.ndarray, thresholds=(0.10, 0.05, 0.01)):\n",
    "    def right_edge(label: str) -> float:\n",
    "        base = label.split(\"(\")[0].strip()\n",
    "        sep = \"–\" if \"–\" in base else \"-\"\n",
    "        parts = [p.strip() for p in base.split(sep)]\n",
    "        try:\n",
    "            return float(parts[-1])\n",
    "        except Exception:\n",
    "            nums = re.findall(r\"[-+]?\\d*\\.\\d+|\\d+\", base)\n",
    "            return float(nums[-1]) if nums else np.nan\n",
    "\n",
    "    rights = np.array([right_edge(s) for s in labels])\n",
    "    order = np.argsort(rights)[::-1]\n",
    "\n",
    "    for thr in thresholds:\n",
    "        result = []\n",
    "        for b in order:\n",
    "            row = acc[b, :]\n",
    "            mask = (~np.isnan(row)) & (row > thr)\n",
    "            idxs = np.flatnonzero(mask)\n",
    "            if support[b] < 10:\n",
    "                result.append(1)\n",
    "                continue\n",
    "            result.append(max(1, int(idxs[-1] + 1) if idxs.size > 0 else 0))\n",
    "        print(f\"last_top_k_where_c_gt_{thr:.2f} = [{','.join(map(str, result))}]\")\n",
    "\n",
    "def make_bins_top1(\n",
    "    top1: np.ndarray,\n",
    "    bins: int = 10,\n",
    "    quantile: bool = False,\n",
    ") -> Tuple[np.ndarray, List[str], np.ndarray]:\n",
    "    if quantile:\n",
    "        cat = pd.qcut(top1, q=bins, duplicates=\"drop\")\n",
    "        codes = cat.cat.codes.to_numpy()\n",
    "        intervals = cat.cat.categories\n",
    "        labels = [f\"{iv.left:.2f}–{iv.right:.2f}\" for iv in intervals]\n",
    "        edges = np.array([iv.left for iv in intervals] + [intervals[-1].right])\n",
    "        return codes, labels, edges\n",
    "    else:\n",
    "        edges = np.linspace(0.0, 1.0, bins + 1)\n",
    "        codes = np.digitize(top1, edges, right=True) - 1\n",
    "        codes = np.clip(codes, 0, bins - 1)\n",
    "        labels = [f\"{edges[i]:.1f}–{edges[i+1]:.1f}\" for i in range(bins)]\n",
    "        return codes, labels, edges\n",
    "\n",
    "def aggregate_rank_by_bin_conditional(\n",
    "    df: pd.DataFrame,\n",
    "    K: int,\n",
    "    bins: int = 10,\n",
    "    quantile_bins: bool = False,\n",
    "):\n",
    "    r\"\"\"\n",
    "    Per confidence bin and rank, statistics of where the gold id lands in the top-K list.\n",
    "\n",
    "    Examples are binned by their top-1 probability (see make_bins_top1). For\n",
    "    bin b and 0-based rank r < K it computes:\n",
    "      acc[b, r]      = \\bar c_{m,r}   = P(gold at rank r | bin b)\n",
    "      mu_cond[b, r]  = \\tilde p_{m,r} = E[p^(r) | bin b, gold at rank r]   <-- conditional mean\n",
    "      std_cond[b, r] = std (ddof=0) of p^(r) over the same conditional subset\n",
    "      mu_uncond[b, r], std_uncond[b, r] = mean / std of p^(r) over the whole bin\n",
    "      N_mr[b, r]     = number of examples in bin b with gold at rank r\n",
    "      C_topK[b]      = sum_r mu_cond[b,r] * acc[b,r]\n",
    "                     = E[p^(gold rank) * 1{gold ∈ top-K} | bin b]\n",
    "                       (assuming each id appears at most once per list;\n",
    "                        NaN terms skipped, 0 for empty bins)\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Must contain ``gt``, ``top_ids`` and ``top_probs`` (see load_jsonl). All\n",
    "        ``top_ids`` / ``top_probs`` lists must share one length so they can be stacked.\n",
    "    K : int\n",
    "        Number of leading ranks to keep; must satisfy 1 <= K <= list length.\n",
    "    bins, quantile_bins :\n",
    "        Forwarded to make_bins_top1 as ``bins`` / ``quantile``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (acc, mu_cond, std_cond, support, labels,\n",
    "         mu_uncond, std_uncond, N_mr, C_topK).\n",
    "        The [B, K] float arrays are NaN for empty bins; mu_cond / std_cond are\n",
    "        also NaN wherever the gold id never appears at that rank within the bin.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If K < 1 or K exceeds the available top-list length.\n",
    "    AssertionError\n",
    "        If top_ids and top_probs have different shapes.\n",
    "    \"\"\"\n",
    "    top_ids = np.stack(df[\"top_ids\"].to_numpy())\n",
    "    top_probs = np.stack(df[\"top_probs\"].to_numpy())\n",
    "    assert top_ids.shape == top_probs.shape, \"top_ids/top_probs shape mismatch\"\n",
    "\n",
    "    if K < 1:\n",
    "        raise ValueError(f\"K={K} must be >= 1\")\n",
    "    if K > top_probs.shape[1]:\n",
    "        raise ValueError(f\"K={K} > available top list length {top_probs.shape[1]}\")\n",
    "\n",
    "    top_ids = top_ids[:, :K]\n",
    "    top_probs = top_probs[:, :K]\n",
    "    top1 = top_probs[:, 0]\n",
    "    gt = df[\"gt\"].to_numpy()\n",
    "\n",
    "    codes, labels, _ = make_bins_top1(top1, bins=bins, quantile=quantile_bins)\n",
    "    B = len(labels)\n",
    "\n",
    "    # Rank-wise correctness (\\bar c_{m,r}) and bin sizes (N_m).\n",
    "    acc = np.full((B, K), np.nan)\n",
    "    support = np.zeros(B, dtype=int)\n",
    "\n",
    "    # Conditional (\\tilde p_{m,r}) and unconditional (\\bar p_{m,r}) means and stds.\n",
    "    mu_cond = np.full((B, K), np.nan)\n",
    "    std_cond = np.full((B, K), np.nan)\n",
    "    mu_uncond = np.full((B, K), np.nan)\n",
    "    std_uncond = np.full((B, K), np.nan)\n",
    "\n",
    "    # Counts per (bin, rank) of examples whose gold id sits at that rank.\n",
    "    N_mr = np.zeros((B, K), dtype=int)\n",
    "\n",
    "    for b in range(B):\n",
    "        idx = (codes == b)\n",
    "        n = int(idx.sum())\n",
    "        support[b] = n\n",
    "        if n == 0:\n",
    "            continue  # empty bin: rows stay NaN, counts stay 0\n",
    "\n",
    "        probs_b = top_probs[idx, :]      # [n, K]\n",
    "        ids_b = top_ids[idx, :]          # [n, K]\n",
    "        gt_b = gt[idx]                   # [n]\n",
    "\n",
    "        # Unconditional statistics over every example in the bin.\n",
    "        mu_uncond[b, :] = probs_b.mean(axis=0)\n",
    "        std_uncond[b, :] = probs_b.std(axis=0, ddof=0)\n",
    "\n",
    "        # mask[t, r] is True iff the gold id equals the rank-r id of example t.\n",
    "        mask = (ids_b == gt_b[:, None])  # [n, K] bool\n",
    "        counts = mask.sum(axis=0)        # [K]\n",
    "        N_mr[b, :] = counts\n",
    "        acc[b, :] = counts / n\n",
    "\n",
    "        # Conditional statistics: average p^(r) ONLY over examples with gold at rank r.\n",
    "        # Ranks with no such example stay NaN; acc[b, r] == 0 there, and nansum\n",
    "        # below skips the resulting NaN products.\n",
    "        for r in np.flatnonzero(counts):\n",
    "            vals = probs_b[mask[:, r], r]\n",
    "            mu_cond[b, r] = vals.mean()\n",
    "            std_cond[b, r] = vals.std(ddof=0)\n",
    "\n",
    "    # C_topK[b] = sum_r \\tilde p_{m,r} * \\bar c_{m,r}\n",
    "    C_topK = np.nansum(mu_cond * acc, axis=1)  # shape [B]\n",
    "\n",
    "    return (\n",
    "        acc,            # \\bar c_{m,r}\n",
    "        mu_cond,        # \\tilde p_{m,r}  (use this for conditioned calculations)\n",
    "        std_cond,\n",
    "        support,\n",
    "        labels,\n",
    "        mu_uncond,      # \\bar p_{m,r}\n",
    "        std_uncond,\n",
    "        N_mr,\n",
    "        C_topK,\n",
    "    )\n",
    "\n",
    "# --- New code for the dual-y plot you asked for ---\n",
    "def compute_expected_accuracy(mu: np.ndarray, acc: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    expected accuracy per bin = sum_k (mu_k * c_k)\n",
    "    Handles NaNs by treating missing entries as 0 contribution.\n",
    "    \"\"\"\n",
    "    mu_safe = np.nan_to_num(mu, nan=0.0)\n",
    "    sum_mu = mu_safe.sum(axis=1)\n",
    "    mu_safe = [mu_safe[i]/sum_mu[i] for i in range(mu_safe.shape[0])]\n",
    "    acc_safe = np.nan_to_num(acc, nan=0.0)\n",
    "    return (mu_safe * acc_safe).sum(axis=1)  # shape: [B]\n",
    "\n",
    "def calculate_expected_acc_and_freq_data(\n",
    "    model_specs: List[Dict[str, str]],\n",
    "    K: int = 10,\n",
    "    bins: int = 10,\n",
    "    quantile_bins: bool = False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Calculate expected accuracy and frequency data for plotting.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    model_specs : list of {\"label\": str, \"path\": str}\n",
    "        e.g., [{\"label\":\"0.5B\",\"path\":\"/path/a.jsonl\"}, ...]\n",
    "    K : int\n",
    "        Use top-K from each file.\n",
    "    bins : int\n",
    "        Number of confidence bins (default 10 for 0.0–0.1 ... 0.9–1.0).\n",
    "    quantile_bins : bool\n",
    "        If True, use quantile bins; else fixed 0..1 bins.\n",
    "        \n",
    "    Returns\n",
    "    -------\n",
    "    dict with keys:\n",
    "        - 'all_expected': List[np.ndarray] - expected accuracy for each model\n",
    "        - 'all_freq_pct': List[np.ndarray] - frequency percentages for each model\n",
    "        - 'x_labels_common': List[str] - bin labels\n",
    "        - 'model_labels': List[str] - model labels\n",
    "    \"\"\"\n",
    "    # Storage\n",
    "    all_expected = []\n",
    "    all_freq_pct = []\n",
    "    x_labels_common = None\n",
    "    model_labels = []\n",
    "\n",
    "    # Process each model\n",
    "    for spec in model_specs:\n",
    "        label, path = spec[\"label\"], spec[\"path\"]\n",
    "        df = load_jsonl(path)\n",
    "        acc, mu, std, support, labels, _, _, _, _ = aggregate_rank_by_bin_conditional(df, K=K, bins=bins, quantile_bins=quantile_bins)\n",
    "\n",
    "        # expected accuracy per bin: sum_k mu_k * c_k\n",
    "        expected_acc = compute_expected_accuracy(mu[:, :K], acc[:, :K])  # [B]\n",
    "\n",
    "        # frequency per bin (rounded to 4 dp), then percentage scale\n",
    "        total = support.sum()\n",
    "        freq = (support / total) if total > 0 else np.zeros_like(support, dtype=float)\n",
    "        freq = np.round(freq, 4)  # to 4 dp\n",
    "        freq_pct = 100.0 * freq\n",
    "\n",
    "        all_expected.append(expected_acc)\n",
    "        all_freq_pct.append(freq_pct)\n",
    "        model_labels.append(label)\n",
    "\n",
    "        # Capture/verify x labels\n",
    "        if x_labels_common is None:\n",
    "            x_labels_common = labels\n",
    "        else:\n",
    "            if labels != x_labels_common:\n",
    "                raise ValueError(\"Bin labels differ across models; ensure consistent binning settings.\")\n",
    "\n",
    "        # Optional: print frequencies for visibility/debugging\n",
    "        print(f\"[{label}] frequency per bin (proportion, 4 dp): {freq.tolist()}\")\n",
    "\n",
    "    return {\n",
    "        'all_expected': all_expected,\n",
    "        'all_freq_pct': all_freq_pct,\n",
    "        'x_labels_common': x_labels_common,\n",
    "        'model_labels': model_labels\n",
    "    }\n",
    "\n",
    "def plot_expected_acc_and_freq(\n",
    "    data: Dict,\n",
    "    figsize: Tuple[int, int] = (12, 6),\n",
    "    save_path: str | None = None,\n",
    "    colors: List[str] | None = None,\n",
    "    title: str = \"\",\n",
    "    show_grid: bool = True,\n",
    "    grid_alpha: float = 0.2,\n",
    "    linewidth: float = 1.0,\n",
    "    marker_size: int = 6,\n",
    "    freq_ylim: Tuple[float, float] | None = None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot expected accuracy (left y-axis, dashed) and bin frequency (right y-axis, solid).\n",
    "\n",
    "    - X-axis ticks at 0.0, 0.1, ..., 1.0\n",
    "    - Data points at uniform bin midpoints (i + 0.5) / num_bins, which matches\n",
    "      equal-width bins on [0, 1]; quantile bins are drawn at the same positions\n",
    "    - Model-color legend and line-style legend side by side in the upper left\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : dict\n",
    "        Output from calculate_expected_acc_and_freq_data()\n",
    "    figsize : (w, h)\n",
    "    save_path : str or None\n",
    "        If provided, saves the figure (dpi=200); otherwise calls plt.show().\n",
    "    colors : List[str] or None\n",
    "        Colors for the models, reused cyclically when there are fewer colors\n",
    "        than models. If None or empty, uses tab10 darkened (RGB scaled by 0.7).\n",
    "    title : str\n",
    "        Plot title\n",
    "    show_grid : bool\n",
    "        Whether to draw horizontal grid lines on the left axis\n",
    "    grid_alpha : float\n",
    "        Grid transparency\n",
    "    linewidth : float\n",
    "        Line width for plots\n",
    "    marker_size : int\n",
    "        Unused (no markers are drawn); kept for backward compatibility.\n",
    "    freq_ylim : (low, high) or None\n",
    "        Frequency-axis limits in percent. If None: (0, 55), with the upper limit\n",
    "        raised to 1.05 x the largest frequency when that exceeds 55, so no curve\n",
    "        is clipped.\n",
    "    \"\"\"\n",
    "    all_expected = data['all_expected']\n",
    "    all_freq_pct = data['all_freq_pct']\n",
    "    x_labels_common = data['x_labels_common']\n",
    "    model_labels = data['model_labels']\n",
    "\n",
    "    if not colors:\n",
    "        # Darker versions of the tab10 cycle (RGB scaled by 0.7).\n",
    "        color_cycle = plt.cm.tab10.colors\n",
    "        colors = [\n",
    "            tuple(c * 0.7 for c in color_cycle[i % len(color_cycle)][:3])\n",
    "            for i in range(len(model_labels))\n",
    "        ]\n",
    "    # Cycle through the palette so a short custom list cannot raise IndexError.\n",
    "    model_colors = [colors[i % len(colors)] for i in range(len(model_labels))]\n",
    "\n",
    "    num_bins = len(x_labels_common)\n",
    "    x_mid = (np.arange(num_bins) + 0.5) / num_bins  # bin midpoints in [0, 1]\n",
    "\n",
    "    # X ticks exactly at 0., 0.1, ..., 1.0 (match two-subplot style)\n",
    "    x_ticks = np.linspace(0.0, 1.0, 11)\n",
    "    x_tick_labels = [\"0.\" if np.isclose(t, 0.0) else f\"{t:.1f}\" for t in x_ticks]\n",
    "\n",
    "    fig, ax_left = plt.subplots(figsize=figsize)\n",
    "    ax_right = ax_left.twinx()\n",
    "\n",
    "    dashed = (0, (3, 2))\n",
    "\n",
    "    # Left y-axis: expected accuracy (0..1), dashed lines\n",
    "    for i, model_label in enumerate(model_labels):\n",
    "        ax_left.plot(\n",
    "            x_mid, all_expected[i],\n",
    "            linestyle=dashed,\n",
    "            linewidth=linewidth, color=model_colors[i],\n",
    "            label=f\"{model_label} (expected acc)\"\n",
    "        )\n",
    "\n",
    "    ax_left.set_xlim(0.0, 1.0)\n",
    "    ax_left.set_ylim(0.0, 1.0)\n",
    "    ax_left.set_xticks(x_ticks)\n",
    "    ax_left.set_xticklabels(x_tick_labels)\n",
    "    ax_left.set_xlabel(\"Confidence bin\")\n",
    "    ax_left.set_ylabel(\"Expected accuracy\")\n",
    "\n",
    "    # Right y-axis: frequency in percent, solid lines\n",
    "    for i, model_label in enumerate(model_labels):\n",
    "        ax_right.plot(\n",
    "            x_mid, all_freq_pct[i],\n",
    "            linestyle=\"-\", linewidth=linewidth, color=model_colors[i],\n",
    "            label=f\"{model_label} (freq)\"\n",
    "        )\n",
    "\n",
    "    if freq_ylim is None:\n",
    "        peak = max((float(np.max(f)) for f in all_freq_pct if np.size(f) > 0), default=0.0)\n",
    "        freq_ylim = (0.0, max(55.0, 1.05 * peak))\n",
    "    ax_right.set_ylim(*freq_ylim)\n",
    "    ax_right.set_ylabel(\"Frequency of occurrences (%)\")\n",
    "\n",
    "    # Legends: (1) colors map to models, (2) line styles map to metrics\n",
    "    model_handles = [\n",
    "        Line2D([0], [0], color=model_colors[i], lw=3, label=model_labels[i])\n",
    "        for i in range(len(model_labels))\n",
    "    ]\n",
    "    style_handles = [\n",
    "        Line2D([0], [0], color=\"black\", lw=2, linestyle=dashed, label=\"Expected accuracy\"),\n",
    "        Line2D([0], [0], color=\"black\", lw=2, linestyle=\"-\", label=\"Frequency\"),\n",
    "    ]\n",
    "\n",
    "    # Place both legends left, side-by-side; add_artist keeps the first legend\n",
    "    # when the second legend() call replaces the axes' legend.\n",
    "    first_legend = ax_left.legend(handles=model_handles, title=\"Model size\", loc=\"upper left\", bbox_to_anchor=(0.03, 0.98), borderaxespad=0.0)\n",
    "    ax_left.add_artist(first_legend)\n",
    "    ax_left.legend(handles=style_handles, loc=\"upper left\", bbox_to_anchor=(0.2, 0.98), borderaxespad=0.0, title=None)\n",
    "\n",
    "    if show_grid:\n",
    "        ax_left.grid(True, axis=\"y\", alpha=grid_alpha)\n",
    "    ax_left.set_title(title)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    if save_path:\n",
    "        fig.savefig(save_path, dpi=200, bbox_inches=\"tight\")\n",
    "        print(f\"Saved figure to: {save_path}\")\n",
    "    else:\n",
    "        plt.show()\n",
    "\n",
    "def plot_expected_acc_and_freq_from_specs(\n",
    "    model_specs: List[Dict[str, str]],\n",
    "    K: int = 10,\n",
    "    bins: int = 10,\n",
    "    quantile_bins: bool = False,\n",
    "    figsize: Tuple[int, int] = (12, 6),\n",
    "    save_path: str | None = None,\n",
    "    **plot_kwargs\n",
    "):\n",
    "    \"\"\"\n",
    "    One-shot helper: compute the plot data from ``model_specs`` and draw it.\n",
    "\n",
    "    Passes ``model_specs``, ``K``, ``bins`` and ``quantile_bins`` to\n",
    "    calculate_expected_acc_and_freq_data, then hands the result together with\n",
    "    ``figsize``, ``save_path`` and any extra ``plot_kwargs`` to\n",
    "    plot_expected_acc_and_freq. Kept for backward compatibility with the\n",
    "    original single-function API.\n",
    "    \"\"\"\n",
    "    plot_expected_acc_and_freq(\n",
    "        calculate_expected_acc_and_freq_data(\n",
    "            model_specs=model_specs, K=K, bins=bins, quantile_bins=quantile_bins\n",
    "        ),\n",
    "        figsize=figsize,\n",
    "        save_path=save_path,\n",
    "        **plot_kwargs,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-model top-k trace files and the output location of the figure.\n",
    "TRACE_DIR = \"/selective_greedy/softmax_values\"\n",
    "FIGURE_PATH = \"/selective_greedy/post_processing/figures/var_size.pdf\"\n",
    "\n",
    "models_to_plot = [\n",
    "    {\"label\": \"0.5B\", \"path\": f\"{TRACE_DIR}/topk_traces_20250904_202648.jsonl\"},\n",
    "    {\"label\": \"1.5B\", \"path\": f\"{TRACE_DIR}/topk_traces_20250904_202131.jsonl\"},\n",
    "    {\"label\": \"3B\",   \"path\": f\"{TRACE_DIR}/topk_traces_20250904_194548.jsonl\"},\n",
    "    {\"label\": \"7B\",   \"path\": f\"{TRACE_DIR}/topk_traces_20250904_190655.jsonl\"},\n",
    "    {\"label\": \"14B\",  \"path\": f\"{TRACE_DIR}/topk_traces_20250904_192814.jsonl\"},\n",
    "]\n",
    "\n",
    "# Step 1: read every trace file and compute the per-bin statistics (run once).\n",
    "print(\"Calculating data...\")\n",
    "data = calculate_expected_acc_and_freq_data(\n",
    "    model_specs=models_to_plot,\n",
    "    K=20,                 # number of leading ranks used from each trace\n",
    "    bins=10,              # fixed bins: 0.0–0.1, ..., 0.9–1.0\n",
    "    quantile_bins=False,  # fixed-width edges rather than quantiles\n",
    ")\n",
    "\n",
    "# Step 2: plot. Okabe–Ito color-blind-safe palette, one color per model in order.\n",
    "custom_colors = [\"#0072B2\", \"#E69F00\", \"#009E73\", \"#D55E00\", \"#CC79A7\", \"#56B4E9\"]\n",
    "\n",
    "plot_expected_acc_and_freq(\n",
    "    data=data,\n",
    "    figsize=(8, 4),\n",
    "    colors=custom_colors,\n",
    "    linewidth=1.2,  # slightly thicker than the default 1.0\n",
    "    show_grid=True,\n",
    "    grid_alpha=0.1,\n",
    "    save_path=FIGURE_PATH,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative, shorter layout built from the `data` computed in the cell above;\n",
    "# re-run this cell with different options without recomputing anything.\n",
    "\n",
    "# Okabe–Ito color-blind-safe palette, one color per model in order.\n",
    "custom_colors = [\"#0072B2\", \"#E69F00\", \"#009E73\", \"#D55E00\", \"#CC79A7\", \"#56B4E9\"]\n",
    "\n",
    "compact_style = dict(\n",
    "    figsize=(8, 3.5),\n",
    "    colors=custom_colors,\n",
    "    linewidth=1.2,  # slightly thicker than the default 1.0\n",
    "    show_grid=True,\n",
    "    grid_alpha=0.1,\n",
    ")\n",
    "plot_expected_acc_and_freq(data=data, save_path=\"var_size.pdf\", **compact_style)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchtune",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
