{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a0b76d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import multiprocessing as mp\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "from analyze_traces import analyze\n",
    "\n",
    "# ANSI escape codes used to highlight warnings in printed reports.\n",
    "YELLOW = \"\\033[33m\"\n",
    "BOLD = \"\\033[1m\"\n",
    "RESET = \"\\033[0m\"\n",
    "\n",
    "# Inputs are read from the kernel's working directory; plots go to ../plots.\n",
    "script_dir = os.getcwd()\n",
    "plot_dir = os.path.join(script_dir, \"..\", \"plots\")\n",
    "# exist_ok=True replaces the separate exists() check, which could race with another process.\n",
    "os.makedirs(plot_dir, exist_ok=True)\n",
    "\n",
    "trace_files = sorted(glob.glob(os.path.join(script_dir, \"input*.jsonl\")))\n",
    "if not trace_files:\n",
    "    # Without this guard the filter at the bottom of the cell fails with an opaque\n",
    "    # KeyError: 'warmup', because an empty DataFrame has no columns.\n",
    "    raise FileNotFoundError(f\"No input*.jsonl trace files found in {script_dir}\")\n",
    "\n",
    "results = []\n",
    "\n",
    "# One worker per trace file, capped at the number of CPUs.\n",
    "worker_count = min(len(trace_files), mp.cpu_count() or 1)\n",
    "if worker_count > 1:\n",
    "    ctx = mp.get_context(\"spawn\")\n",
    "    # Hand out files in small batches (about four batches per worker).\n",
    "    chunksize = max(1, len(trace_files) // (worker_count * 4))\n",
    "    with ctx.Pool(worker_count) as pool:\n",
    "        for bundles in tqdm(\n",
    "            pool.imap_unordered(analyze, trace_files, chunksize=chunksize),\n",
    "            total=len(trace_files),\n",
    "            desc=\"Analyzing\",\n",
    "        ):\n",
    "            results.extend(bundles)\n",
    "else:\n",
    "    for bundles in tqdm(map(analyze, trace_files), total=len(trace_files), desc=\"Analyzing\"):\n",
    "        results.extend(bundles)\n",
    "\n",
    "df = pd.DataFrame(results)\n",
    "if \"warmup\" not in df.columns:\n",
    "    # Same failure mode as above, reached when the files yield no usable bundles.\n",
    "    raise ValueError(\"No bundles with a 'warmup' field were produced from the trace files\")\n",
    "# Keep only the non-warmup bundles.\n",
    "df = df[~df[\"warmup\"]].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9a719dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def safename(name: str) -> str:\n",
    "    return name.replace('.', '_').replace('%', 'pct')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "571db666",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Optimisation direction per metric: 'max' means higher is better, 'min' means lower is better.\n",
    "metric_preferences = {'gen_throughput': 'max', 'vsr': 'max', 'latency': 'min'}\n",
    "\n",
    "\n",
    "def report_minedraft_improvement(metrics_slice: pd.DataFrame, targets: list[str], context: str):\n",
    "    \"\"\"Print MineDraft's relative improvement for each metric in ``targets``.\n",
    "\n",
    "    Two percentages are printed per target:\n",
    "\n",
    "    * ``max ΔStdSD`` - the largest improvement of any selected MineDraft variant over\n",
    "      the reference method, taken over the ``k`` values both have results for.\n",
    "    * ``ΔbestBaseline`` - the improvement of the best MineDraft value over the best\n",
    "      non-MineDraft value, both read at the ``k`` where that baseline value is best.\n",
    "\n",
    "    Non-positive percentages are printed in bold yellow.\n",
    "\n",
    "    Args:\n",
    "        metrics_slice: Rows with a ``method`` column, a ``k`` column and one\n",
    "            ``avg_<target>`` column per entry of ``targets``.\n",
    "        targets: Metric names; each must be a key of ``metric_preferences``.\n",
    "        context: Label placed at the start of every printed line.\n",
    "\n",
    "    Returns:\n",
    "        None. Returns early, without a report, when the baseline, reference or\n",
    "        MineDraft rows are missing.\n",
    "    \"\"\"\n",
    "    # Baseline = every method that is not a MineDraft variant.\n",
    "    base_slice = metrics_slice[~metrics_slice['method'].str.startswith('MineDraft')]\n",
    "    # Reference = 'Standard SD', compared against MineDraft variants without a '+' in their name.\n",
    "    std_slice = metrics_slice[metrics_slice['method'].str.startswith('Standard SD')]\n",
    "    mine_slice = metrics_slice[\n",
    "        metrics_slice['method'].str.startswith('MineDraft')\n",
    "        & ~metrics_slice['method'].str.contains(r'\\+')\n",
    "    ]\n",
    "    # Fallback when there is no 'Standard SD' row: compare EAGLE against 'MineDraft + EAGLE'.\n",
    "    if std_slice.empty:\n",
    "        std_slice = metrics_slice[metrics_slice['method'].str.startswith('EAGLE')]\n",
    "        mine_slice = metrics_slice[metrics_slice['method'].str.startswith('MineDraft + EAGLE')]\n",
    "    # NOTE(review): this second fallback cannot select any rows - the 'EAGLE' prefix above\n",
    "    # already matches 'EAGLE-3', so std_slice is still empty here and the function returns below.\n",
    "    if std_slice.empty:\n",
    "        std_slice = metrics_slice[metrics_slice['method'].str.startswith('EAGLE-3')]\n",
    "        mine_slice = metrics_slice[metrics_slice['method'].str.startswith('MineDraft + EAGLE-3')]\n",
    "\n",
    "    if base_slice.empty or std_slice.empty or mine_slice.empty:\n",
    "        return\n",
    "\n",
    "    # One reference row per k; rows with a NaN in any column are dropped.\n",
    "    std_sd = std_slice.groupby('k').first().dropna()\n",
    "    if std_sd.empty:\n",
    "        return\n",
    "\n",
    "    # 1) Max improvement over Std SD across ks\n",
    "    max_improvements = {target: -np.inf for target in targets}\n",
    "\n",
    "    # 2) Improvement at the k where baseline is best\n",
    "    # Start from the worst possible value so any finite measurement replaces it.\n",
    "    mine_at_base_k = {\n",
    "        target: (-np.inf if metric_preferences[target] == 'max' else np.inf)\n",
    "        for target in targets\n",
    "    }\n",
    "    # target -> (k at which the baseline is best, baseline value at that k)\n",
    "    base_best_info = {}\n",
    "\n",
    "    # Compute optimal baseline value at each k: for each target, aggregate across\n",
    "    # all baseline methods by taking max (if higher-is-better) or min (if lower-is-better)\n",
    "    baseline_agg = {}\n",
    "    for target in targets:\n",
    "        direction = metric_preferences[target]\n",
    "        col = f'avg_{target}'\n",
    "        if direction == 'max':\n",
    "            baseline_agg[col] = base_slice.groupby('k')[col].max()\n",
    "        else:\n",
    "            baseline_agg[col] = base_slice.groupby('k')[col].min()\n",
    "    # Keep only the k values at which every target has a baseline value.\n",
    "    baseline = pd.DataFrame(baseline_agg).dropna()\n",
    "    if baseline.empty:\n",
    "        return\n",
    "\n",
    "    # Determine baseline best k per target\n",
    "    for target in targets:\n",
    "        direction = metric_preferences[target]\n",
    "        base_series = baseline[f'avg_{target}'].dropna()\n",
    "        if base_series.empty:\n",
    "            continue\n",
    "        k_star = base_series.idxmax() if direction == 'max' else base_series.idxmin()\n",
    "        base_best_val = base_series.loc[k_star]\n",
    "        # A zero or non-finite baseline cannot serve as the denominator of a percentage.\n",
    "        if base_best_val == 0 or not np.isfinite(base_best_val):\n",
    "            continue\n",
    "        base_best_info[target] = (k_star, base_best_val)\n",
    "\n",
    "    # Iterate MineDraft variants\n",
    "    for _, mine_indexes in mine_slice.groupby('method').groups.items():\n",
    "        mine_grouped = mine_slice.loc[mine_indexes].groupby('k').first().dropna()\n",
    "        # A row-count mismatch means duplicate k values were collapsed or NaN rows were dropped.\n",
    "        if len(mine_indexes) != len(mine_grouped) or len(mine_grouped) == 0:\n",
    "            print(f\"{BOLD}{YELLOW}[WARNING]{RESET} NaN or duplicates found in MineDraft results for {context}\")\n",
    "\n",
    "        # Only k values measured for both the reference and this variant are comparable.\n",
    "        common_k = std_sd.index.intersection(mine_grouped.index)\n",
    "        if common_k.empty:\n",
    "            continue\n",
    "\n",
    "        for target in targets:\n",
    "            direction = metric_preferences[target]\n",
    "            # Max improvement over Std SD across k\n",
    "            comparison = (\n",
    "                pd.DataFrame({\n",
    "                    'mine': mine_grouped.loc[common_k, f'avg_{target}'],\n",
    "                    'std': std_sd.loc[common_k, f'avg_{target}'],\n",
    "                })\n",
    "                .dropna()\n",
    "            )\n",
    "            # Avoid division by zero in the relative improvement.\n",
    "            comparison = comparison[comparison['std'] != 0]\n",
    "            if not comparison.empty:\n",
    "                # The sign is arranged so that a positive percentage always means MineDraft is better.\n",
    "                if direction == 'max':\n",
    "                    improvements = (comparison['mine'] - comparison['std']) / comparison['std'] * 100\n",
    "                else:\n",
    "                    improvements = (comparison['std'] - comparison['mine']) / comparison['std'] * 100\n",
    "                if not improvements.empty:\n",
    "                    max_improvements[target] = max(max_improvements[target], improvements.max())\n",
    "\n",
    "            # MineDraft value at baseline's best k\n",
    "            if target in base_best_info:\n",
    "                k_star, _ = base_best_info[target]\n",
    "                if k_star in mine_grouped.index:\n",
    "                    v = mine_grouped.loc[k_star, f'avg_{target}']\n",
    "                    if np.isfinite(v):\n",
    "                        # Keep the best value seen across all MineDraft variants.\n",
    "                        if direction == 'max':\n",
    "                            mine_at_base_k[target] = max(mine_at_base_k[target], v)\n",
    "                        else:\n",
    "                            mine_at_base_k[target] = min(mine_at_base_k[target], v)\n",
    "\n",
    "    # Print per-target report\n",
    "    for target in targets:\n",
    "        # Skip targets for which either of the two figures could not be computed.\n",
    "        if target not in base_best_info or not np.isfinite(max_improvements[target]):\n",
    "            continue\n",
    "        k_star, base_best_val = base_best_info[target]\n",
    "        mine_val = mine_at_base_k[target]\n",
    "        if not np.isfinite(mine_val) or base_best_val == 0:\n",
    "            continue\n",
    "\n",
    "        direction = metric_preferences[target]\n",
    "        if direction == 'max':\n",
    "            improvement_over_best_base = (mine_val - base_best_val) / base_best_val * 100\n",
    "        else:\n",
    "            improvement_over_best_base = (base_best_val - mine_val) / base_best_val * 100\n",
    "\n",
    "        # Non-positive improvements are highlighted in bold yellow.\n",
    "        print(\n",
    "            f\"{context} [{target}] (Base-best@k={k_star}): \" +\n",
    "            (f\"max ΔStdSD={max_improvements[target]:.2f}% | \" if max_improvements[target] > 0 else f\"{BOLD}{YELLOW}max ΔStdSD={max_improvements[target]:.2f}%{RESET} | \") +\n",
    "            (f\"ΔbestBaseline={improvement_over_best_base:.2f}%\" if improvement_over_best_base > 0 else f\"{BOLD}{YELLOW}ΔbestBaseline={improvement_over_best_base:.2f}%{RESET}\")\n",
    "        )\n",
    "\n",
    "def report_minedraft_improvement_vs_e(metrics_slice: pd.DataFrame, targets: list[str], context: str):\n",
    "    \"\"\"Print MineDraft's largest relative improvement over the reference method across ``e``.\n",
    "\n",
    "    For every target, the improvement of each selected MineDraft variant over the\n",
    "    reference method is computed at each ``e`` value both have results for, and the\n",
    "    maximum is printed. Non-positive percentages are printed in bold yellow. Targets\n",
    "    for which no comparable pair exists are skipped, as in\n",
    "    ``report_minedraft_improvement``.\n",
    "\n",
    "    Args:\n",
    "        metrics_slice: Rows with a ``method`` column, an ``e`` column and one\n",
    "            ``avg_<target>`` column per entry of ``targets``.\n",
    "        targets: Metric names; each must be a key of ``metric_preferences``.\n",
    "        context: Label placed at the start of every printed line.\n",
    "\n",
    "    Returns:\n",
    "        None. Returns early, without a report, when the reference or MineDraft rows\n",
    "        are missing.\n",
    "    \"\"\"\n",
    "    # Reference = 'Standard SD', compared against MineDraft variants without a '+' in their name.\n",
    "    std_slice = metrics_slice[metrics_slice['method'].str.startswith('Standard SD')]\n",
    "    mine_slice = metrics_slice[\n",
    "        metrics_slice['method'].str.startswith('MineDraft')\n",
    "        & ~metrics_slice['method'].str.contains(r'\\+')\n",
    "    ]\n",
    "    # Fallback when there is no 'Standard SD' row: compare EAGLE against 'MineDraft + EAGLE'.\n",
    "    if std_slice.empty:\n",
    "        std_slice = metrics_slice[metrics_slice['method'].str.startswith('EAGLE')]\n",
    "        mine_slice = metrics_slice[metrics_slice['method'].str.startswith('MineDraft + EAGLE')]\n",
    "    # NOTE(review): this second fallback cannot select any rows - the 'EAGLE' prefix above\n",
    "    # already matches 'EAGLE-3', so std_slice is still empty here and the function returns below.\n",
    "    if std_slice.empty:\n",
    "        std_slice = metrics_slice[metrics_slice['method'].str.startswith('EAGLE-3')]\n",
    "        mine_slice = metrics_slice[metrics_slice['method'].str.startswith('MineDraft + EAGLE-3')]\n",
    "\n",
    "    if std_slice.empty or mine_slice.empty:\n",
    "        return\n",
    "\n",
    "    # One reference row per e; rows with a NaN in any column are dropped.\n",
    "    std_sd = std_slice.groupby('e').first().dropna()\n",
    "    if std_sd.empty:\n",
    "        return\n",
    "\n",
    "    # 1) Max improvement over Std SD across es\n",
    "    max_improvements = {target: -np.inf for target in targets}\n",
    "\n",
    "    # Iterate MineDraft variants\n",
    "    for _, mine_indexes in mine_slice.groupby('method').groups.items():\n",
    "        mine_grouped = mine_slice.loc[mine_indexes].groupby('e').first().dropna()\n",
    "        # A row-count mismatch means duplicate e values were collapsed or NaN rows were dropped.\n",
    "        if len(mine_indexes) != len(mine_grouped) or len(mine_grouped) == 0:\n",
    "            print(f\"{BOLD}{YELLOW}[WARNING]{RESET} NaN or duplicates found in MineDraft results for {context}\")\n",
    "\n",
    "        # Only e values measured for both the reference and this variant are comparable.\n",
    "        common_e = std_sd.index.intersection(mine_grouped.index)\n",
    "        if common_e.empty:\n",
    "            continue\n",
    "\n",
    "        for target in targets:\n",
    "            direction = metric_preferences[target]\n",
    "            # Max improvement over Std SD across e\n",
    "            comparison = (\n",
    "                pd.DataFrame({\n",
    "                    'mine': mine_grouped.loc[common_e, f'avg_{target}'],\n",
    "                    'std': std_sd.loc[common_e, f'avg_{target}'],\n",
    "                })\n",
    "                .dropna()\n",
    "            )\n",
    "            # Avoid division by zero in the relative improvement.\n",
    "            comparison = comparison[comparison['std'] != 0]\n",
    "            if not comparison.empty:\n",
    "                # The sign is arranged so that a positive percentage always means MineDraft is better.\n",
    "                if direction == 'max':\n",
    "                    improvements = (comparison['mine'] - comparison['std']) / comparison['std'] * 100\n",
    "                else:\n",
    "                    improvements = (comparison['std'] - comparison['mine']) / comparison['std'] * 100\n",
    "                if not improvements.empty:\n",
    "                    max_improvements[target] = max(max_improvements[target], improvements.max())\n",
    "\n",
    "    # Print per-target report\n",
    "    for target in targets:\n",
    "        # No comparable (reference, MineDraft) pair was found for this target; printing the\n",
    "        # -inf placeholder as a percentage would be meaningless, so skip it.\n",
    "        if not np.isfinite(max_improvements[target]):\n",
    "            continue\n",
    "        print(\n",
    "            f\"{context} [{target}] : \" +\n",
    "            (f\"max ΔStdSD={max_improvements[target]:.2f}%\" if max_improvements[target] > 0 else f\"{BOLD}{YELLOW}max ΔStdSD={max_improvements[target]:.2f}%{RESET}\")\n",
    "        )\n",
    "\n",
    "\n",
    "def report_minedraft_improvement_by(metrics_df: pd.DataFrame, targets: list[str], group_by_col: str, context: str):\n",
    "    \"\"\"Print MineDraft's relative improvement separately for each value of ``group_by_col``.\n",
    "\n",
    "    For every distinct value of ``group_by_col`` (in sorted order) and every target,\n",
    "    two percentages are printed:\n",
    "\n",
    "    * ``max ΔStdSD`` - the largest improvement of any selected MineDraft variant over\n",
    "      the reference method, taken over the ``k`` values both have results for.\n",
    "    * ``ΔbestBaseline`` - the improvement of the best MineDraft value over the best\n",
    "      non-MineDraft value, both read at the ``k`` where that baseline value is best.\n",
    "\n",
    "    Non-positive percentages are printed in bold yellow. A group that lacks usable\n",
    "    data is skipped; the remaining groups are still reported.\n",
    "\n",
    "    Args:\n",
    "        metrics_df: Rows with ``method``, ``k`` and ``group_by_col`` columns and one\n",
    "            ``avg_<target>`` column per entry of ``targets``.\n",
    "        targets: Metric names; each must be a key of ``metric_preferences``.\n",
    "        group_by_col: Column whose distinct values define the groups.\n",
    "        context: Label placed at the start of every printed line.\n",
    "\n",
    "    Returns:\n",
    "        None. Prints an error and returns when ``group_by_col`` is not a column.\n",
    "    \"\"\"\n",
    "    if group_by_col not in metrics_df.columns:\n",
    "        print(f\"{BOLD}{YELLOW}[ERROR]{RESET} Column '{group_by_col}' not found in DataFrame\")\n",
    "        return\n",
    "\n",
    "    unique_groups = metrics_df[group_by_col].unique()\n",
    "\n",
    "    for group_val in sorted(unique_groups):\n",
    "        metrics_slice = metrics_df[metrics_df[group_by_col] == group_val]\n",
    "\n",
    "        # Baseline = every method that is not a MineDraft variant.\n",
    "        base_slice = metrics_slice[~metrics_slice['method'].str.startswith('MineDraft')]\n",
    "        # Reference = 'Standard SD', compared against MineDraft variants without a '+' in their name.\n",
    "        std_slice = metrics_slice[metrics_slice['method'].str.startswith('Standard SD')]\n",
    "        mine_slice = metrics_slice[\n",
    "            metrics_slice['method'].str.startswith('MineDraft')\n",
    "            & ~metrics_slice['method'].str.contains(r'\\+')\n",
    "        ]\n",
    "        # Fallback when there is no 'Standard SD' row: compare EAGLE against 'MineDraft + EAGLE'.\n",
    "        if std_slice.empty:\n",
    "            std_slice = metrics_slice[metrics_slice['method'].str.startswith('EAGLE')]\n",
    "            mine_slice = metrics_slice[metrics_slice['method'].str.startswith('MineDraft + EAGLE')]\n",
    "        # NOTE(review): this second fallback cannot select any rows - the 'EAGLE' prefix above\n",
    "        # already matches 'EAGLE-3', so std_slice is still empty here and the group is skipped below.\n",
    "        if std_slice.empty:\n",
    "            std_slice = metrics_slice[metrics_slice['method'].str.startswith('EAGLE-3')]\n",
    "            mine_slice = metrics_slice[metrics_slice['method'].str.startswith('MineDraft + EAGLE-3')]\n",
    "\n",
    "        if base_slice.empty or std_slice.empty or mine_slice.empty:\n",
    "            continue\n",
    "\n",
    "        # One reference row per k; rows with a NaN in any column are dropped.\n",
    "        std_sd = std_slice.groupby('k').first().dropna()\n",
    "        if std_sd.empty:\n",
    "            # Skip only this group: a `return` here would silently drop every later group.\n",
    "            continue\n",
    "\n",
    "        # 1) Max improvement over Std SD across ks\n",
    "        max_improvements = {target: -np.inf for target in targets}\n",
    "\n",
    "        # 2) Improvement at the k where baseline is best\n",
    "        # Start from the worst possible value so any finite measurement replaces it.\n",
    "        mine_at_base_k = {\n",
    "            target: (-np.inf if metric_preferences[target] == 'max' else np.inf)\n",
    "            for target in targets\n",
    "        }\n",
    "        # target -> (k at which the baseline is best, baseline value at that k)\n",
    "        base_best_info = {}\n",
    "\n",
    "        # Compute optimal baseline value at each k: for each target, aggregate across\n",
    "        # all baseline methods by taking max (if higher-is-better) or min (if lower-is-better)\n",
    "        baseline_agg = {}\n",
    "        for target in targets:\n",
    "            direction = metric_preferences[target]\n",
    "            col = f'avg_{target}'\n",
    "            if direction == 'max':\n",
    "                baseline_agg[col] = base_slice.groupby('k')[col].max()\n",
    "            else:\n",
    "                baseline_agg[col] = base_slice.groupby('k')[col].min()\n",
    "        # Keep only the k values at which every target has a baseline value.\n",
    "        baseline = pd.DataFrame(baseline_agg).dropna()\n",
    "        if baseline.empty:\n",
    "            # Skip only this group: a `return` here would silently drop every later group.\n",
    "            continue\n",
    "\n",
    "        # Determine baseline best k per target\n",
    "        for target in targets:\n",
    "            direction = metric_preferences[target]\n",
    "            base_series = baseline[f'avg_{target}'].dropna()\n",
    "            if base_series.empty:\n",
    "                continue\n",
    "            k_star = base_series.idxmax() if direction == 'max' else base_series.idxmin()\n",
    "            base_best_val = base_series.loc[k_star]\n",
    "            # A zero or non-finite baseline cannot serve as the denominator of a percentage.\n",
    "            if base_best_val == 0 or not np.isfinite(base_best_val):\n",
    "                continue\n",
    "            base_best_info[target] = (k_star, base_best_val)\n",
    "\n",
    "        # Iterate MineDraft variants\n",
    "        for _, mine_indexes in mine_slice.groupby('method').groups.items():\n",
    "            mine_grouped = mine_slice.loc[mine_indexes].groupby('k').first().dropna()\n",
    "            # A row-count mismatch means duplicate k values were collapsed or NaN rows were dropped.\n",
    "            if len(mine_indexes) != len(mine_grouped) or len(mine_grouped) == 0:\n",
    "                print(f\"{BOLD}{YELLOW}[WARNING]{RESET} {context}: NaN or duplicates found in MineDraft results for {group_by_col}={group_val}\")\n",
    "\n",
    "            # Only k values measured for both the reference and this variant are comparable.\n",
    "            common_k = std_sd.index.intersection(mine_grouped.index)\n",
    "            if common_k.empty:\n",
    "                continue\n",
    "\n",
    "            for target in targets:\n",
    "                direction = metric_preferences[target]\n",
    "                # Max improvement over Std SD across k\n",
    "                comparison = (\n",
    "                    pd.DataFrame({\n",
    "                        'mine': mine_grouped.loc[common_k, f'avg_{target}'],\n",
    "                        'std': std_sd.loc[common_k, f'avg_{target}'],\n",
    "                    })\n",
    "                    .dropna()\n",
    "                )\n",
    "                # Avoid division by zero in the relative improvement.\n",
    "                comparison = comparison[comparison['std'] != 0]\n",
    "                if not comparison.empty:\n",
    "                    # The sign is arranged so that a positive percentage always means MineDraft is better.\n",
    "                    if direction == 'max':\n",
    "                        improvements = (comparison['mine'] - comparison['std']) / comparison['std'] * 100\n",
    "                    else:\n",
    "                        improvements = (comparison['std'] - comparison['mine']) / comparison['std'] * 100\n",
    "                    if not improvements.empty:\n",
    "                        max_improvements[target] = max(max_improvements[target], improvements.max())\n",
    "\n",
    "                # MineDraft value at the baseline's best k (best over all non-MineDraft methods)\n",
    "                if target in base_best_info:\n",
    "                    k_star, _ = base_best_info[target]\n",
    "                    if k_star in mine_grouped.index:\n",
    "                        v = mine_grouped.loc[k_star, f'avg_{target}']\n",
    "                        if np.isfinite(v):\n",
    "                            # Keep the best value seen across all MineDraft variants.\n",
    "                            if direction == 'max':\n",
    "                                mine_at_base_k[target] = max(mine_at_base_k[target], v)\n",
    "                            else:\n",
    "                                mine_at_base_k[target] = min(mine_at_base_k[target], v)\n",
    "\n",
    "        # Print per-target report\n",
    "        for target in targets:\n",
    "            # Skip targets for which either of the two figures could not be computed.\n",
    "            if target not in base_best_info or not np.isfinite(max_improvements[target]):\n",
    "                continue\n",
    "            k_star, base_best_val = base_best_info[target]\n",
    "            mine_val = mine_at_base_k[target]\n",
    "            if not np.isfinite(mine_val) or base_best_val == 0:\n",
    "                continue\n",
    "\n",
    "            direction = metric_preferences[target]\n",
    "            if direction == 'max':\n",
    "                improvement_over_best_base = (mine_val - base_best_val) / base_best_val * 100\n",
    "            else:\n",
    "                improvement_over_best_base = (base_best_val - mine_val) / base_best_val * 100\n",
    "\n",
    "            # NOTE(review): the 'Std-best@k' label shows k_star, which is derived from all\n",
    "            # non-MineDraft rows rather than from the reference method alone - confirm the wording.\n",
    "            print(\n",
    "                f\"{context} {group_by_col}={group_val} [{target}] (Std-best@k={k_star}): \" +\n",
    "                (f\"max ΔStdSD={max_improvements[target]:.2f}% | \" if max_improvements[target] > 0 else f\"{BOLD}{YELLOW}max ΔStdSD={max_improvements[target]:.2f}%{RESET} | \") +\n",
    "                (f\"ΔbestBaseline={improvement_over_best_base:.2f}%\" if improvement_over_best_base > 0 else f\"{BOLD}{YELLOW}ΔbestBaseline={improvement_over_best_base:.2f}%{RESET}\")\n",
    "            )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2f1f42d",
   "metadata": {},
   "source": [
    "### n = 1 & default c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0f5c3af",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "# Metric names; both are keys of metric_preferences.\n",
    "# NOTE(review): not referenced in the visible part of this cell - presumably consumed further down; confirm.\n",
    "targets = ['gen_throughput', 'latency']\n",
    "# Styling constants. NOTE(review): presumably consumed by plotting code further down - confirm.\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "# Display label for each raw value of df['method'].\n",
    "# 'peagle' and 'ptetris_eagle' share one label, as do 'peagle3' and 'ptetris_eagle3'.\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "# Row labels of the e == 0, n == 1 runs (without TETRIS), keyed first by\n",
    "# (target_model, draft_model, batch_size, reqs) and then by dataset name.\n",
    "no_tetris_bundles = {}\n",
    "no_tetris_groups = df[(df['e'] == 0) & (df['n'] == 1)].groupby(\n",
    "    [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\"]).groups\n",
    "for setting_group, setting_group_indexes in no_tetris_groups.items():\n",
    "    # groupby(...).groups already maps dataset name -> row labels; copy it into a plain dict.\n",
    "    no_tetris_bundles[setting_group] = dict(\n",
    "        df.loc[setting_group_indexes].groupby(\"dataset\").groups\n",
    "    )\n",
    "\n",
    "\n",
    "for (target_model, draft_model, batch_size, reqs), setting_indexes in df[(df['e'] > 0) & (df['n'] == 1) & (df['c'] <= 0)].groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        methods = {}\n",
    "\n",
    "        for (method_name, e), method_indexes in df.loc[dataset_indexes].groupby([\"method\", \"e\"]).groups.items():\n",
    "            # if \"pearl\" in method_name or e > 3 or method_name.startswith(\"tetris\") and e > 2:\n",
    "            #     continue\n",
    "\n",
    "            method = {\"name\": method_labels[method_name] + f' (extra={e})'}\n",
    "            max_k = df.loc[method_indexes][\"k\"].max()\n",
    "\n",
    "            metrics_by_k_and_percentile = {\n",
    "                k: {\n",
    "                    1: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.8: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.7: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []}\n",
    "                } for k in range(1, max_k + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by k value\n",
    "            for k, indexes in df.loc[method_indexes].groupby(\"k\", sort=True).groups.items():\n",
    "                for i, bundle in df.loc[indexes].iterrows():\n",
    "                    # Print total times of preemption\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method={method_name}, k={k} + (extra={e}): \"\n",
    "                            f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    for percentile in [1, 0.8, 0.7]:\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                        e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                        # Store metrics for this bundle by k and percentile\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"gen_throughput\"].append(gen_throughput)\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"vsr\"].append(vsr)\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"latency\"].append(e2el)\n",
    "\n",
    "            metrics_by_percentile = {\n",
    "                percentile: {\n",
    "                    \"avg_gen_throughput\": [np.nan] * max_k,\n",
    "                    \"std_gen_throughput\": [np.nan] * max_k,\n",
    "                    \"avg_vsr\": [np.nan] * max_k,\n",
    "                    \"std_vsr\": [np.nan] * max_k,\n",
    "                    \"avg_latency\": [np.nan] * max_k,\n",
    "                    \"std_latency\": [np.nan] * max_k,\n",
    "                } for percentile in [1, 0.8, 0.7]\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for k in range(1, max_k + 1):\n",
    "                for percentile in [1, 0.8, 0.7]:\n",
    "                    metrics = metrics_by_k_and_percentile[k][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics[\"gen_throughput\"]:\n",
    "                        continue\n",
    "\n",
    "                    # Calculate means\n",
    "                    avg_gen = np.mean(metrics[\"gen_throughput\"])\n",
    "                    avg_vsr = np.mean(metrics[\"vsr\"])\n",
    "                    avg_latency = np.mean(metrics[\"latency\"])\n",
    "\n",
    "                    # Calculate stds\n",
    "                    std_gen = np.std(metrics[\"gen_throughput\"])\n",
    "                    std_vsr = np.std(metrics[\"vsr\"])\n",
    "                    std_latency = np.std(metrics[\"latency\"])\n",
    "\n",
    "                    # Store back in the dictionary\n",
    "                    metrics_by_percentile[percentile][\"avg_gen_throughput\"][k-1] = avg_gen\n",
    "                    metrics_by_percentile[percentile][\"std_gen_throughput\"][k-1] = std_gen\n",
    "                    metrics_by_percentile[percentile][\"avg_vsr\"][k-1] = avg_vsr\n",
    "                    metrics_by_percentile[percentile][\"std_vsr\"][k-1] = std_vsr\n",
    "                    metrics_by_percentile[percentile][\"avg_latency\"][k-1] = avg_latency\n",
    "                    metrics_by_percentile[percentile][\"std_latency\"][k-1] = std_latency\n",
    "\n",
    "            trims = []\n",
    "            for percentile in [1, 0.8, 0.7]:\n",
    "                trims.append({\n",
    "                    \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                    **metrics_by_percentile[percentile]\n",
    "                })\n",
    "\n",
    "            method[\"trims\"] = trims\n",
    "            methods.update({method[\"name\"]: method[\"trims\"]})\n",
    "\n",
    "        # Find matching e=0 data\n",
    "        setting_key = (target_model, draft_model, batch_size, reqs)\n",
    "        dataset_key = dataset_name\n",
    "        if setting_key in no_tetris_bundles and dataset_key in no_tetris_bundles[setting_key]:\n",
    "            no_tetris_dataset_indexes = no_tetris_bundles[setting_key][dataset_key]\n",
    "\n",
    "            # Insert e=0 data\n",
    "            for method_name, method_indexes in df.loc[no_tetris_dataset_indexes].groupby(\"method\").groups.items():\n",
    "                # if \"pearl\" in method_name:\n",
    "                #     continue\n",
    "\n",
    "                method = {\"name\": method_labels[method_name]}\n",
    "                max_k = df.loc[method_indexes][\"k\"].max()\n",
    "\n",
    "                metrics_by_k_and_percentile = {\n",
    "                    k: {\n",
    "                        1: {\"gen_throughputs\": [], \"vsrs\": [], \"latencies\": []},\n",
    "                        0.8: {\"gen_throughputs\": [], \"vsrs\": [], \"latencies\": []},\n",
    "                        0.7: {\"gen_throughputs\": [], \"vsrs\": [], \"latencies\": []}\n",
    "                    } for k in range(1, max_k + 1)\n",
    "                }\n",
    "\n",
    "                # Collect metrics for all bundles by k value\n",
    "                for k, indexes in df.loc[method_indexes].groupby(\"k\", sort=True).groups.items():\n",
    "                    for i, bundle in df.loc[indexes].iterrows():\n",
    "                        # Print total times of preemption\n",
    "                        if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                            print(\n",
    "                                f\"{YELLOW}[WARN]{RESET} \"\n",
    "                                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                                f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                                f\"method={method_name}, k={k}: \"\n",
    "                                f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                            )\n",
    "\n",
    "                        for percentile in [1, 0.8, 0.7]:\n",
    "                            bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                            gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                            vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                            e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                            # Store metrics for this bundle by k and percentile\n",
    "                            metrics_by_k_and_percentile[k][percentile][\"gen_throughputs\"].append(gen_throughput)\n",
    "                            metrics_by_k_and_percentile[k][percentile][\"vsrs\"].append(vsr)\n",
    "                            metrics_by_k_and_percentile[k][percentile][\"latencies\"].append(e2el)\n",
    "\n",
    "                metrics_by_percentile = {\n",
    "                    percentile: {\n",
    "                        \"avg_gen_throughput\": [np.nan] * max_k,\n",
    "                        \"std_gen_throughput\": [np.nan] * max_k,\n",
    "                        \"avg_vsr\": [np.nan] * max_k,\n",
    "                        \"std_vsr\": [np.nan] * max_k,\n",
    "                        \"avg_latency\": [np.nan] * max_k,\n",
    "                        \"std_latency\": [np.nan] * max_k,\n",
    "                    } for percentile in [1, 0.8, 0.7]\n",
    "                }\n",
    "\n",
    "                # Calculate means and stds\n",
    "                for k in range(1, max_k + 1):\n",
    "                    for percentile in [1, 0.8, 0.7]:\n",
    "                        metrics = metrics_by_k_and_percentile[k][percentile]\n",
    "\n",
    "                        # Skip if no data collected\n",
    "                        if not metrics[\"gen_throughputs\"]:\n",
    "                            continue\n",
    "\n",
    "                        # Calculate means\n",
    "                        avg_gen = np.mean(metrics[\"gen_throughputs\"])\n",
    "                        avg_vsr = np.mean(metrics[\"vsrs\"])\n",
    "                        avg_latency = np.mean(metrics[\"latencies\"])\n",
    "\n",
    "                        # Calculate stds\n",
    "                        std_gen = np.std(metrics[\"gen_throughputs\"])\n",
    "                        std_vsr = np.std(metrics[\"vsrs\"])\n",
    "                        std_latency = np.std(metrics[\"latencies\"])\n",
    "\n",
    "                        # Store back in the dictionary\n",
    "                        metrics_by_percentile[percentile][\"avg_gen_throughput\"][k-1] = avg_gen\n",
    "                        metrics_by_percentile[percentile][\"std_gen_throughput\"][k-1] = std_gen\n",
    "                        metrics_by_percentile[percentile][\"avg_vsr\"][k-1] = avg_vsr\n",
    "                        metrics_by_percentile[percentile][\"std_vsr\"][k-1] = std_vsr\n",
    "                        metrics_by_percentile[percentile][\"avg_latency\"][k-1] = avg_latency\n",
    "                        metrics_by_percentile[percentile][\"std_latency\"][k-1] = std_latency\n",
    "\n",
    "                trims = []\n",
    "                for percentile in [1, 0.8, 0.7]:\n",
    "                    trims.append({\n",
    "                        \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                        **metrics_by_percentile[percentile]\n",
    "                    })\n",
    "\n",
    "                method[\"trims\"] = trims\n",
    "                methods.update({method[\"name\"]: method[\"trims\"]})\n",
    "\n",
    "        # Prepare DataFrame for plotting\n",
    "        rows = []\n",
    "        for method, trims in methods.items():\n",
    "            for trim in trims:\n",
    "                percentile = trim['percentile']\n",
    "                for k, (avg_gen_throughput, std_gen_throughput, avg_vsr, std_vsr, avg_latency, std_latency) in enumerate(\n",
    "                    zip(trim['avg_gen_throughput'], trim['std_gen_throughput'],\n",
    "                        trim['avg_vsr'], trim['std_vsr'],\n",
    "                        trim['avg_latency'], trim['std_latency']),\n",
    "                    1\n",
    "                ):\n",
    "                    rows.append({\n",
    "                        'method': method,\n",
    "                        'k': k,\n",
    "                        'percentile': percentile,\n",
    "                        'avg_gen_throughput': avg_gen_throughput,\n",
    "                        'std_gen_throughput': std_gen_throughput,\n",
    "                        'avg_vsr': avg_vsr,\n",
    "                        'std_vsr': std_vsr,\n",
    "                        'avg_latency': avg_latency,\n",
    "                        'std_latency': std_latency,\n",
    "                    })\n",
    "\n",
    "        if not rows:\n",
    "            continue\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows)\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            for method, method_indexes in metrics_df.loc[percentile_indexes].groupby('method').groups.items():\n",
    "                grouped = metrics_df.loc[method_indexes].groupby('k').first().dropna()\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "\n",
    "                    # Plot line with markers\n",
    "                    axs.plot(grouped.index, grouped[f'avg_{target}'], marker='o', linestyle='-', label=method, linewidth=linewidth, markersize=markersize)\n",
    "                    axs.fill_between(grouped.index, \n",
    "                                     grouped[f'avg_{target}'] - grouped[f'std_{target}'], \n",
    "                                     grouped[f'avg_{target}'] + grouped[f'std_{target}'], \n",
    "                                     alpha=0.2)\n",
    "\n",
    "                    # Add labels and styling\n",
    "                    axs.set_xlabel('No. Speculative Tokens', fontsize=axis_label_fontsize)\n",
    "                    ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "                    axs.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                    axs.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                    axs.legend(loc='upper left' if target == 'latency' else 'lower left', fontsize=legend_fontsize)\n",
    "                    axs.grid(True)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, trim={percentile}\"\n",
    "            )\n",
    "            report_minedraft_improvement(\n",
    "                metrics_df.loc[percentile_indexes],\n",
    "                targets,\n",
    "                report_context\n",
    "            )\n",
    "\n",
    "            # Force integer x-ticks across subplots\n",
    "            k_vals = sorted(metrics_df.loc[percentile_indexes, 'k'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                ax.set_xticks(k_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_{draft_model}_'\n",
    "                            f'{reqs}_bs{batch_size}_'\n",
    "                            f'{percentile}_{target}'\n",
    "                        ) + '.png'\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "390bfcea",
   "metadata": {},
   "source": [
    "### n = 1 & diff draft models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe03569e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "# Compare draft models for the psd method (n == 1, e == 0, c <= 0, batch size 16):\n",
    "# per-run metrics are aggregated for every number of speculative tokens (k), and one\n",
    "# figure is drawn and saved per (dataset, trim, metric) with one line per draft model.\n",
    "\n",
    "targets = ['gen_throughput', 'latency', 'vsr']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "# Fraction of each run's leading steps that is kept when computing step-based metrics.\n",
    "trim_fractions = [1, 0.8, 0.7]\n",
    "# Per-run metrics collected for every k; each one gets an avg_* and a std_* column.\n",
    "metric_keys = ['gen_throughput', 'vsr', 'latency']\n",
    "stat_columns = [f'{stat}_{key}' for key in metric_keys for stat in ('avg', 'std')]\n",
    "ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "\n",
    "psd_runs = df[(df['e'] == 0) & (df['n'] == 1) & (df['c'] <= 0) & (df['method'] == \"psd\") & (df['batch_size'] == 16)]\n",
    "\n",
    "for (target_model, batch_size, reqs), setting_indexes in psd_runs.groupby(\n",
    "        [\"target_model\", \"batch_size\", \"reqs\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        draft_models = {}\n",
    "\n",
    "        for draft_model_name, draft_model_indexes in df.loc[dataset_indexes].groupby(\"draft_model\").groups.items():\n",
    "            max_k = df.loc[draft_model_indexes][\"k\"].max()\n",
    "\n",
    "            # metrics_by_k_and_percentile[k][fraction][metric] holds one value per run\n",
    "            metrics_by_k_and_percentile = {\n",
    "                k: {fraction: {key: [] for key in metric_keys} for fraction in trim_fractions}\n",
    "                for k in range(1, max_k + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by k value\n",
    "            for k, indexes in df.loc[draft_model_indexes].groupby(\"k\", sort=True).groups.items():\n",
    "                # Number the runs inside this k group: the DataFrame row label is a global\n",
    "                # position in df, not an iteration count, so it is not used for the message.\n",
    "                for run_no, (_, bundle) in enumerate(df.loc[indexes].iterrows(), 1):\n",
    "                    # Print total times of preemption\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model_name}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method=psd, k={k}: \"\n",
    "                            f\"iter={run_no}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    # Latency per request does not depend on the trim, so compute it once per run\n",
    "                    e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                    for percentile in trim_fractions:\n",
    "                        # Only the first `bound` steps of the run contribute to the step-based metrics\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        # Generated tokens per unit of generation time\n",
    "                        gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                        # Accepted tokens as a percentage of verified tokens\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "\n",
    "                        # Store metrics for this bundle by k and percentile\n",
    "                        samples = metrics_by_k_and_percentile[k][percentile]\n",
    "                        samples[\"gen_throughput\"].append(gen_throughput)\n",
    "                        samples[\"vsr\"].append(vsr)\n",
    "                        samples[\"latency\"].append(e2el)\n",
    "\n",
    "            # One slot per k (index k-1); slots stay NaN for k values without any run\n",
    "            metrics_by_percentile = {\n",
    "                percentile: {column: [np.nan] * max_k for column in stat_columns}\n",
    "                for percentile in trim_fractions\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for k in range(1, max_k + 1):\n",
    "                for percentile in trim_fractions:\n",
    "                    metrics = metrics_by_k_and_percentile[k][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics[\"gen_throughput\"]:\n",
    "                        continue\n",
    "\n",
    "                    for key in metric_keys:\n",
    "                        metrics_by_percentile[percentile][f\"avg_{key}\"][k-1] = np.mean(metrics[key])\n",
    "                        metrics_by_percentile[percentile][f\"std_{key}\"][k-1] = np.std(metrics[key])\n",
    "\n",
    "            # One entry per trim, labelled as a percentage string such as \"80%\"\n",
    "            draft_models[draft_model_name] = [\n",
    "                {\n",
    "                    \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                    **metrics_by_percentile[percentile]\n",
    "                }\n",
    "                for percentile in trim_fractions\n",
    "            ]\n",
    "\n",
    "        # Prepare DataFrame for plotting: one row per (draft model, trim, k)\n",
    "        rows = []\n",
    "        for draft_model_name, trims in draft_models.items():\n",
    "            for trim in trims:\n",
    "                per_k_stats = zip(*(trim[column] for column in stat_columns))\n",
    "                for k, stats in enumerate(per_k_stats, 1):\n",
    "                    rows.append({\n",
    "                        'draft_model': draft_model_name,\n",
    "                        'k': k,\n",
    "                        'percentile': trim['percentile'],\n",
    "                        **dict(zip(stat_columns, stats)),\n",
    "                    })\n",
    "\n",
    "        if not rows:\n",
    "            continue\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows)\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            # One figure per metric; every draft model adds a line to each of them\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            for draft_model_name, draft_model_indexes in metrics_df.loc[percentile_indexes].groupby('draft_model').groups.items():\n",
    "                # dropna() removes the k values for which no run was collected\n",
    "                grouped = metrics_df.loc[draft_model_indexes].groupby('k').first().dropna()\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "                    mean = grouped[f'avg_{target}']\n",
    "                    spread = grouped[f'std_{target}']\n",
    "\n",
    "                    # Plot line with markers and a +/- one std band\n",
    "                    axs.plot(grouped.index, mean, marker='o', linestyle='-', label=draft_model_name, linewidth=linewidth, markersize=markersize)\n",
    "                    axs.fill_between(grouped.index, mean - spread, mean + spread, alpha=0.2)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, trim={percentile}\"\n",
    "            )\n",
    "            print(report_context)\n",
    "\n",
    "            k_vals = sorted(metrics_df.loc[percentile_indexes, 'k'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                # Add labels and styling once per figure, after every line has been drawn,\n",
    "                # so the legend lists all draft models\n",
    "                ax.set_xlabel('No. Speculative Tokens', fontsize=axis_label_fontsize)\n",
    "                ax.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                ax.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                ax.legend(loc='upper left' if target == 'latency' else 'lower left', fontsize=legend_fontsize)\n",
    "                ax.grid(True)\n",
    "\n",
    "                # Force integer x-ticks across subplots\n",
    "                ax.set_xticks(k_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_'\n",
    "                            f'{reqs}_bs{batch_size}_'\n",
    "                            f'{percentile}_{target}'\n",
    "                        ) + '.png',\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                # Release the figure so figures do not accumulate across loop iterations\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cde07f48",
   "metadata": {},
   "source": [
    "### n = 1 & fixed c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30a6891c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "targets = ['gen_throughput', 'latency']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "# Index e=0 data (without TETRIS)\n",
    "no_tetris_bundles = {}\n",
    "no_tetris_df = df.loc[(df['e'] == 0) & (df['n'] == 1)].copy()\n",
    "no_tetris_df['c'] = no_tetris_df['batch_size'] * no_tetris_df['k']\n",
    "for setting_group, setting_group_indexes in no_tetris_df.groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\", \"c\"]).groups.items():\n",
    "\n",
    "    dataset_bundles = {}\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_group_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset_bundles[dataset_name] = dataset_indexes\n",
    "\n",
    "    no_tetris_bundles[setting_group] = dataset_bundles\n",
    "\n",
    "\n",
    "for (target_model, draft_model, batch_size, reqs, c), setting_indexes in df[(df['e'] > 0) & (df['n'] == 1) & (df['c'] > 0)].groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\", \"c\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        methods = {}\n",
    "\n",
    "        for (method_name, e), method_indexes in df.loc[dataset_indexes].groupby([\"method\", \"e\"]).groups.items():\n",
    "            # if method_name == \"tetris\":\n",
    "            #     continue\n",
    "\n",
    "            method = {\"name\": method_labels[method_name] + f' (extra={e})'}\n",
    "            max_k = df.loc[method_indexes][\"k\"].max()\n",
    "\n",
    "            metrics_by_k_and_percentile = {\n",
    "                k: {\n",
    "                    1: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.8: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.7: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []}\n",
    "                } for k in range(1, max_k + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by k value\n",
    "            for k, indexes in df.loc[method_indexes].groupby(\"k\", sort=True).groups.items():\n",
    "                for i, bundle in df.loc[indexes].iterrows():\n",
    "                    # Print total times of preemption\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method={method_name}, k={k} + (extra={e}), c={c}: \"\n",
    "                            f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    for percentile in [1, 0.8, 0.7]:\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                        e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                        # Store metrics for this bundle by k and percentile\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"gen_throughput\"].append(gen_throughput)\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"vsr\"].append(vsr)\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"latency\"].append(e2el)\n",
    "\n",
    "            metrics_by_percentile = {\n",
    "                percentile: {\n",
    "                    \"avg_gen_throughput\": [np.nan] * max_k,\n",
    "                    \"std_gen_throughput\": [np.nan] * max_k,\n",
    "                    \"avg_vsr\": [np.nan] * max_k,\n",
    "                    \"std_vsr\": [np.nan] * max_k,\n",
    "                    \"avg_latency\": [np.nan] * max_k,\n",
    "                    \"std_latency\": [np.nan] * max_k,\n",
    "                } for percentile in [1, 0.8, 0.7]\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for k in range(1, max_k + 1):\n",
    "                for percentile in [1, 0.8, 0.7]:\n",
    "                    metrics = metrics_by_k_and_percentile[k][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics[\"gen_throughput\"]:\n",
    "                        continue\n",
    "\n",
    "                    # Calculate means\n",
    "                    avg_gen = np.mean(metrics[\"gen_throughput\"])\n",
    "                    avg_vsr = np.mean(metrics[\"vsr\"])\n",
    "                    avg_latency = np.mean(metrics[\"latency\"])\n",
    "\n",
    "                    # Calculate stds\n",
    "                    std_gen = np.std(metrics[\"gen_throughput\"])\n",
    "                    std_vsr = np.std(metrics[\"vsr\"])\n",
    "                    std_latency = np.std(metrics[\"latency\"])\n",
    "\n",
    "                    # Store back in the dictionary\n",
    "                    metrics_by_percentile[percentile][\"avg_gen_throughput\"][k-1] = avg_gen\n",
    "                    metrics_by_percentile[percentile][\"std_gen_throughput\"][k-1] = std_gen\n",
    "                    metrics_by_percentile[percentile][\"avg_vsr\"][k-1] = avg_vsr\n",
    "                    metrics_by_percentile[percentile][\"std_vsr\"][k-1] = std_vsr\n",
    "                    metrics_by_percentile[percentile][\"avg_latency\"][k-1] = avg_latency\n",
    "                    metrics_by_percentile[percentile][\"std_latency\"][k-1] = std_latency\n",
    "\n",
    "            trims = []\n",
    "            for percentile in [1, 0.8, 0.7]:\n",
    "                trims.append({\n",
    "                    \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                    **metrics_by_percentile[percentile]\n",
    "                })\n",
    "\n",
    "            method[\"trims\"] = trims\n",
    "            methods.update({method[\"name\"]: method[\"trims\"]})\n",
    "\n",
    "        # Find matching e=0 data\n",
    "        setting_key = (target_model, draft_model, batch_size, reqs, c)\n",
    "        dataset_key = dataset_name\n",
    "        if setting_key in no_tetris_bundles and dataset_key in no_tetris_bundles[setting_key]:\n",
    "            no_tetris_dataset_indexes = no_tetris_bundles[setting_key][dataset_key]\n",
    "\n",
    "            # Insert e=0 data\n",
    "            for method_name, method_indexes in df.loc[no_tetris_dataset_indexes].groupby(\"method\").groups.items():\n",
    "                method = {\"name\": method_labels[method_name]}\n",
    "                max_k = df.loc[method_indexes][\"k\"].max()\n",
    "\n",
    "                metrics_by_k_and_percentile = {\n",
    "                    k: {\n",
    "                        1: {\"gen_throughputs\": [], \"vsrs\": [], \"latencies\": []},\n",
    "                        0.8: {\"gen_throughputs\": [], \"vsrs\": [], \"latencies\": []},\n",
    "                        0.7: {\"gen_throughputs\": [], \"vsrs\": [], \"latencies\": []}\n",
    "                    } for k in range(1, max_k + 1)\n",
    "                }\n",
    "\n",
    "                # Collect metrics for all bundles by k value\n",
    "                for k, indexes in df.loc[method_indexes].groupby(\"k\", sort=True).groups.items():\n",
    "                    for i, bundle in df.loc[indexes].iterrows():\n",
    "                        # Print total times of preemption\n",
    "                        if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                            print(\n",
    "                                f\"{YELLOW}[WARN]{RESET} \"\n",
    "                                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                                f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                                f\"method={method_name}, k={k}, c={c}: \"\n",
    "                                f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                            )\n",
    "\n",
    "                        for percentile in [1, 0.8, 0.7]:\n",
    "                            bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                            gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                            vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                            e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                            # Store metrics for this bundle by k and percentile\n",
    "                            metrics_by_k_and_percentile[k][percentile][\"gen_throughputs\"].append(gen_throughput)\n",
    "                            metrics_by_k_and_percentile[k][percentile][\"vsrs\"].append(vsr)\n",
    "                            metrics_by_k_and_percentile[k][percentile][\"latencies\"].append(e2el)\n",
    "\n",
    "                metrics_by_percentile = {\n",
    "                    percentile: {\n",
    "                        \"avg_gen_throughput\": [np.nan] * max_k,\n",
    "                        \"std_gen_throughput\": [np.nan] * max_k,\n",
    "                        \"avg_vsr\": [np.nan] * max_k,\n",
    "                        \"std_vsr\": [np.nan] * max_k,\n",
    "                        \"avg_latency\": [np.nan] * max_k,\n",
    "                        \"std_latency\": [np.nan] * max_k,\n",
    "                    } for percentile in [1, 0.8, 0.7]\n",
    "                }\n",
    "\n",
    "                # Calculate means and stds\n",
    "                for k in range(1, max_k + 1):\n",
    "                    for percentile in [1, 0.8, 0.7]:\n",
    "                        metrics = metrics_by_k_and_percentile[k][percentile]\n",
    "\n",
    "                        # Skip if no data collected\n",
    "                        if not metrics[\"gen_throughputs\"]:\n",
    "                            continue\n",
    "\n",
    "                        # Calculate means\n",
    "                        avg_gen = np.mean(metrics[\"gen_throughputs\"])\n",
    "                        avg_vsr = np.mean(metrics[\"vsrs\"])\n",
    "                        avg_latency = np.mean(metrics[\"latencies\"])\n",
    "\n",
    "                        # Calculate stds\n",
    "                        std_gen = np.std(metrics[\"gen_throughputs\"])\n",
    "                        std_vsr = np.std(metrics[\"vsrs\"])\n",
    "                        std_latency = np.std(metrics[\"latencies\"])\n",
    "\n",
    "                        # Store back in the dictionary\n",
    "                        metrics_by_percentile[percentile][\"avg_gen_throughput\"][k-1] = avg_gen\n",
    "                        metrics_by_percentile[percentile][\"std_gen_throughput\"][k-1] = std_gen\n",
    "                        metrics_by_percentile[percentile][\"avg_vsr\"][k-1] = avg_vsr\n",
    "                        metrics_by_percentile[percentile][\"std_vsr\"][k-1] = std_vsr\n",
    "                        metrics_by_percentile[percentile][\"avg_latency\"][k-1] = avg_latency\n",
    "                        metrics_by_percentile[percentile][\"std_latency\"][k-1] = std_latency\n",
    "\n",
    "                trims = []\n",
    "                for percentile in [1, 0.8, 0.7]:\n",
    "                    trims.append({\n",
    "                        \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                        **metrics_by_percentile[percentile]\n",
    "                    })\n",
    "\n",
    "                method[\"trims\"] = trims\n",
    "                methods.update({method[\"name\"]: method[\"trims\"]})\n",
    "\n",
    "        # Prepare DataFrame for plotting\n",
    "        rows = []\n",
    "        for method, trims in methods.items():\n",
    "            for trim in trims:\n",
    "                percentile = trim['percentile']\n",
    "                for k, (avg_gen_throughput, std_gen_throughput, avg_vsr, std_vsr, avg_latency, std_latency) in enumerate(\n",
    "                    zip(trim['avg_gen_throughput'], trim['std_gen_throughput'],\n",
    "                        trim['avg_vsr'], trim['std_vsr'],\n",
    "                        trim['avg_latency'], trim['std_latency']),\n",
    "                    1\n",
    "                ):\n",
    "                    rows.append({\n",
    "                        'method': method,\n",
    "                        'k': k,\n",
    "                        'percentile': percentile,\n",
    "                        'avg_gen_throughput': avg_gen_throughput,\n",
    "                        'std_gen_throughput': std_gen_throughput,\n",
    "                        'avg_vsr': avg_vsr,\n",
    "                        'std_vsr': std_vsr,\n",
    "                        'avg_latency': avg_latency,\n",
    "                        'std_latency': std_latency,\n",
    "                    })\n",
    "\n",
    "        if not rows:\n",
    "            continue\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows)\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            for method, method_indexes in metrics_df.loc[percentile_indexes].groupby('method').groups.items():\n",
    "                grouped = metrics_df.loc[method_indexes].groupby('k').first().dropna()\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "\n",
    "                    if e == 0:\n",
    "                        # Draw dotted horizontal line for non-tetris methods\n",
    "                        axs.axhline(y=grouped[f'avg_{target}'].iloc[0], linestyle='--', label=method)\n",
    "                        axs.fill_between(range(1, max_k+1),\n",
    "                                         grouped[f'avg_{target}'] - grouped[f'std_{target}'],\n",
    "                                         grouped[f'avg_{target}'] + grouped[f'std_{target}'],\n",
    "                                         alpha=0.2)\n",
    "\n",
    "                    else:\n",
    "                        # Plot line with markers\n",
    "                        axs.plot(grouped.index, grouped[f'avg_{target}'], marker='o', linestyle='-', label=method, linewidth=linewidth, markersize=markersize)\n",
    "                        axs.fill_between(grouped.index, \n",
    "                                         grouped[f'avg_{target}'] - grouped[f'std_{target}'], \n",
    "                                         grouped[f'avg_{target}'] + grouped[f'std_{target}'], \n",
    "                                         alpha=0.2)\n",
    "\n",
    "                    # Add labels and styling\n",
    "                    axs.set_xlabel('No. Speculative Tokens', fontsize=axis_label_fontsize)\n",
    "                    ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "                    axs.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                    axs.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                    axs.legend(loc='upper left' if target == 'latency' else 'lower left', fontsize=legend_fontsize)\n",
    "                    axs.grid(True)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, c={c}, trim={percentile}\"\n",
    "            )\n",
    "            report_minedraft_improvement(\n",
    "                metrics_df.loc[percentile_indexes],\n",
    "                targets,\n",
    "                report_context\n",
    "            )\n",
    "\n",
    "            # Force integer x-ticks across subplots\n",
    "            k_vals = sorted(metrics_df.loc[percentile_indexes, 'k'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                ax.set_xticks(k_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_{draft_model}_'\n",
    "                            f'{reqs}_bs{batch_size}_c{c}_'\n",
    "                            f'{percentile}_{target}'\n",
    "                        ) + '.png'\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "681b0508",
   "metadata": {},
   "source": [
    "### n > 1 & default c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c43f9d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "targets = ['gen_throughput', 'latency']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "# Index e=0 data (without TETRIS)\n",
    "no_tetris_bundles = {}\n",
    "for setting_group, setting_group_indexes in df[(df['e'] == 0) & (df['n'] > 1)].groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\", \"n\"]).groups.items():\n",
    "\n",
    "    dataset_bundles = {}\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_group_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset_bundles[dataset_name] = dataset_indexes\n",
    "\n",
    "    no_tetris_bundles[setting_group] = dataset_bundles\n",
    "\n",
    "\n",
    "for (target_model, draft_model, batch_size, reqs, n), setting_indexes in df[(df['e'] > 0) & (df['n'] > 1) & (df['c'] <= 0)].groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\", \"n\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        methods = {}\n",
    "\n",
    "        for (method_name, e), method_indexes in df.loc[dataset_indexes].groupby([\"method\", \"e\"]).groups.items():\n",
    "            # if \"pearl\" in method_name or e > 3 or method_name.startswith(\"tetris\") and e > 2:\n",
    "            #     continue\n",
    "\n",
    "            method = {\"name\": method_labels[method_name] + f' (extra={e})'}\n",
    "            max_k = df.loc[method_indexes][\"k\"].max()\n",
    "\n",
    "            metrics_by_k_and_percentile = {\n",
    "                k: {\n",
    "                    1: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.8: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.7: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []}\n",
    "                } for k in range(1, max_k + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by k value\n",
    "            for k, indexes in df.loc[method_indexes].groupby(\"k\", sort=True).groups.items():\n",
    "                for i, bundle in df.loc[indexes].iterrows():\n",
    "                    # Print total times of preemption\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method={method_name}, k={k} + (extra={e}), n={n}: \"\n",
    "                            f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    for percentile in [1, 0.8, 0.7]:\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                        e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                        # Store metrics for this bundle by k and percentile\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"gen_throughput\"].append(gen_throughput)\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"vsr\"].append(vsr)\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"latency\"].append(e2el)\n",
    "\n",
    "            metrics_by_percentile = {\n",
    "                percentile: {\n",
    "                    \"avg_gen_throughput\": [np.nan] * max_k,\n",
    "                    \"std_gen_throughput\": [np.nan] * max_k,\n",
    "                    \"avg_vsr\": [np.nan] * max_k,\n",
    "                    \"std_vsr\": [np.nan] * max_k,\n",
    "                    \"avg_latency\": [np.nan] * max_k,\n",
    "                    \"std_latency\": [np.nan] * max_k,\n",
    "                } for percentile in [1, 0.8, 0.7]\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for k in range(1, max_k + 1):\n",
    "                for percentile in [1, 0.8, 0.7]:\n",
    "                    metrics = metrics_by_k_and_percentile[k][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics[\"gen_throughput\"]:\n",
    "                        continue\n",
    "\n",
    "                    # Calculate means\n",
    "                    avg_gen = np.mean(metrics[\"gen_throughput\"])\n",
    "                    avg_vsr = np.mean(metrics[\"vsr\"])\n",
    "                    avg_latency = np.mean(metrics[\"latency\"])\n",
    "\n",
    "                    # Calculate stds\n",
    "                    std_gen = np.std(metrics[\"gen_throughput\"])\n",
    "                    std_vsr = np.std(metrics[\"vsr\"])\n",
    "                    std_latency = np.std(metrics[\"latency\"])\n",
    "\n",
    "                    # Store back in the dictionary\n",
    "                    metrics_by_percentile[percentile][\"avg_gen_throughput\"][k-1] = avg_gen\n",
    "                    metrics_by_percentile[percentile][\"std_gen_throughput\"][k-1] = std_gen\n",
    "                    metrics_by_percentile[percentile][\"avg_vsr\"][k-1] = avg_vsr\n",
    "                    metrics_by_percentile[percentile][\"std_vsr\"][k-1] = std_vsr\n",
    "                    metrics_by_percentile[percentile][\"avg_latency\"][k-1] = avg_latency\n",
    "                    metrics_by_percentile[percentile][\"std_latency\"][k-1] = std_latency\n",
    "\n",
    "            trims = []\n",
    "            for percentile in [1, 0.8, 0.7]:\n",
    "                trims.append({\n",
    "                    \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                    **metrics_by_percentile[percentile]\n",
    "                })\n",
    "\n",
    "            method[\"trims\"] = trims\n",
    "            methods.update({method[\"name\"]: method[\"trims\"]})\n",
    "\n",
    "        # Find matching e=0 data\n",
    "        setting_key = (target_model, draft_model, batch_size, reqs, n)\n",
    "        dataset_key = dataset_name\n",
    "        if setting_key in no_tetris_bundles and dataset_key in no_tetris_bundles[setting_key]:\n",
    "            no_tetris_dataset_indexes = no_tetris_bundles[setting_key][dataset_key]\n",
    "\n",
    "            # Insert e=0 data\n",
    "            for method_name, method_indexes in df.loc[no_tetris_dataset_indexes].groupby(\"method\").groups.items():\n",
    "                # if \"pearl\" in method_name:\n",
    "                #     continue\n",
    "\n",
    "                method = {\"name\": method_labels[method_name]}\n",
    "                max_k = df.loc[method_indexes][\"k\"].max()\n",
    "\n",
    "                metrics_by_k_and_percentile = {\n",
    "                    k: {\n",
    "                        1: {\"gen_throughputs\": [], \"vsrs\": [], \"latencies\": []},\n",
    "                        0.8: {\"gen_throughputs\": [], \"vsrs\": [], \"latencies\": []},\n",
    "                        0.7: {\"gen_throughputs\": [], \"vsrs\": [], \"latencies\": []}\n",
    "                    } for k in range(1, max_k + 1)\n",
    "                }\n",
    "\n",
    "                # Collect metrics for all bundles by k value\n",
    "                for k, indexes in df.loc[method_indexes].groupby(\"k\", sort=True).groups.items():\n",
    "                    for i, bundle in df.loc[indexes].iterrows():\n",
    "                        # Print total times of preemption\n",
    "                        if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                            print(\n",
    "                                f\"{YELLOW}[WARN]{RESET} \"\n",
    "                                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                                f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                                f\"method={method_name}, k={k}, n={n}: \"\n",
    "                                f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                            )\n",
    "\n",
    "                        for percentile in [1, 0.8, 0.7]:\n",
    "                            bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                            gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                            vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                            e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                            # Store metrics for this bundle by k and percentile\n",
    "                            metrics_by_k_and_percentile[k][percentile][\"gen_throughputs\"].append(gen_throughput)\n",
    "                            metrics_by_k_and_percentile[k][percentile][\"vsrs\"].append(vsr)\n",
    "                            metrics_by_k_and_percentile[k][percentile][\"latencies\"].append(e2el)\n",
    "\n",
    "                metrics_by_percentile = {\n",
    "                    percentile: {\n",
    "                        \"avg_gen_throughput\": [np.nan] * max_k,\n",
    "                        \"std_gen_throughput\": [np.nan] * max_k,\n",
    "                        \"avg_vsr\": [np.nan] * max_k,\n",
    "                        \"std_vsr\": [np.nan] * max_k,\n",
    "                        \"avg_latency\": [np.nan] * max_k,\n",
    "                        \"std_latency\": [np.nan] * max_k,\n",
    "                    } for percentile in [1, 0.8, 0.7]\n",
    "                }\n",
    "\n",
    "                # Calculate means and stds\n",
    "                for k in range(1, max_k + 1):\n",
    "                    for percentile in [1, 0.8, 0.7]:\n",
    "                        metrics = metrics_by_k_and_percentile[k][percentile]\n",
    "\n",
    "                        # Skip if no data collected\n",
    "                        if not metrics[\"gen_throughputs\"]:\n",
    "                            continue\n",
    "\n",
    "                        # Calculate means\n",
    "                        avg_gen = np.mean(metrics[\"gen_throughputs\"])\n",
    "                        avg_vsr = np.mean(metrics[\"vsrs\"])\n",
    "                        avg_latency = np.mean(metrics[\"latencies\"])\n",
    "\n",
    "                        # Calculate stds\n",
    "                        std_gen = np.std(metrics[\"gen_throughputs\"])\n",
    "                        std_vsr = np.std(metrics[\"vsrs\"])\n",
    "                        std_latency = np.std(metrics[\"latencies\"])\n",
    "\n",
    "                        # Store back in the dictionary\n",
    "                        metrics_by_percentile[percentile][\"avg_gen_throughput\"][k-1] = avg_gen\n",
    "                        metrics_by_percentile[percentile][\"std_gen_throughput\"][k-1] = std_gen\n",
    "                        metrics_by_percentile[percentile][\"avg_vsr\"][k-1] = avg_vsr\n",
    "                        metrics_by_percentile[percentile][\"std_vsr\"][k-1] = std_vsr\n",
    "                        metrics_by_percentile[percentile][\"avg_latency\"][k-1] = avg_latency\n",
    "                        metrics_by_percentile[percentile][\"std_latency\"][k-1] = std_latency\n",
    "\n",
    "                trims = []\n",
    "                for percentile in [1, 0.8, 0.7]:\n",
    "                    trims.append({\n",
    "                        \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                        **metrics_by_percentile[percentile]\n",
    "                    })\n",
    "\n",
    "                method[\"trims\"] = trims\n",
    "                methods.update({method[\"name\"]: method[\"trims\"]})\n",
    "\n",
    "        # Prepare DataFrame for plotting\n",
    "        rows = []\n",
    "        for method, trims in methods.items():\n",
    "            for trim in trims:\n",
    "                percentile = trim['percentile']\n",
    "                for k, (avg_gen_throughput, std_gen_throughput, avg_vsr, std_vsr, avg_latency, std_latency) in enumerate(\n",
    "                    zip(trim['avg_gen_throughput'], trim['std_gen_throughput'],\n",
    "                        trim['avg_vsr'], trim['std_vsr'],\n",
    "                        trim['avg_latency'], trim['std_latency']),\n",
    "                    1\n",
    "                ):\n",
    "                    rows.append({\n",
    "                        'method': method,\n",
    "                        'k': k,\n",
    "                        'percentile': percentile,\n",
    "                        'avg_gen_throughput': avg_gen_throughput,\n",
    "                        'std_gen_throughput': std_gen_throughput,\n",
    "                        'avg_vsr': avg_vsr,\n",
    "                        'std_vsr': std_vsr,\n",
    "                        'avg_latency': avg_latency,\n",
    "                        'std_latency': std_latency,\n",
    "                    })\n",
    "\n",
    "        if not rows:\n",
    "            continue\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows)\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            for method, method_indexes in metrics_df.loc[percentile_indexes].groupby('method').groups.items():\n",
    "                grouped = metrics_df.loc[method_indexes].groupby('k').first().dropna()\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "\n",
    "                    # Plot line with markers\n",
    "                    axs.plot(grouped.index, grouped[f'avg_{target}'], marker='o', linestyle='-', label=method, linewidth=linewidth, markersize=markersize)\n",
    "                    axs.fill_between(grouped.index, \n",
    "                                     grouped[f'avg_{target}'] - grouped[f'std_{target}'], \n",
    "                                     grouped[f'avg_{target}'] + grouped[f'std_{target}'], \n",
    "                                     alpha=0.2)\n",
    "\n",
    "                    # Add labels and styling\n",
    "                    axs.set_xlabel('No. Speculative Tokens', fontsize=axis_label_fontsize)\n",
    "                    ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "                    axs.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                    axs.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                    axs.legend(loc='upper left' if target == 'latency' else 'lower left', fontsize=legend_fontsize)\n",
    "                    axs.grid(True)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, n={n}, trim={percentile}\"\n",
    "            )\n",
    "            print(report_context)\n",
    "            # report_minedraft_improvement(\n",
    "            #     metrics_df.loc[percentile_indexes],\n",
    "            #     targets,\n",
    "            #     report_context\n",
    "            # )\n",
    "\n",
    "            # Force integer x-ticks across subplots\n",
    "            k_vals = sorted(metrics_df.loc[percentile_indexes, 'k'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                ax.set_xticks(k_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_{draft_model}_'\n",
    "                            f'{reqs}_bs{batch_size}_n{n}_'\n",
    "                            f'{percentile}_{target}'\n",
    "                        ) + '.png',\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dae97614",
   "metadata": {},
   "source": [
    "### n > 1 & diff draft models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b0cc074",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "targets = ['gen_throughput', 'latency', 'vsr']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "\n",
    "for (target_model, batch_size, reqs, n), setting_indexes in df[(df['e'] == 0) & (df['n'] > 1) & (df['c'] <= 0) & (df['method'] == \"psd\")].groupby(\n",
    "        [\"target_model\", \"batch_size\", \"reqs\", \"n\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        draft_models = {}\n",
    "\n",
    "        for draft_model_name, draft_model_indexes in df.loc[dataset_indexes].groupby(\"draft_model\").groups.items():\n",
    "            draft_model = {\"name\": draft_model_name}\n",
    "            max_k = df.loc[draft_model_indexes][\"k\"].max()\n",
    "\n",
    "            metrics_by_k_and_percentile = {\n",
    "                k: {\n",
    "                    1: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.8: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.7: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []}\n",
    "                } for k in range(1, max_k + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by k value\n",
    "            for k, indexes in df.loc[draft_model_indexes].groupby(\"k\", sort=True).groups.items():\n",
    "                for i, bundle in df.loc[indexes].iterrows():\n",
    "                    # Print total times of preemption\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model_name}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method=psd, k={k}, n={n}: \"\n",
    "                            f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    for percentile in [1, 0.8, 0.7]:\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                        e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                        # Store metrics for this bundle by k and percentile\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"gen_throughput\"].append(gen_throughput)\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"vsr\"].append(vsr)\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"latency\"].append(e2el)\n",
    "\n",
    "            metrics_by_percentile = {\n",
    "                percentile: {\n",
    "                    \"avg_gen_throughput\": [np.nan] * max_k,\n",
    "                    \"std_gen_throughput\": [np.nan] * max_k,\n",
    "                    \"avg_vsr\": [np.nan] * max_k,\n",
    "                    \"std_vsr\": [np.nan] * max_k,\n",
    "                    \"avg_latency\": [np.nan] * max_k,\n",
    "                    \"std_latency\": [np.nan] * max_k,\n",
    "                } for percentile in [1, 0.8, 0.7]\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for k in range(1, max_k + 1):\n",
    "                for percentile in [1, 0.8, 0.7]:\n",
    "                    metrics = metrics_by_k_and_percentile[k][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics[\"gen_throughput\"]:\n",
    "                        continue\n",
    "\n",
    "                    # Calculate means\n",
    "                    avg_gen = np.mean(metrics[\"gen_throughput\"])\n",
    "                    avg_vsr = np.mean(metrics[\"vsr\"])\n",
    "                    avg_latency = np.mean(metrics[\"latency\"])\n",
    "\n",
    "                    # Calculate stds\n",
    "                    std_gen = np.std(metrics[\"gen_throughput\"])\n",
    "                    std_vsr = np.std(metrics[\"vsr\"])\n",
    "                    std_latency = np.std(metrics[\"latency\"])\n",
    "\n",
    "                    # Store back in the dictionary\n",
    "                    metrics_by_percentile[percentile][\"avg_gen_throughput\"][k-1] = avg_gen\n",
    "                    metrics_by_percentile[percentile][\"std_gen_throughput\"][k-1] = std_gen\n",
    "                    metrics_by_percentile[percentile][\"avg_vsr\"][k-1] = avg_vsr\n",
    "                    metrics_by_percentile[percentile][\"std_vsr\"][k-1] = std_vsr\n",
    "                    metrics_by_percentile[percentile][\"avg_latency\"][k-1] = avg_latency\n",
    "                    metrics_by_percentile[percentile][\"std_latency\"][k-1] = std_latency\n",
    "\n",
    "            trims = []\n",
    "            for percentile in [1, 0.8, 0.7]:\n",
    "                trims.append({\n",
    "                    \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                    **metrics_by_percentile[percentile]\n",
    "                })\n",
    "\n",
    "            draft_model[\"trims\"] = trims\n",
    "            draft_models.update({draft_model[\"name\"]: draft_model[\"trims\"]})\n",
    "\n",
    "        # Prepare DataFrame for plotting\n",
    "        rows = []\n",
    "        for draft_model_name, trims in draft_models.items():\n",
    "            for trim in trims:\n",
    "                percentile = trim['percentile']\n",
    "                for k, (avg_gen_throughput, std_gen_throughput, avg_vsr, std_vsr, avg_latency, std_latency) in enumerate(\n",
    "                    zip(trim['avg_gen_throughput'], trim['std_gen_throughput'],\n",
    "                        trim['avg_vsr'], trim['std_vsr'],\n",
    "                        trim['avg_latency'], trim['std_latency']),\n",
    "                    1\n",
    "                ):\n",
    "                    rows.append({\n",
    "                        'draft_model': draft_model_name,\n",
    "                        'k': k,\n",
    "                        'percentile': percentile,\n",
    "                        'avg_gen_throughput': avg_gen_throughput,\n",
    "                        'std_gen_throughput': std_gen_throughput,\n",
    "                        'avg_vsr': avg_vsr,\n",
    "                        'std_vsr': std_vsr,\n",
    "                        'avg_latency': avg_latency,\n",
    "                        'std_latency': std_latency,\n",
    "                    })\n",
    "\n",
    "        if not rows:\n",
    "            continue\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows)\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            for draft_model_name, draft_model_indexes in metrics_df.loc[percentile_indexes].groupby('draft_model').groups.items():\n",
    "                grouped = metrics_df.loc[draft_model_indexes].groupby('k').first().dropna()\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "\n",
    "                    # Plot line with markers\n",
    "                    axs.plot(grouped.index, grouped[f'avg_{target}'], marker='o', linestyle='-', label=draft_model_name, linewidth=linewidth, markersize=markersize)\n",
    "                    axs.fill_between(grouped.index, \n",
    "                                     grouped[f'avg_{target}'] - grouped[f'std_{target}'], \n",
    "                                     grouped[f'avg_{target}'] + grouped[f'std_{target}'], \n",
    "                                     alpha=0.2)\n",
    "\n",
    "                    # Add labels and styling\n",
    "                    axs.set_xlabel('No. Speculative Tokens', fontsize=axis_label_fontsize)\n",
    "                    ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "                    axs.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                    axs.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                    axs.legend(loc='upper left' if target == 'latency' else 'lower left', fontsize=legend_fontsize)\n",
    "                    axs.grid(True)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, n={n}, trim={percentile}\"\n",
    "            )\n",
    "            print(report_context)\n",
    "\n",
    "            # Force integer x-ticks across subplots\n",
    "            k_vals = sorted(metrics_df.loc[percentile_indexes, 'k'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                ax.set_xticks(k_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_'\n",
    "                            f'{reqs}_bs{batch_size}_n{n}_'\n",
    "                            f'{percentile}_{target}'\n",
    "                        ) + '.png',\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53d6fcd4",
   "metadata": {},
   "source": [
    "### n > 1 & fixed c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2bf524a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "targets = ['gen_throughput', 'latency']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "# Display names for the method identifiers stored in df['method']\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "\n",
    "def _summarize_method_by_k(method_df, describe_run):\n",
    "    \"\"\"Aggregate the runs of one method into per-trim mean/std series indexed by k.\n",
    "\n",
    "    Args:\n",
    "        method_df: rows of `df` (one per run) that belong to a single method.\n",
    "        describe_run: callable mapping a k value to the context string shown in\n",
    "            the preemption warning for runs with that k.\n",
    "\n",
    "    Returns:\n",
    "        A list with one dict per trim level (100%, 80%, 70%). Each dict holds the\n",
    "        'percentile' label plus 'avg_*' / 'std_*' lists of length max(k) for\n",
    "        gen_throughput, vsr and latency; slot k-1 stays NaN when no run used k.\n",
    "    \"\"\"\n",
    "    percentiles = [1, 0.8, 0.7]\n",
    "    metric_names = ['gen_throughput', 'vsr', 'latency']\n",
    "    max_k = int(method_df['k'].max())\n",
    "\n",
    "    samples = {\n",
    "        k: {percentile: {name: [] for name in metric_names} for percentile in percentiles}\n",
    "        for k in range(1, max_k + 1)\n",
    "    }\n",
    "\n",
    "    # Collect metrics for all bundles by k value\n",
    "    for k, indexes in method_df.groupby('k', sort=True).groups.items():\n",
    "        for i, bundle in method_df.loc[indexes].iterrows():\n",
    "            # Print total times of preemption\n",
    "            if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                print(\n",
    "                    f\"{YELLOW}[WARN]{RESET} \"\n",
    "                    f\"{describe_run(k)}: \"\n",
    "                    f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                )\n",
    "\n",
    "            # Latency is taken from run totals, so it is identical for every trim level\n",
    "            e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "            for percentile in percentiles:\n",
    "                # Only the first `bound` steps contribute to the step-level metrics\n",
    "                bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "\n",
    "                # Store metrics for this bundle by k and percentile\n",
    "                samples[k][percentile]['gen_throughput'].append(gen_throughput)\n",
    "                samples[k][percentile]['vsr'].append(vsr)\n",
    "                samples[k][percentile]['latency'].append(e2el)\n",
    "\n",
    "    # Calculate means and stds\n",
    "    trims = []\n",
    "    for percentile in percentiles:\n",
    "        summary = {\n",
    "            f\"{stat}_{name}\": [np.nan] * max_k\n",
    "            for name in metric_names\n",
    "            for stat in ('avg', 'std')\n",
    "        }\n",
    "        for k in range(1, max_k + 1):\n",
    "            collected = samples[k][percentile]\n",
    "\n",
    "            # Skip if no data collected\n",
    "            if not collected['gen_throughput']:\n",
    "                continue\n",
    "\n",
    "            for name in metric_names:\n",
    "                summary[f\"avg_{name}\"][k - 1] = np.mean(collected[name])\n",
    "                summary[f\"std_{name}\"][k - 1] = np.std(collected[name])\n",
    "\n",
    "        trims.append({\"percentile\": f\"{int(percentile*100)}%\", **summary})\n",
    "\n",
    "    return trims\n",
    "\n",
    "\n",
    "# Index e=0 data (without TETRIS)\n",
    "no_tetris_bundles = {}\n",
    "no_tetris_df = df.loc[(df['e'] == 0) & (df['n'] > 1)].copy()\n",
    "# c is derived as batch_size * k here so the e=0 runs can be looked up with the same key as the e>0 groups\n",
    "no_tetris_df['c'] = no_tetris_df['batch_size'] * no_tetris_df['k']\n",
    "for setting_group, setting_group_indexes in no_tetris_df.groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\", \"n\", \"c\"]).groups.items():\n",
    "\n",
    "    no_tetris_bundles[setting_group] = {\n",
    "        dataset_name: dataset_indexes\n",
    "        for dataset_name, dataset_indexes in df.loc[setting_group_indexes].groupby(\"dataset\").groups.items()\n",
    "    }\n",
    "\n",
    "\n",
    "for (target_model, draft_model, batch_size, reqs, n, c), setting_indexes in df[(df['e'] > 0) & (df['n'] > 1) & (df['c'] > 0)].groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\", \"n\", \"c\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        methods = {}\n",
    "        # Labels of the e=0 methods; these are drawn as horizontal reference lines below\n",
    "        baseline_methods = set()\n",
    "\n",
    "        for (method_name, e), method_indexes in df.loc[dataset_indexes].groupby([\"method\", \"e\"]).groups.items():\n",
    "            # if method_name == \"tetris\":\n",
    "            #     continue\n",
    "\n",
    "            methods[method_labels[method_name] + f' (extra={e})'] = _summarize_method_by_k(\n",
    "                df.loc[method_indexes],\n",
    "                lambda k: (\n",
    "                    f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                    f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                    f\"method={method_name}, k={k} + (extra={e}), n={n}, c={c}\"\n",
    "                ),\n",
    "            )\n",
    "\n",
    "        # Find matching e=0 data\n",
    "        setting_key = (target_model, draft_model, batch_size, reqs, n, c)\n",
    "        no_tetris_dataset_indexes = no_tetris_bundles.get(setting_key, {}).get(dataset_name)\n",
    "        if no_tetris_dataset_indexes is not None:\n",
    "            # Insert e=0 data\n",
    "            for method_name, method_indexes in df.loc[no_tetris_dataset_indexes].groupby(\"method\").groups.items():\n",
    "                baseline_label = method_labels[method_name]\n",
    "                methods[baseline_label] = _summarize_method_by_k(\n",
    "                    df.loc[method_indexes],\n",
    "                    lambda k: (\n",
    "                        f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                        f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                        f\"method={method_name}, k={k}, n={n}, c={c}\"\n",
    "                    ),\n",
    "                )\n",
    "                baseline_methods.add(baseline_label)\n",
    "\n",
    "        # Prepare DataFrame for plotting\n",
    "        rows = []\n",
    "        for method, trims in methods.items():\n",
    "            for trim in trims:\n",
    "                percentile = trim['percentile']\n",
    "                for k, (avg_gen_throughput, std_gen_throughput, avg_vsr, std_vsr, avg_latency, std_latency) in enumerate(\n",
    "                    zip(trim['avg_gen_throughput'], trim['std_gen_throughput'],\n",
    "                        trim['avg_vsr'], trim['std_vsr'],\n",
    "                        trim['avg_latency'], trim['std_latency']),\n",
    "                    1\n",
    "                ):\n",
    "                    rows.append({\n",
    "                        'method': method,\n",
    "                        'k': k,\n",
    "                        'percentile': percentile,\n",
    "                        'avg_gen_throughput': avg_gen_throughput,\n",
    "                        'std_gen_throughput': std_gen_throughput,\n",
    "                        'avg_vsr': avg_vsr,\n",
    "                        'std_vsr': std_vsr,\n",
    "                        'avg_latency': avg_latency,\n",
    "                        'std_latency': std_latency,\n",
    "                    })\n",
    "\n",
    "        if not rows:\n",
    "            continue\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows)\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "            k_vals = sorted(metrics_df.loc[percentile_indexes, 'k'].unique())\n",
    "\n",
    "            for color_idx, (method, method_indexes) in enumerate(\n",
    "                    metrics_df.loc[percentile_indexes].groupby('method').groups.items()):\n",
    "                grouped = metrics_df.loc[method_indexes].groupby('k').first().dropna()\n",
    "                if grouped.empty:\n",
    "                    # Nothing finite to draw for this method at this trim level\n",
    "                    continue\n",
    "\n",
    "                # axhline does not advance the colour cycle, so pin one cycle colour per method\n",
    "                color = f'C{color_idx}'\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "                    avg = grouped[f'avg_{target}']\n",
    "                    std = grouped[f'std_{target}']\n",
    "\n",
    "                    if method in baseline_methods:\n",
    "                        # Draw dotted horizontal line for non-tetris methods\n",
    "                        level, spread = avg.iloc[0], std.iloc[0]\n",
    "                        axs.axhline(y=level, linestyle='--', label=method, color=color)\n",
    "                        axs.fill_between([k_vals[0], k_vals[-1]],\n",
    "                                         level - spread,\n",
    "                                         level + spread,\n",
    "                                         alpha=0.2, color=color)\n",
    "                    else:\n",
    "                        # Plot line with markers\n",
    "                        axs.plot(grouped.index, avg, marker='o', linestyle='-', label=method, linewidth=linewidth, markersize=markersize, color=color)\n",
    "                        axs.fill_between(grouped.index,\n",
    "                                         avg - std,\n",
    "                                         avg + std,\n",
    "                                         alpha=0.2, color=color)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, c={c}, n={n}, trim={percentile}\"\n",
    "            )\n",
    "            report_minedraft_improvement(\n",
    "                metrics_df.loc[percentile_indexes],\n",
    "                targets,\n",
    "                report_context\n",
    "            )\n",
    "\n",
    "            ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                # Add labels and styling (once per figure, after every method has been drawn)\n",
    "                ax.set_xlabel('No. Speculative Tokens', fontsize=axis_label_fontsize)\n",
    "                ax.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                ax.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                ax.legend(loc='upper left' if target == 'latency' else 'lower left', fontsize=legend_fontsize)\n",
    "                ax.grid(True)\n",
    "\n",
    "                # Force integer x-ticks across subplots\n",
    "                ax.set_xticks(k_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_{draft_model}_'\n",
    "                            f'{reqs}_bs{batch_size}_c{c}_n{n}_'\n",
    "                            f'{percentile}_{target}'\n",
    "                        ) + '.png'\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62ccbad3",
   "metadata": {},
   "source": [
    "## Ablation Studies"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fbe1455",
   "metadata": {},
   "source": [
    "### VSR vs _e_"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aedc9bfc",
   "metadata": {},
   "source": [
    "#### n = 1 & default c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c820ea2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "targets = ['vsr']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "\n",
    "for (target_model, draft_model, batch_size, reqs), setting_indexes in df[(df['e'] > 0) & (df['n'] == 1) & (df['c'] <= 0)].groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        methods = {}\n",
    "\n",
    "        for (method_name, k), method_indexes in df.loc[dataset_indexes].groupby([\"method\", \"k\"]).groups.items():\n",
    "            if not method_name.startswith(\"ptetris\"):\n",
    "                continue\n",
    "\n",
    "            method = {\"name\": method_labels[method_name] + f' (k={k})'}\n",
    "            max_e = df.loc[method_indexes][\"e\"].max()\n",
    "\n",
    "            metrics_by_e_and_percentile = {\n",
    "                e: {\n",
    "                    1: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.8: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.7: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []}\n",
    "                } for e in range(1, max_e + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by e value\n",
    "            for e, indexes in df.loc[method_indexes].groupby(\"e\", sort=True).groups.items():\n",
    "                for i, bundle in df.loc[indexes].iterrows():\n",
    "                    # Print total times of preemption\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method={method_name}, k={k} + (extra={e}): \"\n",
    "                            f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    for percentile in [1, 0.8, 0.7]:\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                        e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                        # Store metrics for this bundle by e and percentile\n",
    "                        metrics_by_e_and_percentile[e][percentile][\"gen_throughput\"].append(gen_throughput)\n",
    "                        metrics_by_e_and_percentile[e][percentile][\"vsr\"].append(vsr)\n",
    "                        metrics_by_e_and_percentile[e][percentile][\"latency\"].append(e2el)\n",
    "\n",
    "            metrics_by_percentile = {\n",
    "                percentile: {\n",
    "                    \"avg_gen_throughput\": [np.nan] * max_e,\n",
    "                    \"std_gen_throughput\": [np.nan] * max_e,\n",
    "                    \"avg_vsr\": [np.nan] * max_e,\n",
    "                    \"std_vsr\": [np.nan] * max_e,\n",
    "                    \"avg_latency\": [np.nan] * max_e,\n",
    "                    \"std_latency\": [np.nan] * max_e,\n",
    "                } for percentile in [1, 0.8, 0.7]\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for e in range(1, max_e + 1):\n",
    "                for percentile in [1, 0.8, 0.7]:\n",
    "                    metrics = metrics_by_e_and_percentile[e][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics[\"gen_throughput\"]:\n",
    "                        continue\n",
    "\n",
    "                    # Calculate means\n",
    "                    avg_gen = np.mean(metrics[\"gen_throughput\"])\n",
    "                    avg_vsr = np.mean(metrics[\"vsr\"])\n",
    "                    avg_latency = np.mean(metrics[\"latency\"])\n",
    "\n",
    "                    # Calculate stds\n",
    "                    std_gen = np.std(metrics[\"gen_throughput\"])\n",
    "                    std_vsr = np.std(metrics[\"vsr\"])\n",
    "                    std_latency = np.std(metrics[\"latency\"])\n",
    "\n",
    "                    # Store back in the dictionary\n",
    "                    metrics_by_percentile[percentile][\"avg_gen_throughput\"][e-1] = avg_gen\n",
    "                    metrics_by_percentile[percentile][\"std_gen_throughput\"][e-1] = std_gen\n",
    "                    metrics_by_percentile[percentile][\"avg_vsr\"][e-1] = avg_vsr\n",
    "                    metrics_by_percentile[percentile][\"std_vsr\"][e-1] = std_vsr\n",
    "                    metrics_by_percentile[percentile][\"avg_latency\"][e-1] = avg_latency\n",
    "                    metrics_by_percentile[percentile][\"std_latency\"][e-1] = std_latency\n",
    "\n",
    "            trims = []\n",
    "            for percentile in [1, 0.8, 0.7]:\n",
    "                trims.append({\n",
    "                    \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                    **metrics_by_percentile[percentile]\n",
    "                })\n",
    "\n",
    "            method[\"trims\"] = trims\n",
    "            methods.update({method[\"name\"]: method[\"trims\"]})\n",
    "\n",
    "        # Prepare DataFrame for plotting\n",
    "        rows = []\n",
    "        for method, trims in methods.items():\n",
    "            for trim in trims:\n",
    "                percentile = trim['percentile']\n",
    "                for e, (avg_gen_throughput, std_gen_throughput, avg_vsr, std_vsr, avg_latency, std_latency) in enumerate(\n",
    "                    zip(trim['avg_gen_throughput'], trim['std_gen_throughput'],\n",
    "                        trim['avg_vsr'], trim['std_vsr'],\n",
    "                        trim['avg_latency'], trim['std_latency']),\n",
    "                    1\n",
    "                ):\n",
    "                    rows.append({\n",
    "                        'method': method,\n",
    "                        'e': e,\n",
    "                        'percentile': percentile,\n",
    "                        'avg_gen_throughput': avg_gen_throughput,\n",
    "                        'std_gen_throughput': std_gen_throughput,\n",
    "                        'avg_vsr': avg_vsr,\n",
    "                        'std_vsr': std_vsr,\n",
    "                        'avg_latency': avg_latency,\n",
    "                        'std_latency': std_latency,\n",
    "                    })\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows).dropna()\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            for method, method_indexes in metrics_df.loc[percentile_indexes].groupby('method').groups.items():\n",
    "                grouped = metrics_df.loc[method_indexes].groupby('e').first()\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "\n",
    "                    # Plot line with markers\n",
    "                    axs.plot(grouped.index, grouped[f'avg_{target}'], marker='o', linestyle='-', label=method, linewidth=linewidth, markersize=markersize)\n",
    "                    axs.fill_between(grouped.index, \n",
    "                                     grouped[f'avg_{target}'] - grouped[f'std_{target}'], \n",
    "                                     grouped[f'avg_{target}'] + grouped[f'std_{target}'], \n",
    "                                     alpha=0.2)\n",
    "\n",
    "                    # Add labels and styling\n",
    "                    axs.set_xlabel('No. Extra Tokens', fontsize=axis_label_fontsize)\n",
    "                    ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "                    axs.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                    axs.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                    axs.legend(loc='upper left', fontsize=legend_fontsize)\n",
    "                    axs.grid(True)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, trim={percentile}\"\n",
    "            )\n",
    "            print(report_context)\n",
    "            # report_minedraft_improvement_vs_e(\n",
    "            #     metrics_df.loc[percentile_indexes],\n",
    "            #     targets,\n",
    "            #     report_context\n",
    "            # )\n",
    "\n",
    "            # Force integer x-ticks across subplots\n",
    "            e_vals = np.sort(metrics_df.loc[percentile_indexes, 'e'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                ax.set_xticks(e_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_{draft_model}_'\n",
    "                            f'{reqs}_bs{batch_size}_'\n",
    "                            f'{percentile}_{target}_vs_e'\n",
    "                        ) + '.png',\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6935e46a",
   "metadata": {},
   "source": [
    "#### n = 1 & fixed c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53b9fccf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "# Plots each target metric against the number of extra tokens (e) for runs with\n",
    "# n == 1 and a fixed c (c > 0). One figure is saved per setting, dataset, trim\n",
    "# percentile and target; only methods whose name starts with 'ptetris' are drawn.\n",
    "\n",
    "# Metrics to plot; each needs avg_<target> / std_<target> columns in metrics_df.\n",
    "targets = ['vsr']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "# Leading fractions of a run's recorded steps used for the step-based metrics.\n",
    "percentiles = [1, 0.8, 0.7]\n",
    "\n",
    "# Y-axis label per target; loop-invariant, so it is built once up front.\n",
    "ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "\n",
    "# Raw method identifier -> label shown in the legend.\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "\n",
    "# Outer loop: one experimental setting (models, batch size, request count, c).\n",
    "for (target_model, draft_model, batch_size, reqs, c), setting_indexes in df[(df['e'] > 0) & (df['n'] == 1) & (df['c'] > 0)].groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\", \"c\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        # Display name: 'ShareGPT' is kept verbatim, anything else becomes Capitalized-Hyphenated.\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        # Legend label -> list of per-percentile summaries (filled below).\n",
    "        methods = {}\n",
    "\n",
    "        for (method_name, k), method_indexes in df.loc[dataset_indexes].groupby([\"method\", \"k\"]).groups.items():\n",
    "            if not method_name.startswith(\"ptetris\"):\n",
    "                continue\n",
    "\n",
    "            method = {\"name\": method_labels[method_name] + f' (k={k})'}\n",
    "            max_e = df.loc[method_indexes][\"e\"].max()\n",
    "\n",
    "            # Per-run samples, keyed by e and then by trim percentile.\n",
    "            metrics_by_e_and_percentile = {\n",
    "                e: {\n",
    "                    percentile: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []}\n",
    "                    for percentile in percentiles\n",
    "                } for e in range(1, max_e + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by e value\n",
    "            for e, indexes in df.loc[method_indexes].groupby(\"e\", sort=True).groups.items():\n",
    "                # i is the row's index label in df; bundle is that row as a Series.\n",
    "                for i, bundle in df.loc[indexes].iterrows():\n",
    "                    # Print total times of preemption\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method={method_name}, k={k} + (extra={e}), c={c}: \"\n",
    "                            f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    for percentile in percentiles:\n",
    "                        # Only the first `bound` recorded steps contribute to the step-based metrics.\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        # Generated tokens per unit of generation time over the kept steps.\n",
    "                        gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                        # Accepted tokens as a percentage of verified tokens over the kept steps.\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                        # Total latency divided by the request count; does not depend on the trim.\n",
    "                        e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                        # Store metrics for this bundle by e and percentile\n",
    "                        metrics_by_e_and_percentile[e][percentile][\"gen_throughput\"].append(gen_throughput)\n",
    "                        metrics_by_e_and_percentile[e][percentile][\"vsr\"].append(vsr)\n",
    "                        metrics_by_e_and_percentile[e][percentile][\"latency\"].append(e2el)\n",
    "\n",
    "            # Aggregates per percentile; position e-1 holds the value for e.\n",
    "            # Slots for e values without any run stay NaN and are removed by dropna() later.\n",
    "            metrics_by_percentile = {\n",
    "                percentile: {\n",
    "                    \"avg_gen_throughput\": [np.nan] * max_e,\n",
    "                    \"std_gen_throughput\": [np.nan] * max_e,\n",
    "                    \"avg_vsr\": [np.nan] * max_e,\n",
    "                    \"std_vsr\": [np.nan] * max_e,\n",
    "                    \"avg_latency\": [np.nan] * max_e,\n",
    "                    \"std_latency\": [np.nan] * max_e,\n",
    "                } for percentile in percentiles\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for e in range(1, max_e + 1):\n",
    "                for percentile in percentiles:\n",
    "                    metrics = metrics_by_e_and_percentile[e][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics[\"gen_throughput\"]:\n",
    "                        continue\n",
    "\n",
    "                    # Calculate means\n",
    "                    avg_gen = np.mean(metrics[\"gen_throughput\"])\n",
    "                    avg_vsr = np.mean(metrics[\"vsr\"])\n",
    "                    avg_latency = np.mean(metrics[\"latency\"])\n",
    "\n",
    "                    # Calculate stds (np.std default, i.e. ddof=0)\n",
    "                    std_gen = np.std(metrics[\"gen_throughput\"])\n",
    "                    std_vsr = np.std(metrics[\"vsr\"])\n",
    "                    std_latency = np.std(metrics[\"latency\"])\n",
    "\n",
    "                    # Store back in the dictionary\n",
    "                    metrics_by_percentile[percentile][\"avg_gen_throughput\"][e-1] = avg_gen\n",
    "                    metrics_by_percentile[percentile][\"std_gen_throughput\"][e-1] = std_gen\n",
    "                    metrics_by_percentile[percentile][\"avg_vsr\"][e-1] = avg_vsr\n",
    "                    metrics_by_percentile[percentile][\"std_vsr\"][e-1] = std_vsr\n",
    "                    metrics_by_percentile[percentile][\"avg_latency\"][e-1] = avg_latency\n",
    "                    metrics_by_percentile[percentile][\"std_latency\"][e-1] = std_latency\n",
    "\n",
    "            # One summary dict per percentile, tagged with a label such as '80%'.\n",
    "            trims = []\n",
    "            for percentile in percentiles:\n",
    "                trims.append({\n",
    "                    \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                    **metrics_by_percentile[percentile]\n",
    "                })\n",
    "\n",
    "            method[\"trims\"] = trims\n",
    "            methods.update({method[\"name\"]: method[\"trims\"]})\n",
    "\n",
    "        # Prepare DataFrame for plotting\n",
    "        rows = []\n",
    "        for method, trims in methods.items():\n",
    "            for trim in trims:\n",
    "                percentile = trim['percentile']\n",
    "                # enumerate(..., 1) recovers e from the position in the per-e lists.\n",
    "                for e, (avg_gen_throughput, std_gen_throughput, avg_vsr, std_vsr, avg_latency, std_latency) in enumerate(\n",
    "                    zip(trim['avg_gen_throughput'], trim['std_gen_throughput'],\n",
    "                        trim['avg_vsr'], trim['std_vsr'],\n",
    "                        trim['avg_latency'], trim['std_latency']),\n",
    "                    1\n",
    "                ):\n",
    "                    rows.append({\n",
    "                        'method': method,\n",
    "                        'e': e,\n",
    "                        'percentile': percentile,\n",
    "                        'avg_gen_throughput': avg_gen_throughput,\n",
    "                        'std_gen_throughput': std_gen_throughput,\n",
    "                        'avg_vsr': avg_vsr,\n",
    "                        'std_vsr': std_vsr,\n",
    "                        'avg_latency': avg_latency,\n",
    "                        'std_latency': std_latency,\n",
    "                    })\n",
    "\n",
    "        # No 'ptetris*' method in this dataset group: an empty frame has no\n",
    "        # 'percentile' column, so the groupby below would raise KeyError.\n",
    "        if not rows:\n",
    "            continue\n",
    "\n",
    "        # Drop the NaN placeholder rows (e values for which no run was recorded).\n",
    "        metrics_df = pd.DataFrame(rows).dropna()\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            # One (figure, axes) pair per target metric, shared by all methods.\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            for method, method_indexes in metrics_df.loc[percentile_indexes].groupby('method').groups.items():\n",
    "                grouped = metrics_df.loc[method_indexes].groupby('e').first()\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "\n",
    "                    # Plot line with markers\n",
    "                    axs.plot(grouped.index, grouped[f'avg_{target}'], marker='o', linestyle='-', label=method, linewidth=linewidth, markersize=markersize)\n",
    "                    # Shaded band of one standard deviation around the mean.\n",
    "                    axs.fill_between(grouped.index, \n",
    "                                     grouped[f'avg_{target}'] - grouped[f'std_{target}'], \n",
    "                                     grouped[f'avg_{target}'] + grouped[f'std_{target}'], \n",
    "                                     alpha=0.2)\n",
    "\n",
    "                    # Add labels and styling\n",
    "                    axs.set_xlabel('No. Extra Tokens', fontsize=axis_label_fontsize)\n",
    "                    axs.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                    axs.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                    axs.legend(loc='upper left', fontsize=legend_fontsize)\n",
    "                    axs.grid(True)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, c={c}, trim={percentile}\"\n",
    "            )\n",
    "            print(report_context)\n",
    "            # report_minedraft_improvement_vs_e(\n",
    "            #     metrics_df.loc[percentile_indexes],\n",
    "            #     targets,\n",
    "            #     report_context\n",
    "            # )\n",
    "\n",
    "            # Force integer x-ticks across subplots\n",
    "            # NOTE(review): set_major_locator replaces the fixed locator installed by\n",
    "            # set_xticks, so the final tick positions are chosen by MaxNLocator.\n",
    "            e_vals = np.sort(metrics_df.loc[percentile_indexes, 'e'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                ax.set_xticks(e_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot: show inline, save as PNG, then release the figure.\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_{draft_model}_'\n",
    "                            f'{reqs}_bs{batch_size}_c{c}_'\n",
    "                            f'{percentile}_{target}_vs_e'\n",
    "                        ) + '.png',\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a9eb254",
   "metadata": {},
   "source": [
    "#### n > 1 & default c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bfdeef8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "# Target metrics against the number of extra tokens (e) for runs with n > 1 and\n",
    "# c <= 0, restricted to methods whose name starts with 'ptetris'.\n",
    "targets = ['vsr']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "# Legend text for every raw method identifier.\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "# Leading fractions of a run's recorded steps over which step metrics are computed.\n",
    "trim_fractions = [1, 0.8, 0.7]\n",
    "metric_names = ('gen_throughput', 'vsr', 'latency')\n",
    "# Aggregate column names, in the order they appear in every plotting row.\n",
    "stat_keys = [f'{stat}_{name}' for name in metric_names for stat in ('avg', 'std')]\n",
    "ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "\n",
    "selected_runs = df[(df['e'] > 0) & (df['n'] > 1) & (df['c'] <= 0)]\n",
    "setting_groups = selected_runs.groupby(\n",
    "    ['target_model', 'draft_model', 'batch_size', 'reqs', 'n']).groups\n",
    "\n",
    "for (target_model, draft_model, batch_size, reqs, n), setting_indexes in setting_groups.items():\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby('dataset').groups.items():\n",
    "        # 'ShareGPT' keeps its spelling; other names become Capitalized-Hyphenated.\n",
    "        if dataset_name == 'ShareGPT':\n",
    "            dataset = dataset_name\n",
    "        else:\n",
    "            dataset = '-'.join(w.capitalize() for w in dataset_name.replace('_', '-').split('-'))\n",
    "\n",
    "        # Legend label -> one summary dict per trim fraction.\n",
    "        methods = {}\n",
    "        for (method_name, k), method_indexes in df.loc[dataset_indexes].groupby(['method', 'k']).groups.items():\n",
    "            if not method_name.startswith('ptetris'):\n",
    "                continue\n",
    "\n",
    "            legend_name = method_labels[method_name] + f' (k={k})'\n",
    "            method_runs = df.loc[method_indexes]\n",
    "            max_e = method_runs['e'].max()\n",
    "\n",
    "            # samples[e][fraction][metric] gathers one value per run.\n",
    "            samples = {\n",
    "                e: {fraction: {name: [] for name in metric_names} for fraction in trim_fractions}\n",
    "                for e in range(1, max_e + 1)\n",
    "            }\n",
    "\n",
    "            for e, indexes in method_runs.groupby('e', sort=True).groups.items():\n",
    "                for row_label, bundle in df.loc[indexes].iterrows():\n",
    "                    # Warn when any preempted requests were recorded for this run.\n",
    "                    preemptions = np.sum(bundle['step_preempted_requests'])\n",
    "                    if preemptions > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method={method_name}, k={k} + (extra={e}), n={n}: \"\n",
    "                            f\"iter={row_label+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    step_count = len(bundle['step_drafted_tokens'])\n",
    "                    for fraction in trim_fractions:\n",
    "                        # Only the first `bound` steps feed the step-based metrics.\n",
    "                        bound = int(step_count * fraction)\n",
    "\n",
    "                        generated = np.sum(bundle['step_generated_tokens'][:bound])\n",
    "                        elapsed = np.sum(bundle['step_generation_times'][:bound])\n",
    "                        gen_throughput = generated / elapsed\n",
    "\n",
    "                        accepted = np.sum(bundle['step_accepted_tokens'][:bound])\n",
    "                        verified = np.sum(bundle['step_verified_tokens'][:bound])\n",
    "                        vsr = accepted / verified * 100\n",
    "\n",
    "                        e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                        bucket = samples[e][fraction]\n",
    "                        bucket['gen_throughput'].append(gen_throughput)\n",
    "                        bucket['vsr'].append(vsr)\n",
    "                        bucket['latency'].append(e2el)\n",
    "\n",
    "            # Mean / std per e for each trim fraction; slot e-1 stays NaN when no run exists.\n",
    "            trims = []\n",
    "            for fraction in trim_fractions:\n",
    "                summary = {key: [np.nan] * max_e for key in stat_keys}\n",
    "                for e in range(1, max_e + 1):\n",
    "                    bucket = samples[e][fraction]\n",
    "                    if not bucket['gen_throughput']:\n",
    "                        continue\n",
    "                    for name in metric_names:\n",
    "                        summary[f'avg_{name}'][e - 1] = np.mean(bucket[name])\n",
    "                        summary[f'std_{name}'][e - 1] = np.std(bucket[name])\n",
    "                trims.append({'percentile': f'{int(fraction*100)}%', **summary})\n",
    "\n",
    "            methods[legend_name] = trims\n",
    "\n",
    "        # Flatten to one row per (method, trim, e) for plotting.\n",
    "        rows = [\n",
    "            {\n",
    "                'method': method,\n",
    "                'e': e,\n",
    "                'percentile': trim['percentile'],\n",
    "                **{key: trim[key][e - 1] for key in stat_keys},\n",
    "            }\n",
    "            for method, trims in methods.items()\n",
    "            for trim in trims\n",
    "            for e in range(1, len(trim['avg_gen_throughput']) + 1)\n",
    "        ]\n",
    "\n",
    "        if not rows:\n",
    "            continue\n",
    "\n",
    "        # Rows that still hold NaN placeholders are discarded here.\n",
    "        metrics_df = pd.DataFrame(rows).dropna()\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            percentile_df = metrics_df.loc[percentile_indexes]\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            for method, method_indexes in percentile_df.groupby('method').groups.items():\n",
    "                grouped = metrics_df.loc[method_indexes].groupby('e').first()\n",
    "                for target in targets:\n",
    "                    fig, ax = plots[target]\n",
    "                    center = grouped[f'avg_{target}']\n",
    "                    spread = grouped[f'std_{target}']\n",
    "\n",
    "                    # Mean curve with a one-standard-deviation band around it.\n",
    "                    ax.plot(\n",
    "                        grouped.index, center,\n",
    "                        marker='o', linestyle='-', label=method,\n",
    "                        linewidth=linewidth, markersize=markersize,\n",
    "                    )\n",
    "                    ax.fill_between(grouped.index, center - spread, center + spread, alpha=0.2)\n",
    "\n",
    "                    # Axis labels, tick size, legend and grid.\n",
    "                    ax.set_xlabel('No. Extra Tokens', fontsize=axis_label_fontsize)\n",
    "                    ax.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                    ax.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                    ax.legend(loc='upper left', fontsize=legend_fontsize)\n",
    "                    ax.grid(True)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, n={n}, trim={percentile}\"\n",
    "            )\n",
    "            print(report_context)\n",
    "            # report_minedraft_improvement_vs_e(\n",
    "            #     metrics_df.loc[percentile_indexes],\n",
    "            #     targets,\n",
    "            #     report_context\n",
    "            # )\n",
    "\n",
    "            # Integer tick labels on the x axis of every figure.\n",
    "            e_vals = np.sort(percentile_df['e'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                ax.set_xticks(e_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Show inline, write the PNG, then release the figure.\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                file_stem = safename(\n",
    "                    f'{dataset}_{target_model}_{draft_model}_'\n",
    "                    f'{reqs}_bs{batch_size}_n{n}_'\n",
    "                    f'{percentile}_{target}_vs_e'\n",
    "                )\n",
    "                fig.savefig(os.path.join(plot_dir, file_stem + '.png'), dpi=100)\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df55a51d",
   "metadata": {},
   "source": [
    "#### n > 1 & fixed c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4f78378",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "targets = ['vsr']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "\n",
    "for (target_model, draft_model, batch_size, reqs, n, c), setting_indexes in df[(df['e'] > 0) & (df['n'] > 1) & (df['c'] > 0)].groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\", \"n\", \"c\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        methods = {}\n",
    "\n",
    "        for (method_name, k), method_indexes in df.loc[dataset_indexes].groupby([\"method\", \"k\"]).groups.items():\n",
    "            if not method_name.startswith(\"ptetris\"):\n",
    "                continue\n",
    "\n",
    "            method = {\"name\": method_labels[method_name] + f' (k={k})'}\n",
    "            max_e = df.loc[method_indexes][\"e\"].max()\n",
    "\n",
    "            metrics_by_e_and_percentile = {\n",
    "                e: {\n",
    "                    1: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.8: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.7: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []}\n",
    "                } for e in range(1, max_e + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by e value\n",
    "            for e, indexes in df.loc[method_indexes].groupby(\"e\", sort=True).groups.items():\n",
    "                for i, bundle in df.loc[indexes].iterrows():\n",
    "                    # Print total times of preemption\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method={method_name}, k={k} + (extra={e}), n={n}, c={c}: \"\n",
    "                            f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    for percentile in [1, 0.8, 0.7]:\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                        e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                        # Store metrics for this bundle by e and percentile\n",
    "                        metrics_by_e_and_percentile[e][percentile][\"gen_throughput\"].append(gen_throughput)\n",
    "                        metrics_by_e_and_percentile[e][percentile][\"vsr\"].append(vsr)\n",
    "                        metrics_by_e_and_percentile[e][percentile][\"latency\"].append(e2el)\n",
    "\n",
    "            metrics_by_percentile = {\n",
    "                percentile: {\n",
    "                    \"avg_gen_throughput\": [np.nan] * max_e,\n",
    "                    \"std_gen_throughput\": [np.nan] * max_e,\n",
    "                    \"avg_vsr\": [np.nan] * max_e,\n",
    "                    \"std_vsr\": [np.nan] * max_e,\n",
    "                    \"avg_latency\": [np.nan] * max_e,\n",
    "                    \"std_latency\": [np.nan] * max_e,\n",
    "                } for percentile in [1, 0.8, 0.7]\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for e in range(1, max_e + 1):\n",
    "                for percentile in [1, 0.8, 0.7]:\n",
    "                    metrics = metrics_by_e_and_percentile[e][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics[\"gen_throughput\"]:\n",
    "                        continue\n",
    "\n",
    "                    # Calculate means\n",
    "                    avg_gen = np.mean(metrics[\"gen_throughput\"])\n",
    "                    avg_vsr = np.mean(metrics[\"vsr\"])\n",
    "                    avg_latency = np.mean(metrics[\"latency\"])\n",
    "\n",
    "                    # Calculate stds\n",
    "                    std_gen = np.std(metrics[\"gen_throughput\"])\n",
    "                    std_vsr = np.std(metrics[\"vsr\"])\n",
    "                    std_latency = np.std(metrics[\"latency\"])\n",
    "\n",
    "                    # Store back in the dictionary\n",
    "                    metrics_by_percentile[percentile][\"avg_gen_throughput\"][e-1] = avg_gen\n",
    "                    metrics_by_percentile[percentile][\"std_gen_throughput\"][e-1] = std_gen\n",
    "                    metrics_by_percentile[percentile][\"avg_vsr\"][e-1] = avg_vsr\n",
    "                    metrics_by_percentile[percentile][\"std_vsr\"][e-1] = std_vsr\n",
    "                    metrics_by_percentile[percentile][\"avg_latency\"][e-1] = avg_latency\n",
    "                    metrics_by_percentile[percentile][\"std_latency\"][e-1] = std_latency\n",
    "\n",
    "            trims = []\n",
    "            for percentile in [1, 0.8, 0.7]:\n",
    "                trims.append({\n",
    "                    \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                    **metrics_by_percentile[percentile]\n",
    "                })\n",
    "\n",
    "            method[\"trims\"] = trims\n",
    "            methods.update({method[\"name\"]: method[\"trims\"]})\n",
    "\n",
    "        # Prepare DataFrame for plotting\n",
    "        rows = []\n",
    "        for method, trims in methods.items():\n",
    "            for trim in trims:\n",
    "                percentile = trim['percentile']\n",
    "                for e, (avg_gen_throughput, std_gen_throughput, avg_vsr, std_vsr, avg_latency, std_latency) in enumerate(\n",
    "                    zip(trim['avg_gen_throughput'], trim['std_gen_throughput'],\n",
    "                        trim['avg_vsr'], trim['std_vsr'],\n",
    "                        trim['avg_latency'], trim['std_latency']),\n",
    "                    1\n",
    "                ):\n",
    "                    rows.append({\n",
    "                        'method': method,\n",
    "                        'e': e,\n",
    "                        'percentile': percentile,\n",
    "                        'avg_gen_throughput': avg_gen_throughput,\n",
    "                        'std_gen_throughput': std_gen_throughput,\n",
    "                        'avg_vsr': avg_vsr,\n",
    "                        'std_vsr': std_vsr,\n",
    "                        'avg_latency': avg_latency,\n",
    "                        'std_latency': std_latency,\n",
    "                    })\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows).dropna()\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            for method, method_indexes in metrics_df.loc[percentile_indexes].groupby('method').groups.items():\n",
    "                grouped = metrics_df.loc[method_indexes].groupby('e').first()\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "\n",
    "                    # Plot line with markers\n",
    "                    axs.plot(grouped.index, grouped[f'avg_{target}'], marker='o', linestyle='-', label=method, linewidth=linewidth, markersize=markersize)\n",
    "                    axs.fill_between(grouped.index, \n",
    "                                     grouped[f'avg_{target}'] - grouped[f'std_{target}'], \n",
    "                                     grouped[f'avg_{target}'] + grouped[f'std_{target}'], \n",
    "                                     alpha=0.2)\n",
    "\n",
    "                    # Add labels and styling\n",
    "                    axs.set_xlabel('No. Extra Tokens', fontsize=axis_label_fontsize)\n",
    "                    ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "                    axs.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                    axs.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                    axs.legend(loc='upper left', fontsize=legend_fontsize)\n",
    "                    axs.grid(True)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, c={c}, n={n}, trim={percentile}\"\n",
    "            )\n",
    "            print(report_context)\n",
    "            # report_minedraft_improvement_vs_e(\n",
    "            #     metrics_df.loc[percentile_indexes],\n",
    "            #     targets,\n",
    "            #     report_context\n",
    "            # )\n",
    "\n",
    "            # Force integer x-ticks across subplots\n",
    "            e_vals = np.sort(metrics_df.loc[percentile_indexes, 'e'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                ax.set_xticks(e_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_{draft_model}_'\n",
    "                            f'{reqs}_bs{batch_size}_c{c}_n{n}_'\n",
    "                            f'{percentile}_{target}_vs_e'\n",
    "                        ) + '.png',\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94bcc022",
   "metadata": {},
   "source": [
    "### VSR v. Percentile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f63dd04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "\n",
    "for (target_model, draft_model, batch_size, reqs, n), setting_indexes in df[\n",
    "    (df['e'] > 0)\n",
    "    & (\n",
    "        ((df['draft_model'] == 'Qwen3-0.6B') & (df['batch_size'] == 16) & (df['n'] <= 2))\n",
    "        | ((df['draft_model'] == 'Llama-3.1-8B-Instruct') & (df['n'] == 2))\n",
    "    )\n",
    "    & (df['c'] <= 0)\n",
    "].groupby([\"target_model\", \"draft_model\", \"batch_size\", \"reqs\", \"n\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        methods = {}\n",
    "\n",
    "        for (method_name, k), method_indexes in df.loc[dataset_indexes].groupby([\"method\", \"k\"]).groups.items():\n",
    "            if method_name not in (\"ptetris\", \"tetris\"):\n",
    "                continue\n",
    "\n",
    "            max_e = 5\n",
    "            percentiles = [1.0, 0.98, 0.96, 0.94, 0.92, 0.9, 0.88, 0.86, 0.84, 0.82, 0.8, 0.78, 0.76, 0.74, 0.72, 0.7]\n",
    "\n",
    "            vsr_by_e_and_percentile = {\n",
    "                e: {\n",
    "                    percentage: [] for percentage in percentiles\n",
    "                } for e in range(1, max_e + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by e value\n",
    "            for e, indexes in df.loc[method_indexes].groupby(\"e\", sort=True).groups.items():\n",
    "                # if e > 2:\n",
    "                #     break\n",
    "\n",
    "                for i, bundle in df.loc[indexes].iterrows():\n",
    "                    # Print total times of preemption\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method={method_name}, k={k} + (extra={e}), n={n}: \"\n",
    "                            f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "                    for percentile in percentiles:\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "\n",
    "                        # Store metrics for this bundle by e and percentile\n",
    "                        vsr_by_e_and_percentile[e][percentile].append(vsr)\n",
    "\n",
    "            vsr = {\n",
    "                percentile: {\n",
    "                    \"avg\": [np.nan] * max_e,\n",
    "                    \"std\": [np.nan] * max_e,\n",
    "                } for percentile in percentiles\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for e in range(1, max_e + 1):\n",
    "                for percentile in percentiles:\n",
    "                    metrics = vsr_by_e_and_percentile[e][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics:\n",
    "                        continue\n",
    "\n",
    "                    # Store back in the dictionary\n",
    "                    vsr[percentile][\"avg\"][e-1] = np.mean(metrics)\n",
    "                    vsr[percentile][\"std\"][e-1] = np.std(metrics)\n",
    "\n",
    "            for percentile in percentiles:\n",
    "                metrics = {\n",
    "                    \"percentile\": int(percentile * 100),\n",
    "                    \"vsr\": vsr[percentile]\n",
    "                }\n",
    "                methods.setdefault(method_labels[method_name] + f\" (k={k})\", []).append(metrics)\n",
    "\n",
    "        # Prepare DataFrame for plotting\n",
    "        rows = []\n",
    "        for method, metrics_list in methods.items():\n",
    "            for metrics in metrics_list:\n",
    "                percentile = metrics['percentile']\n",
    "                vsr = metrics['vsr']\n",
    "                for e, (avg_vsr, std_vsr) in enumerate(\n",
    "                    zip(vsr['avg'], vsr['std']),\n",
    "                    1\n",
    "                ):\n",
    "                    rows.append({\n",
    "                        'method': method,\n",
    "                        'percentile': percentile,\n",
    "                        'e': e,\n",
    "                        'avg_vsr': avg_vsr,\n",
    "                        'std_vsr': std_vsr,\n",
    "                    })\n",
    "\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows)\n",
    "\n",
    "        for e, e_indexes in metrics_df.groupby('e').groups.items():\n",
    "            plots = plt.subplots(figsize=(9, 9))\n",
    "\n",
    "            for method, method_indexes in metrics_df.loc[e_indexes].groupby('method').groups.items():\n",
    "                grouped = metrics_df.loc[method_indexes].groupby('percentile').first().dropna()\n",
    "\n",
    "                fig, axs = plots\n",
    "\n",
    "                # Plot line with markers\n",
    "                axs.plot(grouped.index, grouped['avg_vsr'], marker='o', linestyle='-', label=method, linewidth=linewidth, markersize=markersize)\n",
    "                axs.fill_between(grouped.index, \n",
    "                                    grouped['avg_vsr'] - grouped['std_vsr'], \n",
    "                                    grouped['avg_vsr'] + grouped['std_vsr'], \n",
    "                                    alpha=0.2)\n",
    "\n",
    "                # Add labels and styling\n",
    "                axs.set_xlabel('Percentile', fontsize=axis_label_fontsize)\n",
    "                axs.set_ylabel('VSR', fontsize=axis_label_fontsize)\n",
    "                axs.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                axs.legend(loc='upper left', fontsize=legend_fontsize)\n",
    "                axs.grid(True)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, n={n}, e={e}\"\n",
    "            )\n",
    "            print(report_context)\n",
    "\n",
    "            # Force integer x-ticks across subplots\n",
    "            fig, axs = plots\n",
    "            axs.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "            # Finalize plot\n",
    "            fig.tight_layout()\n",
    "            display(fig)\n",
    "            fig.savefig(\n",
    "                os.path.join(\n",
    "                    plot_dir,\n",
    "                    safename(\n",
    "                        f'{dataset}_'\n",
    "                        f'{target_model}_{draft_model}_'\n",
    "                        f'{reqs}_bs{batch_size}_n{n}_'\n",
    "                        f'e{e}_vsr_vs_percentage'\n",
    "                    ) + '.png',\n",
    "                ),\n",
    "                dpi=100\n",
    "            )\n",
    "            plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af56cefd",
   "metadata": {},
   "source": [
    "### Vary _m_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8f475f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "targets = ['gen_throughput', 'latency']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 16\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "for (target_model, draft_model, reqs), setting_indexes in df[(df['e'] == 0) & (df['n'] == 1) & (df['c'] <= 0)].groupby(\n",
    "        [\"target_model\", \"draft_model\", \"reqs\"]).groups.items():\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        methods = {}\n",
    "\n",
    "        for (method_name, batch_size), method_indexes in df.loc[dataset_indexes].groupby([\"method\", \"batch_size\"]).groups.items():\n",
    "            if method_name not in (\"psd\", \"sd\"):\n",
    "                continue\n",
    "\n",
    "            method = {\"name\": method_labels[method_name] + f' (m={batch_size})'}\n",
    "            max_k = df.loc[method_indexes][\"k\"].max()\n",
    "\n",
    "            metrics_by_k_and_percentile = {\n",
    "                k: {\n",
    "                    1: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.8: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []},\n",
    "                    0.7: {\"gen_throughput\": [], \"vsr\": [], \"latency\": []}\n",
    "                } for k in range(1, max_k + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by k value\n",
    "            for k, indexes in df.loc[method_indexes].groupby(\"k\", sort=True).groups.items():\n",
    "                for i, bundle in df.loc[indexes].iterrows():\n",
    "                    # Print total times of preemption\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method={method_name}, k={k}: \"\n",
    "                            f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    for percentile in [1, 0.8, 0.7]:\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "                        e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                        # Store metrics for this bundle by k and percentile\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"gen_throughput\"].append(gen_throughput)\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"vsr\"].append(vsr)\n",
    "                        metrics_by_k_and_percentile[k][percentile][\"latency\"].append(e2el)\n",
    "\n",
    "            metrics_by_percentile = {\n",
    "                percentile: {\n",
    "                    \"avg_gen_throughput\": [np.nan] * max_k,\n",
    "                    \"std_gen_throughput\": [np.nan] * max_k,\n",
    "                    \"avg_vsr\": [np.nan] * max_k,\n",
    "                    \"std_vsr\": [np.nan] * max_k,\n",
    "                    \"avg_latency\": [np.nan] * max_k,\n",
    "                    \"std_latency\": [np.nan] * max_k,\n",
    "                } for percentile in [1, 0.8, 0.7]\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for k in range(1, max_k + 1):\n",
    "                for percentile in [1, 0.8, 0.7]:\n",
    "                    metrics = metrics_by_k_and_percentile[k][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics[\"gen_throughput\"]:\n",
    "                        continue\n",
    "\n",
    "                    # Calculate means\n",
    "                    avg_gen = np.mean(metrics[\"gen_throughput\"])\n",
    "                    avg_vsr = np.mean(metrics[\"vsr\"])\n",
    "                    avg_latency = np.mean(metrics[\"latency\"])\n",
    "\n",
    "                    # Calculate stds\n",
    "                    std_gen = np.std(metrics[\"gen_throughput\"])\n",
    "                    std_vsr = np.std(metrics[\"vsr\"])\n",
    "                    std_latency = np.std(metrics[\"latency\"])\n",
    "\n",
    "                    # Store back in the dictionary\n",
    "                    metrics_by_percentile[percentile][\"avg_gen_throughput\"][k-1] = avg_gen\n",
    "                    metrics_by_percentile[percentile][\"std_gen_throughput\"][k-1] = std_gen\n",
    "                    metrics_by_percentile[percentile][\"avg_vsr\"][k-1] = avg_vsr\n",
    "                    metrics_by_percentile[percentile][\"std_vsr\"][k-1] = std_vsr\n",
    "                    metrics_by_percentile[percentile][\"avg_latency\"][k-1] = avg_latency\n",
    "                    metrics_by_percentile[percentile][\"std_latency\"][k-1] = std_latency\n",
    "\n",
    "            trims = []\n",
    "            for percentile in [1, 0.8, 0.7]:\n",
    "                trims.append({\n",
    "                    \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                    \"batch_size\": batch_size,\n",
    "                    **metrics_by_percentile[percentile]\n",
    "                })\n",
    "\n",
    "            method[\"trims\"] = trims\n",
    "            methods.update({method[\"name\"]: method[\"trims\"]})\n",
    "\n",
    "        # Prepare DataFrame for plotting\n",
    "        rows = []\n",
    "        for method, trims in methods.items():\n",
    "            for trim in trims:\n",
    "                percentile = trim['percentile']\n",
    "                batch_size = trim['batch_size']\n",
    "                for k, (avg_gen_throughput, std_gen_throughput, avg_vsr, std_vsr, avg_latency, std_latency) in enumerate(\n",
    "                    zip(trim['avg_gen_throughput'], trim['std_gen_throughput'],\n",
    "                        trim['avg_vsr'], trim['std_vsr'],\n",
    "                        trim['avg_latency'], trim['std_latency']),\n",
    "                    1\n",
    "                ):\n",
    "                    rows.append({\n",
    "                        'method': method,\n",
    "                        'k': k,\n",
    "                        'batch_size': batch_size,\n",
    "                        'percentile': percentile,\n",
    "                        'avg_gen_throughput': avg_gen_throughput,\n",
    "                        'std_gen_throughput': std_gen_throughput,\n",
    "                        'avg_vsr': avg_vsr,\n",
    "                        'std_vsr': std_vsr,\n",
    "                        'avg_latency': avg_latency,\n",
    "                        'std_latency': std_latency,\n",
    "                    })\n",
    "\n",
    "        if not rows:\n",
    "            continue\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows)\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            # Create a color map for m values - MineDraft and Standard SD with same m get same color\n",
    "            unique_m_values = sorted(metrics_df.loc[percentile_indexes]['batch_size'].unique())\n",
    "            color_palette = plt.cm.tab10.colors\n",
    "            m_to_color = {m: color_palette[i % len(color_palette)] for i, m in enumerate(unique_m_values)}\n",
    "\n",
    "            # Sort methods by batch_size (m) in increasing order\n",
    "            method_to_batch = metrics_df.loc[percentile_indexes].groupby('method')['batch_size'].first().to_dict()\n",
    "            methods_sorted = sorted(\n",
    "                metrics_df.loc[percentile_indexes]['method'].unique(),\n",
    "                key=lambda x: (method_to_batch[x], x.startswith('Standard'))\n",
    "            )\n",
    "\n",
    "            for method in methods_sorted:\n",
    "                method_indexes = metrics_df.loc[percentile_indexes][metrics_df.loc[percentile_indexes]['method'] == method].index\n",
    "                grouped = metrics_df.loc[method_indexes].groupby('k').first().dropna()\n",
    "                \n",
    "                # Get m value from batch_size column\n",
    "                m_val = metrics_df.loc[method_indexes, 'batch_size'].iloc[0]\n",
    "                color = m_to_color[m_val]\n",
    "                \n",
    "                # Standard SD uses dotted line, MineDraft uses solid line\n",
    "                linestyle = ':' if method.startswith('Standard SD') else '-'\n",
    "\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "\n",
    "                    # Plot line with markers\n",
    "                    axs.plot(grouped.index, grouped[f'avg_{target}'], marker='o', linestyle=linestyle, color=color, label=method, linewidth=linewidth, markersize=markersize)\n",
    "                    axs.fill_between(grouped.index, \n",
    "                                     grouped[f'avg_{target}'] - grouped[f'std_{target}'], \n",
    "                                     grouped[f'avg_{target}'] + grouped[f'std_{target}'], \n",
    "                                     alpha=0.2, color=color)\n",
    "\n",
    "                    # Add labels and styling\n",
    "                    axs.set_xlabel('No. Speculative Tokens', fontsize=axis_label_fontsize)\n",
    "                    ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "                    axs.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                    axs.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                    axs.legend(loc='upper left' if target == 'latency' else 'lower left', fontsize=legend_fontsize)\n",
    "                    axs.grid(True)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, trim={percentile}\"\n",
    "            )\n",
    "            print(report_context)\n",
    "            # report_max_minedraft_improvement(\n",
    "            #     metrics_df.loc[percentile_indexes],\n",
    "            #     targets,\n",
    "            #     report_context\n",
    "            # )\n",
    "\n",
    "            # Force integer x-ticks across subplots\n",
    "            k_vals = sorted(metrics_df.loc[percentile_indexes, 'k'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                ax.set_xticks(k_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_{draft_model}_'\n",
    "                            f'{reqs}_'\n",
    "                            f'{percentile}_{target}_vs_m'\n",
    "                        ) + '.png'\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2f8798a",
   "metadata": {},
   "source": [
    "### Vary _n_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8c9d3dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.ticker import MaxNLocator, FuncFormatter\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "\n",
    "# Plot each metric in `targets` against k (x-axis label: 'No. Speculative Tokens') for the\n",
    "# 'sd' and 'psd' methods, one line per (method, n). For every setting, dataset and trim\n",
    "# fraction, one figure per target metric is displayed and saved under plot_dir.\n",
    "targets = ['gen_throughput', 'latency']\n",
    "axis_label_fontsize = 28\n",
    "legend_fontsize = 22\n",
    "linewidth = 3\n",
    "markersize = 12\n",
    "\n",
    "# Fractions of each run's steps (counted from the first step) that feed the step-level\n",
    "# metrics; every fraction gets its own set of figures.\n",
    "trim_fractions = (1, 0.8, 0.7)\n",
    "# Metrics collected per run. 'vsr' is aggregated but is not in `targets`, so it is not plotted here.\n",
    "metric_names = ('gen_throughput', 'vsr', 'latency')\n",
    "# Aggregated statistic columns, in the order they appear in metrics_df.\n",
    "stat_columns = [f'{stat}_{name}' for name in metric_names for stat in ('avg', 'std')]\n",
    "# NOTE(review): the '(ms)' unit assumes total_latency is recorded in milliseconds - confirm.\n",
    "ylabels = {'gen_throughput': 'Throughput (tokens/s)', 'latency': 'E2E Latency (ms)', 'vsr': 'VSR'}\n",
    "\n",
    "method_labels = {\n",
    "    'sd': 'Standard SD',\n",
    "    'pearl_sd': 'PEARL',\n",
    "    'psd': 'MineDraft (standalone)',\n",
    "    'tetris': 'Tetris',\n",
    "    'pearl_tetris': 'PEARL + Tetris',\n",
    "    'ptetris': 'MineDraft',\n",
    "    'eagle': 'EAGLE',\n",
    "    'pearl_eagle': 'PEARL + EAGLE',\n",
    "    'peagle': 'MineDraft + EAGLE',\n",
    "    'eagle3': 'EAGLE-3',\n",
    "    'pearl_eagle3': 'PEARL + EAGLE-3',\n",
    "    'peagle3': 'MineDraft + EAGLE-3',\n",
    "    'tetris_eagle': 'Tetris + EAGLE',\n",
    "    'tetris_eagle3': 'Tetris + EAGLE-3',\n",
    "    'ptetris_eagle': 'MineDraft + EAGLE',\n",
    "    'pearl_tetris_eagle': 'PEARL + Tetris + EAGLE',\n",
    "    'ptetris_eagle3': 'MineDraft + EAGLE-3',\n",
    "    'pearl_tetris_eagle3': 'PEARL + Tetris + EAGLE-3',\n",
    "}\n",
    "\n",
    "# Only rows with e == 0 and c <= 0 are considered; they are grouped into one setting per\n",
    "# (target model, draft model, batch size, request count).\n",
    "for (target_model, draft_model, batch_size, reqs), setting_indexes in df[(df['e'] == 0) & (df['c'] <= 0)].groupby(\n",
    "        [\"target_model\", \"draft_model\", \"batch_size\", \"reqs\"]).groups.items():\n",
    "    # Qwen3-32B is plotted only with the Qwen3-0.6B draft at batch size 16, and the Llama\n",
    "    # target only at batch size 64; other target models are not filtered.\n",
    "    skip_qwen = target_model == \"Qwen3-32B\" and (draft_model != \"Qwen3-0.6B\" or batch_size != 16)\n",
    "    skip_llama = target_model == \"Meta-Llama-3.3-70B-Instruct-AWQ-INT4\" and batch_size != 64\n",
    "    if skip_qwen or skip_llama:\n",
    "        continue\n",
    "\n",
    "    for dataset_name, dataset_indexes in df.loc[setting_indexes].groupby(\"dataset\").groups.items():\n",
    "        # Display name: 'ShareGPT' is kept verbatim, anything else becomes Capitalized-Hyphenated words.\n",
    "        dataset = dataset_name if dataset_name == \"ShareGPT\" else \"-\".join(w.capitalize() for w in dataset_name.replace(\"_\", \"-\").split(\"-\"))\n",
    "        # Maps the legend label of each (method, n) to its list of per-trim summaries.\n",
    "        methods = {}\n",
    "\n",
    "        for (method_name, n), method_indexes in df.loc[dataset_indexes].groupby([\"method\", \"n\"]).groups.items():\n",
    "            if method_name not in (\"psd\", \"sd\"):\n",
    "                continue\n",
    "\n",
    "            max_k = df.loc[method_indexes][\"k\"].max()\n",
    "\n",
    "            # metrics_by_k_and_percentile[k][fraction][metric] holds one value per run.\n",
    "            metrics_by_k_and_percentile = {\n",
    "                k: {fraction: {name: [] for name in metric_names} for fraction in trim_fractions}\n",
    "                for k in range(1, max_k + 1)\n",
    "            }\n",
    "\n",
    "            # Collect metrics for all bundles by k value\n",
    "            for k, indexes in df.loc[method_indexes].groupby(\"k\", sort=True).groups.items():\n",
    "                for i, bundle in df.loc[indexes].iterrows():\n",
    "                    # Warn when any request was preempted during this run.\n",
    "                    # NOTE(review): `i` is the row label in df, so 'iter' is not a per-k repetition counter - confirm intended.\n",
    "                    if (preemptions := np.sum(bundle['step_preempted_requests'])) > 0:\n",
    "                        print(\n",
    "                            f\"{YELLOW}[WARN]{RESET} \"\n",
    "                            f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                            f\"bs={batch_size}, reqs={reqs}, \"\n",
    "                            f\"method={method_name}, k={k}, n={n}: \"\n",
    "                            f\"iter={i+1}, {BOLD}{YELLOW}Total Preemptions: {preemptions}{RESET}\"\n",
    "                        )\n",
    "\n",
    "                    # Latency per request does not depend on the trim, so compute it once per run.\n",
    "                    e2el = bundle['total_latency'] / bundle['reqs']\n",
    "\n",
    "                    for percentile in trim_fractions:\n",
    "                        # Only the first `bound` steps of the run contribute to the step-level metrics.\n",
    "                        bound = int(len(bundle['step_drafted_tokens']) * percentile)\n",
    "\n",
    "                        gen_throughput = np.sum(bundle['step_generated_tokens'][:bound]) / np.sum(bundle['step_generation_times'][:bound])\n",
    "                        # Accepted tokens as a percentage of verified tokens.\n",
    "                        vsr = np.sum(bundle['step_accepted_tokens'][:bound]) / np.sum(bundle['step_verified_tokens'][:bound]) * 100\n",
    "\n",
    "                        # Store metrics for this bundle by k and percentile\n",
    "                        collected = metrics_by_k_and_percentile[k][percentile]\n",
    "                        collected[\"gen_throughput\"].append(gen_throughput)\n",
    "                        collected[\"vsr\"].append(vsr)\n",
    "                        collected[\"latency\"].append(e2el)\n",
    "\n",
    "            # One slot per k in 1..max_k; slots for k values without data stay NaN.\n",
    "            metrics_by_percentile = {\n",
    "                fraction: {column: [np.nan] * max_k for column in stat_columns}\n",
    "                for fraction in trim_fractions\n",
    "            }\n",
    "\n",
    "            # Calculate means and stds\n",
    "            for k in range(1, max_k + 1):\n",
    "                for percentile in trim_fractions:\n",
    "                    metrics = metrics_by_k_and_percentile[k][percentile]\n",
    "\n",
    "                    # Skip if no data collected\n",
    "                    if not metrics[\"gen_throughput\"]:\n",
    "                        continue\n",
    "\n",
    "                    summary = metrics_by_percentile[percentile]\n",
    "                    for name in metric_names:\n",
    "                        summary[f\"avg_{name}\"][k-1] = np.mean(metrics[name])\n",
    "                        summary[f\"std_{name}\"][k-1] = np.std(metrics[name])\n",
    "\n",
    "            # The legend label carries n so that the same method with different n stays distinct.\n",
    "            methods[method_labels[method_name] + f' (n={n})'] = [\n",
    "                {\n",
    "                    \"percentile\": f\"{int(percentile*100)}%\",\n",
    "                    \"n\": n,\n",
    "                    **metrics_by_percentile[percentile]\n",
    "                }\n",
    "                for percentile in trim_fractions\n",
    "            ]\n",
    "\n",
    "        # Prepare DataFrame for plotting: one row per (method label, trim, k)\n",
    "        rows = []\n",
    "        for method, trims in methods.items():\n",
    "            for trim in trims:\n",
    "                # zip walks the per-k statistic lists in lockstep; enumerate restores the 1-based k.\n",
    "                for k, values in enumerate(zip(*(trim[column] for column in stat_columns)), 1):\n",
    "                    rows.append({\n",
    "                        'method': method,\n",
    "                        'k': k,\n",
    "                        'n': trim['n'],\n",
    "                        'percentile': trim['percentile'],\n",
    "                        **dict(zip(stat_columns, values)),\n",
    "                    })\n",
    "\n",
    "        if not rows:\n",
    "            continue\n",
    "\n",
    "        metrics_df = pd.DataFrame(rows)\n",
    "\n",
    "        for percentile, percentile_indexes in metrics_df.groupby('percentile').groups.items():\n",
    "            # Slice the rows of this trim once instead of re-indexing metrics_df for every lookup.\n",
    "            trim_df = metrics_df.loc[percentile_indexes]\n",
    "            plots = {target: plt.subplots(figsize=(9, 9)) for target in targets}\n",
    "\n",
    "            # Create a color map for n values - MineDraft and Standard SD with same n get same color\n",
    "            unique_n_values = sorted(trim_df['n'].unique())\n",
    "            color_palette = plt.cm.tab10.colors\n",
    "            n_to_color = {n: color_palette[i % len(color_palette)] for i, n in enumerate(unique_n_values)}\n",
    "\n",
    "            # Sort methods by n in increasing order; within the same n, 'Standard...' labels come last\n",
    "            method_to_n = trim_df.groupby('method')['n'].first().to_dict()\n",
    "            methods_sorted = sorted(\n",
    "                trim_df['method'].unique(),\n",
    "                key=lambda x: (method_to_n[x], x.startswith('Standard'))\n",
    "            )\n",
    "\n",
    "            for method in methods_sorted:\n",
    "                method_rows = trim_df[trim_df['method'] == method]\n",
    "                # dropna removes every k that has a NaN in any statistic column.\n",
    "                grouped = method_rows.groupby('k').first().dropna()\n",
    "\n",
    "                # Get n value from n column\n",
    "                color = n_to_color[method_rows['n'].iloc[0]]\n",
    "\n",
    "                # Standard SD uses dotted line, MineDraft uses solid line\n",
    "                linestyle = ':' if method.startswith('Standard SD') else '-'\n",
    "\n",
    "                for target in targets:\n",
    "                    fig, axs = plots[target]\n",
    "                    avg = grouped[f'avg_{target}']\n",
    "                    std = grouped[f'std_{target}']\n",
    "\n",
    "                    # Plot line with markers, plus a band of one standard deviation around it\n",
    "                    axs.plot(grouped.index, avg, marker='o', linestyle=linestyle, color=color, label=method, linewidth=linewidth, markersize=markersize)\n",
    "                    axs.fill_between(grouped.index, avg - std, avg + std, alpha=0.2, color=color)\n",
    "\n",
    "            report_context = (\n",
    "                f\"{dataset}, target={target_model}, draft={draft_model}, \"\n",
    "                f\"bs={batch_size}, reqs={reqs}, trim={percentile}\"\n",
    "            )\n",
    "            print(report_context)\n",
    "            # report_max_minedraft_improvement(\n",
    "            #     metrics_df.loc[percentile_indexes],\n",
    "            #     targets,\n",
    "            #     report_context\n",
    "            # )\n",
    "\n",
    "            k_vals = sorted(trim_df['k'].unique())\n",
    "            for target, (fig, ax) in plots.items():\n",
    "                # Labels, legend and grid are applied once per figure, after every line has been drawn.\n",
    "                ax.set_xlabel('No. Speculative Tokens', fontsize=axis_label_fontsize)\n",
    "                ax.set_ylabel(ylabels.get(target, target.capitalize()), fontsize=axis_label_fontsize)\n",
    "                ax.tick_params(axis='both', which='major', labelsize=axis_label_fontsize)\n",
    "                ax.legend(loc='center left' if target == 'latency' else 'lower left', fontsize=legend_fontsize)\n",
    "                ax.grid(True)\n",
    "\n",
    "                # Force integer x-ticks across subplots.\n",
    "                # NOTE(review): the MaxNLocator below replaces the fixed ticks installed by set_xticks;\n",
    "                # the call is kept because set_xticks can also widen the x-limits to cover every k.\n",
    "                ax.set_xticks(k_vals)\n",
    "                ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{int(x)}'))\n",
    "\n",
    "                # Finalize plot\n",
    "                fig.tight_layout()\n",
    "                display(fig)\n",
    "                fig.savefig(\n",
    "                    os.path.join(\n",
    "                        plot_dir,\n",
    "                        safename(\n",
    "                            f'{dataset}_'\n",
    "                            f'{target_model}_{draft_model}_'\n",
    "                            f'{reqs}_bs{batch_size}_'\n",
    "                            f'{percentile}_{target}_vs_n'\n",
    "                        ) + '.png',\n",
    "                    ),\n",
    "                    dpi=100\n",
    "                )\n",
    "                # Close the figure so figures do not accumulate across the nested loops.\n",
    "                plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tetris",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
